【每日一题】2021年3月26日题目[HAOI2015]树上操作
题号 NC19995
名称 [HAOI2015]树上操作
来源 [HAOI2015]
有一棵点数为 N 的树,以点 1 为根,且树点有边权。
然后有 M 个 操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
样例
输入 5 5 1 2 3 4 5 1 2 1 4 2 3 2 5 3 3 1 2 1 3 5 2 1 2 3 3 输出 6 9 13
dfs序和欧拉序
记录进出栈的时间戳我记得好像是叫欧拉序,欧拉序和dfs序一直没分清楚,网上对这块的讲解好乱
以前遇到欧拉序和dfs序没有仔细整理过,下面来简单整理一下
DFS序:
定义:对一棵树进行dfs遍历,按照访问顺序,给每个节点按照第一次访问的时间给节点排序,得到的序列
如图:
代码实现:
伪代码:
void dfs(当前节点) { 分配时间戳 往dfs序列加入当前节点 for(遍历子节点节点) { dfs(子节点) } }
C++:
int dfn[N],seq[N],cnt;//dfn[]当前节点时间戳,seq[]记录dfs序 int h[N],e[M],ne[M],idx; void dfs(int u) { dfn[u] = ++ cnt; seq[cnt] = u; for(int i = h[u];~i;i = ne[i]) { int j = e[i]; if(!dfn[j]) dfs(j); } }
欧拉序:
欧拉序有两种:
- 在dfs过程中,每个节点按照dfs遍历的顺序,进栈记录一次出栈记录一次得到的序列
- 在dfs过程中,每个节点按照dfs遍历的顺序,每递归访问一次记录一次得到的序列
第一种:
如图:
代码实现:
伪代码:
void dfs(当前节点) { 记录入栈时间 往dfs序列加入当前节点 for(遍历子节点节点) { dfs(子节点) } 记录出栈时间 往dfs序列加入当前节点 }
C++:
int seq[N],fir[N],last[N],top;//seq记录欧拉序,fir入栈时间,last出栈时间 int h[N],ne[M],e[M],idx; void dfs(int u,int father) { seq[++ top] = u; fir[u] = top; for(int i = h[u];~i;i = ne[i]) { int j = e[i]; if(j == father) continue; dfs(j,u); } seq[++ top] = u; last[u] = top; }
第二种:
如图:
代码实现:
伪代码:
void dfs(当前节点) { 往dfs序列加入当前节点 for(遍历子节点节点) { dfs(子节点) 往dfs序列加入当前节点 } }
c++:
int seq[N],top;//seq记录欧拉序 int h[N],ne[M],e[M],idx; void dfs(int u,int father) { seq[++ top] = u; for(int i = h[u];~i;i = ne[i]) { int j = e[i]; if(j == father) continue; dfs(j,u); seq[++ top] = u; } }
算法1
(线段树 +欧拉序)
思路:
和树链剖分的思路类似,如果想用线段树这类维护序列的数据结构维护树的信息
我们就将转化成序列,常见的方式就是将树转化成欧拉序
记录每个节点第一次访问的时间戳和最后一次访问的时间戳
好处就是,树的路径可以对应到一段连续的区间线段,就可以用线段树维护树中路径上的信息
实现:
我们在欧拉序列的基础上构建线段树
我们定义操作:在计算序列上某一段区间的数值和时,入栈位置做加法,出栈位置做减法
fir[x]表示x在序列中第一次出现的位置,last[x]表示x在序列中最后一次出现的位置
操作一:我们在fir[x]加上a,在last[x]减去一个a
操作二:我们对区间[fir[x],last[x]]中入栈的位置同时加上a,出栈的位置同时减去a
操作三:输出[fir[1],fir[x]]的区间和
落实到代码上我们可以给入栈的位置分配数值1,出栈的位置分配数值-1
然后用线段树维护以下信息:
cnt维护上图区间中1和-1数值相加的结果(这样维护就不用分别记录入栈位置的个数,和出栈位置的个数)
lazy维护对一个区间进行操作二的数值a
sum表示一个区间“入栈位置加上节点的权值,出栈位置减去节点的权值”的区间和的结果
信息更新:
void pushup(int u) { tr[u].cnt = tr[lc].cnt + tr[rc].cnt; tr[u].sum = tr[lc].sum + tr[rc].sum; } void pushdown(int u) { if(tr[u].lazy) { tr[lc].lazy += tr[u].lazy; tr[rc].lazy += tr[u].lazy; tr[lc].sum += 1ll * tr[lc].cnt * tr[u].lazy; tr[rc].sum += 1ll * tr[rc].cnt * tr[u].lazy; tr[u].lazy = 0; } }
时间复杂度
C++ 代码
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <unordered_map> #include <map> #include <vector> #include <queue> #include <set> #include <bitset> #include <cmath> #define P 131 #define lc u << 1 #define rc u << 1 | 1 using namespace std; typedef long long LL; const int N = 200010; int h[N],ne[N * 2],e[N * 2],idx; int w[N],fir[N],last[N]; int sign[N]; int seq[N],top; struct Node { int l,r; int cnt; LL sum; LL lazy; }tr[N * 4]; int n,m; void add(int a,int b) { e[idx] = b,ne[idx] = h[a],h[a] = idx ++; } void dfs(int u,int father) { seq[++ top] = u; fir[u] = top; sign[top] = 1; for(int i = h[u];~i;i = ne[i]) { int j = e[i]; if(j == father) continue; dfs(j,u); } seq[++ top] = u; last[u] = top; sign[top] = -1; } void pushup(int u) { tr[u].cnt = tr[lc].cnt + tr[rc].cnt; tr[u].sum = tr[lc].sum + tr[rc].sum; } void pushdown(int u) { if(tr[u].lazy) { tr[lc].lazy += tr[u].lazy; tr[rc].lazy += tr[u].lazy; tr[lc].sum += 1ll * tr[lc].cnt * tr[u].lazy; tr[rc].sum += 1ll * tr[rc].cnt * tr[u].lazy; tr[u].lazy = 0; } } void build(int u,int l,int r) { if(l == r) { tr[u] = {l,r,sign[l],sign[l] * w[seq[l]],0}; return; } int mid = l + r >> 1; tr[u] = {l,r}; build(lc,l,mid); build(rc,mid + 1,r); pushup(u); } void modify(int u,int l,int r,int k) { if(tr[u].l >= l && tr[u].r <= r) { tr[u].lazy += k; tr[u].sum += 1ll * tr[u].cnt * k; return; } pushdown(u); int mid = (tr[u].l + tr[u].r) >> 1; if(l <= mid) modify(lc,l,r,k); if(r > mid) modify(rc,l,r,k); pushup(u); } LL query(int u,int l,int r) { if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum; pushdown(u); int mid = (tr[u].l + tr[u].r) >> 1; LL res = 0; if(l <= mid) res += query(lc,l,r); if(r > mid) res += query(rc,l,r); return res; } void solve() { scanf("%d%d",&n,&m); memset(h,-1,sizeof h); for(int i = 1;i <= n;i ++) scanf("%d",&w[i]); for(int i = 1;i <= n - 1;i ++) { int a,b; scanf("%d%d",&a,&b); add(a,b); add(b,a); } dfs(1,-1); build(1,1,top); while(m --) { int op,x,a; scanf("%d",&op); if(op == 1) { scanf("%d%d",&x,&a); modify(1,fir[x],fir[x],a); modify(1,last[x],last[x],a); }else if(op == 2) { scanf("%d%d",&x,&a); modify(1,fir[x],last[x],a); }else { scanf("%d",&x); printf("%lld\n",query(1,fir[1],fir[x])); } } } int main() { #ifdef LOCAL freopen("in.txt", "r", stdin); freopen("out.txt", "w", stdout); #else #endif // LOCAL int T = 1; // init(500); // scanf("%d",&T); while(T --) { // scanf("%lld%lld",&n,&m); solve(); // test(); } return 0; }
算法2
(树链剖分)
留个坑。。。