<span>BZOJ 4034 [HAOI2015]树上操作 线段树+树剖或dfs</span>
题意
直接照搬原题面
有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个
操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
分析
先树剖一下,按重新编号的点建线段树
-
操作1:直接单点修改
-
操作2:一个子树里的点的编号是连在一起的,直接区间修改
-
操作3:该点的\(top\)不为1时,即该点跟根结点不在一条链上,加上这条链的贡献(线段树的区间求和),
再跳到\(top\)的父节点所在链,直到\(top\)为1再加上\(top\)为1这条链的贡献,就能求出1到x的答案了
其实还有另一种不用树剖的做法,用线段树维护前缀和,\(a[x]\)为从\(1\)到\(x\)的点权和,操作1就等于区间修改\(x\)的子树中所有节点,
操作2就等于对\(x\)的子树中每个节点进行一次操作1,这肯定不行,考虑单个节点的贡献,每个节点总共增加的值为它在\(x\)的子树中的深度\(p\)
乘上增加量\(k\),区间贡献和即为区间深度之和乘\(k\).
线段树要多记录区间结点的深度和\(w[p]\),区间修改的式子为\(val[p]+=w[p]*k-(r-l+1)*dep*k\),\(dep\)为\(x\)的父节点的深度
加个lazy标记记录\(dep*k\)就行了
Code 1
#include<bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define lson l,mid,p<<1
#define rson mid+1,r,p<<1|1
using namespace std;
typedef long long ll;
const int inf=1e9;
const int maxn=3e5+10;
int n,q;
ll a[maxn];
vector<int>g[maxn];
int top[maxn],in[maxn],out[maxn],sz[maxn],f[maxn],son[maxn],id[maxn],tot;
ll val[maxn<<2],tag[maxn<<2];
void pp(int p){val[p]=val[p<<1]+val[p<<1|1];}
void pd(int l,int r,int p,ll k){val[p]+=(r-l+1)*k,tag[p]+=k;}
void bd(int l,int r,int p){
if(l==r) return val[p]=a[id[l]],void();
int mid=l+r>>1;
bd(lson);bd(rson);pp(p);
}
void up(int dl,int dr,int l,int r,int p,ll k){
if(l>=dl&&r<=dr){
val[p]+=(r-l+1)*k;tag[p]+=k;
return;
}int mid=l+r>>1;
pd(lson,tag[p]);pd(rson,tag[p]);tag[p]=0;
if(dl<=mid) up(dl,dr,lson,k);
if(dr>mid) up(dl,dr,rson,k);
pp(p);
}
ll qy(int dl,int dr,int l,int r,int p){
if(l>=dl&&r<=dr) return val[p];
int mid=l+r>>1;ll ret=0;
pd(lson,tag[p]);pd(rson,tag[p]);tag[p]=0;
if(dl<=mid) ret+=qy(dl,dr,lson);
if(dr>mid) ret+=qy(dl,dr,rson);
return ret;
}
void dfs1(int u){
sz[u]=1;
for(int i=0;i<g[u].size();i++){
int x=g[u][i];
if(x==f[u]) continue;
f[x]=u;dfs1(x);
sz[u]+=sz[x];
if(sz[son[u]]<sz[x]) son[u]=x;
}
}
void dfs2(int u,int t){
top[u]=t;in[u]=++tot;id[tot]=u;
if(son[u]) dfs2(son[u],t);
for(int i=0;i<g[u].size();i++){
int x=g[u][i];
if(x==f[u]||x==son[u]) continue;
dfs2(x,x);
}
out[u]=tot;
}
ll cal(int x){
ll res=0;
while(top[x]!=1){
res+=qy(in[top[x]],in[x],1,n,1);
x=f[top[x]];
}
res+=qy(1,in[x],1,n,1);return res;
}
int main(){
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
for(int i=1,a,b;i<n;i++){
scanf("%d%d",&a,&b);
g[a].pb(b);g[b].pb(a);
}
dfs1(1);dfs2(1,1);bd(1,n,1);
while(q--){
int op,x;ll a;
scanf("%d%d",&op,&x);
if(op==1){
scanf("%lld",&a);
up(in[x],in[x],1,n,1,a);
}else if(op==2){
scanf("%lld",&a);
up(in[x],out[x],1,n,1,a);
}else{
printf("%lld\n",cal(x));
}
}
return 0;
}
Code 2
#include<bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define lson l,mid,p<<1
#define rson mid+1,r,p<<1|1
using namespace std;
typedef long long ll;
const int inf=1e9;
const int maxn=3e5+10;
int n,q;
int d[maxn];
ll a[maxn],dep[maxn];
vector<int>g[maxn];
int f[maxn],in[maxn],out[maxn],tot;
ll val[maxn<<2],tag[maxn<<2],w[maxn<<2],tw[maxn<<2],qw[maxn<<2];
void pushup(int p){
val[p]=val[p<<1]+val[p<<1|1];
}
void tag1(int l,int r,int p,ll k,ll tk,ll qk){
val[p]+=w[p]*k-(r-l+1)*tk+(r-l+1)*qk;tag[p]+=k;
tw[p]+=tk;qw[p]+=qk;
}
void bd(int l,int r,int p){
if(l==r){
val[p]=a[d[l]];
w[p]=dep[d[l]];
return;
}
int mid=l+r>>1;
bd(lson);bd(rson);
w[p]=w[p<<1]+w[p<<1|1];
pushup(p);
}
void up(int dl,int dr,int l,int r,int p,ll k,ll dep){
if(l>=dl&&r<=dr){
val[p]+=(w[p]-(r-l+1)*dep)*k;tag[p]+=k;
tw[p]+=dep*k;
return;
}int mid=l+r>>1;
tag1(lson,tag[p],tw[p],qw[p]);tag1(rson,tag[p],tw[p],qw[p]);tag[p]=0;tw[p]=0;qw[p]=0;
if(dl<=mid) up(dl,dr,lson,k,dep);
if(dr>mid) up(dl,dr,rson,k,dep);
pushup(p);
}
void upd(int dl,int dr,int l,int r,int p,ll k){
if(l>=dl&&r<=dr){
val[p]+=(r-l+1)*k;qw[p]+=k;
return;
}int mid=l+r>>1;
tag1(lson,tag[p],tw[p],qw[p]);tag1(rson,tag[p],tw[p],qw[p]);tag[p]=0;tw[p]=0;qw[p]=0;
if(dl<=mid) upd(dl,dr,lson,k);
if(dr>mid) upd(dl,dr,rson,k);
pushup(p);
}
ll qy(int dl,int dr,int l,int r,int p){
ll ret=0;
if(l>=dl&&r<=dr) return val[p];int mid=l+r>>1;
tag1(lson,tag[p],tw[p],qw[p]);tag1(rson,tag[p],tw[p],qw[p]);tag[p]=0;tw[p]=0;qw[p]=0;
if(dl<=mid) ret+=qy(dl,dr,lson);
if(dr>mid) ret+=qy(dl,dr,rson);
return ret;
}
void dfs(int u){
in[u]=++tot;d[tot]=u;dep[u]=dep[f[u]]+1;
for(int i=0;i<g[u].size();i++){
int x=g[u][i];
if(x==f[u]) continue;
f[x]=u;a[x]+=a[u];
dfs(x);
}
out[u]=tot;
}
int main(){
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
for(int i=1,a,b;i<n;i++){
scanf("%d%d",&a,&b);
g[a].pb(b);g[b].pb(a);
}
dfs(1);
bd(1,n,1);
while(q--){
int op,x;
ll a;
scanf("%d%d",&op,&x);
if(op!=3) scanf("%lld",&a);
if(op==1) upd(in[x],out[x],1,n,1,a);
else if(op==2) up(in[x],out[x],1,n,1,a,dep[f[x]]);
else printf("%lld\n",qy(in[x],in[x],1,n,1));
}
return 0;
}