题解 | #小 Q 与异或#

小 Q 与树

https://ac.nowcoder.com/acm/contest/11171/D

题目大意

给定一棵包含 个节点的树,每个节点有个权值

解题思路

对于节点

  • 记权值小于 的节点有
  • 记权值大于等于 的节点有

那么节点 对答案的贡献为:

即:

定义 为当前子树的根,那么

开四棵权值树状数组,分别用来维护

然后跑一遍 即可

AC_Code

#include<bits/stdc++.h>
#define int long long
using namespace std;
template<typename T>void read(T &res)
{
    bool flag=false;
    char ch;
    while(!isdigit(ch=getchar()))(ch=='-')&&(flag=true);
    for(res=ch-48; isdigit(ch=getchar()); res=(res<<1)+(res<<3)+ch - 48);
    flag&&(res=-res);
}
template<typename T>void Out(T x)
{
    if(x<0)putchar('-'),x=-x;
    if(x>9)Out(x/10);
    putchar(x%10+'0');
}
const int N = 2e5 + 10 , mod = 998244353;
int n , ans , a[N] , dep[N] , sz[N] , HH , hson[N] , M;
struct Edge{
    int nex , to;
} edge[N << 1];
int head[N] , TOT;
void add_edge(int u , int v)
{
    edge[++ TOT].nex = head[u];
    edge[TOT].to = v;
    head[u] = TOT;
}
struct TR{
    int tr[N];
    int lowbit(int x){
        return x & (-x);
    }
    void add(int pos , int val)
    {
        while(pos <= M)
        {
            tr[pos] = (tr[pos] + val + mod) % mod;
            pos += lowbit(pos);
        }
    }
    int query(int pos)
    {
        int res = 0;
        while(pos)
        {
            res += tr[pos];
            res %= mod;
            pos -= lowbit(pos);
        }
        return res;
    }
    int get_sum(int L , int R){
        return (query(R) - query(L - 1) + mod) % mod;
    }
} tree1 , tree2 , tr1 , tr2;
vector<int>vec;
int get_id(int x){
    return lower_bound(vec.begin() , vec.end() , x) - vec.begin() + 1;
}
void dfs(int u , int far)
{
    dep[u] = dep[far] + 1 , sz[u] = 1;
    for(int i = head[u] ; i ; i = edge[i].nex)
    {
        int v = edge[i].to;
        if(v == far) continue ;
        dfs(v , u);
        sz[u] += sz[v];
        if(sz[v] > sz[hson[u]]) hson[u] = v;
    }
}
void change(int u , int far , int val)
{
    tree1.add(a[u] , dep[u] * val);
    tree2.add(a[u] , vec[a[u] - 1] * dep[u] * val);
    tr1.add(a[u] , val);
    tr2.add(a[u] , val * vec[a[u] - 1]);
    for(int i = head[u] ; i ; i = edge[i].nex)
    {
        int v = edge[i].to;
        if(v == far || v == HH) continue ;
        change(v , u , val);
    }
}
void calc(int u , int far , int rt)
{
    int cnt = tr1.get_sum(a[u] , M);
    int sum = tree1.get_sum(a[u] , M);
    int mi = vec[a[u] - 1];
        ans += mi * dep[u] * cnt + mi * sum;
        ans -= mi * cnt * 2 * dep[rt];
        ans = (ans + mod) % mod;
    sum = tree2.get_sum(1 , a[u] - 1);
    cnt = tr2.get_sum(1 , a[u] - 1);
        ans += sum + cnt * dep[u];
        ans -= cnt * 2 * dep[rt];
        ans = (ans + mod) % mod;
    for(int i = head[u] ; i ; i = edge[i].nex)
    {
        int v = edge[i].to;
        if(v == far || v == HH) continue ;
        calc(v , u , rt);
    }
}
void dsu(int u , int far , int op)
{
    for(int i = head[u] ; i ; i = edge[i].nex)
    {
        int v = edge[i].to;
        if(v == far || v == hson[u]) continue ;
        dsu(v , u , 0);
    }
    if(hson[u]) dsu(hson[u] , u , 1) , HH = hson[u];
    for(int i = head[u] ; i ; i = edge[i].nex)
    {
        int v = edge[i].to;
        if(v == far || v == HH) continue;
        calc(v , u , u) , change(v , u , 1);
    }
    int cnt = tr1.get_sum(a[u] , M);
    int sum = tree1.get_sum(a[u] , M);
    int mi = vec[a[u] - 1];
        ans += mi * dep[u] * cnt + mi * sum;
        ans -= mi * cnt * 2 * dep[u];
        ans = (ans + mod) % mod;
    sum = tree2.get_sum(1 , a[u] - 1);
    cnt = tr2.get_sum(1 , a[u] - 1);
        ans += sum + cnt * dep[u];
        ans -= cnt * 2 * dep[u];
        ans = (ans + mod) % mod;
    tree1.add(a[u] , dep[u]);
    tree2.add(a[u] , vec[a[u] - 1] * dep[u]);
    tr1.add(a[u] , 1);
    tr2.add(a[u] , vec[a[u] - 1]);
    HH = 0;
    if(!op) change(u , far , -1);
}
signed main()
{
    read(n);
    for(int i = 1 ; i <= n ; i ++) read(a[i]) , vec.push_back(a[i]);
    for(int i = 1 ; i <  n ; i ++)
    {
        int u , v;
        read(u) , read(v);
        add_edge(u , v) , add_edge(v , u);
    }
    sort(vec.begin() , vec.end());
    vec.erase(unique(vec.begin() , vec.end()) , vec.end());
    for(int i = 1 ; i <= n ; i ++) a[i] = get_id(a[i]);
    M = vec.size();
    dfs(1 , 0);
    dsu(1 , 0 , 1);
    Out(ans * 2 % mod) , puts("");
    return 0;
}
全部评论

相关推荐

10-30 23:23
已编辑
中山大学 Web前端
去B座二楼砸水泥地:这无论是个人素质还是专业素质都👇拉满了吧
点赞 评论 收藏
分享
10-07 23:57
已编辑
电子科技大学 Java
八街九陌:博士?客户端?开发?啊?
点赞 评论 收藏
分享
1 收藏 评论
分享
牛客网
牛客企业服务