在上一篇文章中,我介绍了解决RMQ问题的ST算法,这篇文章我再介绍一种名为“线段树”的数据结构。所谓线段树,其实就是在求解一个区间所有元素的最大值、最小值或所有元素的和之类问题的时候,用一棵树来维护这段区间及其子区间对应的值。在输入数组经常变化的情况下,线段树比起ST算法更加高效。每当原数组变化后,线段树可以在O(logn)
的时间内进行更新(ST算法要进行更新,相当于重新生成一遍Sparse Table,其时间复杂度为O(nlogn)
),并且对于每次查询,也可以在O(logn)
的时间内给出答案。
一、概述
本文还是以求一段区间中元素的最小值问题为例,建立线段树这种数据结构。
线段树首先是一棵普通的二叉树,每个结点都有指向自己的左右儿子的指针。但在此基础上,线段树的每个结点还表示一个区间[l,r]
。如果l
等于r
,该结点就是一个叶子结点;否则该结点就是一个中间结点,其左儿子表示这个区间的前半部分[l,(l+r)/2]
,其右儿子表示这个区间的后半部分[(l+r)/2+1,r]
。线段树的根结点表示的是原数组对应的整个区间[1,n]
。
于是我们可以在脑海中大致勾勒出线段树的结点对应的数据结构Node
的轮廓:
1
2
3
4
5
6
7
8
struct Node {
Node * left ; //指向左儿子的指针
Node * right ; //指向右儿子的指针
int leftIndex ; //所表示区间的左边界
int rightIndex ; //所表示区间的右边界
int minVal ; //所表示区间中元素的最小值
Node ( int l = 0 , int r = 0 , int m = 0 ) : leftIndex ( l ), rightIndex ( r ), minVal ( m ), left ( NULL ), right ( NULL ){}
};
二、线段树的初始化
对于一个数组arr
,我们可以按照上面给出的线段树的定义,构建出线段树的基本结构,确定每个结点代表的是哪一段区间。假设原数组有10个元素,构建好的线段树中每个结点对应的区间如下图所示(图片来自hihoCoder):
构建完基本结构之后,下一步操作是计算每个结点对应的区间中元素的最小值。对于叶子结点来讲,因为其对应的区间只包含一个元素,所以最小值就是该元素;而对于中间结点,其最小值则是它的两个儿子所记录的最小值中较小的那一个。这种先计算子结点的值,再根据子结点的值计算父结点的值的过程,与后序遍历的过程类似,可以通过一次后序遍历来实现 ,其时间复杂度为O(n)
。构建线段树并且初始化最小值的代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class SegmentTree {
private :
Node * root ;
Node * constructTree ( const vector < int >& arr , int left , int right );
void initializeMin ( Node * node );
void updateNode ( Node * node , int index , int val );
int queryNode ( Node * node , int left , int right );
SegmentTree ( const SegmentTree & s ){} //Do not copy it!
public :
SegmentTree ();
SegmentTree ( const vector < int >& arr );
void update ( int index , int val );
int query ( int left , int right );
};
Node * SegmentTree :: constructTree ( const vector < int >& arr , int left , int right ){
if ( left == right )
return new Node ( left , right , arr [ left ]);
else {
Node * res = new Node ( left , right );
int middle = ( left + right ) / 2 ;
res -> left = constructTree ( arr , left , middle );
res -> right = constructTree ( arr , middle + 1 , right );
return res ;
}
}
void SegmentTree :: initializeMin ( Node * node ){
if ( ! node -> left && ! node -> right )
return ;
int minVal = 2147483647 ;
if ( node -> left ){
initializeMin ( node -> left );
minVal = min ( minVal , node -> left -> minVal );
}
if ( node -> right ){
initializeMin ( node -> right );
minVal = min ( minVal , node -> right -> minVal );
}
node -> minVal = minVal ;
return ;
}
SegmentTree :: SegmentTree () : root ( NULL ){
}
//由一个数组构造一个线段树:
SegmentTree :: SegmentTree ( const vector < int >& arr ){
if ( arr . size () == 0 ){
root = NULL ;
}
root = constructTree ( arr , 0 , arr . size () - 1 ); //构建线段树,顺便初始化叶子结点的最小值
initializeMin ( root ); //利用后序遍历,初始化非叶子结点的最小值
}
三、线段树的更新
所谓更新,就是在对应的数组arr
中某个元素的值发生变化时,更新该线段树上的结点所记录的最小值。我们可以想象一下受影响的结点包括哪些:首先这个发生变化的元素对应的叶子结点的值肯定要更新,然后是叶子结点的父结点、叶子结点的祖父结点……最后是根结点。由于其他结点对应的区间没有包含该元素,因此其值不受影响。
也就是说,我只需要找到从根结点到该叶子结点的路径,然后倒序更新该路径上结点的值就可以了。由于根结点到叶子结点的路径上结点数目为O(logn)
,所以更新线段树的时间复杂度为O(logn)
。代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
//在更新一个元素的值的时候,用类似二分查找的方法更新查找路径上结点的值。
//具体更新的方法是利用递归,先更新子结点,再根据子结点更新之后的值更新父结点。
//这一步操作的复杂度是O(nlogn)。
void SegmentTree :: updateNode ( Node * node , int index , int val ){
if ( ! node )
return ;
if ( node -> leftIndex == index && node -> rightIndex == index ){
node -> minVal = val ;
return ;
}
int left = node -> leftIndex ;
int right = node -> rightIndex ;
int middle = ( left + right ) / 2 ;
if ( index <= middle ){
//查找左子树:
updateNode ( node -> left , index , val );
} else {
//查找右子树:
updateNode ( node -> right , index , val );
}
//更新该结点的最小值:
int minVal = 2147483647 ;
if ( node -> left )
minVal = min ( minVal , node -> left -> minVal );
if ( node -> right )
minVal = min ( minVal , node -> right -> minVal );
node -> minVal = minVal ;
return ;
}
void SegmentTree :: update ( int index , int val ){
updateNode ( root , index , val );
}
四、线段树的查询:
还是以上面所建立的10个元素的线段树为例:如果要查询的区间就是[1,10]
,那么我们直接取出根结点记录的最小值就可以了。但如果不是[1,10]
呢?又分为三种情况:
首先,要查询的区间完全位于左子树,这个时候我们对左子树进行递归调用,获取查询结果即可;其次,要查询的区间完全位于右子树,类似地,我们对右子树进行递归调用;最后,要查询的区间横跨左右子树,这个时候,我们只能把这个区间[l,r]
分为两部分,第一部分[l,middle]
完全位于左子树中,第二部分[middle+1,r]
完全位于右子树中,对两部分都进行递归调用,然后取两部分结果中较小的那一个。这一步的时间复杂度仍然是O(logn)
,代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
int SegmentTree :: queryNode ( Node * node , int left , int right ){
if ( ! node )
return 0 ;
if ( node -> leftIndex == left && node -> rightIndex == right )
return node -> minVal ;
int l = node -> leftIndex ;
int r = node -> rightIndex ;
int middle = ( l + r ) / 2 ;
if ( left > middle && right > middle )
return queryNode ( node -> right , left , right );
else if ( left <= middle && right <= middle )
return queryNode ( node -> left , left , right );
else {
int leftMin = queryNode ( node -> left , left , middle );
int rightMin = queryNode ( node -> right , middle + 1 , right );
return min ( leftMin , rightMin );
}
}
int SegmentTree :: query ( int left , int right ){
return queryNode ( root , left , right );
}
五、总结
相比于RMQ-ST算法处理每次查询的O(1)
时间复杂度,线段树在这方面要略逊一筹。但是,在原数组发生变化时,线段树可以做到更快地更新;并且线段树适用的场合也比RMQ-ST算法要广(求区间元素之和也可以用到线段树)。在求区间元素之和的问题上,还有一种数据结构也可以达到初始化是O(n)
、更新和查询都是O(logn)
复杂度的效率,那就是树状数组 。并且这种数据结构实现起来比较小巧轻便,不像线段树这么庞大,动辄就递归调用。关于树状数组,之后我会在博客中介绍。