2019 ICPC 沈阳网络赛 D. Fish eating fruit（点分治）
大意：给定一棵带边权的树，把所有有序点对 (u, v)（u ≠ v）之间的路径按长度模 3 的余数（0、1、2）分成三组，分别输出每组路径长度之和（对 1e9+7 取模）。
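思路：树上的路径分两类：经过当前根的，以及完全落在某棵子树内的；后者递归到子树中处理，所以直接点分治。以重心为根统计经过它的路径时，逐棵子树处理：对已处理过的点（包括重心本身），分别记录到重心距离模 3 余 0、1、2 的点数 cnt[r] 和距离和 sum[r]；新子树中距离为 d 的点与余数为 r 的已处理点配对，路径长度模 3 为 (d + r) mod 3，该组增加 d·cnt[r] + sum[r]。处理完一棵子树后再把它的统计并入。这样每条无序路径恰好统计一次，最后乘 2 得到有序点对的答案。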
细节见代码:
#include<bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define LL long long
#define SZ(X) X.size()
#define pii pair<int,int>
#define ALL(X) X.begin(),X.end()
using namespace std;
// Greatest common divisor via the Euclidean algorithm, written as a loop:
// repeatedly replace (a, b) with (b, a % b) until b becomes 0.
long long gcd(long long a, long long b) {
    while (b != 0) {
        const long long remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
LL lcm(LL a, LL b) {return a / gcd(a, b) * b;}
// Computes a^b mod MOD by binary exponentiation.
// Expects b >= 0 (a negative b yields 1 % MOD) and 1 <= MOD < ~3e9, so that
// the product of two residues fits in a long long.
// The base is reduced (and made non-negative) up front, so a * a cannot
// overflow for bases >= MOD. The accumulator starts at 1 % MOD, so MOD == 1
// correctly yields 0.
long long powmod(long long a, long long b, long long MOD) {
    long long ans = 1 % MOD;
    a %= MOD;
    if (a < 0) a += MOD;
    for (; b > 0; b >>= 1) {
        if (b & 1) ans = ans * a % MOD;
        a = a * a % MOD;
    }
    return ans;
}
// Capacity of the per-vertex arrays. Vertex ids are used 1-based (input ids
// are shifted by +1 in main), so N must exceed the largest n.
const int N = 10000 + 11;
int n;  // size of the component being processed; read in main, reassigned by dfs()
int m;  // unused in the visible code
std::vector<std::pair<int, int>> v[N];  // adjacency list: v[u] holds (neighbour, edge weight)
int q[N];  // unused in the visible code
int rt;      // best centroid candidate found by root(); 0 means "none yet"
int son[N];  // largest piece left if the vertex were removed (relative to n)
int sz[N];   // subtree size from the most recent root() pass
int vis[N];  // set to 1 once dfs() has used the vertex as a centroid
void root(int now, int pre) { //找重心
sz[now] = 1;
son[now] = 0;
for (pii k : v[now]) {
if (vis[k.fi] || k.fi == pre)continue;
root(k.fi, now);
sz[now] += sz[k.fi];
son[now] = max(son[now], sz[k.fi]);
}
son[now] = max(son[now], n - sz[now]);
if (!rt || son[rt] > son[now]) {
rt = now;
}
return ;
}
int a[N];  // unused in the visible code
// Scratch filled by getdis(): md[1..cnt] hold the distances from the vertices
// of one subtree to the current centroid.
long long md[N];
long long cnt;
// Tallies keyed by distance % 3.
// A and B cover the subtree currently being scanned: A holds distance sums,
// B holds vertex counts.
// C holds vertex counts accumulated over the centroid and the subtrees
// already merged. Only indices 0..2 are used.
long long A[3];
long long B[5];
long long C[5];
void getdis(int now, int pre, LL di) {
md[++cnt] = di;
A[di%3]+=di;
B[di%3]++;
for (pii k : v[now]) {
if (vis[k.fi] || k.fi == pre)continue;
getdis(k.fi, now, di + k.se);
}
}
// qa[r]: sum (modulo `mod`) of the distances to the current centroid, over
// already-merged vertices whose distance % 3 == r. Index 3 is unused.
long long qa[4];
long long tmp;  // unused in the visible code
// Running totals, over unordered vertex pairs, of path lengths whose
// length % 3 is 0 / 1 / 2 respectively; main() doubles them when printing.
long long AA;
long long BB;
long long CC;
long long mod = 1000000007;  // results are printed modulo 1e9+7
void get(int now, int pre) {
for (pii k : v[now]) {
if (k.fi == pre || vis[k.fi])continue;
cnt = 0;
for(int i=0;i<3;i++)A[i]=0,B[i]=0;
getdis(k.fi, now, k.se);
for (int i = 1; i <= cnt; ++i) {
if(md[i]%3==0){
AA+=md[i]*C[0]%mod+qa[0];
BB+=md[i]*C[1]%mod+qa[1];
CC+=md[i]*C[2]%mod+qa[2];
}else if(md[i]%3==1){
AA+=md[i]*C[2]%mod+qa[2];
BB+=md[i]*C[0]%mod+qa[0];
CC+=md[i]*C[1]%mod+qa[1];
}else{
AA+=md[i]*C[1]%mod+qa[1];
BB+=md[i]*C[2]%mod+qa[2];
CC+=md[i]*C[0]%mod+qa[0];
}
}
for(int i=0;i<3;i++){
qa[i]+=A[i];
qa[i]%=mod;
C[i]+=B[i];
}
}
}
void dfs(int now){
vis[now] = 1;
for(int i=0;i<3;i++)qa[i]=0,C[i]=0;
C[0]=1;
get(now, 0);
for (pii k : v[now]) {
if (vis[k.fi])continue;
rt = 0;
n = sz[k.fi];
root(k.fi, 0);
dfs(rt);
}
return ;
}
// Processes test cases until input ends. Each case gives n, then n - 1
// weighted edges with 0-based endpoints. Prints the path-length sums for
// remainders 0 / 1 / 2, modulo `mod`.
int main() {
    while (scanf("%d", &n) == 1) {
        rt = 0;
        for (int u = 1; u <= n; ++u) {
            v[u].clear();
            vis[u] = 0;
        }
        // Shift endpoints to 1-based: 0 is used as the "no parent" value
        // passed to root() / get().
        for (int e = 1; e < n; ++e) {
            int s, t, w;
            scanf("%d%d%d", &s, &t, &w);
            ++s;
            ++t;
            v[s].emplace_back(t, w);
            v[t].emplace_back(s, w);
        }
        root(1, 0);   // find a centroid of the whole tree...
        root(rt, 0);  // ...then recompute sz[] with that centroid as the root
        dfs(rt);
        // AA / BB / CC count each unordered pair once; ordered pairs double that.
        printf("%lld %lld %lld\n", (AA * 2ll) % mod, (BB * 2ll) % mod, (CC * 2ll) % mod);
        AA = BB = CC = 0;
    }
    return 0;
}