树链剖分

2024-08-15

通过两次 DFS 完成轻重链剖分(第一次求父节点、深度、子树大小和重儿子,第二次求 DFS 序和链顶),再借助线段树维护路径与子树上的区间加和区间求和。

#include <bits/stdc++.h>
#define io cin.tie(0), cout.tie(0), ios::sync_with_stdio(false)
#define LL long long
#define ULL unsigned long long
#define EPS 1e-8
#define INF 0x7fffffff
#define SUB -INF - 1
using namespace std;
// Capacity bound used to size the static arrays of this program.
const int N = 100010;
// Values read from the first input line in main():
//   n   - number of nodes (weights w[1..n] and n - 1 edges follow),
//   m   - number of operations to process,
//   r   - node used as the root of both DFS passes,
//   mod - modulus applied to every stored and reported sum.
int n;
int m;
int r;
int mod;
// One directed arc of the linked adjacency list.
struct Edge
{
    int to;   // node this arc points to
    int next; // index in edge[] of the next arc with the same source, -1 ends the list
};
// Pool of arcs; main() stores every tree edge in both directions.
Edge edge[2 * N];
// head[u]: index in edge[] of the most recently added arc leaving u, -1 if none.
int head[2 * N];
// Number of arcs stored so far, which is also the next free slot of edge[].
int cnt;
void init()
{
    for (int i = 0; i < 2 * N; i++)
    {
        edge[i].next = -1;
        head[i] = -1;
    }
    cnt = 0;
}
// Appends the directed arc u -> v by prepending it to u's adjacency list.
void addedge(int u, int v)
{
    const int slot = cnt;
    edge[slot].next = head[u]; // link to the previous front of u's list
    edge[slot].to = v;
    head[u] = slot;            // the new arc becomes the front
    cnt = slot + 1;
}
// w[v]: weight of node v as read from the input.
int w[N];
// w_new[i]: weight of the node that dfs2() placed at position i.
int w_new[N];
// tree[p]: sum of the range covered by segment-tree node p, reduced modulo mod.
int tree[N << 2];
// tag[p]: increment applied to node p that its children have not received yet.
int tag[N << 2];
// Applies "add d to every element" to segment-tree node p, which covers the
// index range [pl, pr]: the increment is recorded in tag[p] for later
// propagation and tree[p] is raised by d times the range length.
//
// All arithmetic is done in 64-bit and reduced modulo mod. The plain int
// product d * (pr - pl + 1) overflows for large increments on long ranges,
// and an unreduced tag grows without bound across repeated updates.
// The increment is normalised into [0, mod), so a negative d is handled as
// the equivalent non-negative residue.
void addtag(int p, int pl, int pr, int d)
{
    const long long delta = ((static_cast<long long>(d) % mod) + mod) % mod;
    const long long length = static_cast<long long>(pr) - pl + 1;
    tag[p] = static_cast<int>((tag[p] + delta) % mod);
    // delta < mod <= INT_MAX and length fits in int, so the product fits in 64 bits.
    tree[p] = static_cast<int>((tree[p] + delta * length) % mod);
}
// Recomputes tree[p] as the sum of its two children, reduced modulo mod.
// The sum is formed in 64-bit: two residues below mod can add up to more than
// INT_MAX when mod is larger than 2^30.
void push_up(int p)
{
    const long long total = static_cast<long long>(tree[p << 1]) + tree[p << 1 | 1];
    tree[p] = static_cast<int>(total % mod);
}
// Hands the pending increment of node p (covering [pl, pr]) to both children
// and clears it on p. Does nothing when no increment is pending.
void push_down(int p, int pl, int pr)
{
    const int pending = tag[p];
    if (pending == 0)
        return;
    const int mid = (pl + pr) >> 1;
    addtag(p << 1, pl, mid, pending);         // left child covers [pl, mid]
    addtag(p << 1 | 1, mid + 1, pr, pending); // right child covers [mid + 1, pr]
    tag[p] = 0;
}
// Builds the segment tree for node p over positions [pl, pr] from w_new[],
// clearing every pending tag on the way.
void build(int p, int pl, int pr)
{
    tag[p] = 0;
    if (pl != pr)
    {
        // Internal node: build both halves, then combine them.
        const int mid = (pl + pr) >> 1;
        build(p << 1, pl, mid);
        build(p << 1 | 1, mid + 1, pr);
        push_up(p);
        return;
    }
    // Leaf: a single position, stored already reduced.
    tree[p] = w_new[pl] % mod;
}
// Adds d to every position in [l, r]. p is the current segment-tree node and
// [pl, pr] the range it covers; callers start at node 1 with range [1, n].
void update(int l, int r, int p, int pl, int pr, int d)
{
    const bool fullyCovered = l <= pl && pr <= r;
    if (fullyCovered)
    {
        addtag(p, pl, pr, d);
        return;
    }
    // Partial overlap: children must see older pending increments first.
    push_down(p, pl, pr);
    const int mid = (pl + pr) >> 1;
    const int leftChild = p << 1;
    const int rightChild = leftChild | 1;
    if (mid >= l)
        update(l, r, leftChild, pl, mid, d);
    if (mid < r)
        update(l, r, rightChild, mid + 1, pr, d);
    push_up(p);
}
int query(int l, int r, int p, int pl, int pr)
{
    if (l <= pl && pr <= r)
        return tree[p] % mod;
    push_down(p, pl, pr);
    int res = 0;
    int mid = (pl + pr) >> 1;
    if (l <= mid)
        res += query(l, r, p << 1, pl, mid);
    if (r > mid)
        res += query(l, r, p << 1 | 1, mid + 1, pr);
    return res;
}
// son[v]: child of v with the largest subtree (heavy child), 0 if v has no child.
int son[N];
// id[v]: 1-based position given to v by dfs2(); used as the segment-tree index.
int id[N];
// fa[v]: parent of v as set by dfs1(); the root's parent is 0.
int fa[N];
// deep[v]: depth of v, computed as deep[parent] + 1.
int deep[N];
// siz[v]: number of nodes in the subtree rooted at v.
int siz[N];
// top[v]: topmost node of the chain that contains v.
int top[N];
// First pass: starting at x (whose parent is father) fills fa[], deep[] and
// siz[] for the subtree of x and picks each node's heavy child into son[].
void dfs1(int x, int father)
{
    fa[x] = father;
    deep[x] = deep[father] + 1;
    siz[x] = 1;
    for (int e = head[x]; e != -1; e = edge[e].next)
    {
        const int y = edge[e].to;
        if (y == father)
            continue; // skip the arc leading back up
        dfs1(y, x);
        siz[x] += siz[y];
        // Keep the child with the strictly largest subtree seen so far.
        if (son[x] == 0 || siz[y] > siz[son[x]])
            son[x] = y;
    }
}
// Number of positions handed out by dfs2() so far.
int num = 0;
// Second pass: gives x the next position, records its chain top and copies its
// weight into w_new[]. The heavy child is visited first so that each chain
// occupies consecutive positions; every other child starts a chain of its own.
void dfs2(int x, int topx)
{
    top[x] = topx;
    id[x] = ++num;
    w_new[id[x]] = w[x];
    const int heavy = son[x];
    if (heavy == 0)
        return; // no child to continue the chain with
    dfs2(heavy, topx);
    for (int e = head[x]; e != -1; e = edge[e].next)
    {
        const int y = edge[e].to;
        if (y == fa[x] || y == heavy)
            continue;
        dfs2(y, y);
    }
}
// Adds z to every node on the path between x and y.
void update_range(int x, int y, int z)
{
    // While the endpoints sit on different chains, process the chain whose top
    // is deeper and continue from the parent of that top.
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (deep[top[y]] > deep[top[x]])
            std::swap(x, y);
        update(id[top[x]], id[x], 1, 1, n, z);
    }
    // Same chain now: the shallower endpoint has the smaller position.
    const bool xIsUpper = deep[x] <= deep[y];
    const int upper = xIsUpper ? x : y;
    const int lower = xIsUpper ? y : x;
    update(id[upper], id[lower], 1, 1, n, z);
}
// Returns the sum of the weights on the path between x and y, modulo mod.
//
// The running total is kept in 64-bit and reduced after every chain, so adding
// a chain's sum to the previous total cannot overflow int.
int query_range(int x, int y)
{
    long long ans = 0;
    // While the endpoints sit on different chains, take the chain whose top is
    // deeper and continue from the parent of that top.
    while (top[x] != top[y])
    {
        if (deep[top[x]] < deep[top[y]])
            std::swap(x, y);
        ans = (ans + query(id[top[x]], id[x], 1, 1, n)) % mod;
        x = fa[top[x]];
    }
    // Same chain now: order the endpoints so the shallower one comes first.
    if (deep[x] > deep[y])
        std::swap(x, y);
    ans = (ans + query(id[x], id[y], 1, 1, n)) % mod;
    return static_cast<int>(ans);
}
// Adds k to every node in the subtree of x, which occupies the consecutive
// positions id[x] .. id[x] + siz[x] - 1.
void update_tree(int x, int k)
{
    const int first = id[x];
    const int last = first + siz[x] - 1;
    update(first, last, 1, 1, n, k);
}
// Returns the sum over the subtree of x (positions id[x] .. id[x] + siz[x] - 1),
// modulo mod.
int query_tree(int x)
{
    const int first = id[x];
    const int last = first + siz[x] - 1;
    const int sum = query(first, last, 1, 1, n);
    return sum % mod;
}
// Reads the tree and its weights, runs both DFS passes, builds the segment
// tree and then answers m operations:
//   1 x y z - add z to every node on the path x..y
//   2 x y   - print the path sum of x..y
//   3 x y   - add y to every node in the subtree of x
//   4 x     - print the subtree sum of x
int main()
{
    init();
    scanf("%d%d%d%d", &n, &m, &r, &mod);
    for (int v = 1; v <= n; ++v)
        scanf("%d", &w[v]);
    // A tree on n nodes has n - 1 edges; store each one in both directions.
    for (int e = 1; e < n; ++e)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        addedge(u, v);
        addedge(v, u);
    }
    dfs1(r, 0);
    dfs2(r, r);
    build(1, 1, n);
    while (m--)
    {
        int k, x, y, z;
        scanf("%d", &k);
        if (k == 1)
        {
            scanf("%d%d%d", &x, &y, &z);
            update_range(x, y, z);
        }
        else if (k == 2)
        {
            scanf("%d%d", &x, &y);
            printf("%d\n", query_range(x, y));
        }
        else if (k == 3)
        {
            scanf("%d%d", &x, &y);
            update_tree(x, y);
        }
        else if (k == 4)
        {
            scanf("%d", &x);
            printf("%d\n", query_tree(x));
        }
        // Any other operation code is ignored.
    }
    return 0;
}