"蔚来杯"2022牛客暑期多校训练营3 A题
英文题面:
NIO is playing a game about trees.
The game has two trees A, BA,B each with NN vertices. The vertices in each tree are numbered from 11 to NN and the ii-th vertex has the weight v_ivi. The root of each tree is vertex 1. Given KK key numbers x_1,\dots,x_kx1,…,xk, find the number of solutions that remove exactly one number so that the weight of the lowest common ancestor of the vertices in A with the remaining key numbers is greater than the weight of the lowest common ancestor of the vertices in B with the remaining key numbers.
中文题面:
NIO 正在玩一个关于树木的游戏。
游戏有两棵树甲,乙一个,乙每个与ññ顶点。每棵树中的顶点编号从11至ññ和一世一世-th 顶点有权重v_iv一世. 每棵树的根是顶点 1。给定ķķ关键数字x_1,\点,x_kX1,…,Xķ, 找到恰好去除一个数字的解的数量,使得 A 中具有剩余键号的顶点的最低共同祖先的权重大于 B 中具有剩余键号的顶点的最低共同祖先的权重.
题解:
首先我们考虑如何去求k个节点的lca,即我们可以求出k个点的时间戳序列,按照时间戳放入数组K中,这个时候的lca就是lca(K[0],K[k-1]),之后我们枚举删除哪个数字即可,如果等于边界值,那么0变成1或者k-1变成k-2这样,每次判断lcaa和lcab的大小关系,符合条件加入答案即可,最后的时间复杂度O(nlogn)。
代码:
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+50;
int n,k,la,lb;
int va[N],vb[N],dpa[N],dpb[N],dfa[N],dfb[N],lg[N];
int fa[N][100],fb[N][100];
vector<int>ea[N],eb[N];
set<int>s;
void dfsa(int u,int fath)
{
dpa[u]=dpa[fath]+1,fa[u][0]=fath;
if(s.count(u)) dfa[++la]=u;
for(int i=1;i<=lg[dpa[u]];i++) fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=0;i<ea[u].size();i++){
int v=ea[u][i];
dfsa(v,u);
}
}
void dfsb(int u,int fath)
{
dpb[u]=dpb[fath]+1,fb[u][0]=fath;
if(s.count(u)) dfb[++lb]=u;
for(int i=1;i<=lg[dpb[u]];i++) fb[u][i]=fb[fb[u][i-1]][i-1];
for(int i=0;i<eb[u].size();i++){
int v=eb[u][i];
dfsb(v,u);
}
}
int lcaa(int x,int y)
{
if(dpa[x]<dpa[y]) swap(x,y);
while(dpa[x]>dpa[y]) x=fa[x][lg[dpa[x]-dpa[y]]-1];
if(x==y) return x;
for(int i=lg[dpa[x]]-1;i>=0;i--){
if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
}
return fa[x][0];
}
int lcab(int x,int y)
{
if(dpb[x]<dpb[y]) swap(x,y);
while(dpb[x]>dpb[y]) x=fb[x][lg[dpb[x]-dpb[y]]-1];
if(x==y) return x;
for(int i=lg[dpb[x]]-1;i>=0;i--){
if(fb[x][i]!=fb[y][i]) x=fb[x][i],y=fb[y][i];
}
return fb[x][0];
}
int main()
{
scanf("%d%d",&n,&k);
for(int i=1;i<=k;i++){
int x;
scanf("%d",&x);
s.insert(x);
}
for(int i=1;i<=n;i++) scanf("%d",&va[i]);
for(int i=2;i<=n;i++){
int u;
scanf("%d",&u);
ea[u].push_back(i);
}
for(int i=1;i<=n;i++) scanf("%d",&vb[i]);
for(int i=2;i<=n;i++){
int u;
scanf("%d",&u);
eb[u].push_back(i);
}
for(int i=1;i<=n;i++){
lg[i]=lg[i-1]+(1<<lg[i-1]==i);
}
dfsa(1,0);
dfsb(1,0);
int ans=0;
for(auto it:s){
int aa,bb;
if(dfa[1]==it) aa=lcaa(dfa[2],dfa[k]);
else if(dfa[k]==it) aa=lcaa(dfa[1],dfa[k-1]);
else aa=lcaa(dfa[1],dfa[k]);
if(dfb[1]==it) bb=lcab(dfb[2],dfb[k]);
else if(dfb[k]==it) bb=lcab(dfb[1],dfb[k-1]);
else bb=lcab(dfb[1],dfb[k]);
if(va[aa]>vb[bb]) ans++;
}
printf("%d\n",ans);
return 0;
}