线段树底层逻辑探讨-P3373-P1438
目录
功能
建树
过程模拟
代码实现
内存空间
区间查询
过程模拟
代码实现(无须修改、不带标记版)
区间修改-加(懒惰标记)
过程模拟
代码实现
建树
懒标记下传
区间加
查询区间和
区间修改-改为某个值
注意
区间查询-极值版(改合并操作)
区间双修改:加 & 乘(下传时先乘后加)
取模版-P3373 (每次运算后 %=MOD)
P1438-差分加-dTag
注:本文图文以oiwiki为基础
功能
单点、区间
修改、查询
建树
过程模拟
线段树会将数组中任一长不为1的数组进行二分,把线段划分为一个树形结构
需要求信息时将左右子进行合并
例如我们有一个数组 a=[ 10 , 11 , 12 , 13 , 14 ],节点索引从1开始
结构中红色表示数组d管辖数组a节点的区间
可以观察到:
设d[ i ]的区间为 [ s,t ]
左儿子是d[ 2*i ] [ s,(s+t)/2 ]
右儿子是d[ 2*i+1 ] [ (s+t)/2+1,t ]
代码实现
在建树的过程中,假设我们当前根节点为p:
1.如果根节点管辖的区间为1那么直接可以从数组a中映射过来
2.如果长度大于1,那么要从区间中点进行分割,然后再进入左右子节点进行建树,最后合并两节点信息
不难看出,这是一个递归过程,情况1是终止条件
def build(s, t, p):# 对 [s,t] 区间建立线段树,当前根的编号为 pif s == t:tree[p] = a[s]returnmid = s + ((t - s) >> 1)# 移位运算符的优先级小于加减法,所以加上括号# 如果写成 (s + t) >> 1 可能会超出 int 范围build(s, mid, p * 2)build(mid + 1, t, p * 2 + 1)# 递归对左右区间建树tree[p] = tree[p * 2] + tree[p * 2 + 1]
内存空间
如果采用堆式存储(2p是p的左儿子,2p+1是p的右儿子),若有n个叶子结点,则 d 数组的范围最最大为 2**(logn +1)
由于线段树是由二分而来,所以易得深度为logn,在堆式存储情况下叶子结点的数量为2**logn,又由于其为一颗完全二叉树(除了最后一层其他全满的,而且最后一层也只有靠左有),所以总结点为2**(logn+1) -1个,当然可以直接设数组长度为4n(因为前面表达式的最大值为4n-5)
而堆式存储存在无用的叶子节点,可以考虑使用内存池管理线段树节点,每当需要新建节点从池中获取。自底向上考虑,必有每两个底层节点合并为一个上层节点,因此可以类似哈夫曼树地证明,如果有 n 个叶子节点,这样的线段树总共有2n-1个节点。其空间效率优于堆式存储,并且是可能的最优情况。
这样的线段树可以自底向上维护
区间查询
区间和,区间极值等等
过程模拟
仍然以这张图为例,查询:
a[ 1:5 ] 直接d[ 1 ]
a[ 3:5 ] 需要合并a[ 3:3 ] + a[ 4:5 ]
等等,一般情况: 要查询[ l,r ]可以将其拆成最多 O(logn) 个 极大 区间,合并他们即可获得信息
代码实现(无须修改、不带标记版)
def getsum(l, r, s, t, p):# [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号if l <= s and t <= r:return tree[p] # 当前区间为询问区间的子集时直接返回当前区间的和mid = s + ((t - s) >> 1)sum = 0if l <= mid:sum = sum + getsum(l, r, s, mid, p * 2)# 如果左儿子代表的区间 [s, m] 与询问区间有交集, 则递归查询左儿子if r > mid:sum = sum + getsum(l, r, mid + 1, t, p * 2 + 1)# 如果右儿子代表的区间 [m + 1, t] 与询问区间有交集, 则递归查询右儿子return sum
区间修改-加(懒惰标记)
过程模拟
如果需要修改一个区间时进行遍历再修改,那是无法接受的,于是乎,懒惰标记诞生了
懒惰标记,简单来说,就是通过延迟对节点信息的更改,减少没必要的操作。在每次执行修改时,我们通过打标记的方式表明该节点对应区间在某一次操作中被更改,但不更新该节点的子节点信息。实质性的修改则在下一次访问带有标记的节点时才进行
仍以数组 a=[ 10 , 11 , 12 , 13 , 14 ]为例,现在我们新增一个信息表示标记值,设为 ti
最开始情况是这样的
现在我们给[ 3,5 ]每个数加上5,和之前区间查询一样,我们找到两个极大区间[ 3,3 ]和[ 4,5 ],分别对应线段树d[ 3 ]和d[ 5 ],那么我们现在需要修改他们的标记值
可以看到,d[ 3 ]节点 t 值修改了,但是他的 两个子节点却没有更新
不过不用担心,虽然现在没有修改,但是我们需要查询这两个子节点信息时我们会利用标记修改这两个子节点信息,使结果依旧准确,比如现在我们需要查询[ 4,4 ]的信息,我们通过递归找到[ 4,5 ]区间,这时候发现该区间并非我们的目标,而且该区间还存在标记,那么这时候就到下放标记的时间了。我们将该区间的两个子区间信息更新,并清除该区间的标记
清零该节点的 t[ 3 ],更新两个子节点的 t[ 6 ]、t[ 7 ]
代码实现
建树
def build(s, t, p):# 建立线段树,[s,t] 当前根为pif s == t:tree[p] = a[s]returnmid = s + ((t - s) >> 1)build(s, mid, p * 2)build(mid + 1, t, p * 2 + 1)tree[p] = tree[p * 2] + tree[p * 2 + 1]
懒标记下传
def pushdown(p, s, t):mid = s + ((t - s) >> 1)if addTag[p]:tree[p * 2] = (tree[p * 2] + addTag[p] * (mid - s + 1))tree[p * 2 + 1] = (tree[p * 2 + 1] + addTag[p] * (t - mid))addTag[p * 2] = (addTag[p * 2] + addTag[p])addTag[p * 2 + 1] = (addTag[p * 2 + 1] + addTag[p])addTag[p] = 0
区间加
def update(l, r, c, s, t, p):# [l, r] 为修改区间, c 为被修改的元素的变化量, [s, t] 为当前节点包含的区间# p为当前节点的编号if l <= s and t <= r:tree[p] += (t - s + 1) * caddTag[p] += creturnpushdown(p,s,t)# 当前区间为修改区间的子集时直接修改当前节点的值, 然后打标记, 结束修改mid = s + ((t - s) >> 1)if l <= mid:addUpdate(l, r, c, s, mid, p * 2)if r > mid:addUpdate(l, r, c, mid + 1, t, p * 2 + 1)tree[p] = tree[p * 2] + tree[p * 2 + 1]
查询区间和
def getsum(l, r, s, t, p):# [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号if l <= s and t <= r:return tree[p] # 当前区间为询问区间的子集时直接返回当前区间的和pushdown(p,s,t)mid = s + ((t - s) >> 1)sum = 0if l <= mid:sum = sum + getsum(l, r, s, mid, p * 2)# 如果左儿子代表的区间 [s, m] 与询问区间有交集, 则递归查询左儿子if r > mid:sum = sum + getsum(l, r, mid + 1, t, p * 2 + 1)# 如果右儿子代表的区间 [m + 1, t] 与询问区间有交集, 则递归查询右儿子return sum
区间修改-改为某个值
如果不是想将区间加值,而是修改为某个值
def update(l, r, c, s, t, p):if l <= s and t <= r:d[p] = (t - s + 1) * cb[p] = cv[p] = 1returnm = s + ((t - s) >> 1)if v[p]:d[p * 2] = b[p] * (m - s + 1)d[p * 2 + 1] = b[p] * (t - m)b[p * 2] = b[p * 2 + 1] = b[p]v[p * 2] = v[p * 2 + 1] = 1v[p] = 0if l <= m:update(l, r, c, s, m, p * 2)if r > m:update(l, r, c, m + 1, t, p * 2 + 1)d[p] = d[p * 2] + d[p * 2 + 1]def getsum(l, r, s, t, p):if l <= s and t <= r:return d[p]m = s + ((t - s) >> 1)if v[p]:d[p * 2] = b[p] * (m - s + 1)d[p * 2 + 1] = b[p] * (t - m)b[p * 2] = b[p * 2 + 1] = b[p]v[p * 2] = v[p * 2 + 1] = 1v[p] = 0sum = 0if l <= m:sum = getsum(l, r, s, m, p * 2)if r > m:sum = sum + getsum(l, r, m + 1, t, p * 2 + 1)return sum
注意
# 注意数组a下标从1开始,所以得在前面价个[0]
# 那么N记得-1
# 别忘记建树 build(1,N,1)
# 形参传入( ,1,N,1):在修改函数中一般需要 ( 操作参数[比如 l,r,v ],s,t,p ),其中s,t,p分别表示当前节点负责的左右区间以及当前节点编号,那么我们调用函数的时候就需要前面写操作参数,后面的s,t,p代入 1,N,1
以上的代码都是基于求区间和,现在需求是求区间极值,那么只需要将合并操作改 + 为 max/min
区间查询-极值版(改合并操作)
# 数组a是需要操作的数据,N是数组a的长度d = [0] * (N * 4) # 线段树数组
b = [0] * (N * 4) # 懒标记数组def build(s, t, p):"""建立线段树,s和t是当前节点p管理的区间"""if s == t:d[p] = a[s]returnm = (s + t) >> 1build(s, m, p * 2)build(m + 1, t, p * 2 + 1)d[p] = max(d[p * 2], d[p * 2 + 1])# 将合并操作改 + 为 max/mindef pushdown(p, s, t):"""把懒标记往下传"""if b[p]:m = (s + t) >> 1d[p * 2] += b[p]d[p * 2 + 1] += b[p]b[p * 2] += b[p]b[p * 2 + 1] += b[p]b[p] = 0def update(l, r, c, s, t, p):"""区间[l, r]加上c当前节点p管理区间[s, t]"""if l <= s and t <= r:d[p] += cb[p] += creturnpushdown(p, s, t)m = (s + t) >> 1if l <= m:update(l, r, c, s, m, p * 2)if r > m:update(l, r, c, m + 1, t, p * 2 + 1)d[p] = max(d[p * 2], d[p * 2 + 1])def query(l, r, s, t, p):"""查询区间[l, r]内的最大值当前节点p管理区间[s, t]"""if l <= s and t <= r:return d[p]pushdown(p, s, t)m = (s + t) >> 1res = -float('inf')if l <= m:res = max(res, query(l, r, s, m, p * 2))if r > m:res = max(res, query(l, r, m + 1, t, p * 2 + 1))return res
既需要区间加,又需要区间乘时,先处理乘法再处理加法
区间双修改:加 & 乘(下传时先乘后加)
# 数组a是需要操作的数据,N是数组a的长度# 注意数组a下标从1开始,所以得在前面价个[0]
# 那么N记得len后-1tree = [0] * (N * 4) # 维护区间和
addTag = [0] * (N * 4) # 区间加法懒标记
mulTag = [1] * (N * 4) # 区间乘法懒标记(初始为1)def build(s, t, p):# 建立线段树,[s,t] 当前根为pif s == t:tree[p] = a[s]returnmid = s + ((t - s) >> 1)build(s, mid, p * 2)build(mid + 1, t, p * 2 + 1)tree[p] = tree[p * 2] + tree[p * 2 + 1]def pushdown(p, s, t):mid = s + ((t - s) >> 1)# 先处理乘法懒标记if mulTag[p] != 1:for child in (p * 2, p * 2 + 1):tree[child] = (tree[child] * mulTag[p]) % mmulTag[child] = (mulTag[child] * mulTag[p]) % maddTag[child] = (addTag[child] * mulTag[p]) % mmulTag[p] = 1 #恢复初始值# 再处理加法懒标记if addTag[p]:tree[p * 2] = (tree[p * 2] + addTag[p] * (mid - s + 1)) % mtree[p * 2 + 1] = (tree[p * 2 + 1] + addTag[p] * (t - mid)) % maddTag[p * 2] = (addTag[p * 2] + addTag[p]) % maddTag[p * 2 + 1] = (addTag[p * 2 + 1] + addTag[p]) % maddTag[p] = 0 #恢复初始值def addUpdate(l, r, c, s, t, p):# [l, r] 区间内加上cif l <= s and t <= r:tree[p] += (t - s + 1) * caddTag[p] += creturnpushdown(p, s, t)mid = s + ((t - s) >> 1)if l <= mid:addUpdate(l, r, c, s, mid, p * 2)if r > mid:addUpdate(l, r, c, mid + 1, t, p * 2 + 1)tree[p] = tree[p * 2] + tree[p * 2 + 1]def mulUpdate(l, r, c, s, t, p):# [l, r] 区间内乘上cif l <= s and t <= r:tree[p] *= cmulTag[p] *= caddTag[p] *= creturnpushdown(p, s, t)mid = s + ((t - s) >> 1)if l <= mid:mulUpdate(l, r, c, s, mid, p * 2)if r > mid:mulUpdate(l, r, c, mid + 1, t, p * 2 + 1)tree[p] = tree[p * 2] + tree[p * 2 + 1]def getSum(l, r, s, t, p):# 查询[l,r]区间和if l <= s and t <= r:return tree[p]pushdown(p, s, t)mid = s + ((t - s) >> 1)res = 0if l <= mid:res += getSum(l, r, s, mid, p * 2)if r > mid:res += getSum(l, r, mid + 1, t, p * 2 + 1)return res
取模版-P3373 (每次运算后 %=MOD)
P3373 【模板】线段树 2 - 洛谷
如果还要求取模,那么需要在每次运算后 %=MOD
n,q,m=map(int,input().split())# 数组a是需要操作的数据,N是数组a的长度
a=[0] + list(map(int,input().split()))
# 注意数组a下标从1开始,所以得在前面价个[0]
# 那么N记得-1
N=len(a)-1tree = [0] * (N * 4) # 维护区间和
addTag = [0] * (N * 4) # 区间加法懒标记
mulTag = [1] * (N * 4) # 区间乘法懒标记(初始为1)def build(s, t, p):# 建立线段树,[s,t] 当前根为pif s == t:tree[p] = a[s]tree[p]%=mreturnmid = s + ((t - s) >> 1)build(s, mid, p * 2)build(mid + 1, t, p * 2 + 1)tree[p] = tree[p * 2] + tree[p * 2 + 1]tree[p]%=mdef pushdown(p, s, t):mid = s + ((t - s) >> 1)# 先处理乘法懒标记if mulTag[p] != 1:for child in (p * 2, p * 2 + 1):tree[child] = (tree[child] * mulTag[p]) % mmulTag[child] = (mulTag[child] * mulTag[p]) % maddTag[child] = (addTag[child] * mulTag[p]) % mmulTag[p] = 1# 再处理加法懒标记if addTag[p]:'''for ch in (p*2,p*2+1):tree[ch]+=addTag[p]*(mid-s+1)tree[ch]%=maddTag[ch]+=addTag[p]addTag%=m'''tree[p * 2] = (tree[p * 2] + addTag[p] * (mid - s + 1)) % mtree[p * 2 + 1] = (tree[p * 2 + 1] + addTag[p] * (t - mid)) % maddTag[p * 2] = (addTag[p * 2] + addTag[p]) % maddTag[p * 2 + 1] = (addTag[p * 2 + 1] + addTag[p]) % maddTag[p] = 0def addUpdate(l, r, c, s, t, p):# [l, r] 区间内加上cif l <= s and t <= r:tree[p] += (t - s + 1) * ctree[p]%=maddTag[p] += caddTag[p]%=mreturnpushdown(p, s, t)mid = s + ((t - s) >> 1)if l <= mid:addUpdate(l, r, c, s, mid, p * 2)if r > mid:addUpdate(l, r, c, mid + 1, t, p * 2 + 1)tree[p] = tree[p * 2] + tree[p * 2 + 1]tree[p]%=mdef mulUpdate(l, r, c, s, t, p):# [l, r] 区间内乘上cif l <= s and t <= r:tree[p] *= ctree[p]%=mmulTag[p] *= cmulTag[p]%=maddTag[p] *= caddTag[p]%=mreturnpushdown(p, s, t)mid = s + ((t - s) >> 1)if l <= mid:mulUpdate(l, r, c, s, mid, p * 2)if r > mid:mulUpdate(l, r, c, mid + 1, t, p * 2 + 1)tree[p] = tree[p * 2] + tree[p * 2 + 1]tree[p]%=mdef getSum(l, r, s, t, p):# 查询[l,r]区间和if l <= s and t <= r:return tree[p]%mpushdown(p, s, t)mid = s + ((t - s) >> 1)res = 0if l <= mid:res += getSum(l, r, s, mid, p * 2)res%=mif r > mid:res += getSum(l, r, mid + 1, t, p * 2 + 1)res%=mreturn resbuild(1,N,1)
#别忘记建树for _ in range(q):te = list(map(int, input().split()))if te[0] == 1:x, y, k = te[1:]mulUpdate(x, y, k, 1, N, 1) # 一开始是整棵树,区间是[1,N],根节点是1号elif te[0] == 2:x, y, k = te[1:]addUpdate(x, y, k, 1, N, 1)else:x, y = te[1:]print(getSum(x, y, 1, N, 1) % m)
P1438-差分加-dTag
P1438 无聊的数列 - 洛谷
n,m=map(int,input().split())a=[0]+list(map(int,input().split()))
N=len(a)-1tree=[0]*(N*4)
addTag=[0]*(N*4)
dTag=[0]*(N*4)
# tree: 线段树数组
# addTag, dTag: 懒标记数组,初始均为0def build(s, t, p):# 建立线段树,[s,t] 当前根为 pif s == t:tree[p] = a[s]returnmid = s + ((t - s) >> 1)build(s, mid, p * 2)build(mid + 1, t, p * 2 + 1)tree[p] = tree[p * 2] + tree[p * 2 + 1]def pushdown(p, s, t):mid = s + ((t - s) >> 1)L = mid - s + 1R = t - mid# 若存在懒标记,则将变更下放到左右子节点if addTag[p] or dTag[p]:# 左子节点更新tree[p * 2] += addTag[p] * L + dTag[p] * (L * (L - 1) // 2)addTag[p * 2] += addTag[p]dTag[p * 2] += dTag[p]# 右子节点,第一个位置对应加值 = addTag[p] + dTag[p] * Ltree[p * 2 + 1] += (addTag[p] + dTag[p] * L) * R + dTag[p] * (R * (R - 1) // 2)addTag[p * 2 + 1] += addTag[p] + dTag[p] * LdTag[p * 2 + 1] += dTag[p]addTag[p] = 0dTag[p] = 0def update(l, r, K, D, s, t, p):# 在区间 [l, r] 内每个位置加上等差数列:首项 K, 公差 Dif l <= s and t <= r:L = t - s + 1# 当前区间第一个位置 s 对应增加 A = K + D * (s - l)A = K + D * (s - l)tree[p] += A * L + D * (L * (L - 1) // 2)addTag[p] += AdTag[p] += Dreturnpushdown(p, s, t)mid = s + ((t - s) >> 1)if l <= mid:update(l, r, K, D, s, mid, p * 2)if r > mid:update(l, r, K, D, mid + 1, t, p * 2 + 1)tree[p] = tree[p * 2] + tree[p * 2 + 1]def getsum(l, r, s, t, p):# 查询区间 [l, r] 的和if l <= s and t <= r:return tree[p]pushdown(p, s, t)mid = s + ((t - s) >> 1)sum_val = 0if l <= mid:sum_val += getsum(l, r, s, mid, p * 2)if r > mid:sum_val += getsum(l, r, mid + 1, t, p * 2 + 1)return sum_val#别忘记build
build(1,N,1)for i in range(m):te=list(map(int,input().split()))if te[0]==1:l,r,K,D=te[1:]update(l,r,K,D,1,N,1)else:p=te[1]print(getsum(p,p,1,N,1))