初遇 - 树链剖分
树链剖分
前言:树链剖分,久闻大名。从半年前了解到这个知识点,却一直拖到寒假才来学...emm,确实拖了挺久的。
小小总结一下,每个父节点都有一个重儿子和若干轻儿子,从一个轻结点出发,一直沿着重儿子延伸下去,就是一条重链。
因此,一颗树就被剖分成了若干重链与若干叶子结点。并且,每一条重链的头都是轻结点(根也是轻结点)。
我的一些心得是,可以把一条重链看作是一个并查集,并查集的头是轻结点,同时若干并查集之间相互关联。
具体的一些实现过程:
- 跑一遍dfs1,标记以下内容:
- 结点的父亲fa
- 结点的深度dep
- 结点的大小siz
- 结点的重儿子son(取决于结点的大小)
- 再跑一遍dfs2,标记以下内容:
- 结点的时间戳dfn
- 当前结点所在重链的头top
- 结点权值的dfs序
我喜欢用结构体封装 (当然也可以不用)
void dfs1(ll x,ll ac){
tr[x].fa=ac;
tr[x].dep=++dep;
tr[x].siz=1;
ll k=head[x];
while(k){
if(eg[k].to!=ac)dfs1(eg[k].to,x);
k=eg[k].nxt;
}
if(!tr[ac].son||tr[tr[ac].son].siz<tr[x].siz)tr[ac].son=x;
tr[ac].siz+=tr[x].siz;
dep--;
}
void dfs2(ll x){
tr[x].dfn=++num;
tr[x].top=pos;
a[num]=tr[x].w;
if(!tr[x].son)return;
dfs2(tr[x].son);
ll k=head[x];
while(k){
if(eg[k].to!=tr[x].fa&&eg[k].to!=tr[x].son)pos=eg[k].to,dfs2(eg[k].to);
k=eg[k].nxt;
}
}
至此,树链剖分预处理部分结束,我们可以用来对树进行操作。
比如取两点lca,操作以x为根的子树,或者操作x到y的最短路上的点权等等。
E.g.
具体拿洛谷P3384举例,是一道标准的模板题
链接:https://www.luogu.com.cn/problem/P3384
题目要求四个操作:
1 x y z,表示将树从 x 到 y 结点最短路径上所有节点的值都加上 z。
2 x y,表示求树从 x 到 y 结点最短路径上所有节点的值之和。
3 x z,表示将以 x 为根节点的子树内所有节点值都加上 z。
4 x 表示求以 x 为根节点的子树内所有节点值之和
很简单,依据dfs序的特性,对于3和4操作,我们可以轻松知道以 x 为根的子树的dfn区间是 [dfn(x),dfn(x)+siz(x)-1] ,即x的dfn到加上x结点的大小-1为止。用线段树处理dfs序后的区间权值即可。
其次,对于1和2的操作,我们需要找出他们的最短路。
- 设想,如果 x 和 y 在同一条重链上,那么他们的dfn肯定是连续的,只需在线段树上区间操作就好了。
- 那么不在一条重链上呢?我们需要让两结点中其top结点深度较大的一端往上跳,跳到top结点的父亲,即 x 跳到 fa(top(x)) 上,同时,因top(x) 到 x 的dfn一定是连续的,所以在跳的过程中,对线段树进行区间操作。该过程持续到两结点处于同一条重链上。
void mchain(ll x,ll y,ll val){
while(tr[x].top!=tr[y].top){
if(tr[tr[x].top].dep<tr[tr[y].top].dep)modify(1,tr[tr[y].top].dfn,tr[y].dfn,val),y=tr[tr[y].top].fa;
else modify(1,tr[tr[x].top].dfn,tr[x].dfn,val),x=tr[tr[x].top].fa;
}
if(tr[x].dep>tr[y].dep)swap(x,y);
modify(1,tr[x].dfn,tr[y].dfn,val);
}
ll qchain(ll x,ll y){
ll res=0;
while(tr[x].top!=tr[y].top){
if(tr[tr[x].top].dep<tr[tr[y].top].dep)res=(res+query(1,tr[tr[y].top].dfn,tr[y].dfn))%pp,y=tr[tr[y].top].fa;
else res=(res+query(1,tr[tr[x].top].dfn,tr[x].dfn))%pp,x=tr[tr[x].top].fa;
}
if(tr[x].dep>tr[y].dep)swap(x,y);
res=(res+query(1,tr[x].dfn,tr[y].dfn))%pp;
return res;
}
全代码:
#include<bits/stdc++.h>
#define ll long long
#define L(i,j,k) for(ll i=(j);i<=(k);i++)
#define R(i,j,k) for(ll i=(j);i>=(k);i--)
#define inf 0x3f3f3f3f3f3f3f3f
#define fi first
#define se second
#define MS(i,j) memset(i,j,sizeof (i))
const ll N=1e6+10,M=5,mod=998244353,mmod=mod-1;
const double pi=acos(-1);
using namespace std;
ll gcd(ll x,ll y){if(y==0) return x;return gcd(y,x%y);}
ll fksm(ll a,ll b){ll r=1;if(b<0)b+=mod-1;for(a%=mod;b;b>>=1){if(b&1)r=r*a%mod;a=a*a%mod;}return r;}//a 分母; b MOD-2
ll lowbit(ll x){return x&(-x);}
ll m,n,t,x,y,z,l,r,k,p,pp,nx,ny,ansx,ansy,lim,num,sum,pos,tot,dep,root,block,key,cnt,minn,maxx,ans;
ll a[N],head[N],dx[5]={0,0,-1,1},dy[5]={-1,1,0,0};
double dans;
bool vis[1010][1010],flag;
char mapp[1010][1010],zz;
struct qq{ll x,y,z;}q;
struct tree{ll l,r,tag,sum;}trs[N*4];
struct Tree{ll fa,dep,dfn,siz,son,top,w;}tr[N];
struct Trp{ll l,r,fat,dep,n,w;}trp;
struct E{ll to,nxt,w;}eg[N*2];
struct matrix{ll n,m[M][M];};
struct complx{
double r,i;
complx(){}
complx(double r,double i):r(r),i(i){}
complx operator+(const complx& rhs)const{return complx (r+rhs.r,i+rhs.i);}
complx operator-(const complx& rhs)const{return complx (r-rhs.r,i-rhs.i);}
complx operator*(const complx& rhs)const{return complx (r*rhs.r-i*rhs.i,i*rhs.r+r*rhs.i);}
void operator+=(const complx& rhs){r+=rhs.r,i+=rhs.i;}
void operator*=(const complx& rhs){r=r*rhs.r-i*rhs.i,i=r*rhs.i+i*rhs.r;}
void operator/=(const double& x){r/=x,i/=x;}
complx conj(){return complx(r,-i);}
};
bool cmp(qq u,qq v){
return u.x*v.y>v.x*u.y;
}
bool cmp1(qq u,qq v){
return u.x<v.x;
}
bool cmpl(ll u,ll v){return u>v;}
struct cmps{bool operator()(ll u,ll v){
return u>v;
}};//shun序
pair<ll,ll>pre[1010][1010];
vector<ll>v;//v.assign(m,vector<ll>(n));
//priority_queue<ll,vector<ll>,cmps>sp;
deque<qq>sq;
map<ll,ll>mp;
bitset<M>bi;
void add(ll u,ll v,ll w){
eg[++cnt].to=v;
eg[cnt].nxt=head[u];
eg[cnt].w=w;
head[u]=cnt;
}
void push_up(ll k){
trs[k].sum=(trs[k*2].sum+trs[k*2+1].sum)%pp;
}
void push_down(ll k){
if(trs[k].tag>0){
ll l=k*2,r=k*2+1;
trs[l].tag=(trs[l].tag+trs[k].tag)%pp;
trs[r].tag=(trs[r].tag+trs[k].tag)%pp;
trs[l].sum=(trs[l].sum+(trs[l].r-trs[l].l+1)*trs[k].tag)%pp;
trs[r].sum=(trs[r].sum+(trs[r].r-trs[r].l+1)*trs[k].tag)%pp;
trs[k].tag=0;
}
}
void bd_tree(ll k,ll l,ll r){
trs[k].tag=0;
trs[k].l=l,trs[k].r=r;
if(l==r){
trs[k].sum=a[l]%pp;
return;
}
ll mid=(l+r)/2;
bd_tree(k*2,l,mid);
bd_tree(k*2+1,mid+1,r);
push_up(k);
}
ll query(ll k,ll pl,ll pr){
ll ml=0,mr=0;
if(trs[k].l>=pl&&trs[k].r<=pr){
return trs[k].sum;
}
push_down(k);
ll mid=(trs[k].l+trs[k].r)/2;
if(mid>=pl)ml=query(k*2,pl,pr);
if(mid+1<=pr)mr=query(k*2+1,pl,pr);
return (ml+mr)%pp;
}
void modify(ll k,ll pl,ll pr,ll val){//[pl,pr]改为val
if(trs[k].l>=pl&&trs[k].r<=pr){
trs[k].sum=(trs[k].sum+(trs[k].r-trs[k].l+1)*val)%pp;
trs[k].tag=(trs[k].tag+val)%pp;
return;
}
push_down(k);
ll mid=(trs[k].l+trs[k].r)/2;
if(mid>=pl)modify(k*2,pl,pr,val);
if(mid+1<=pr)modify(k*2+1,pl,pr,val);
push_up(k);
}
void dfs1(ll x,ll ac){
tr[x].fa=ac;
tr[x].dep=++dep;
tr[x].siz=1;
ll k=head[x];
while(k){
if(eg[k].to!=ac)dfs1(eg[k].to,x);
k=eg[k].nxt;
}
if(!tr[ac].son||tr[tr[ac].son].siz<tr[x].siz)tr[ac].son=x;
tr[ac].siz+=tr[x].siz;
dep--;
}
void dfs2(ll x){
tr[x].dfn=++num;
tr[x].top=pos;
a[num]=tr[x].w;
if(!tr[x].son)return;
dfs2(tr[x].son);
ll k=head[x];
while(k){
if(eg[k].to!=tr[x].fa&&eg[k].to!=tr[x].son)pos=eg[k].to,dfs2(eg[k].to);
k=eg[k].nxt;
}
}
void mchain(ll x,ll y,ll val){
while(tr[x].top!=tr[y].top){
if(tr[tr[x].top].dep<tr[tr[y].top].dep)modify(1,tr[tr[y].top].dfn,tr[y].dfn,val),y=tr[tr[y].top].fa;
else modify(1,tr[tr[x].top].dfn,tr[x].dfn,val),x=tr[tr[x].top].fa;
}
if(tr[x].dep>tr[y].dep)swap(x,y);
modify(1,tr[x].dfn,tr[y].dfn,val);
}
ll qchain(ll x,ll y){
ll res=0;
while(tr[x].top!=tr[y].top){
if(tr[tr[x].top].dep<tr[tr[y].top].dep)res=(res+query(1,tr[tr[y].top].dfn,tr[y].dfn))%pp,y=tr[tr[y].top].fa;
else res=(res+query(1,tr[tr[x].top].dfn,tr[x].dfn))%pp,x=tr[tr[x].top].fa;
}
if(tr[x].dep>tr[y].dep)swap(x,y);
res=(res+query(1,tr[x].dfn,tr[y].dfn))%pp;
return res;
}
int main(){
scanf("%lld%lld%lld%lld",&n,&m,&p,&pp);
L(i,1,n)scanf("%lld",&tr[i].w);
cnt=0;
L(i,1,n-1){
scanf("%lld%lld",&x,&y);
add(x,y,0);
add(y,x,0);
}
num=0;dep=0,pos=p;
dfs1(p,0);
dfs2(p);
bd_tree(1,1,n);
//L(i,1,n)printf("%lld ",tr[i].fa);printf("\n");
L(i,1,m){
scanf("%lld",&k);
if(k==1){
scanf("%lld%lld%lld",&x,&y,&z);
mchain(x,y,z%pp);
}
else if(k==2){
scanf("%lld%lld",&x,&y);
printf("%lld\n",qchain(x,y));
}
else if(k==3){
scanf("%lld%lld",&x,&y);
modify(1,tr[x].dfn,tr[x].dfn+tr[x].siz-1,y%pp);
}
else{
scanf("%lld",&x);
printf("%lld\n",query(1,tr[x].dfn,tr[x].dfn+tr[x].siz-1));
}
//L(i,1,n)printf("%lld ",query(1,i,i));printf("\n");
}
}
E.g.2有难度
最近做到的牛客挑战赛57的C题,我拿来讲一下
链接:https://ac.nowcoder.com/acm/contest/11197/C
题目要求两个操作:
-
给定 x,y,令 x→y 的最短路上的点构成的点序列为 p,对于所有 i>1,令 b[ p[i] ] 增加 a[ p[i-1] ]
-
给定 x,输出 b[x] 的值。
很明显的树链操作,对于 x 到 y 的最短路径上的点的计数器 b 依次加上前一个点的点权 a。 一样的先树链剖分,对于一个结点,我们可以发现只需统计三个值:
- 由父亲加过来的点权
- 由重儿子加过来的点权
- 由其他轻儿子加过来的点权
在操作1中,因为 x 到 y 的重链数量为logn量级,所以可以发现,第3个值的统计也为logn量级,而由于一条链上结点的dfn值连续性,可以区间处理统计第1、2的值,因此同样可以借助两个线段树或树状数组来维护1、2值加过来的个数。
对于操作2,因为只询问单个结点,因此只需将3个值加起来即可。
代码有些冗余,因为使用了两个相同的线段树……
全代码:
#include<bits/stdc++.h>
#define ll long long
#define L(i,j,k) for(ll i=(j);i<=(k);i++)
#define R(i,j,k) for(ll i=(j);i>=(k);i--)
#define inf 0x3f3f3f3f3f3f3f3f
#define fi first
#define se second
#define MS(i,j) memset(i,j,sizeof (i))
const ll N=1e6+10,M=5,mod=998244353,mmod=mod-1;
const double pi=acos(-1);
using namespace std;
ll gcd(ll x,ll y){if(y==0) return x;return gcd(y,x%y);}
ll fksm(ll a,ll b){ll r=1;if(b<0)b+=mod-1;for(a%=mod;b;b>>=1){if(b&1)r=r*a%mod;a=a*a%mod;}return r;}//a 分母; b MOD-2
ll lowbit(ll x){return x&(-x);}
ll m,n,t,x,y,z,l,r,k,p,pp,nx,ny,ansx,ansy,lim,num,sum,pos,tot,dep,root,block,key,cnt,minn,maxx,ans;
ll a[N],b[N],head[N],dx[5]={0,0,-1,1},dy[5]={-1,1,0,0};
double dans;
bool vis,flag;
char mapp,zz;
struct qq{ll x,y,z;}q;
struct tree{ll l,r,tag,sum;}tra[N*4],trb[N*4];
struct Tree{ll fa,dep,dfn,siz,son,top,w;}tr[N];
struct Trp{ll l,r,fat,dep,n,w;}trp;
struct E{ll to,nxt,w;}eg[N*2];
struct matrix{ll n,m[M][M];};
struct complx{
double r,i;
complx(){}
complx(double r,double i):r(r),i(i){}
complx operator+(const complx& rhs)const{return complx (r+rhs.r,i+rhs.i);}
complx operator-(const complx& rhs)const{return complx (r-rhs.r,i-rhs.i);}
complx operator*(const complx& rhs)const{return complx (r*rhs.r-i*rhs.i,i*rhs.r+r*rhs.i);}
void operator+=(const complx& rhs){r+=rhs.r,i+=rhs.i;}
void operator*=(const complx& rhs){r=r*rhs.r-i*rhs.i,i=r*rhs.i+i*rhs.r;}
void operator/=(const double& x){r/=x,i/=x;}
complx conj(){return complx(r,-i);}
};
bool cmp(qq u,qq v){
return u.x*v.y>v.x*u.y;
}
bool cmp1(qq u,qq v){
return u.x<v.x;
}
bool cmpl(ll u,ll v){return u>v;}
struct cmps{bool operator()(ll u,ll v){
return u>v;
}};//shun序
pair<ll,ll>pre[1010][1010];
vector<ll>v;//v.assign(m,vector<ll>(n));
//priority_queue<ll,vector<ll>,cmps>sp;
deque<qq>sq;
map<ll,ll>mp;
bitset<M>bi;
void add(ll u,ll v,ll w){
eg[++cnt].to=v;
eg[cnt].nxt=head[u];
eg[cnt].w=w;
head[u]=cnt;
}
void push_upA(ll k){
tra[k].sum=tra[k*2].sum+tra[k*2+1].sum;
}
void push_upB(ll k){
trb[k].sum=trb[k*2].sum+trb[k*2+1].sum;
}
void push_downA(ll k){
if(tra[k].tag>0){
ll l=k*2,r=k*2+1;
tra[l].tag+=tra[k].tag;
tra[r].tag+=tra[k].tag;
tra[l].sum+=(tra[l].r-tra[l].l+1)*tra[k].tag;
tra[r].sum+=(tra[r].r-tra[r].l+1)*tra[k].tag;
tra[k].tag=0;
}
}
void push_downB(ll k){
if(trb[k].tag>0){
ll l=k*2,r=k*2+1;
trb[l].tag+=trb[k].tag;
trb[r].tag+=trb[k].tag;
trb[l].sum+=(trb[l].r-trb[l].l+1)*trb[k].tag;
trb[r].sum+=(trb[r].r-trb[r].l+1)*trb[k].tag;
trb[k].tag=0;
}
}
void bd_tree(ll k,ll l,ll r){
tra[k].tag=0;trb[k].tag=0;
tra[k].l=l,tra[k].r=r;
trb[k].l=l,trb[k].r=r;
if(l==r){
tra[k].sum=0;
trb[k].sum=0;
return;
}
ll mid=(l+r)/2;
bd_tree(k*2,l,mid);
bd_tree(k*2+1,mid+1,r);
push_upA(k);push_upB(k);
}
ll queryA(ll k,ll pl,ll pr){
ll ml=0,mr=0;
if(tra[k].l>=pl&&tra[k].r<=pr){
return tra[k].sum;
}
push_downA(k);
ll mid=(tra[k].l+tra[k].r)/2;
if(mid>=pl)ml=queryA(k*2,pl,pr);
if(mid+1<=pr)mr=queryA(k*2+1,pl,pr);
return ml+mr;
}
ll queryB(ll k,ll pl,ll pr){
ll ml=0,mr=0;
if(trb[k].l>=pl&&trb[k].r<=pr){
return trb[k].sum;
}
push_downB(k);
ll mid=(trb[k].l+trb[k].r)/2;
if(mid>=pl)ml=queryB(k*2,pl,pr);
if(mid+1<=pr)mr=queryB(k*2+1,pl,pr);
return ml+mr;
}
void modifyA(ll k,ll pl,ll pr,ll val){//后效
if(pl>pr)return;
if(tra[k].l>=pl&&tra[k].r<=pr){
tra[k].sum+=(tra[k].r-tra[k].l+1)*val;
tra[k].tag+=val;
return;
}
push_downA(k);
ll mid=(tra[k].l+tra[k].r)/2;
if(mid>=pl)modifyA(k*2,pl,pr,val);
if(mid+1<=pr)modifyA(k*2+1,pl,pr,val);
push_upA(k);
}
void modifyB(ll k,ll pl,ll pr,ll val){//前效
if(pl>pr)return;
if(trb[k].l>=pl&&trb[k].r<=pr){
trb[k].sum+=(trb[k].r-trb[k].l+1)*val;
trb[k].tag+=val;
return;
}
push_downB(k);
ll mid=(trb[k].l+trb[k].r)/2;
if(mid>=pl)modifyB(k*2,pl,pr,val);
if(mid+1<=pr)modifyB(k*2+1,pl,pr,val);
push_upB(k);
}
void dfs1(ll x,ll ac){
tr[x].fa=ac;
tr[x].dep=++dep;
tr[x].siz=1;
ll k=head[x];
while(k){
if(eg[k].to!=ac)dfs1(eg[k].to,x);
k=eg[k].nxt;
}
if(!tr[ac].son||tr[tr[ac].son].siz<tr[x].siz)tr[ac].son=x;
tr[ac].siz+=tr[x].siz;
dep--;
}
void dfs2(ll x){
tr[x].dfn=++num;
tr[x].top=pos;
a[num]=x;
if(!tr[x].son)return;
dfs2(tr[x].son);
ll k=head[x];
while(k){
if(eg[k].to!=tr[x].fa&&eg[k].to!=tr[x].son)pos=eg[k].to,dfs2(eg[k].to);
k=eg[k].nxt;
}
}
void mchain(ll x,ll y){
while(tr[x].top!=tr[y].top){
if(tr[tr[x].top].dep<tr[tr[y].top].dep){
modifyB(1,tr[tr[y].top].dfn+1,tr[y].dfn,1);
y=tr[y].top;
b[y]+=tr[tr[y].fa].w;
y=tr[y].fa;
}
else{
modifyA(1,tr[tr[x].top].dfn,tr[x].dfn-1,1);
x=tr[x].top;
b[tr[x].fa]+=tr[x].w;
x=tr[x].fa;
}
}
if(tr[x].dep>tr[y].dep)modifyA(1,tr[y].dfn,tr[x].dfn-1,1);
else modifyB(1,tr[x].dfn+1,tr[y].dfn,1);
}
ll qchain(ll x){
ll res=b[x];
res+=queryA(1,tr[x].dfn,tr[x].dfn)*tr[tr[x].son].w;
res+=queryB(1,tr[x].dfn,tr[x].dfn)*tr[tr[x].fa].w;
return res;
}
int main(){
scanf("%lld%lld",&n,&m);
L(i,1,n)scanf("%lld",&tr[i].w);
cnt=0;
L(i,1,n-1){
scanf("%lld%lld",&x,&y);
add(x,y,0);
add(y,x,0);
}
num=0;dep=0,pos=p;
dfs1(1,0);
dfs2(1);
bd_tree(1,1,n);
// L(i,1,n)printf("%lld ",a[i]);printf("\n");
L(i,1,m){
scanf("%lld",&k);
if(k==1){
scanf("%lld%lld",&x,&y);
mchain(x,y);
}
else if(k==2){
scanf("%lld",&x);
printf("%lld\n",qchain(x));
}
}
}