<span>2019icpc沈阳网络赛 D Fish eating fruit 树形dp</span>
题意
分别算一个树中所有简单路径长度模3为0,1,2的距离和乘2。
分析
记录两个数组,
\(dp[i][k]\)为距i模3为k的子节点到i的距离和
\(f[i][k]\)为距i模3为k的子节点的个数
\(ans[k]\)为所有简单路径长度模3为k的距离和
\(v\)是\(u\)的儿子,\(c\)是u到v的边长度,\(0<i,j<3,k=(j-c\%3+3)\%3\)
-
\(dp[u][(i+c\%3)\%3]+=dp[v][i]+f[v][i]*c\)
-
\(f[u][(i+c\%3)\%3]+=f[v][i]\)
-
\(ans[(i+j+c\%3)\%3]+=f[v][i]*(d[u][j]-d[v][k]-f[v][k]*c)\),算u的子节点跨越u的答案
-
\(ans[i]+=dp[u][i]\)
Code
#include<bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define lson l,mid,p<<1
#define rson mid+1,r,p<<1|1
#define ll long long
using namespace std;
const int inf=1e9;
const int mod=1e9+7;
const int maxn=1e5+10;
int n;
typedef pair<int,int>pii;
vector<pii>g[maxn];
ll d[maxn][3],f[maxn][3],ans[3];
void dfs(int u,int fa){
f[u][0]=1;
for(pii x:g[u]){
if(x.fi==fa) continue;
int v=x.fi;ll c=x.se;
dfs(x.fi,u);
for(int i=0;i<3;i++){
(d[u][(i+c%3)%3]+=(d[v][i]+f[v][i]*c%mod)%mod)%=mod;
(f[u][(i+c%3)%3]+=f[v][i])%=mod;
}
}
}
void dfs1(int u,int fa){
ans[0]=(ans[0]+d[u][0])%mod;
ans[1]=(ans[1]+d[u][1])%mod;
ans[2]=(ans[2]+d[u][2])%mod;
for(pii x:g[u]){
if(x.fi==fa) continue;
int v=x.fi;ll c=x.se;
for(int i=0;i<3;i++){
for(int j=0;j<3;j++){
int k=(j-c%3+3)%3;
(ans[(i+j+c%3)%3]+=f[v][i]*(d[u][j]-d[v][k]-f[v][k]*c%mod)%mod)%=mod;
}
}
dfs1(x.fi,u);
}
}
int main(){
//ios::sync_with_stdio(false);
//freopen("in","r",stdin);
while(~scanf("%d",&n)){
for(int i=1,a,b,c;i<n;i++){
scanf("%d%d%d",&a,&b,&c);
++a;++b;
g[a].pb(pii(b,c));
g[b].pb(pii(a,c));
}
ans[0]=ans[1]=ans[2]=0;
dfs(1,0);dfs1(1,0);
printf("%lld %lld %lld\n",ans[0]*2%mod,ans[1]*2%mod,ans[2]*2%mod);
for(int i=1;i<=n;i++) g[i].clear(),d[i][0]=d[i][1]=d[i][2]=f[i][0]=f[i][1]=f[i][2]=0;
}
return 0;
}