树状数组,了解三点就够了

简介

何为树状数组?首先引用百度百科的定义:树状数组二叉索引树(英语:Binary Indexed Tree),又以其发明者命名为Fenwick树,最早由Peter M. Fenwick于1994年以A New Data Structure for Cumulative Frequency Tables为题发表在SOFTWARE PRACTICE AND EXPERIENCE。其初衷是解决数据压缩里的累积频率(Cumulative Frequency)的计算问题,现多用于高效计算数列的前缀和, 区间和。

说白了树状数组就是用来高效计算并修改数列前缀和的一种高级数据结构,和线段树类似,但是实现起来比线段树复杂,能够实现的功能相较于线段树也会较少一些。

普通的前缀和数组求取前缀和的时间复杂度为\(O(1)\),而更新前缀和的时间复杂度为\(O(n)\)​​,而树状数组求取和更新前缀和的时间复杂度均为\(O(log(n))\)。因此,当对数组的更新较为频繁时,可以使用树状数组来进行维护。而更新较少时使用普通的前缀和数组维护可得到更好的时间复制度。

要点

一图概括树状数组。

上图中,arr表示原始数组,treeArr表示树状数组,注意,建立树状数组的时候,索引从1开始,将索引0处留出来,这是为了方便进行位寻址操作。所以对一个长度为n的原始数组,建立成树状数组后长度将变为n+1.

树状数组的三个核心操作

  • 取二进制中的最低位1:lowbit(x)

根据补码知识,一个数与上它的相反数就可以取出其的最低位1.

1
2
def lowbit(x):
return x & (-x)
  • 更新操作(往父节点传播),索引变化:x+lowbit(x)
1
2
3
4
5
def update(self, index, value):
self.arr[index] += value
while index <= self.length:
self.treeArr[index] += value
index += lowbit(index)
  • 求和操作(下级传播),索引变化:x-lowbit(x)
1
2
3
4
5
6
def prefix_sum(self, index):
ans = 0
while index:
ans += self.treeArr[index]
index -= lowbit(index)
return ans

树状数组的每一个节点管理的数据数量为其最低位1的值,节点的值表示所管理的所有数之和。例如,索引为4(100)的节点最低位1为(100),所管理的数据量为4,表示\(arr_1+arr_2+arr_3+arr_4\);索引为5(101)的节点管理的数据量为1(最低位1为1),于是t5值所表示的数为arr5,对于任何一个\(t_i\),其值为\(sum(arr(i-lowbit(i),i])\)​​.

树状数组求前缀和时,只需将所有当前索引不断去掉低位1的数据进行求和即可。以t7求和为例,7的二进制为111,去掉低位一后为110,继续得到100,最终得到000(000没用到,其值也定位0,不予理会)。所以前7个数的前缀和就可以使用t7(111)+t6(110)+t4(100)求得。求区间和时可以使用两前缀和相减:

1
2
def interval_sum(self, start, end):
return self.prefix_sum(end - 1) - self.prefix_sum(start - 1)

树状数组进行更新时需要更新当前节点及其所有的父节点,如对arr3加上4,需要对t3(11)+4,再对t4(100)+4,再对t8(1000)+4,这个过程进行的操作为x+lowbit(x).

树状数组的初始化

树状数组在初始化时,可直接使用更新操作来建立,时间复制度为\(O(nlog(n))\)​.

1
2
3
4
5
6
def __init__(self, origin):
self.length = len(origin)
self.arr = [0] * (self.length + 1)
self.treeArr = [0] * (self.length + 1)
for i in range(1, self.length + 1):
self.update(i, origin[i - 1])

也可以使用直接求和来建立,时间复杂度为\(O(n^2)\)​​​,(但是此方法在python中比上面的快很多,或许python的sum有特殊机制,迷惑~~)。

1
2
3
4
5
6
def __init__(self, origin):
self.length = len(origin)
self.arr = [0] + origin[:]
self.treeArr = [0] * (self.length + 1)
for i in range(1, self.length + 1):
self.treeArr[i] = sum(self.arr[i - lowbit(i) + 1:i + 1])

End

最后放个完整示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class TreeArray:
def __init__(self, origin):
self.length = len(origin)
# self.arr = [0] + origin[:]
self.arr = [0] * (self.length + 1)
self.treeArr = [0] * (self.length + 1)
for i in range(1, self.length + 1):
# self.treeArr[i] = sum(self.arr[i - lowbit(i) + 1:i + 1])
self.update(i, origin[i - 1])

def update(self, index, value): # 索引index处加上value
self.arr[index] += value
while index <= self.length:
self.treeArr[index] += value
index += lowbit(index)

def prefix_sum(self, index):
ans = 0
while index:
ans += self.treeArr[index]
index -= lowbit(index)
return ans

def interval_sum(self, start, end): # [start,end)左闭右开
return self.prefix_sum(end - 1) - self.prefix_sum(start - 1)

def get_current(self):
print('arr:', self.arr)
print('Tree:', self.treeArr)


def lowbit(x):
return x & (-x)


if __name__ == '__main__':
origin = [i for i in range(1, 10)]
s = TreeArray(origin)
s.get_current()
print(s.interval_sum(3, 8))
s.update(4, 6)
s.get_current()
print(s.interval_sum(3, 8))

[out]:
arr: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Tree: [0, 1, 3, 3, 10, 5, 11, 7, 36, 9]
25
arr: [0, 1, 2, 3, 10, 5, 6, 7, 8, 9]
Tree: [0, 1, 3, 3, 16, 5, 11, 7, 42, 9]
31

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
package main

import "fmt"

func main() {
var origin []int
n := 8
for i := 1; i <= n; i++ {
origin = append(origin, i)
}
var tree *treeArr
tree = new(treeArr)
tree.init(origin)
tree.print()
fmt.Printf("sum[3:8]=%d\n", tree.interval_sum(3, 8))
tree.update(4,6)
tree.print()
fmt.Printf("sum[3:8]=%d\n", tree.interval_sum(3, 8))

}

type treeArr struct {
arr []int
treeArr []int
length int
}

func (the *treeArr) init(arr []int) {
the.length = len(arr)
the.arr = make([]int, the.length+1)
the.treeArr = make([]int, the.length+1)
for i := 1; i <= the.length; i++ {
the.update(i, arr[i-1])
}
}
func (the *treeArr) update(index, value int) {
the.arr[index] = value
for index <= the.length {
the.treeArr[index] += value
index += lowbit(index)
}

}
func (the *treeArr) prefix_sum(index int) int {
ans := 0
for index > 0 {
ans += the.treeArr[index]
index -= lowbit(index)
}
return ans
}
func (the *treeArr) print() {
fmt.Println(the.arr)
fmt.Println(the.treeArr)
}
func (the *treeArr) interval_sum(start, end int) int {
return the.prefix_sum(end-1) - the.prefix_sum(start-1)
}
func lowbit(x int) int {
return x & (-x)
}

[out]
[0 1 2 3 4 5 6 7 8]
[0 1 3 3 10 5 11 7 36]
sum[3:8]=25
[0 1 2 3 6 5 6 7 8]
[0 1 3 3 16 5 11 7 42]
sum[3:8]=31