题解 | #小 Q 与异或#
小 Q 与树
https://ac.nowcoder.com/acm/contest/11171/D
题目大意
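给定一棵包含 $n$ 个节点的树,每个节点 $i$ 有一个权值 $a_i$
求 $\sum_{u=1}^{n}\sum_{v=1}^{n}\min(a_u,a_v)\cdot dis(u,v)$ 对 $998244353$ 取模的结果,其中 $dis(u,v)$ 表示 $u$ 到 $v$ 的路径上的边数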
解题思路
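对于节点 $u$
- 记权值小于 $a_u$ 的节点有 $x_1, x_2, \dots, x_p$
- 记权值大于等于 $a_u$ 的节点有 $y_1, y_2, \dots, y_q$
那么节点 $u$ 对答案的贡献为:$\sum_{i=1}^{p} a_{x_i}\cdot dis(u,x_i) + \sum_{j=1}^{q} a_u\cdot dis(u,y_j)$
即:$\sum_{i=1}^{p} a_{x_i}\left(dep_u + dep_{x_i} - 2\,dep_{lca(u,x_i)}\right) + a_u\sum_{j=1}^{q}\left(dep_u + dep_{y_j} - 2\,dep_{lca(u,y_j)}\right)$,其中 $dep_x$ 表示节点 $x$ 的深度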
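定义 $rt$ 为当前子树的根,那么对于位于 $rt$ 不同儿子子树中的两个节点 $u$、$v$(或其中一个就是 $rt$),有 $lca(u,v)=rt$
开四棵权值树状数组,分别用来维护 $\sum dep_v$、$\sum a_v\cdot dep_v$、节点个数 $\sum 1$、$\sum a_v$
然后跑一遍 dsu on tree(树上启发式合并)即可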
AC_Code
#include <bits/stdc++.h>

// Computes  sum over ordered node pairs (u, v) of  min(a_u, a_v) * dis(u, v)
// modulo 998244353, where dis is the number of edges between u and v.
//
// Method: small-to-large merging on the tree ("dsu on tree"). While the
// subtree of a node `rt` is being assembled, every pair whose lowest common
// ancestor is `rt` is counted once, using
//     dis(u, v) = dep[u] + dep[v] - 2 * dep[rt].
// Four Fenwick trees indexed by the compressed rank of the weight hold, for
// the nodes merged so far: sum of dep, sum of weight*dep, node count, and
// sum of weight. Each unordered pair is counted once, so the result is
// doubled before printing.
//
// All modular values are kept in [0, mod); every product reduces its factors
// first so that no intermediate exceeds 64 bits and no remainder is negative.

using i64 = long long;

constexpr int N = 200000 + 10;      // array capacity; node ids are 1-based
constexpr i64 mod = 998244353;

/// Returns the representative of `x` in [0, mod), for any sign of `x`.
static i64 norm_mod(i64 x) {
    x %= mod;
    return x < 0 ? x + mod : x;
}

/// Returns x * y modulo `mod`. Both factors are reduced first, so the
/// intermediate product is below mod^2 and fits in 64 bits.
static i64 mul_mod(i64 x, i64 y) {
    return norm_mod(x) * norm_mod(y) % mod;
}

/// Reads the next decimal integer from stdin into `res`.
/// Characters before the first digit are skipped; a '-' among them makes the
/// result negative. If input ends before any digit is seen, `res` becomes 0.
template <typename T>
void read(T &res) {
    bool negative = false;
    int ch = std::getchar();        // int, so EOF stays distinguishable
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') negative = true;
        ch = std::getchar();
    }
    res = 0;
    while (ch != EOF && std::isdigit(ch)) {
        res = res * 10 + (ch - '0');
        ch = std::getchar();
    }
    if (negative) res = -res;
}

/// Writes `x` in decimal to stdout (with a leading '-' if negative), no newline.
template <typename T>
void Out(T x) {
    if (x < 0) {
        std::putchar('-');
        x = -x;
    }
    if (x > 9) Out(x / 10);
    std::putchar(static_cast<int>(x % 10) + '0');
}

int n;              // number of nodes
i64 ans;            // sum over unordered pairs, always in [0, mod)
i64 a[N];           // raw weight while reading; afterwards the 1-based rank in `vec`
i64 dep[N];         // depth of each node; the root has depth 1 (dep[0] stays 0)
int sz[N];          // subtree sizes
int hson[N];        // child with the largest subtree, 0 for a leaf
int HH;             // child that calc()/change() must not descend into; 0 = none
int M;              // number of distinct weights == Fenwick tree size

// Adjacency list stored as linked edge records; index 0 terminates a list.
struct Edge {
    int nex, to;
} edge[N << 1];
int head[N], TOT;

/// Appends the directed edge u -> v to u's adjacency list.
void add_edge(int u, int v) {
    ++TOT;
    edge[TOT].nex = head[u];
    edge[TOT].to = v;
    head[u] = TOT;
}

/// Fenwick tree over positions 1..M holding values modulo `mod`.
struct TR {
    i64 tr[N];

    static int lowbit(int x) { return x & (-x); }

    /// Adds `val` (any sign, any magnitude) at position `pos`.
    void add(int pos, i64 val) {
        if (pos <= 0) return;       // position 0 would never advance
        const i64 delta = norm_mod(val);
        for (; pos <= M; pos += lowbit(pos)) tr[pos] = (tr[pos] + delta) % mod;
    }

    /// Prefix sum over positions 1..pos, in [0, mod).
    i64 query(int pos) const {
        i64 res = 0;
        for (; pos > 0; pos -= lowbit(pos)) res = (res + tr[pos]) % mod;
        return res;
    }

    /// Sum over positions L..R, in [0, mod); an empty range (L > R by one) gives 0.
    i64 get_sum(int L, int R) const { return norm_mod(query(R) - query(L - 1)); }
} tree1,    // sum of dep
  tree2,    // sum of weight * dep
  tr1,      // node count
  tr2;      // sum of weight

std::vector<i64> vec;   // sorted distinct weights; rank r corresponds to vec[r - 1]

/// 1-based rank of weight `x` among the distinct weights in `vec`.
int get_id(i64 x) {
    return static_cast<int>(std::lower_bound(vec.begin(), vec.end(), x) - vec.begin()) + 1;
}

/// Rank of node u's weight (valid after compression in main).
static int rank_of(int u) { return static_cast<int>(a[u]); }

/// First pass: fills dep, sz and hson for the subtree of `u` (parent `far`).
void dfs(int u, int far) {
    dep[u] = dep[far] + 1;
    sz[u] = 1;
    for (int i = head[u]; i; i = edge[i].nex) {
        const int v = edge[i].to;
        if (v == far) continue;
        dfs(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[hson[u]]) hson[u] = v;
    }
}

/// Adds (val = 1) or removes (val = -1) the single node `u` in all four trees.
static void insert_node(int u, i64 val) {
    const int id = rank_of(u);
    const i64 w = vec[id - 1];
    tree1.add(id, dep[u] * val);
    tree2.add(id, mul_mod(w, dep[u]) * val);
    tr1.add(id, val);
    tr2.add(id, norm_mod(w) * val);
}

/// Adds or removes every node of u's subtree, skipping the child equal to HH.
void change(int u, int far, int val) {
    insert_node(u, val);
    for (int i = head[u]; i; i = edge[i].nex) {
        const int v = edge[i].to;
        if (v == far || v == HH) continue;
        change(v, u, val);
    }
}

/// Adds to `ans` the pairs formed by `u` and every node currently stored in
/// the Fenwick trees, assuming their lowest common ancestor with `u` is `rt`.
static void contribute(int u, int rt) {
    const int id = rank_of(u);
    const i64 w = vec[id - 1];

    // Stored nodes with weight >= a_u: the minimum of the pair is a_u.
    //   a_u * (cnt * dep[u] + sum(dep[v]) - 2 * cnt * dep[rt])
    const i64 cntGe = tr1.get_sum(id, M);
    const i64 depGe = tree1.get_sum(id, M);
    const i64 pathGe =
        norm_mod(mul_mod(cntGe, dep[u]) + depGe - mul_mod(mul_mod(cntGe, 2), dep[rt]));
    ans = (ans + mul_mod(w, pathGe)) % mod;

    // Stored nodes with weight < a_u: the minimum is the stored node's weight.
    //   sum(a_v * dep[v]) + sum(a_v) * dep[u] - 2 * sum(a_v) * dep[rt]
    const i64 wDepLt = tree2.get_sum(1, id - 1);
    const i64 wLt = tr2.get_sum(1, id - 1);
    ans = norm_mod(ans + wDepLt + mul_mod(wLt, dep[u]) - mul_mod(mul_mod(wLt, 2), dep[rt]));
}

/// Counts the pairs between every node of u's subtree (skipping child HH) and
/// the nodes currently stored, with `rt` as their common ancestor.
void calc(int u, int far, int rt) {
    contribute(u, rt);
    for (int i = head[u]; i; i = edge[i].nex) {
        const int v = edge[i].to;
        if (v == far || v == HH) continue;
        calc(v, u, rt);
    }
}

/// Small-to-large driver. Processes the subtree of `u`; when `op` is 0 the
/// subtree's nodes are removed from the Fenwick trees again before returning,
/// when `op` is non-zero they are left in place for the caller.
void dsu(int u, int far, int op) {
    // Light children first; each one cleans up after itself.
    for (int i = head[u]; i; i = edge[i].nex) {
        const int v = edge[i].to;
        if (v == far || v == hson[u]) continue;
        dsu(v, u, 0);
    }

    // Heavy child last; its nodes stay in the trees.
    if (hson[u]) {
        dsu(hson[u], u, 1);
        HH = hson[u];
    }

    // Pair each light subtree with what is already stored, then merge it in.
    for (int i = head[u]; i; i = edge[i].nex) {
        const int v = edge[i].to;
        if (v == far || v == HH) continue;
        calc(v, u, u);
        change(v, u, 1);
    }

    // Finally pair u itself with its whole subtree and store u.
    contribute(u, u);
    insert_node(u, 1);

    HH = 0;     // the clean-up below must also visit the heavy child
    if (!op) change(u, far, -1);
}

int main() {
    read(n);
    if (n < 1) {    // empty or missing input: there are no pairs
        Out(0);
        std::puts("");
        return 0;
    }

    vec.reserve(static_cast<std::size_t>(n));
    for (int i = 1; i <= n; i++) {
        read(a[i]);
        vec.push_back(a[i]);
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        read(u);
        read(v);
        add_edge(u, v);
        add_edge(v, u);
    }

    // Coordinate-compress the weights so they can index the Fenwick trees.
    std::sort(vec.begin(), vec.end());
    vec.erase(std::unique(vec.begin(), vec.end()), vec.end());
    for (int i = 1; i <= n; i++) a[i] = get_id(a[i]);
    M = static_cast<int>(vec.size());

    dfs(1, 0);
    dsu(1, 0, 1);

    // Each unordered pair was counted once; the required sum is over ordered pairs.
    Out(ans * 2 % mod);
    std::puts("");
    return 0;
}