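// Heavy-light decomposition: two DFS passes split the tree into chains, and a segment tree answers path and subtree operations over them.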
#include <bits/stdc++.h>
// Clears the tied streams of cin/cout and turns off iostream/stdio
// synchronisation. Not used by the visible code, which reads with scanf.
#define io cin.tie(0), cout.tie(0), ios::sync_with_stdio(false)
// Shorthand names for the 64-bit integer types.
#define LL long long
#define ULL unsigned long long
// Small constant 1e-8; presumably a floating-point tolerance (unused in the
// visible code).
#define EPS 1e-8
// Largest value of a 32-bit signed integer.
#define INF 0x7fffffff
// Smallest value of a 32-bit signed integer. Parenthesised so the expansion
// stays a single operand inside a larger expression such as SUB / 2.
#define SUB (-INF - 1)
using namespace std;
// Capacity used to size the per-vertex arrays and the segment tree.
const int N = 100010;
// Values read from the first input line in main().
int n;   // number of vertices; also the right bound of the segment tree
int m;   // number of operations to process
int r;   // root vertex handed to dfs1/dfs2
int mod; // modulus applied to stored and reported sums
// One directed arc of the adjacency list.
struct Edge
{
    int to;   // vertex the arc points to
    int next; // index of the next arc leaving the same vertex, -1 at the end
};
// Arc pool; main() stores every input edge as two arcs, hence 2 * N slots.
Edge edge[2 * N];
// head[u] is the index of the first arc leaving u, or -1 when u has none.
int head[2 * N];
// Number of arcs handed out so far.
int cnt;
// Empties the adjacency list: every list head and every arc link becomes -1
// (the "no arc" marker) and the arc counter restarts at zero.
void init()
{
    cnt = 0;
    for (int slot = 2 * N - 1; slot >= 0; --slot)
    {
        head[slot] = -1;
        edge[slot].next = -1;
    }
}
// Adds the directed arc u -> v by taking the next free slot of the arc pool
// and pushing it onto the front of u's list.
void addedge(int u, int v)
{
    const int slot = cnt++;
    edge[slot].to = v;
    edge[slot].next = head[u]; // previous first arc becomes the successor
    head[u] = slot;
}
// w[v]: initial value of vertex v as read from the input.
int w[N];
// w_new[i]: value of the vertex that dfs2 placed at position i.
int w_new[N];
// tree[p]: sum of the positions covered by segment-tree node p, modulo mod.
int tree[N << 2];
// tag[p]: per-position addition already applied to tree[p] but not yet
// pushed down to the children of p.
int tag[N << 2];
// Applies "add d to every position" to segment-tree node p.
//   p      - index of the node in tree[] / tag[]
//   pl, pr - inclusive range of positions covered by p
//   d      - amount added to each position
// The increment is first reduced into [0, mod) and all products and sums are
// computed in 64 bits, so neither d * (range length) nor the accumulated tag
// can overflow a 32-bit int. Only residues modulo mod are ever observed, so
// storing the tag reduced does not change any answer.
void addtag(int p, int pl, int pr, int d)
{
    long long delta = d % mod;
    if (delta < 0)
        delta += mod;

    const long long len = pr - pl + 1;

    tag[p] = static_cast<int>((tag[p] + delta) % mod);

    long long sum = (tree[p] + delta * len) % mod;
    if (sum < 0)
        sum += mod;
    tree[p] = static_cast<int>(sum);
}
// Recomputes the sum stored at node p from its two children, modulo mod.
// The children are added in 64 bits: each is a residue below mod, and their
// sum does not fit in a 32-bit int once mod is larger than 2^30.
void push_up(int p)
{
    const long long total = static_cast<long long>(tree[p << 1]) + tree[p << 1 | 1];
    tree[p] = static_cast<int>(total % mod);
}
// Hands the pending addition of node p (covering pl..pr) to both children
// and clears it on p. Does nothing when no addition is pending.
void push_down(int p, int pl, int pr)
{
    const int pending = tag[p];
    if (pending == 0)
        return;

    const int mid = (pl + pr) >> 1;
    const int left = p << 1;
    const int right = p << 1 | 1;
    addtag(left, pl, mid, pending);
    addtag(right, mid + 1, pr, pending);
    tag[p] = 0;
}
// Builds the segment tree below node p for positions pl..pr from w_new[]:
// leaves take the value at their position modulo mod, inner nodes the sum of
// their children. Every visited node starts with an empty tag.
void build(int p, int pl, int pr)
{
    tag[p] = 0;
    if (pl != pr)
    {
        const int mid = (pl + pr) >> 1;
        build(p << 1, pl, mid);
        build(p << 1 | 1, mid + 1, pr);
        push_up(p);
        return;
    }
    tree[p] = w_new[pl] % mod;
}
// Adds d to every position in l..r.
//   l, r   - inclusive target range
//   p      - current node, covering pl..pr
//   d      - amount added to each position
// A node lying completely inside the target range is tagged and not entered;
// otherwise its pending tag is pushed down, the overlapping children are
// visited and the node's sum is rebuilt.
void update(int l, int r, int p, int pl, int pr, int d)
{
    const bool covered = (l <= pl) && (pr <= r);
    if (covered)
    {
        addtag(p, pl, pr, d);
        return;
    }

    push_down(p, pl, pr);

    const int mid = (pl + pr) >> 1;
    const int left = p << 1;
    const int right = p << 1 | 1;
    if (l <= mid)
        update(l, r, left, pl, mid, d);
    if (r > mid)
        update(l, r, right, mid + 1, pr, d);

    push_up(p);
}
int query(int l, int r, int p, int pl, int pr)
{
if (l <= pl && pr <= r)
return tree[p] % mod;
push_down(p, pl, pr);
int res = 0;
int mid = (pl + pr) >> 1;
if (l <= mid)
res += query(l, r, p << 1, pl, mid);
if (r > mid)
res += query(l, r, p << 1 | 1, mid + 1, pr);
return res;
}
// son[v]: child of v with the largest subtree (chosen in dfs1), 0 if none.
int son[N];
// id[v]: position of v in the order produced by dfs2 (starts at 1).
int id[N];
// fa[v]: parent of v as recorded by dfs1.
int fa[N];
// deep[v]: depth of v, set by dfs1 to the parent's depth plus one.
int deep[N];
// siz[v]: number of vertices in the subtree of v.
int siz[N];
// top[v]: first vertex of the chain that dfs2 assigned v to.
int top[N];
// First pass: walks the tree from x (whose parent is father) and fills
// fa[], deep[], siz[] and son[] for every vertex in the subtree of x.
void dfs1(int x, int father)
{
    fa[x] = father;
    deep[x] = deep[father] + 1;
    siz[x] = 1;

    for (int e = head[x]; e != -1; e = edge[e].next)
    {
        const int child = edge[e].to;
        if (child == father)
            continue; // skip the arc leading back to the parent

        dfs1(child, x);
        siz[x] += siz[child];

        // Keep the child with the strictly largest subtree seen so far.
        const int best = son[x];
        if (best == 0 || siz[best] < siz[child])
            son[x] = child;
    }
}
// Last position handed out by dfs2.
int num = 0;

// Second pass: gives x the next position, copies its value into w_new[] at
// that position and records topx as the top of its chain. The child stored in
// son[x] is visited first and stays on the same chain, so a chain occupies
// consecutive positions; every other child starts a chain of its own.
void dfs2(int x, int topx)
{
    top[x] = topx;
    id[x] = ++num;
    w_new[id[x]] = w[x];

    const int heavy = son[x];
    if (heavy == 0)
        return;
    dfs2(heavy, topx);

    for (int e = head[x]; e != -1; e = edge[e].next)
    {
        const int child = edge[e].to;
        if (child == fa[x] || child == heavy)
            continue;
        dfs2(child, child);
    }
}
// Adds z to every vertex on the path between x and y.
// While the two vertices sit on different chains, the one whose chain top is
// deeper has its chain segment updated and then jumps to the parent of that
// top. Once both share a chain, the stretch between them is updated.
void update_range(int x, int y, int z)
{
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (deep[top[x]] < deep[top[y]])
            swap(x, y);
        update(id[top[x]], id[x], 1, 1, n, z);
    }

    const bool x_is_upper = deep[x] <= deep[y];
    const int upper = x_is_upper ? x : y;
    const int lower = x_is_upper ? y : x;
    update(id[upper], id[lower], 1, 1, n, z);
}
// Returns the sum of the values on the path between x and y, modulo mod.
// The path is walked chain by chain exactly as in update_range. The running
// total is kept in 64 bits and reduced after every chain, so adding a
// segment result to it cannot overflow a 32-bit int.
int query_range(int x, int y)
{
    long long ans = 0;
    while (top[x] != top[y])
    {
        // Always lift the vertex whose chain top is deeper.
        if (deep[top[x]] < deep[top[y]])
            swap(x, y);
        ans = (ans + query(id[top[x]], id[x], 1, 1, n)) % mod;
        x = fa[top[x]];
    }
    // Same chain now: the shallower vertex has the smaller position.
    if (deep[x] > deep[y])
        swap(x, y);
    ans = (ans + query(id[x], id[y], 1, 1, n)) % mod;
    return static_cast<int>(ans);
}
// Adds k to every vertex in the subtree of x, which occupies the positions
// id[x] .. id[x] + siz[x] - 1.
void update_tree(int x, int k)
{
    const int first = id[x];
    const int last = first + siz[x] - 1;
    update(first, last, 1, 1, n, k);
}
// Returns the sum over the subtree of x (positions id[x] .. id[x] + siz[x] - 1)
// modulo mod.
int query_tree(int x)
{
    const int first = id[x];
    const int last = first + siz[x] - 1;
    const int total = query(first, last, 1, 1, n);
    return total % mod;
}
int main()
{
init();
scanf("%d%d%d%d", &n, &m, &r, &mod);
for (int i = 1; i <= n; i++)
scanf("%d", &w[i]);
for (int i = 1; i < n; i++)
{
int u, v;
scanf("%d%d", &u, &v);
addedge(u, v);
addedge(v, u);
}
dfs1(r, 0);
dfs2(r, r);
build(1, 1, n);
while (m--)
{
int k, x, y, z;
scanf("%d", &k);
switch (k)
{
case 1:
scanf("%d%d%d", &x, &y, &z);
update_range(x, y, z);
break;
case 2:
scanf("%d%d", &x, &y);
printf("%d\n", query_range(x, y));
break;
case 3:
scanf("%d%d", &x, &y);
update_tree(x, y);
break;
case 4:
scanf("%d", &x);
printf("%d\n", query_tree(x));
break;
}
}
return 0;
}