题解 | #I题,标准的LCA+树上差分边源码#

你也喜欢数学吗

https://ac.nowcoder.com/acm/contest/61132/A

I题,标准的LCA+树上差分边源码

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <iostream>
using namespace std;
int fa[100001][20];
int deep[100001];
long long diff[100001];
//倍增lca
int lca(int x, int y){
    int i=0;
    while(1){
        if(x==y) return x;
        if (fa[x][0] == fa[y][0]) return fa[x][0];
        if (deep[x] == deep[y]){
            i=0;
            while(fa[x][i] != fa[y][i]) ++i;
            if (i>0) --i;
            x= fa[x][i];y=fa[y][i];
        }else{
            if (deep[x]<deep[y]) swap(x,y);
            i=0;
            while(deep[fa[x][i]] > deep[y]) ++i;
            if (deep[fa[x][i]]<deep[y]) --i;
            x= fa[x][i];
        }
        
    }
    return x;
}
void my_ans(){
    int i,j,ij,n,m,q,x,y,z;
    cin>>n>>m>>q;
    vector<vector<int>> vc[n+1];
    for(i=1;i<=n;++i) {
        diff[i] = 0;deep[i] = 0;
    }
    memset(fa,0,sizeof(fa));
    for(i=1;i<n;++i){
        cin>>x>>y>>z;
        vc[x].push_back({y,z});vc[y].push_back({x,z});
    }
    //建树
    vector<int> point;
    point.push_back(1);deep[1] = 1;i=0;
    while(i<point.size()){
        x = point[i];++i;
        for(j=0;j<vc[x].size();++j) if (deep[vc[x][j][0]] ==0){
            y = vc[x][j][0];z= vc[x][j][1];
            diff[y]+=z;diff[x]-=z;
            fa[y][0] = x;deep[y] = deep[x]+1;point.push_back(y);
            for(ij=1;ij<20;++ij) fa[y][ij] = fa[fa[y][ij-1]][ij-1];
        }
    }
    //差分边加权
    for(i=0;i<m;++i){
        cin>>x>>y>>z;
        if (x!=y){
            diff[x]+=z;diff[y]+=z;diff[lca(x,y)]-=2*z;            
        }
    }
    //计算边权值
    for(i=n-1;i>=0;--i){
        x= point[i];
        if (fa[x][0] !=0) diff[fa[x][0]] += diff[x];
    }
    //计算点到树根的边权总和
    for(i=0;i<n;++i) diff[point[i]] += diff[fa[point[i]][0]]; 
    //计算点到点的边权总和
    for(i=0;i<q;++i){
        cin>>x>>y;
        cout<<diff[x]+diff[y] - 2*diff[lca(x,y)]<<endl;
    }
    return;
}
int main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int t=1,i,j;
    //cin>>t;
    while(t>0){
        --t;my_ans();
    }
    return 0;
}
// 64 位输出请用 printf("%lld")
全部评论

相关推荐

11-03 14:38
重庆大学 Java
AAA求offer教程:我手都抬起来了又揣裤兜了
点赞 评论 收藏
分享
1 收藏 评论
分享
牛客网
牛客企业服务