树上倍增板子
树上倍增比起树链剖分代码短,容易查错,时空优,但是广度不如树链剖分.
具体实现
首先开一个n×logn的数组,比如fa[n][logn],其中fa[i][j]表示i节点的第2^j个父亲是谁。
那么就有: fa[i][j]=fa[fa[i][j-1]][j-1]
用文字叙述为:i的第2^j 个父亲 是i的第2^(j-1) 个父亲的第2^(j-1)个父亲。
下面是求i的第k个父亲的代码段:
int father(int i,int k)
{
for(int x=0;x<=int(log2(k));x++)
if((1<<x)&k) //(1<<x)&k可以判断k的二进制表示中,第(x-1)位上是否为1
i=fa[i][x]; //把i往上提
return i;
}
我们可以通过一次dfs处理出fa数组:(dep[i]表示i的深度,这个可以一起处理出来,以后要用)
void dfs(int x)
{
for(int i=1;i<=max0;i++)
if(fa[x][i-1]) //在dfs(x)之前,x的父辈们的fa数组都已经计算完毕,所以可以用来计算x
fa[x][i]=fa[fa[x][i-1]][i-1];
else break; //如果x已经没有第2^(i-1)个父亲了,那么也不会有更远的父亲,直接break
for(/*每一个与x相连的节点i*/)
if(i!=fa[x][0]) //如果i不是x的父亲就是x的儿子
{
fa[i][0]=x; //记录儿子的第一个父亲是x
dep[i]=dep[x]+1; //处理深度
dfs(i);
}
}
树上倍增常用来求最近公共祖先
int LCA(int u,int v)
{
if(dep[u]<dep[v])swap(u,v); //我们默认u的深度一开始大于v,那么如果u的深度小就交换u和v
int delta=dep[u]-dep[v]; //计算深度差
for(int x=0;x<=max0;x++) //此循环用于提到深度相同。
if((1<<x)&delta)
u=fa[u][x];
if(u==v)return u;
for(int x=max0;x>=0;x--) //注意!此处循环必须是从大到小!因为我们应该越提越“精确”,
if(fa[u][x]!=fa[v][x]) //如果从小到大的话就有可能无法提到正确位置,自己可以多想一下
{
u=fa[u][x];
v=fa[v][x];
}
return fa[u][0]; //此时u、v的第一个父亲就是LCA。
}
倍增还可以有很多变化,这让倍增法可以优更多的变化。比如用data[i][j]记录i到他的第2^j 个父亲的路径长度,就可以边求LCA边求出两点距离,因为data[i][j]满足倍增的递推式:data[i][j]=data[i][j-1]+data[fa[i][j-1]][j-1]。或者用maxlen[i][j]记录i到第2^j个父亲的路径上最长边的边权,它满足maxlen[i][j]=max{maxlen[i][j-1],maxlen[fa[i][j-1]][j-1]},这样就可以快速求出两点路径上最长边的边权……
最后附上一道LCA的模板题代码
#include<cstdio>
#include<vector>
#include<cstring>
using namespace std;
const int maxn=10010;
vector<int> v[maxn<<1];
int deep[maxn],fa[maxn][22];
int sum[maxn];
void dfs(int x){
for(int i=1;i<=21;i++){
if(fa[x][i-1]) fa[x][i]=fa[fa[x][i-1]][i-1];
else break;
}
for(int i=0;i<v[x].size();i++){
int t=v[x][i];
if(t!=fa[x][0]){
fa[t][0]=x;
deep[t]=deep[x]+1;
dfs(t);
}
}
}
int LCA(int u,int v){
if(deep[u]<deep[v]) swap(u,v);
int del=deep[u]-deep[v];
for(int i=0;i<=21;i++)
if((1<<i)&del)
u=fa[u][i];
if(u==v) return u;
for(int i=21;i>=0;i--){
if(fa[u][i]!=fa[v][i]){
u=fa[u][i];
v=fa[v][i];
}
}
return fa[u][0];
}
int main(){
int t;
scanf("%d",&t);
while(t--){
memset(v,0,sizeof(v));
memset(deep,0,sizeof(deep));
memset(fa,0,sizeof(fa));
memset(sum,0,sizeof(sum));
int n;
scanf("%d",&n);
int x,y;
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);
v[x].push_back(y);
v[y].push_back(x);
sum[y]++;
}
scanf("%d%d",&x,&y);
for(int i=1;i<=n;i++)
if(!sum[i])
dfs(i);
printf("%d\n",LCA(x,y));
}
return 0;
}