联合权值

联合权值

https://ac.nowcoder.com/acm/problem/16495

题意:
一棵树
存在两个点x,y,dis(x,y)=2,然后他们的贡献就是a[x]*a[y]
求所有这样点对的贡献的和,还有贡献值最大值

solve:

1.首先明确是一棵树,那么我们只要枚举连接(x,y)的点,然后直接算即可。

#include <bits/stdc++.h>
#include <tr1/unordered_map>
using namespace std;
#define ll long long
//#define ll __int128
#define re register
#define pb push_back
#define fi first
#define se second
const int N=1e6+10;
const int mod7=1e9+7;
const int mod=1e9+7;
void read(int &a)
{
    a=0;int d=1;char ch;
    while(ch=getchar(),ch>'9'||ch<'0')
        if(ch=='-')
            d=-1;
    a=ch^48;
    while(ch=getchar(),ch>='0'&&ch<='9')
        a=(a<<3)+(a<<1)+(ch^48);
    a*=d;
}
void read(ll &a)
{
    a=0;int d=1;char ch;
    while(ch=getchar(),ch>'9'||ch<'0')
        if(ch=='-')
            d=-1;
    a=ch^48;
    while(ch=getchar(),ch>='0'&&ch<='9')
        a=(a<<3)+(a<<1)+(ch^48);
    a*=d;
}
void write(ll x)
{
    if(x<0) putchar(45);
    if(x>9) write(x/10);
    putchar(x%10+'0');
}
int a[N],maxn=0;
ll ans=0;
vector <int> v[N];
int main()
{
    int n;read(n);a[0]=0;
    for(int i=1;i<n;i++)
    {
        int x,y;
        read(x),read(y);
        v[x].pb(y),v[y].pb(x);
    }
    for(int i=1;i<=n;i++) read(a[i]);
    for(int j=1;j<=n;j++)
    {
        priority_queue <int,vector<int>,greater<int> > q;
        int now=0;
        for(auto i:v[j])
        {
            ans+=1ll*now*a[i];
            q.push(a[i]);
            if(q.size()>2) q.pop();
            now+=a[i];
        }
        if(q.size()==2)
        {
            int b=q.top();q.pop();
            int c=q.top();q.pop();
            ans%=10007;maxn=max(maxn,b*c);
        }
    }
    printf("%d %lld",maxn,ans*2%10007);
    return 0;
}

2.在这里我主要讲的是假如不是找dis(x,y)=2,是找dis(x,y)=k (但是我也不知道对不对,欢迎大佬吐槽)

当距离是k的时候,我采用dsu on tree做法
根据公式dis(x,y)=dep[x]+dep[y]-2dep[lca]=k
那么我们遍历 x然后就可以通过公式求得dep[y]
dep[y]=k+2
dep[lca]-dep[x]
我们只要用vis[dep[y]]统计深度是dep[y]的权值和即可,用ans+=a[x]*vis[dep[y]]即可

之后还有一个问题,怎么处理最大值?
我们再用一个数组去记录就可以了,mx[dep[y]]=max(mx[dep[y]],a[y]);
这样我们就能得到我们要求得所有东西

最后把贡献和×2即可

#include <bits/stdc++.h>
#include <tr1/unordered_map>
using namespace std;
#define ll long long
//#define ll __int128
#define re register
#define pb push_back
#define fi first
#define se second
const int N=1e6+10;
const int mod7=1e9+7;
const int mod=1e9+7;
void read(int &a)
{
    a=0;int d=1;char ch;
    while(ch=getchar(),ch>'9'||ch<'0')
        if(ch=='-')
            d=-1;
    a=ch^48;
    while(ch=getchar(),ch>='0'&&ch<='9')
        a=(a<<3)+(a<<1)+(ch^48);
    a*=d;
}
void read(ll &a)
{
    a=0;int d=1;char ch;
    while(ch=getchar(),ch>'9'||ch<'0')
        if(ch=='-')
            d=-1;
    a=ch^48;
    while(ch=getchar(),ch>='0'&&ch<='9')
        a=(a<<3)+(a<<1)+(ch^48);
    a*=d;
}
int a[N],maxn=0,dep[N],son[N],siz[N],id[N],rk[N],cnt,mx[N],k;
ll ans=0,vis[N];
vector <int> v[N];
void dfs(int x,int fa)
{
    dep[x]=dep[fa]+1;siz[x]=1;
    id[x]=++cnt,rk[cnt]=x;
    for(auto i:v[x])
    {
        if(i==fa) continue;
        dfs(i,x);siz[x]+=siz[i];
        if(siz[son[x]]<siz[i]) son[x]=i;
    }
}
void solve(int x,int val)
{
    if(val==1) mx[dep[x]]=max(mx[dep[x]],a[x]),vis[dep[x]]+=a[x];
    else mx[dep[x]]=0,vis[dep[x]]=0;
}
void dfs(int x,int fa,bool keep)
{
    for(auto i:v[x])
    {
        if(i==fa||i==son[x]) continue;
        dfs(i,x,0);
    }
    if(son[x]) dfs(son[x],x,1);
    for(auto i:v[x])
    {
        if(i==fa||i==son[x]) continue;
        for(int j=0;j<siz[i];j++)
        {
            int nxt=rk[id[i]+j];
            int rep=k+2*dep[x]-dep[nxt];
            if(vis[rep]) 
            {
                ans+=1ll*a[nxt]*vis[rep];
                maxn=max(maxn,a[nxt]*mx[rep]);
            }
        }
        for(int j=0;j<siz[i];j++) solve(rk[id[i]+j],1);
    }
    solve(x,1);
    if(vis[k+dep[x]]) ans+=1ll*a[x]*vis[k+dep[x]],maxn=max(maxn,a[x]*mx[k+dep[x]]);
    if(!keep) for(int i=0;i<siz[x];i++) solve(rk[id[x]+i],-1);
}
int main()
{
    int n;read(n);k=2;
    for(int i=1;i<n;i++)
    {
        int x,y;
        read(x),read(y);
        v[x].pb(y),v[y].pb(x);
    }
    for(int i=1;i<=n;i++) read(a[i]);
    dfs(1,0);dfs(1,0,0);
    printf("%d %lld",maxn,ans*2%10007);
    return 0;
}

这个代码要在c++14才能过, c++11只能过60%,很玄学!!! (点分治貌似也能搞)

全部评论

相关推荐

秋国🐮🐴:拿到你简历编号然后让你知道世间险恶
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客企业服务