树上DP例题及思路总结
树上dp整理
大体思路
以树为载体的最优值问题是建立在遍历的基础上的。不同状态的表示往往按照节点分类,状态转移往往发生在父节点和子节点之间。建树的方式还是使用建图的方式。
两个基本模型
- 树的最长路径:即所有节点的子路径的最大值加次大值的最大值
int dfs(int u,int father){ int dist = 0; int d1 = 0,d2 = 0; for (int i = h[u];i != -1;i = ne[i]){ int j = e[i]; if (j == father) continue; int d = dfs(j,u) + w[i]; dist = max(dist,d); if (d > d1) d2 = d1,d1 = d; else if (d > d2) d2 = d; } ans = max(ans,d1 + d2); return dist; }
- 树的中心:该节点到树中其他节点的最远距离最近。考虑该点向上走的距离和向下的距离。其中向上走的路径中不能碰到自己。
int dfs_d(int u,int father){ d1[u] = d2[u] = -INF; for (int i = h[u];i != -1;i = ne[i]){ int j = e[i]; if (j == father) continue; int d = dfs_d(j,u) + w[i]; if (d >= d1[u]){ d2[u] = d1[u],d1[u] = d; p2[u] = p1[u],p1[u] = j; } else if (d > d2[u]) d2[u] = d,p2[u] = j; } if (d1[u] == -INF) d1[u] = d2[u] = 0; return d1[u]; } void dfs_u(int u,int father){ for (int i = h[u];i != -1;i = ne[i]){ int j = e[i]; if (j == father) continue; if (p1[u] == j) up[j] = max(up[u],d2[u]) + w[i]; else up[j] = max(up[u],d1[u]) + w[i]; dfs_u(j,u); } }
例题
数字转换:x转换成y可以看成x到y的有向无权边,以此建树,求该树(其实是森林)的最长路径。
由于是森林,需要用标记数组求一下根节点。
#include using namespace std; typedef long long LL; typedef unsigned long long uLL; typedef pair PII; typedef pair PLL; typedef pair PDD; #define fi first #define se second #define mp make_pair #define IOS ios::sync_with_stdio(false),cin.tie(0) const int N = 5e4 + 10; const int MOD = 998244353; const int INF = 0x3f3f3f3f; mt19937 mrand(random_device{}()); int rnd(int x) { return mrand() % x;} LL powmod(LL a,LL b,LL mod) {LL res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;} LL gcd(LL a,LL b) { return b?gcd(b,a%b):a;} LL lcm(LL a,LL b) { return a*b/gcd(a,b);} int n; int h[N],e[N],ne[N],idx; int w[N]; bool vis[N]; int ans; int dfs(int u){ int d1 = 0,d2 = 0; for (int i = h[u];i != -1;i = ne[i]){ int j = e[i]; int d = dfs(j) + 1; if (d > d1) d2 = d1,d1 = d; else if(d > d2) d2 = d; } //cout << u << " " << d1 << endl; ans = max(ans,d1 + d2); return d1; } void add(int a,int b){ e[idx] = b,ne[idx] = h[a],h[a] = idx++; } int main(){ cin >> n; for (int i = 1;i <= n;++i){ for (int j = 2;j * i <= n;++j){ w[i * j] += i; } } memset(h,-1,sizeof h); for (int i = 2;i <= n;++i){ if (i > w[i]){ add(w[i],i); vis[i] = 1; } } for (int i = 1;i <= n;++i){ if (!vis[i]) dfs(i); } cout << ans << endl; return 0; }
- 二叉苹果树:分组背包问题和树上dp的结合。
#include using namespace std; typedef long long LL; typedef unsigned long long uLL; typedef pair PII; typedef pair PLL; typedef pair PDD; #define fi first #define se second #define mp make_pair #define IOS ios::sync_with_stdio(false),cin.tie(0) const int N = 220 + 10; const int MOD = 998244353; const int INF = 0x3f3f3f3f; mt19937 mrand(random_device{}()); int rnd(int x) { return mrand() % x;} LL powmod(LL a,LL b,LL mod) {LL res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;} LL gcd(LL a,LL b) { return b?gcd(b,a%b):a;} LL lcm(LL a,LL b) { return a*b/gcd(a,b);} int n,m; int h[N],e[N],w[N],ne[N],idx; int f[N][N]; void add(int a,int b,int c){ e[idx] = b,w[idx] = c,ne[idx] = h[a],h[a] = idx++; } void dfs(int u,int father){ for (int i = h[u];~i;i = ne[i]){ if (e[i] == father) continue; dfs(e[i],u); for (int j = m;j >= 0;--j){ for (int k = 0;k < j;++k){ f[u][j] = max(f[u][j],f[u][j - k -1] + f[e[i]][k] + w[i]); } } } } int main(){ cin >> n >> m; memset(h,-1,sizeof h); for (int i = 0;i < n-1;++i){ int a,b,c;cin >> a >> b >> c; add(a,b,c),add(b,a,c); } dfs(1,-1); cout << f[1][m] << endl; return 0; }
- 战略游戏:遍历整棵树,对于每个节点遍历它的子节点,用f[i,j]表示第i个节点放置状态为j的所有士兵数;如果当前点不放,那么所有的子节点必须放置;如果当前点放置,子节点可放可不放,取min即可。
#include using namespace std; typedef long long LL; typedef unsigned long long uLL; typedef pair PII; typedef pair PLL; typedef pair PDD; #define fi first #define se second #define mp make_pair #define IOS ios::sync_with_stdio(false),cin.tie(0) const int N = 1500 + 10; const int MOD = 998244353; const int INF = 0x3f3f3f3f; mt19937 mrand(random_device{}()); int rnd(int x) { return mrand() % x;} LL powmod(LL a,LL b,LL mod) {LL res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;} LL gcd(LL a,LL b) { return b?gcd(b,a%b):a;} LL lcm(LL a,LL b) { return a*b/gcd(a,b);} int n; int h[N],e[N],w[N],ne[N],idx; int f[N][2]; bool vis[N]; void add(int a,int b){ e[idx] = b,ne[idx] = h[a],h[a] = idx++; } void dfs(int u){ f[u][0] = 0; f[u][1] = 1; for (int i = h[u];~i;i = ne[i]){ int j = e[i]; dfs(j); f[u][0] += f[j][1]; f[u][1] += min(f[j][0],f[j][1]); } } int main(){ cin >> n; memset(h,-1,sizeof h); for (int i = 0;i < n;++i){ int a;cin >> a; int k;cin >> k; while (k--){ int b;cin >> b; add(a,b); vis[b] = 1; } } int root = 0; while (vis[root]) root++; dfs(root); cout << min(f[root][0],f[root][1]) << endl; return 0; }
- 皇宫守卫:和战略游戏相似,但是当前点的状态有三种,分别是当前点自己有守卫,当前点被父节点望到,当前点被子节点望到。当前点被父节点看到的最小花费就是其子节点上有警卫和子节点被其子节点看到的min;同理,当前点放置警卫子节点三种状态都可以,取min即可。而当前点被子节点看到的情况,必须至少有一个子节点得放置棋子,其他的就是被子节点和放置守卫这两种状态的min。
#include using namespace std; typedef long long LL; typedef unsigned long long uLL; typedef pair PII; typedef pair PLL; typedef pair PDD; #define fi first #define se second #define mp make_pair #define IOS ios::sync_with_stdio(false),cin.tie(0) const int N = 1500 + 10; const int MOD = 998244353; const int INF = 0x3f3f3f3f; mt19937 mrand(random_device{}()); int rnd(int x) { return mrand() % x;} LL powmod(LL a,LL b,LL mod) {LL res=1;a%=mod; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mod;a=a*a%mod;}return res;} LL gcd(LL a,LL b) { return b?gcd(b,a%b):a;} LL lcm(LL a,LL b) { return a*b/gcd(a,b);} int n; int h[N],e[N],ne[N],idx; int w[N]; int f[N][3]; bool vis[N]; void add(int a,int b){ e[idx] = b,ne[idx] = h[a],h[a] = idx++; } void dfs(int u){ f[u][2] = w[u]; //cout << w[u] << endl; for (int i = h[u];~i;i = ne[i]){ int j = e[i]; dfs(j); f[u][0] += min(f[j][1],f[j][2]); f[u][2] += min({f[j][0],f[j][1],f[j][2]}); } //cout << f[u][0] << " " << f[u][2] << endl; f[u][1] = INF; for (int i = h[u];~i;i = ne[i]){ int j = e[i]; f[u][1] = min(f[u][1],f[j][2] + f[u][0] - min(f[j][1],f[j][2])); } } int main(){ cin >> n; memset(h,-1,sizeof h); for (int i = 0;i < n;++i){ int a,b,c;cin >> a >> b >> c; w[a] = b; while (c--){ int x;cin >> x; add(a,x); vis[x] = 1; } } int root = 1; while (vis[root]) root++; dfs(root); cout << min({f[root][1],f[root][2]}) << endl; return 0; }