美柑の部屋

涙は見せないと誓った。

Loading…

随机查询(二):线段树

在上一篇文章中,我介绍了解决RMQ问题的ST算法,这篇文章我再介绍一种名为“线段树”的数据结构。所谓线段树,其实就是在求解一个区间所有元素的最大值、最小值或所有元素的和之类问题的时候,用一棵树来维护这段区间及其子区间对应的值。在输入数组经常变化的情况下,线段树比起ST算法更加高效。每当原数组变化后,线段树可以在O(logn)的时间内进行更新(ST算法要进行更新,相当于重新生成一遍Sparse Table,其时间复杂度为O(nlogn)),并且对于每次查询,也可以在O(logn)的时间内给出答案。

一、概述

本文还是以求一段区间中元素的最小值问题为例,建立线段树这种数据结构。
线段树首先是一棵普通的二叉树,每个结点都有指向自己的左右儿子的指针。但在此基础上,线段树的每个结点还表示一个区间[l,r]。如果l等于r,该结点就是一个叶子结点;否则该结点就是一个中间结点,其左儿子表示这个区间的前半部分[l,(l+r)/2],其右儿子表示这个区间的后半部分[(l+r)/2+1,r]。线段树的根结点表示的是原数组对应的整个区间[1,n]
于是我们可以在脑海中大致勾勒出线段树的结点对应的数据结构Node的轮廓:

1
2
3
4
5
6
7
8
struct Node{
    Node* left; //指向左儿子的指针
    Node* right; //指向右儿子的指针
    int leftIndex; //所表示区间的左边界
    int rightIndex; //所表示区间的右边界
    int minVal; //所表示区间中元素的最小值
    Node(int l = 0, int r = 0, int m = 0): leftIndex(l), rightIndex(r), minVal(m), left(NULL), right(NULL){}
};

二、线段树的初始化

对于一个数组arr,我们可以按照上面给出的线段树的定义,构建出线段树的基本结构,确定每个结点代表的是哪一段区间。假设原数组有10个元素,构建好的线段树中每个结点对应的区间如下图所示(图片来自hihoCoder):

构建完基本结构之后,下一步操作是计算每个结点对应的区间中元素的最小值。对于叶子结点来讲,因为其对应的区间只包含一个元素,所以最小值就是该元素;而对于中间结点,其最小值则是它的两个儿子所记录的最小值中较小的那一个。这种先计算子结点的值,再根据子结点的值计算父结点的值的过程,与后序遍历的过程类似,可以通过一次后序遍历来实现,其时间复杂度为O(n)。构建线段树并且初始化最小值的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class SegmentTree{
private:
    Node* root;
    Node* constructTree(const vector<int>& arr, int left, int right);
    void initializeMin(Node* node);
    void updateNode(Node* node, int index, int val);
    int queryNode(Node* node, int left, int right);
    SegmentTree(const SegmentTree& s){} //Do not copy it!
public:
    SegmentTree();
    SegmentTree(const vector<int>& arr);
    void update(int index, int val);
    int query(int left, int right);
};

Node* SegmentTree::constructTree(const vector<int>& arr, int left, int right){
    if(left == right)
        return new Node(left, right, arr[left]);
    else{
        Node* res = new Node(left, right);
        int middle = (left + right) / 2;
        res->left = constructTree(arr, left, middle);
        res->right = constructTree(arr, middle+1, right);
        return res;
    }
}

void SegmentTree::initializeMin(Node* node){
    if(!node->left && !node->right)
        return;

    int minVal = 2147483647;
    if(node->left){
        initializeMin(node->left);
        minVal = min(minVal, node->left->minVal);
    }
    if(node->right){
        initializeMin(node->right);
        minVal = min(minVal, node->right->minVal);
    }
    node->minVal = minVal;
    return;
}

SegmentTree::SegmentTree(): root(NULL){
}

//由一个数组构造一个线段树:
SegmentTree::SegmentTree(const vector<int>& arr){
    if(arr.size() == 0){
        root = NULL;
    }
    root = constructTree(arr, 0, arr.size() - 1); //构建线段树,顺便初始化叶子结点的最小值
    initializeMin(root); //利用后序遍历,初始化非叶子结点的最小值
}

三、线段树的更新

所谓更新,就是在对应的数组arr中某个元素的值发生变化时,更新该线段树上的结点所记录的最小值。我们可以想象一下受影响的结点包括哪些:首先这个发生变化的元素对应的叶子结点的值肯定要更新,然后是叶子结点的父结点、叶子结点的祖父结点……最后是根结点。由于其他结点对应的区间没有包含该元素,因此其值不受影响。
也就是说,我只需要找到从根结点到该叶子结点的路径,然后倒序更新该路径上结点的值就可以了。由于根结点到叶子结点的路径上结点数目为O(logn),所以更新线段树的时间复杂度为O(logn)。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
//在更新一个元素的值的时候,用类似二分查找的方法更新查找路径上结点的值。
//具体更新的方法是利用递归,先更新子结点,再根据子结点更新之后的值更新父结点。
//这一步操作的复杂度是O(nlogn)。
void SegmentTree::updateNode(Node* node, int index, int val){
    if(!node)
        return;
    if(node->leftIndex == index && node->rightIndex == index){
        node->minVal = val;
        return;
    }
    int left = node->leftIndex;
    int right = node->rightIndex;
    int middle = (left + right) / 2;
    if(index <= middle){
        //查找左子树:
        updateNode(node->left, index, val);
    }else{
        //查找右子树:
        updateNode(node->right, index, val);
    }
    //更新该结点的最小值:
    int minVal = 2147483647;
    if(node->left)
        minVal = min(minVal, node->left->minVal);
    if(node->right)
        minVal = min(minVal, node->right->minVal);
    node->minVal = minVal;
    return;
}

void SegmentTree::update(int index, int val){
    updateNode(root, index, val);
}

四、线段树的查询:

还是以上面所建立的10个元素的线段树为例:如果要查询的区间就是[1,10],那么我们直接取出根结点记录的最小值就可以了。但如果不是[1,10]呢?又分为三种情况:
首先,要查询的区间完全位于左子树,这个时候我们对左子树进行递归调用,获取查询结果即可;其次,要查询的区间完全位于右子树,类似地,我们对右子树进行递归调用;最后,要查询的区间横跨左右子树,这个时候,我们只能把这个区间[l,r]分为两部分,第一部分[l,middle]完全位于左子树中,第二部分[middle+1,r]完全位于右子树中,对两部分都进行递归调用,然后取两部分结果中较小的那一个。这一步的时间复杂度仍然是O(logn),代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
int SegmentTree::queryNode(Node* node, int left, int right){
    if(!node)
        return 0;
    if(node->leftIndex == left && node->rightIndex == right)
        return node->minVal;

    int l = node->leftIndex;
    int r = node->rightIndex;
    int middle = (l + r) / 2;
    if(left > middle && right > middle)
        return queryNode(node->right, left, right);
    else if(left <= middle && right <= middle)
        return queryNode(node->left, left, right);
    else{
        int leftMin = queryNode(node->left, left, middle);
        int rightMin = queryNode(node->right, middle+1, right);
        return min(leftMin, rightMin);
    }
}

int SegmentTree::query(int left, int right){
    return queryNode(root, left, right);
}

五、总结

相比于RMQ-ST算法处理每次查询的O(1)时间复杂度,线段树在这方面要略逊一筹。但是,在原数组发生变化时,线段树可以做到更快地更新;并且线段树适用的场合也比RMQ-ST算法要广(求区间元素之和也可以用到线段树)。在求区间元素之和的问题上,还有一种数据结构也可以达到初始化是O(n)、更新和查询都是O(logn)复杂度的效率,那就是树状数组。并且这种数据结构实现起来比较小巧轻便,不像线段树这么庞大,动辄就递归调用。关于树状数组,之后我会在博客中介绍。

评论