7-5 1E. 树与路径(巧妙的树上差分)
在一棵有根树 T 上,任何两点间的最短路径都能够分为两个阶段:
从起点出发,沿着向根的方向走若干条边。
向着终点,沿着离开根的方向走若干条边。
定义一条路径的权值为向上走的边数乘上向下走的边数。特殊地,当起点等于终点的时候,两阶段的边数都是 0;当起点是终点的祖先的时候,第一阶段的边数是 0;当终点是起点的祖先的时候,第二阶段的边数是 0------这三种情况下,路径的权值都是 0。
现在给出一棵 n 个节点的无根树 T 和 m 条路径 (ai ,bi )。对于每一个 r∈[1,n],你需要计算当 r 是根节点的时候,所有路径的权值和是多少。
输入格式:
第一行输入两个整数 n,m(1≤n,m≤3×105 )。
接下来 n−1 行每行输入两个整数 ui ,vi (1≤ui ,vi ≤n),表示树上的一条边。
接下来 m 行每行输入两个整数 ai ,bi (1≤ai ,bi ≤n),表示一条路径。
输出格式
输出 n 行每行一个整数,第 i 行表示以 i 为根时,所有路径的权值和。
思路:树上差分一个等差序列,可以化成常数和已知数的形式,真棒
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<cmath>
#include<queue>
#include<cstdlib>
#include<map>
#include<set>
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
#define x first
#define y second
#define int ll
using namespace std;
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const int mod=1e9+7;
const int maxn=300100;
int head[maxn],ver[maxn<<1],nt[maxn<<1];
int d[maxn],f[maxn][22],lca[maxn];
int a[maxn],b[maxn],l[maxn];
int rk1[maxn],rk2[maxn],rk3[maxn];
int ans1[maxn];
int n,m,tot,x,y,t;
void add(int x,int y)
{
ver[++tot]=y,nt[tot]=head[x],head[x]=tot;
}
void dfs1(int x,int fa)
{
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(y==fa) continue;
d[y]=d[x]+1;
f[y][0]=x;
for(int j=1;j<=t;j++)
f[y][j]=f[f[y][j-1]][j-1];
dfs1(y,x);
}
}
int lc(int x,int y)
{
if(d[x]>d[y]) swap(x,y);
for(int i=t;i>=0;i--)
if(d[f[y][i]]>=d[x]) y=f[y][i];
if(x==y) return x;
for(int i=t;i>=0;i--)
if(f[y][i]!=f[x][i]) y=f[y][i],x=f[x][i];
return f[x][0];
}
void dfs2(int x,int fa)
{
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(y==fa) continue;
dfs2(y,x);
rk1[x]+=rk1[y];
rk2[x]+=rk2[y];
rk3[x]+=rk3[y];
}
}
void dfs3(int x,int fa)
{
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(y==fa) continue;
ans1[y]=ans1[x]-(rk1[y]+rk2[y]*d[y]+rk3[y]);
dfs3(y,x);
//cout<<"y: "<<y<<" ans1: "<<ans1[y]<<endl;
}
}
signed main(void)
{
scanf("%lld%lld",&n,&m);
t=log(n)/log(2)+1;
for(int i=1;i<n;i++)
{
scanf("%lld%lld",&x,&y);
add(x,y),add(y,x);
}
d[0]=-1;
dfs1(1,0);
int ans=0;
for(int i=1;i<=m;i++)
{
scanf("%lld%lld",&a[i],&b[i]);
lca[i]=lc(a[i],b[i]);
l[i]=d[a[i]]+d[b[i]]-2*d[lca[i]];
ans+=(d[a[i]]-d[lca[i]])*(d[b[i]]-d[lca[i]]);
rk1[a[i]]+=l[i]-2*d[a[i]];
rk1[b[i]]+=l[i]-2*d[b[i]];
rk1[lca[i]]-=l[i]-2*d[a[i]]+l[i]-2*d[b[i]];
rk2[a[i]]+=2;
rk2[b[i]]+=2;
rk2[lca[i]]-=4;
rk3[a[i]]-=1;
rk3[b[i]]-=1;
rk3[lca[i]]+=2;
}
dfs2(1,0);
ans1[1]=ans;
dfs3(1,0);
//cout<<ans<<endl;
//cout<<d[2]<<endl;
//for(int i=1;i<=n;i++)
// printf("rk1:%lld rk2:%lld rk3:%lld\n",rk1[i],rk2[i],rk3[i]);
for(int i=1;i<=n;i++)
printf("%lld\n",ans1[i]);
return 0;
}