联合权值

联合权值

https://ac.nowcoder.com/acm/problem/16495

题意:
一棵树
存在两个点x,y,dis(x,y)=2,然后他们的贡献就是a[x]*a[y]
求所有这样点对的贡献的和,还有贡献值最大值

solve:

1.首先明确是一棵树,那么我们只要枚举连接(x,y)的点,然后直接算即可。

#include <bits/stdc++.h>
#include <tr1/unordered_map>
using namespace std;
#define ll long long
//#define ll __int128
#define re register
#define pb push_back
#define fi first
#define se second
const int N=1e6+10;
const int mod7=1e9+7;
const int mod=1e9+7;
void read(int &a)
{
    a=0;int d=1;char ch;
    while(ch=getchar(),ch>'9'||ch<'0')
        if(ch=='-')
            d=-1;
    a=ch^48;
    while(ch=getchar(),ch>='0'&&ch<='9')
        a=(a<<3)+(a<<1)+(ch^48);
    a*=d;
}
void read(ll &a)
{
    a=0;int d=1;char ch;
    while(ch=getchar(),ch>'9'||ch<'0')
        if(ch=='-')
            d=-1;
    a=ch^48;
    while(ch=getchar(),ch>='0'&&ch<='9')
        a=(a<<3)+(a<<1)+(ch^48);
    a*=d;
}
void write(ll x)
{
    if(x<0) putchar(45);
    if(x>9) write(x/10);
    putchar(x%10+'0');
}
int a[N],maxn=0;
ll ans=0;
vector <int> v[N];
int main()
{
    int n;read(n);a[0]=0;
    for(int i=1;i<n;i++)
    {
        int x,y;
        read(x),read(y);
        v[x].pb(y),v[y].pb(x);
    }
    for(int i=1;i<=n;i++) read(a[i]);
    for(int j=1;j<=n;j++)
    {
        priority_queue <int,vector<int>,greater<int> > q;
        int now=0;
        for(auto i:v[j])
        {
            ans+=1ll*now*a[i];
            q.push(a[i]);
            if(q.size()>2) q.pop();
            now+=a[i];
        }
        if(q.size()==2)
        {
            int b=q.top();q.pop();
            int c=q.top();q.pop();
            ans%=10007;maxn=max(maxn,b*c);
        }
    }
    printf("%d %lld",maxn,ans*2%10007);
    return 0;
}

2.在这里我主要讲的是假如不是找dis(x,y)=2,是找dis(x,y)=k (但是我也不知道对不对,欢迎大佬吐槽)

当距离是k的时候,我采用dsu on tree做法
根据公式dis(x,y)=dep[x]+dep[y]-2dep[lca]=k
那么我们遍历 x然后就可以通过公式求得dep[y]
dep[y]=k+2
dep[lca]-dep[x]
我们只要用vis[dep[y]]统计深度是dep[y]的权值和即可,用ans+=a[x]*vis[dep[y]]即可

之后还有一个问题,怎么处理最大值?
我们再用一个数组去记录就可以了,mx[dep[y]]=max(mx[dep[y]],a[y]);
这样我们就能得到我们要求得所有东西

最后把贡献和×2即可

#include <bits/stdc++.h>
#include <tr1/unordered_map>
using namespace std;
#define ll long long
//#define ll __int128
#define re register
#define pb push_back
#define fi first
#define se second
const int N=1e6+10;
const int mod7=1e9+7;
const int mod=1e9+7;
void read(int &a)
{
    a=0;int d=1;char ch;
    while(ch=getchar(),ch>'9'||ch<'0')
        if(ch=='-')
            d=-1;
    a=ch^48;
    while(ch=getchar(),ch>='0'&&ch<='9')
        a=(a<<3)+(a<<1)+(ch^48);
    a*=d;
}
void read(ll &a)
{
    a=0;int d=1;char ch;
    while(ch=getchar(),ch>'9'||ch<'0')
        if(ch=='-')
            d=-1;
    a=ch^48;
    while(ch=getchar(),ch>='0'&&ch<='9')
        a=(a<<3)+(a<<1)+(ch^48);
    a*=d;
}
int a[N],maxn=0,dep[N],son[N],siz[N],id[N],rk[N],cnt,mx[N],k;
ll ans=0,vis[N];
vector <int> v[N];
void dfs(int x,int fa)
{
    dep[x]=dep[fa]+1;siz[x]=1;
    id[x]=++cnt,rk[cnt]=x;
    for(auto i:v[x])
    {
        if(i==fa) continue;
        dfs(i,x);siz[x]+=siz[i];
        if(siz[son[x]]<siz[i]) son[x]=i;
    }
}
void solve(int x,int val)
{
    if(val==1) mx[dep[x]]=max(mx[dep[x]],a[x]),vis[dep[x]]+=a[x];
    else mx[dep[x]]=0,vis[dep[x]]=0;
}
void dfs(int x,int fa,bool keep)
{
    for(auto i:v[x])
    {
        if(i==fa||i==son[x]) continue;
        dfs(i,x,0);
    }
    if(son[x]) dfs(son[x],x,1);
    for(auto i:v[x])
    {
        if(i==fa||i==son[x]) continue;
        for(int j=0;j<siz[i];j++)
        {
            int nxt=rk[id[i]+j];
            int rep=k+2*dep[x]-dep[nxt];
            if(vis[rep]) 
            {
                ans+=1ll*a[nxt]*vis[rep];
                maxn=max(maxn,a[nxt]*mx[rep]);
            }
        }
        for(int j=0;j<siz[i];j++) solve(rk[id[i]+j],1);
    }
    solve(x,1);
    if(vis[k+dep[x]]) ans+=1ll*a[x]*vis[k+dep[x]],maxn=max(maxn,a[x]*mx[k+dep[x]]);
    if(!keep) for(int i=0;i<siz[x];i++) solve(rk[id[x]+i],-1);
}
int main()
{
    int n;read(n);k=2;
    for(int i=1;i<n;i++)
    {
        int x,y;
        read(x),read(y);
        v[x].pb(y),v[y].pb(x);
    }
    for(int i=1;i<=n;i++) read(a[i]);
    dfs(1,0);dfs(1,0,0);
    printf("%d %lld",maxn,ans*2%10007);
    return 0;
}

这个代码要在c++14才能过, c++11只能过60%,很玄学!!! (点分治貌似也能搞)

全部评论

相关推荐

不愿透露姓名的神秘牛友
07-09 16:15
我应届生,去年10月份开始在这家公司实习,到今年10月份正好一年想(实习+试用期),在想要不要提前9月份就离职,这样好找工作些,但又差一个月满一年,又怕10月份国庆回来离职,容易错过了下半年的金九银十,到年底容易gap到年后
小破站_程序员YT:说这家公司不好吧,你干了快一年 说这家公司好吧,你刚毕业就想跑路说你不懂行情吧,你怕错过金九银十说 你懂行情吧,校招阶段在实习,毕业社招想换工作 哥们,我该怎么劝你留下来呢
应届生,你找到工作了吗
点赞 评论 收藏
分享
06-26 17:24
已编辑
宁波大学 golang
迷失西雅图:别给,纯kpi,别问我为什么知道
点赞 评论 收藏
分享
07-09 15:55
门头沟学院 Java
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务