浙农大第十九届程序设计竞赛 J-Tree & 树上启发式合并的理解
Tree
https://ac.nowcoder.com/acm/contest/7872/J
前言
树上启发式合并好啊,可惜还不太会啊
分析
1.凡事始于朴素:根据题目,我得在每一个子树内部去寻找答案。
假设是这样的一棵树。首先在3节点的子树中去找,首先遍历4节点,
统计一下,这时的lca是3,记录一个k并求出 表示另一个数的出现次数,同时 。然后去到5节点,
同样的,首先找到满足 的数的个数,同时将 。
然后进入10号节点....
这个时候我们发现已经统计完了3节点的子树,向上走一步同时vis数组清零,到达2号,根据统计3节点子树的方法一样,先把3节点所有子树加入贡献(注意,是把vis清零后再重新跑一遍3节点的子树)
然后再进入6节点,统计答案....
2.会发现,有许多不必要的重复操作,我为什么要把3节点子树的vis清零呢?我一定要把所有的都清零吗?于是树上启发式合并就来了。名字听着挺nb的,其实就是最大减少重复操作。就比如,在这棵树中,我如果要减少重复操作,明显就得最小化进入节点较多的子树的次数。也就是说,尽可能多的保留重儿子的信息(不清零),在统计时,只需要搜索轻儿子。
先给出代码
inline void dsu(int u,int v,bool w) { for (int i=h[u];~i;i=nex[i]) { int j=ver[i]; if(j==v||j==son[u]) continue; dsu(j,u,0); } if(son[u]) dsu(son[u],u,1); for (int i=h[u];~i;i=nex[i]) { int j=ver[i]; if(j==v||j==son[u]) continue; cal(j,u,u),upd(j,u,1); } k[val[u]]++; if(!w) upd(u,v,-1); }
然后看图模拟一遍(以1节点为lca统计答案):首先,重儿子(节点数最多的那个)为2,先跑入7,9节点统计完这两个子树对答案产生的贡献之后,再加入重儿子
会发现,此时2节点的子树内部信息不会被清零,
然后开始与轻儿子统计对答案的贡献,避免了重新进入2号节点统计每一个值的出现个数的问题。
代码
/*树上启发式合并*/ #include<bits/stdc++.h> #define R register #define ll long long #define inf INT_MAX using namespace std; const int N=1e5+10; int n,tot;ll ans; int h[N],nex[N<<1],ver[N<<1]; int son[N],val[N],siz[N]; map<int,int>k; inline void add(int x,int y) { nex[tot]=h[x]; ver[tot]=y; h[x]=tot++; } inline void dfs(int u,int v) { siz[u]=1; for (int i=h[u];~i;i=nex[i]) { int j=ver[i]; if(j==v) continue; dfs(j,u); siz[u]+=siz[j]; if(siz[son[u]]<siz[j]) son[u]=j; } } inline void cal(int u,int v,int lca) { ans+=(ll)k[2*val[lca]-val[u]]; for (int i=h[u];~i;i=nex[i]) if(ver[i]!=v) cal(ver[i],u,lca); } inline void upd(int u,int v,int va) { k[val[u]]+=va; for (int i=h[u];~i;i=nex[i]) if(ver[i]!=v) upd(ver[i],u,va); } inline void dsu(int u,int v,bool w) { for (int i=h[u];~i;i=nex[i]) { int j=ver[i]; if(j==v||j==son[u]) continue; dsu(j,u,0); } if(son[u]) dsu(son[u],u,1); for (int i=h[u];~i;i=nex[i]) { int j=ver[i]; if(j==v||j==son[u]) continue; cal(j,u,u),upd(j,u,1); } k[val[u]]++; if(!w) upd(u,v,-1); } int main() { memset(h,-1,sizeof(h)); scanf("%d",&n); for (int i=1;i<=n;i++) scanf("%d",&val[i]); for (int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); add(x,y),add(y,x); } dfs(1,0); dsu(1,1,0); printf("%lld\n",ans*2ll); return 0; }
后话
篇幅极短,且用词不规范,不知是否会误导其他人(应该也没多少人看)
比赛题解 文章被收录于专栏
牛客IOI周赛,团队赛,练习赛,挑战赛,各种模拟赛的部分题解