题解 | #小 Q 与树#
小 Q 与树
https://ac.nowcoder.com/acm/contest/11171/D
给出一颗树,结点权值为 求:
思路
本题为点分治模板题
以重心为根,用 solve(x)
解决 子树内贡献
每次 solve(x)
时,首先得到 经过该点 和 不经过该点 的贡献总和 calc(x, fa, 0)
这个过程首先利用 dfs_dis(x, fa, 0)
得到以 为根的链信息再将链两两合并,得到 的路径贡献
排除 不经过该点, 即排除 的情况,只需要 先向下走一步,然后统计答案,及 calc(x, fa, 1)
注意到以 为根后会将 删去,则每条路径有且仅会被统计一次,故答案正确。
对于本题,由于求 时若暴力枚举,calc
复杂度会为 总复杂度为
需要先排序处理,calc
复杂度为 总复杂度为
注意: 求重心时 S = sz[x]
每次都要更新,否则复杂度不对
代码
#include <bits/stdc++.h> using namespace std; const int N = 2e5 + 10; #define rep(i, s, t) for (int i = (int)(s); i <= (int)(t); ++i) const int mod = 998244353; int root, mx[N], sz[N]; int n, m, S, tot; bool vis[N]; vector<int> G[N]; bool chkmax(int &x, int y) { if (x < y) return x = y, 1; return false; } int read() { int x = 0, ch = getchar(); while (ch < '0' || ch > '9') ch = getchar(); while (ch >='0' && ch <='9') x = x * 10 + (ch ^ 48), ch = getchar(); return x; } #define int long long int a[N], ans; pair<int, int> q[(int)(1e6) + 10]; void dfs_root(int x, int fa) { mx[x] = 0, sz[x] = 1; for (int y : G[x]) { if (y == fa || vis[y]) continue; dfs_root(y, x); sz[x] += sz[y]; chkmax(mx[x], sz[y]); } chkmax(mx[x], S - sz[x]); if (mx[x] < mx[root]) root = x; } void dfs_dis(int x, int fa, int dis) { q[++ tot] = make_pair(a[x], dis); for (int y : G[x]) { if (y == fa || vis[y]) continue; dfs_dis(y, x, dis + 1); } } int Mod(int x) { return (x % mod + mod) % mod; } int calc(int x, int fa, int org) { int res = 0; tot = 0; dfs_dis(x, fa, org); sort(q + 1, q + tot + 1); int dis_sum = 0, dis_prefix = 0; rep(i, 1, tot) { dis_sum += q[i].second; } rep(i, 1, tot) { dis_prefix += q[i].second; res = Mod(res + q[i].first * (dis_sum - dis_prefix)); res = Mod(res + (tot - i) * q[i].second * q[i].first); } return Mod(res + res); } void solve(int x) { tot = 0; vis[x] = 1; ans += calc(x, 0, 0); for (int y : G[x]) { if (vis[y]) continue; ans = Mod(ans - calc(y, x, 1)); mx[root = 0] = S = n; S = sz[x]; dfs_root(y, x); solve(root); } } signed main() { n = read(); rep(i, 1, n) a[i] = read(); rep(i, 1, n - 1) { int x = read(), y = read(); G[x].push_back(y); G[y].push_back(x); } mx[root = 0] = S = n; dfs_root(1, 0); solve(root); printf("%lld", ans); }