线段树

线段树是一种支持区间修改、区间查询的数据结构。

区间修改:比如区间加减或者乘除一个值(或者高级一些的线段树,支持同时加减与乘除)

区间查询:求解一个区间的最大值/最小值/和等

以下以区间修改(加减)和区间查询(求和)为例

题目描述

(可以参考这个题目进行测试)

洛谷 P3372 【模板】线段树 1

查询

比如这张图

image.png

红字是线段树数组(即存储这棵树用的数组,代码中的 node tree[MX << 2];
黑字是实际的区间,即当前节点表示了哪段区间。
可以发现,一个点 x 的左节点是 x * 2,而他的右节点是 x * 2 + 1。最下面的一排的叶子(实际可能有倒数第二层的,或更高层的)对应了原始的数组

叶子节点是原始数组的值。而其他的不是叶子的节点,则是区间

对于 1-3 这个点,记录了 a[1-2] + a[3]a[1] + a[2] + a[3] 的值

而 6-8 这个点,则记录了 a[6-7] + a[8]a[6] + a[7] + a[8] 的值

这样对于求和的查询,可以直接拆分成若干个子区间的求和,从而加速求和操作

(而且范围越大的话,使用的子区间范围也越大,这样节省的运算次数也越多,比如 [3, 10] 只需要 [3], [4, 5], [6-10] 即可,而原来需要 8 个元素)

修改

比如要给 [4, 5] 的每个数字加上 3

按照朴素的做法,就是 a[4] += 3, a[5] += 3

但是求和查询的时候,如果查询区间包含 [4, 5],那么,只要保证 4-5 这个点的值(包括全部的父节点)加了 6 即可,至于 4, 5 这两个数字本身是怎么变化的,没有影响。(1)

只有当求和涉及到了 4/5 但是不涉及另一个数字的时候(比如 [1, 4] 或者 [5, 7]),才需要考虑到底是如何加上去的(即把这个加法落实到直到被覆盖的子区间上)(2)

也就是说,在步骤 (1) 的时候,可以直接给 4-5 加上 6,而这个加法不需要落实到 4, 5 数字本身

但是从 (2) 可以看出,在必要的时候(需要拆分区间),还是需要将这个改动落实到子区间(即 pushdown 函数),因此需要有一个变量记录当前区间下每个元素变化的值(只记录,需要的时候才会更新)

而这个修改是可以累积的,比如

add [3, 6] 6
add [4, 5] 3
add [5, 6] 2

在第一行和第二行,对于 [4, 5] 区间的操作是覆盖到每个元素的,因此这里不需要 pushdown(即此时满足 add 函数的第一个 if

而到了第三行,因为更新的区间需要拆开 [4, 5],而 [4, 5] 的标记只能记录全部元素的相同变化,因此首先将之前的操作落实(pushdown),然后,按照递归的方式,进一步的更新 5

代码

#include <iostream>
#include <cstdio>
#include <algorithm>

#define cnt tree[root] // 当前节点
#define lr (root << 1) // 左节点的索引(root * 2)
#define rr ((root << 1) | 1) // 右节点的索引(root * 2 + 1)
#define MX (100010) // 最大容量(实际使用的空间是最大容量的四倍)
#define lt tree[lr] // 左节点
#define rt tree[rr] // right tree(node)
#define LL long long // just short for long long

using namespace std;

struct node { // tree node
    int left, right; // 当前节点覆盖的的子节点的左右索引
    LL add, val; // 标记(记录了区间内每个元素加的值),当前节点的值(和)

    node() {
        add = 0; // 求和用的标记,设置为 0,这个是线段树的核心
    }
} tree[MX << 2]; // 实际需要的是四倍大小
// 为啥四倍,参考这里吧,https://www.cnblogs.com/FengZeng666/p/11446827.html 我之前是直接记住的,没想过这个问题(

LL a[MX]; // 用来存储输入的节点
int n; // 节点数目

void pushup(int root) { // 维护这个节点的值(下推之后,往上更新。根据两个子区间重新算出当前区间的值)
    /*
     * 有两处使用,一种是递归建树(build)的时候,从数组里生成这棵树非叶节点的值(区间和)
     * 另外就是,修改函数(add)内,对于非完全覆盖的区间,需要深入到子区间修改,修改完成后,维护当前节点的性质
     */
    cnt.val = lt.val + rt.val;
    // 这里如果维护的不是和,而是其他的值,那么就这样 pushup 的时候更新
    // 比如,cnt.val = max(lt.val + rt.val)
}

void pushdown(int root) { // 下推当前节点的记录(即原来不需要的时候就存到了根节点,现在需要把这个值往下推送给子节点)
    /*
     * 当当前区间的标记不为 0,即有未下放的标记(累加值),但是需要深入到子区间进行计算的时候
     * 因为之前带有标记的时候不会更新子区间,因此需要把标记下方到子区间(即本来父节点记录了一个 +5,然后,下方后,两个子节点分别记录一个 +5,然后父节点重置为 0)
     */
    if (cnt.add != 0) {
        // cnt.add 是当前节点所记录的,累积起来的数值变动(加法)
        lt.val = lt.val + cnt.add * (lt.right - lt.left + 1); // 左节点 += cnt.add * 左节点元素数目
        lt.add = lt.add + cnt.add; // 左节点标记 += 当前节点标记
        rt.val = rt.val + cnt.add * (rt.right - rt.left + 1);
        rt.add = rt.add + cnt.add;
        cnt.add = 0;
    }
}

void add(int root, int left, int right, int c) { // 区间加一个固定值
    if (left <= cnt.left && right >= cnt.right) { // 如果当前节点所表示的区间位于原始求和区间内
        cnt.add = cnt.add + c; // 当前节点更新一下标记
        cnt.val = (cnt.val + c * (cnt.right - cnt.left + 1)); // 当前节点加上 标记 * 长度
        return;
    }
    pushdown(root); // 因为这次更新覆盖不均匀(即有一部分子节点需要变化,一部分不需要变化),所以把当前的标记下放给子节点
    // 下放后当前节点的标记是 0
    int mid = (cnt.left + cnt.right) >> 1; // build 保证了当前节点的子节点的区间分别为 [left, mid] 和 (mid, right]
    if (left <= mid) add(lr, left, right, c); // 更新左节点
    if (right > mid) add(rr, left, right, c); // 更新右节点
    pushup(root); // 重新获取当前节点的值
    return;
}

LL query(int root, int left, int right) { // 查询(求和)
    if (left <= cnt.left && right >= cnt.right) { // 区间内,直接返回当前值
        return cnt.val;
    }
    pushdown(root); // 因为查询覆盖一部分,所以需要下推全部记录
    int mid = (cnt.left + cnt.right) >> 1;
    LL ans = 0;
    if (left <= mid) ans += query(lr, left, right); // 非全覆盖的则会递归的求下去。
    if (right > mid) ans += query(rr, left, right); // 全覆盖的区间直接返回当前值(最底层单个节点一定全部覆盖)
    return ans;
}

void build(int root, int left, int right) { // 构造线段树(初始化)
    cnt.left = left; // 设置
    cnt.right = right;
    if (left == right) { // 如果当前节点是叶子节点
        cnt.val = a[left];
        return;
    }
    int mid = (cnt.left + cnt.right) >> 1;
    build(lr, left, mid);
    build(rr, mid + 1, right);
    pushup(root);
    return;
}

int main() {
    int m;
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    build(1, 1, n);
    int left, right, c;
    for (int i = 1; i <= m; i++) {
        cin >> c;
        if (c == 1) {
            cin >> left >> right >> c;
            add(1, left, right, c);
        } else {
            cin >> left >> right;
            cout << query(1, left, right) << endl;
        }
    }
    return 0;
}
最后修改:2021 年 11 月 09 日
如果觉得我的文章对你有用,请随意赞赏