长链剖分优化树上dp

长链剖分可以把维护子树中只与深度有关的信息做到线性的时间复杂度。

例题cf1009f

给一棵树,定义每个点的dom值为,以该点为根的子树重中,节点数最多的那一层的层数,即那一层距离这个根节点有几条边,如果多层节点数相同,取最小的层数。例如单独叶节点的dom值是0,一条垂直的链的dom值也是0(每层数量都是1,取最前面的)

定义dp[i][j]为i树第j层的节点数,如果暴力遍历复杂度是n^2。我们用类似启发式合并的思想,对于长链(最深而非最重的链)处理后的结果直接拿给父节点用,然后再去依次解决轻节点,并把这些轻节点向父节点暴力合并。由于所有的轻链合并到深链上之后就“失效”了,之后再也用不到,所有整体时间复杂度只有O(n)。(太nb了这个算法……

由于剖分的性质,一条链上的节点在一个区间,子树的节点也都在一个区间,所以dp数组不必为难开n^2大小,我们考虑用指针实现,这样也实现了重儿子O(1)传递dp数组给父亲,实际上就是用了同一块内存,重儿子跑完了直接父亲就顺着用(看代码吧。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include<bits/stdc++.h>
using namespace std;

const int maxn=1e6+4;

struct edge{
    int to,next;
}es[maxn*2];
int head[maxn],cnt;
void add(int u,int v){
    es[++cnt]=(edge){ v,head[u]};
    head[u]=cnt;
}

int h[maxn],son[maxn];
void dfs1(int x,int fa){
    h[x]=1;
    for(int i=head[x];i;i=es[i].next){
        int v=es[i].to;
        if(v==fa)
            continue;
        dfs1(v,x);
        h[x]=max(h[x],h[v]+1);
        if(h[v]>h[son[x]])
            son[x]=v;
    }
}

int *dp[maxn],g[maxn],dfn[maxn];    //dp[i][j]表示和根节点i相距j条边的子孙数量(树的第j层节点个数)
void dfs2(int x,int fa){
    //先搜重儿子,这样重儿子和父节点的dp数组地址差1,刚好可以公用
    dfn[x]=++dfn[0];
    dp[x]=g+dfn[x];
    if(son[x])
        dfs2(son[x],x);
    for(int i=head[x];i;i=es[i].next){
        int v=es[i].to;
        if(v==fa||v==son[x])
            continue;
        dfs2(v,x);
    }
}

int ans[maxn];
void solve_dp(int x,int fa){
    if(son[x]>0){
        solve_dp(son[x],x);
        //这时son[x]的dp数组已经构建完成,地址就比x的大1,刚好深度也大1,直接就成为了x的dp数组
        ans[x]=ans[son[x]]+1;
    }
    dp[x][0]=1;
    if(dp[x][ans[x]]==dp[x][0]) //如果x只有一个子节点,且是一条垂直的链,那么它的ans可以再向前提
        ans[x]=0;
    for(int i=head[x];i;i=es[i].next){
        int v=es[i].to;
        if(v==son[x]||fa==v)
            continue;
        solve_dp(v,x);
        int len=h[v];
        for(int j=0;j<len;j++){
            dp[x][j+1]+=dp[v][j];
            if(dp[x][j+1]>dp[x][ans[x]])
                ans[x]=j+1;
            if(dp[x][j+1]==dp[x][ans[x]]&&ans[x]>j+1)
                ans[x]=j+1;
        }
    }
}
int n;
int main(){
    cin>>n;
    for(int i=1,x,y;i<n;i++){
        cin>>x>>y;
        add(x,y);add(y,x);
    }
    dfs1(1,0);
    dfs2(1,0);
    solve_dp(1,0);
    for(int i=1;i<=n;i++)
        printf("%d\n",ans[i]);
    return 0;
}
全部评论

相关推荐

点赞 评论 收藏
分享
10-17 12:16
同济大学 Java
7182oat:快快放弃了然后发给我,然后让我也泡他七天最后再拒掉,狠狠羞辱他一把😋
点赞 评论 收藏
分享
点赞 收藏 评论
分享
牛客网
牛客企业服务