树上路径
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int mod=1e9+7; const int N=1e5+5; const int iv2=(mod+1)/2; ll w[N]; vector<int>g[N]; ll add(ll a,ll b) { return (a+b)%mod; } ll add(ll a,ll b,ll c) { return add(add(a,b),c); } ll mul(ll a,ll b) { return a*b%mod; } ll mul(ll a,ll b,ll c) { return mul(mul(a,b),c); } int sz[N],son[N],dep[N],f[N]; void dfs(int u,int fa) { dep[u]=dep[fa]+1; f[u]=fa; sz[u]=1; for(int v:g[u]) { if(v==fa) continue; dfs(v,u); sz[u]+=sz[v]; if(sz[son[u]]<sz[v]) { son[u]=v; } } } int idx[N],top[N],id; ll val[N]; void DFS(int u,int tp) { idx[u]=++id; top[u]=tp; val[id]=w[u]; if(!son[u]) return; DFS(son[u],tp); for(int v:g[u]) { if(!idx[v]) { DFS(v,v); } } } struct SegTree{ int l,r,len; ll lazy,sum,ans; }Tr[N<<2]; void change(int u,ll k) { Tr[u].lazy=add(Tr[u].lazy,k); Tr[u].ans=add(Tr[u].ans,mul(mul(Tr[u].len,iv2,Tr[u].len-1),(k*k%mod)),mul(Tr[u].sum,k,Tr[u].len-1)); Tr[u].sum=add(Tr[u].sum,(Tr[u].len)*k%mod); } void pushup(int u) { Tr[u].sum=add(Tr[u<<1].sum,Tr[u<<1|1].sum); Tr[u].ans=add(Tr[u<<1].ans,Tr[u<<1|1].ans,mul(Tr[u<<1].sum,Tr[u<<1|1].sum)); } void pushdown(int u) { if(Tr[u].lazy) { change(u<<1,Tr[u].lazy); change(u<<1|1,Tr[u].lazy); Tr[u].lazy=0; } } void build(int u,int l,int r) { Tr[u].l=l,Tr[u].r=r; Tr[u].len=(r-l+1); if(l==r) { Tr[u].sum=val[l]; return; } int mid=(l+r)>>1; build(u<<1,l,mid); build(u<<1|1,mid+1,r); pushup(u); } void add(int u,int l,int r,ll k) { //if(l>Tr[u].r||r<Tr[u].l) return; if(Tr[u].l>=l&&Tr[u].r<=r) { change(u,k); return; } pushdown(u); int mid=(Tr[u].l+Tr[u].r)/2; if(l<=mid) add(u<<1,l,r,k); if(r>mid) add(u<<1|1,l,r,k); pushup(u); } SegTree merge(SegTree l,SegTree r) { SegTree res; res.ans=add(l.ans,r.ans,mul(l.sum,r.sum)); res.sum=add(l.sum,r.sum); return res; } SegTree query(int u,int l,int r) { if(Tr[u].l>=l&&Tr[u].r<=r) { return Tr[u]; } pushdown(u); int mid=(Tr[u].l+Tr[u].r)>>1; if(r<=mid) return query(u<<1,l,r); else if(l>mid) return query(u<<1|1,l,r); else return merge(query(u<<1,l,r),query(u<<1|1,l,r)); } void Treeadd(int u,int v,ll k) { while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]]) swap(u,v); add(1,idx[top[u]],idx[u],k); u=f[top[u]]; } if(dep[v]>dep[u]) swap(u,v); add(1,idx[v],idx[u],k); } ll Treequery(int u,int v) { SegTree res; res.sum=0,res.ans=0; while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]]) swap(u,v); res=merge(res,query(1,idx[top[u]],idx[u])); u=f[top[u]]; } if(dep[v]>dep[u]) swap(u,v); res=merge(res,query(1,idx[v],idx[u])); return res.ans; } int main() { int n,m; scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) { scanf("%lld",&w[i]); } for(int i=1;i<n;i++) { int u,v; scanf("%d%d",&u,&v); g[u].push_back(v); g[v].push_back(u); } dfs(1,1); DFS(1,1); build(1,1,n); while(m--) { int opt,v,u;ll k; scanf("%d",&opt); if(opt==1) { scanf("%d%lld",&u,&k); add(1,idx[u],idx[u]+sz[u]-1,k); } else if(opt==2) { scanf("%d%d%lld",&u,&v,&k); Treeadd(u,v,k); } else { scanf("%d%d",&u,&v); printf("%lld\n",Treequery(u,v)); } } return 0; }
这份代码交这个题,为什么有时候段错误,有时候ac...