2019 ICPC 西安邀请赛 J. And And And 树上启发式合并
题目链接:https://nanti.jisuanke.com/t/39277
题目大意:给你一棵树。每条边一边权。我们一共有n*(n+1)/2的点对(u, v)。对每个点对,假设deep[v]<deep[u]那么u-v的最短路径上的所有点属于集合E。问有多少个deep[v']< deep[u']并且v', u'属于E。并且v'到u'的最短路路径边权异或和为0。
思路:这个应该按贡献来算。因为一条路径的异或为0, 这条路一定存在应该LCA,树上启发式合并。
- 1 子树间产生贡献
例如root的子树之间异或为0,那么u的子节点到v的子节点都会包含这条路径。产生贡献ans+=siz[u]+
- 2 根和子树之间产生贡献
那么贡献为ans+=siz[u]*(n-siz[u]+1)
用map[i]维护i的所有的子节点到i的异或值和子节点的个数和
这里我写了应该O(n^2 * log n)的维护。对于一个节点v。往祖先u上传时,把所有的mp[u]^=w(u, v)。如果是一条链,并且所有的异或值都不同就是O(n^2* log n),不过这题可以卡过去。
我们考虑:严格的O(n * log n * log n)的写法。
我们按照上面的思路我们想一想往祖先u上传时能不能不修改mp里面的东西。答案是可以的,我们定义一个数组k[i]:表示mp[i]的点都要^k[i]才是真正的到i的异或值。
那么上传时:
mp[u][k[u]]=(mp[u][k[u]]+s[u])%mod;//u节点自己的贡献 k[u]^=w;
再考虑子树合并时,怎么合并子树。
k[v]^mp[v]=v这个是v这棵子树,真实的到u的的异或值。 因为mp[u]保存的节点真实的到u的的异或值mp[u]^k[u]。 把mp[v]合并到mp[u]只要mp[v]^k[v]^k[u]就可以了
for(auto x: mp[v]){ mp[u][((x.first^k[v])^k[u])]=(mp[u][((x.first^k[v])^k[u])]+(x.second))%mod; }
还有没有更简单的方法?有。
我们甚至不用维护k数组。只要维护根1到每个节点的异或值就可以了。在往根上传递时,不用更新。
1. O(n*n*logn) #include <bits/stdc++.h> #define LL long long using namespace std; struct Node{ LL to; LL w; }; vector<Node> v[100005]; map<LL, LL> mp[100005]; LL s[100005], ans=0; const LL mod=1000000007; LL n; void PUT(LL x){ for(auto pos: mp[x]){ cout<<pos.first<<" "<<pos.second<<endl; }cout<<endl; } void dfs(LL u){ for(auto x: v[u]){ dfs(x.to); s[u]+=s[x.to]; } s[u]++; } void DFS(LL u, LL w){ for(auto to: v[u]){ DFS(to.to, to.w); if(mp[to.to].count(0)){//计算u到子树的贡献 ans+=(mp[to.to][0]*(n-s[to.to]))%mod; ans%=mod; } if(mp[u].size()<mp[to.to].size()){//启发式贪心 swap(mp[u], mp[to.to]); } for(auto x: mp[to.to]){//计算子树间贡献 LL w=x.first; LL siz=x.second; if(mp[u].count(w)){ ans+=(siz*(mp[u][w]))%mod; ans%=mod; } } for(auto x: mp[to.to]){//合并 mp[u][x.first]=(mp[u][x.first]+(x.second))%mod; } } for(auto &x: mp[u]){//暴力更新 mp[0][x.first^w]+=x.second; mp[0][x.first^w]%=mod; } swap(mp[0], mp[u]); mp[0].clear(); mp[u][w]=(mp[u][w]+s[u])%mod; } int main(){ LL x; scanf("%lld", &n); LL w; for(LL i=2; i<=n; i++){ scanf("%lld%lld", &x, &w); v[x].push_back(Node{i, w}); } dfs(1); DFS(1, 0); printf("%lld\n", ans); return 0; }
2.O(n*logn*logn) k[]的写法 #include <bits/stdc++.h> #define LL long long using namespace std; struct Node{ LL to; LL w; }; vector<Node> v[100005]; map<LL, LL> mp[100005]; LL s[100005], k[100005], ans=0; const LL mod=1000000007; LL n; void dfs(LL u){ for(auto x: v[u]){ dfs(x.to); s[u]+=s[x.to]; } s[u]++; } void DFS(LL u, LL w){ for(auto to: v[u]){ DFS(to.to, to.w); if(mp[to.to].count(k[to.to])){计算u到子树的贡献 ans+=(mp[to.to][k[to.to]]*(n-s[to.to]))%mod; ans%=mod; } if(mp[u].size()<mp[to.to].size()){ swap(mp[u], mp[to.to]); swap(k[u], k[to.to]); } for(auto x: mp[to.to]){//计算子树间贡献 LL w=x.first^k[to.to]; LL siz=x.second; if(mp[u].count(w^k[u])){ ans+=(siz*(mp[u][(w^k[u])]))%mod; ans%=mod; } } for(auto x: mp[to.to]){//合并 mp[u][((x.first^k[to.to])^k[u])]=(mp[u][((x.first^k[to.to])^k[u])]+(x.second))%mod; } } mp[u][k[u]]=(mp[u][k[u]]+s[u])%mod; k[u]^=w;//更新 //PUT(u, -1); } int main(){ LL x; scanf("%lld", &n); LL w; for(LL i=2; i<=n; i++){ scanf("%lld%lld", &x, &w); v[x].push_back(Node{i, w}); } dfs(1); DFS(1, 0); printf("%lld\n", ans); return 0; }
//记录根节点到当前节点的记录 #include <bits/stdc++.h> #define LL long long using namespace std; struct Node{ LL to; LL w; }; vector<Node> v[100005]; map<LL, LL> mp[100005]; LL s[100005], ans=0; const LL mod=1000000007; LL n; void dfs(LL u){ for(auto x: v[u]){ dfs(x.to); s[u]+=s[x.to]; } s[u]++; } void DFS(LL u, LL weight){ for(auto to: v[u]){ DFS(to.to, to.w^weight); if(mp[to.to].count(weight)){//计算u到子树的贡献 ans+=(mp[to.to][weight]*(n-s[to.to]))%mod; ans%=mod; } if(mp[u].size()<mp[to.to].size()){ swap(mp[u], mp[to.to]); } for(auto x: mp[to.to]){//计算子树间贡献 LL w=x.first; LL siz=x.second; if(mp[u].count(w)){ ans+=(siz*(mp[u][w]))%mod; ans%=mod; } } for(auto x: mp[to.to]){//合并 mp[u][x.first]=(mp[u][x.first]+(x.second))%mod; } } mp[u][weight]=(mp[u][weight]+s[u])%mod; } int main(){ LL x; scanf("%lld", &n); LL w; for(LL i=2; i<=n; i++){ scanf("%lld%lld", &x, &w); v[x].push_back(Node{i, w}); } dfs(1); DFS(1, 0); printf("%lld\n", ans); return 0; }