求和(dfs序+线段树)

求和

https://ac.nowcoder.com/acm/problem/204871

题意:
已知有n个节点,有n−1条边,形成一个树的结构。
给定一个根节点k,每个节点都有一个权值,节点i的权值为vi​。
给m个操作,操作有两种类型:
1 a x :表示将节点a的权值加上x
2 a :表示求a节点的子树上所有节点的和(包括a节点本身)
题解:
dfs序+线段树
用dfs序确定in[x]和out[x]的位置,in是当前结点开始的时间戳,out是回溯到当前结点的时间戳,当查询某结点时,可以快速确定他子树的范围。
代码:

/*Keep on going Never give up*/
//#pragma GCC optimize(3,"Ofast","inline")
#include<bits/stdc++.h>
//#define int long long

#define endl '\n'
#define Accepted 0
#define AK main()
#define I_can signed
using namespace std;
const int maxn =1e6+10;
const int MaxN = 0x3f3f3f3f;
const int MinN = 0xc0c0c00c;
typedef long long ll;
const int inf=0x3f3f3f3f;
const ll mod=1e9+7;

vector<int> edge[maxn];
int a[maxn];
int in[maxn],out[maxn],id[maxn],cnt;
int tree[maxn<<2];
void dfs(int u,int fa){
    in[u]=++cnt;
    id[cnt]=u;
    for(auto i:edge[u]){
        if(i==fa) continue;
        dfs(i,u);
    }
    out[u]=cnt;
}

void build(int node,int l,int r){
    if(l==r){
        tree[node]=a[id[l]];
        return ;
    }
    int mid=(l+r)/2;
    build(2*node,l,mid);
    build(2*node+1,mid+1,r);
    tree[node]=tree[2*node]+tree[2*node+1];

}

void update(int node,int l,int r,int pos,int val){
    if(l==r){
       tree[node]+=val;
       return ;
    }
    int mid=(l+r)/2;
    if(pos<=mid) update(node*2,l,mid,pos,val);
    else update(node*2+1,mid+1,r,pos,val);
    tree[node]=tree[node*2]+tree[node*2+1];
}

int query(int node,int l,int r,int start,int ends){
    if(l>=start&&r<=ends){
        return tree[node];
    }
    int ans=0;
    int mid=(l+r)/2;
    if(start<=mid) ans+=query(node*2,l,mid,start,ends);
    if(ends>mid) ans+=query(node*2+1,mid+1,r,start,ends);
    return ans;
}

int main(){

    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int n,m,root;
    cin>>n>>m>>root;
    for(int i=1;i<=n;i++) cin>>a[i];
    for(int i=0;i<n-1;i++){
        int x,y;
        cin>>x>>y;
        edge[x].push_back(y);
        edge[y].push_back(x);
    }
    dfs(root,-1);
    //for(int i=1;i<=n;i++) cout<<out[i]<<" ";
    build(1,1,n);
    for(int i=0;i<m;i++){
        int opt,x,y;
        cin>>opt;
        if(opt==1){
            cin>>x>>y;
            update(1,1,n,in[x],y);
        }
        else{
            cin>>x;
            //cout<<in[x]<<" "<<out[x]<<endl;
            cout<<query(1,1,n,in[x],out[x])<<endl;
        }
    }
    return 0;
}

题解 文章被收录于专栏

主要写一些题目的题解

全部评论

相关推荐

点赞 评论 收藏
分享
2 收藏 评论
分享
牛客网
牛客企业服务