牛客练习赛71 E 神奇的迷宫
神奇的迷宫
https://ac.nowcoder.com/acm/contest/7745/E
题目可以简化为:令 ans[i] 表示所有距离为 i 的点对的概率之和,求出每个 ans[i] 即可。
考虑使用点分治分解整棵树,然后在每个分治中心处将各子树按最大深度从小到大排序,依次进行启发式合并,合并时用 NTT 加速卷积。
复杂度 $O(n \log^2 n)$。
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;

typedef long long ll;

const int N = 262144 + 100;  // capacity of every per-vertex / per-depth array
const int MOD = 998244353;

// Polynomial multiplication modulo P via the number-theoretic transform.
namespace NTT {
#define pw(n) (1 << (n))

const int N = 262144, P = 998244353, g = 3;  // or P = 1004535809

// n, m   : degrees of the two operands of the current solve() call
// bit    : transform length (a power of two), bitnum = log2(bit)
// a, b   : scratch buffers; solve() leaves them zeroed on return
// rev    : bit-reversal permutation for the current transform length
int n, m, bit, bitnum = 0, a[N + 5], b[N + 5], rev[N + 5];

// Fills rev[0 .. 2^l - 1] with the l-bit reversal of each index.
void getrev(int l) {
    rev[0] = 0;
    // Starting at 1 keeps the l == 0 case (length-1 transform) away from the
    // shift by (l - 1), which would otherwise be a shift by -1.
    for (int i = 1; i < pw(l); i++) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    }
}

// Returns a^b mod P by binary exponentiation.
int fastpow(int a, int b) {
    int ans = 1;
    for (; b; b >>= 1, a = 1LL * a * a % P) {
        if (b & 1) ans = 1LL * ans * a % P;
    }
    return ans;
}

// In-place transform of s[0 .. bit-1]. op == 1 is the forward transform;
// op == -1 is the inverse (forward butterflies, then reverse s[1..bit-1]
// and scale by bit^-1).
void NTT(int *s, int op) {
    for (int i = 0; i < bit; i++) {
        if (i < rev[i]) swap(s[i], s[rev[i]]);
    }
    for (int i = 1; i < bit; i <<= 1) {
        int w = fastpow(g, (P - 1) / (i << 1));  // primitive (2i)-th root of unity
        for (int p = i << 1, j = 0; j < bit; j += p) {
            int wk = 1;
            for (int k = j; k < i + j; k++, wk = 1LL * wk * w % P) {
                int x = s[k], y = 1LL * s[k + i] * wk % P;
                s[k] = (x + y) % P;
                s[k + i] = (x - y + P) % P;
            }
        }
    }
    if (op == -1) {
        reverse(s + 1, s + bit);
        int inv = fastpow(bit, P - 2);
        for (int i = 0; i < bit; i++) s[i] = 1LL * s[i] * inv % P;
    }
}

// Multiplies aa[0..nn] by bb[0..mm] modulo P.
// Writes the product into c[0 .. len-1] and returns len, the smallest power
// of two strictly greater than nn + mm; entries of c above index nn + mm are
// zero. c must therefore have room for len ints.
// Requires nn + mm < N so that len fits in the scratch buffers.
int solve(int *aa, int nn, int *bb, int mm, int *c) {
    n = nn;
    m = mm;
    bit = bitnum = 0;
    for (int i = 0; i <= n; i++) a[i] = aa[i];
    for (int i = 0; i <= m; i++) b[i] = bb[i];
    m += n;
    for (bit = 1; bit <= m; bit <<= 1) bitnum++;
    getrev(bitnum);
    NTT(a, 1);
    NTT(b, 1);
    for (int i = 0; i < bit; i++) a[i] = 1LL * a[i] * b[i] % P;
    NTT(a, -1);
    for (int i = 0; i < bit; i++) c[i] = a[i];
    // Leave the scratch buffers clean: the next call only overwrites the
    // low coefficients and relies on everything above them being zero.
    for (int i = 0; i < bit; i++) a[i] = b[i] = 0;
    return bit;
}
}  // namespace NTT

// Returns x^n mod MOD for n >= 0.
ll qpow(ll x, ll n) {
    x %= MOD;  // keep res * x inside 64 bits even for an unreduced base
    ll res = 1;
    while (n > 0) {
        if (n & 1) res = res * x % MOD;
        n /= 2;
        x = x * x % MOD;
    }
    return res;
}

// n  : number of vertices (read in main)
// MX : size of the component currently being searched by getroot
// R  : best centroid candidate found so far by getroot
int n, MX, R;
// sa  : per-vertex value read in main, then replaced by sa[i] / sum(sa) mod MOD
// ww  : per-distance coefficients read in main (indices 0 .. n-1)
// siz : subtree sizes, ms : size of the largest piece left when a vertex is removed
int sa[N], ww[N], siz[N], ms[N];
bool vis[N];       // vertices already chosen as a centroid by divide()
vector<int> V[N];  // adjacency lists

// Searches the component containing u (ignoring vis-marked vertices) for the
// vertex minimising ms[], storing it in R. Before the top-level call the
// caller sets MX to the component size and R to a vertex whose ms[] is a
// valid upper bound (main uses R = 0 with ms[0] = 1e9).
void getroot(int u, int fa) {
    siz[u] = 1;
    ms[u] = 0;
    for (int v : V[u]) {
        if (vis[v] || v == fa) continue;
        getroot(v, u);
        siz[u] += siz[v];
        ms[u] = max(ms[u], siz[v]);
    }
    ms[u] = max(ms[u], MX - siz[u]);  // the piece on the parent's side
    if (ms[u] < ms[R]) R = u;
}

// dep : depth relative to the current centroid
// res : sum of sa[] per depth for the subtree being merged
// now : sum of sa[] per depth for the centroid plus subtrees merged so far
// ss  : output buffer for NTT::solve
// md  : maximum depth inside each vertex's subtree
// ans : ans[d] = sum of sa[u] * sa[v] over unordered pairs at distance d
int dep[N], res[N], now[N], ss[N], md[N], ans[N];
void upd(int &a, int b) { a += b; if (a >= MOD) a -= MOD; } void dfs(int u, int fa) { md[u] = dep[u]; siz[u] = 1; for (int v : V[u]) { if (vis[v] || v == fa) continue; dep[v] = dep[u] + 1; dfs(v, u); siz[u] += siz[v]; md[u] = max(md[u], md[v]); } } void dfs1(int u, int fa) { upd(res[dep[u]], sa[u]); for (int v : V[u]) { if (vis[v] || v == fa) continue; dfs1(v, u); } } int id[N], tp; bool cmp(int a, int b) { return md[a] < md[b]; } void divide(int u) { vis[u] = true; tp = 0; int mm = 0; for (int v : V[u]) { if (vis[v]) continue; dep[v] = 1; dfs(v, u); id[++tp] = v; } sort(id + 1, id + tp + 1, cmp); now[0] = sa[u]; for (int i = 1; i <= tp; i++) { int v = id[i]; dfs1(v, u); int tt = NTT::solve(now, mm, res, md[v], ss); for (int i = 1; i <= tt; i++) upd(ans[i], ss[i]); for (int i = 0; i <= md[v]; i++) upd(now[i], res[i]); for (int i = 0; i <= md[v]; i++) res[i] = 0; for (int i = 0; i <= tt; i++) ss[i] = 0; mm = max(mm, md[v]); } for (int i = 0; i <= mm; i++) now[i] = 0; for (int v : V[u]) { if (vis[v]) continue; R = 0; MX = siz[v]; getroot(v, u); divide(R); } } int main() { //freopen("0.txt", "r", stdin); int a, b; scanf("%d", &n); ll sum = 0; for (int i = 1; i <= n; i++) { scanf("%d", sa + i); sum += sa[i]; if (sum >= MOD) sum -= MOD; } ll RR = qpow(sum, MOD - 2); for (int i = 1; i <= n; i++) { sa[i] = RR * sa[i] % MOD; ans[0] = (ans[0] + 1LL * sa[i] * sa[i]) % MOD; } for (int i = 0; i < n; i++) scanf("%d", ww + i); for (int i = 1; i < n; i++) { scanf("%d%d", &a, &b); V[a].push_back(b); V[b].push_back(a); } ms[0] = 1e9; MX = n; getroot(1, 0); divide(R); ll r = 1LL * ans[0] * ww[0] % MOD; for (int i = 1; i < n; i++) r = (r + 1LL * ans[i] * ww[i] * 2) % MOD; printf("%lld\n", r); return 0; }