联合权值
联合权值
https://ac.nowcoder.com/acm/problem/16495
题意:
一棵树
存在两个点x,y,dis(x,y)=2,然后他们的贡献就是a[x]*a[y]
求所有这样点对的贡献的和,还有贡献值最大值
solve:
1.首先明确是一棵树,那么我们只要枚举连接(x,y)的点,然后直接算即可。
#include <bits/stdc++.h> #include <tr1/unordered_map> using namespace std; #define ll long long //#define ll __int128 #define re register #define pb push_back #define fi first #define se second const int N=1e6+10; const int mod7=1e9+7; const int mod=1e9+7; void read(int &a) { a=0;int d=1;char ch; while(ch=getchar(),ch>'9'||ch<'0') if(ch=='-') d=-1; a=ch^48; while(ch=getchar(),ch>='0'&&ch<='9') a=(a<<3)+(a<<1)+(ch^48); a*=d; } void read(ll &a) { a=0;int d=1;char ch; while(ch=getchar(),ch>'9'||ch<'0') if(ch=='-') d=-1; a=ch^48; while(ch=getchar(),ch>='0'&&ch<='9') a=(a<<3)+(a<<1)+(ch^48); a*=d; } void write(ll x) { if(x<0) putchar(45); if(x>9) write(x/10); putchar(x%10+'0'); } int a[N],maxn=0; ll ans=0; vector <int> v[N]; int main() { int n;read(n);a[0]=0; for(int i=1;i<n;i++) { int x,y; read(x),read(y); v[x].pb(y),v[y].pb(x); } for(int i=1;i<=n;i++) read(a[i]); for(int j=1;j<=n;j++) { priority_queue <int,vector<int>,greater<int> > q; int now=0; for(auto i:v[j]) { ans+=1ll*now*a[i]; q.push(a[i]); if(q.size()>2) q.pop(); now+=a[i]; } if(q.size()==2) { int b=q.top();q.pop(); int c=q.top();q.pop(); ans%=10007;maxn=max(maxn,b*c); } } printf("%d %lld",maxn,ans*2%10007); return 0; }
2.在这里我主要讲的是假如不是找dis(x,y)=2,是找dis(x,y)=k (但是我也不知道对不对,欢迎大佬吐槽)
当距离是k的时候,我采用dsu on tree做法
根据公式dis(x,y)=dep[x]+dep[y]-2dep[lca]=k
那么我们遍历 x然后就可以通过公式求得dep[y]
dep[y]=k+2dep[lca]-dep[x]
我们只要用vis[dep[y]]统计深度是dep[y]的权值和即可,用ans+=a[x]*vis[dep[y]]即可
之后还有一个问题,怎么处理最大值?
我们再用一个数组去记录就可以了,mx[dep[y]]=max(mx[dep[y]],a[y]);
这样我们就能得到我们要求得所有东西
最后把贡献和×2即可
#include <bits/stdc++.h> #include <tr1/unordered_map> using namespace std; #define ll long long //#define ll __int128 #define re register #define pb push_back #define fi first #define se second const int N=1e6+10; const int mod7=1e9+7; const int mod=1e9+7; void read(int &a) { a=0;int d=1;char ch; while(ch=getchar(),ch>'9'||ch<'0') if(ch=='-') d=-1; a=ch^48; while(ch=getchar(),ch>='0'&&ch<='9') a=(a<<3)+(a<<1)+(ch^48); a*=d; } void read(ll &a) { a=0;int d=1;char ch; while(ch=getchar(),ch>'9'||ch<'0') if(ch=='-') d=-1; a=ch^48; while(ch=getchar(),ch>='0'&&ch<='9') a=(a<<3)+(a<<1)+(ch^48); a*=d; } int a[N],maxn=0,dep[N],son[N],siz[N],id[N],rk[N],cnt,mx[N],k; ll ans=0,vis[N]; vector <int> v[N]; void dfs(int x,int fa) { dep[x]=dep[fa]+1;siz[x]=1; id[x]=++cnt,rk[cnt]=x; for(auto i:v[x]) { if(i==fa) continue; dfs(i,x);siz[x]+=siz[i]; if(siz[son[x]]<siz[i]) son[x]=i; } } void solve(int x,int val) { if(val==1) mx[dep[x]]=max(mx[dep[x]],a[x]),vis[dep[x]]+=a[x]; else mx[dep[x]]=0,vis[dep[x]]=0; } void dfs(int x,int fa,bool keep) { for(auto i:v[x]) { if(i==fa||i==son[x]) continue; dfs(i,x,0); } if(son[x]) dfs(son[x],x,1); for(auto i:v[x]) { if(i==fa||i==son[x]) continue; for(int j=0;j<siz[i];j++) { int nxt=rk[id[i]+j]; int rep=k+2*dep[x]-dep[nxt]; if(vis[rep]) { ans+=1ll*a[nxt]*vis[rep]; maxn=max(maxn,a[nxt]*mx[rep]); } } for(int j=0;j<siz[i];j++) solve(rk[id[i]+j],1); } solve(x,1); if(vis[k+dep[x]]) ans+=1ll*a[x]*vis[k+dep[x]],maxn=max(maxn,a[x]*mx[k+dep[x]]); if(!keep) for(int i=0;i<siz[x];i++) solve(rk[id[x]+i],-1); } int main() { int n;read(n);k=2; for(int i=1;i<n;i++) { int x,y; read(x),read(y); v[x].pb(y),v[y].pb(x); } for(int i=1;i<=n;i++) read(a[i]); dfs(1,0);dfs(1,0,0); printf("%d %lld",maxn,ans*2%10007); return 0; }
这个代码要在c++14才能过, c++11只能过60%,很玄学!!! (点分治貌似也能搞)