牛客练习赛84 E 牛客推荐系统开发之标签重复度
牛客推荐系统开发之标签重复度
https://ac.nowcoder.com/acm/contest/11174/E
考虑使用点分树维护,每次维护经过点分中心的权值。我们考虑计算每个点到点分中心的最小值和最大值,然后合并两个到点分中心的路径。将这样的最大最小值对按最小值排序后,计算贡献。
考虑到后计算到的点最小值一定更大,则只有两种情况。
1.最大值比之前的大,则贡献是之前的最小值乘现在的最大值。
2.最大值比之前的小,则贡献是之前的最小值乘之前的最大值。
因此我们只要按照最大值为下标插入权值即可。
因为维护的是过点分中心的路径,要删除掉同一子树内点的贡献。
复杂度。
#include <cstdio> #include <algorithm> #include <cstring> #include <vector> using namespace std; #define lson rt * 2 #define rson rt * 2 + 1 #define MP make_pair typedef long long ll; void read(int &x) { x = 0; char c = getchar(); while (c < '0' || c > '9') c = getchar(); while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar(); } const int N = 1e5 + 100; const int MOD = 998244353; int add(int a, int b) { return a + b >= MOD ? a + b - MOD : a + b; } int mul(int a, int b) { return 1LL * a * b % MOD; } void upd(int &a, int b) { a += b; if (a >= MOD) a -= MOD; } int tp; struct node { int tree[N]; void insert(int id, int x) { for (; id <= tp; id += id & -id) upd(tree[id], x); } int query(int id) { int sum = 0; for (; id > 0; id -= id & -id) upd(sum, tree[id]); return sum; } int query(int l, int r) { return add(query(r), MOD - query(l - 1)); } }t1, t2; int R, Mn, SZ, ans; int siz[N], mn[N]; bool vis[N]; vector<int> V[N]; void getroot(int u, int fa) { siz[u] = 1; mn[u] = 0; for (int v : V[u]) { if (vis[v] || v == fa) continue; getroot(v, u); siz[u] += siz[v]; mn[u] = max(mn[u], siz[v]); } mn[u] = max(mn[u], SZ - siz[u]); if (mn[u] < Mn) R = u, Mn = mn[u]; } int n, tot; int sa[N], has[N]; pair<int, int> res[N]; void dfs(int u, int fa, int mi, int ma) { mi = min(mi, sa[u]); ma = max(ma, sa[u]); res[++tot] = MP(mi, ma); siz[u] = 1; for (int v : V[u]) { if (vis[v] || v == fa) continue; dfs(v, u, mi, ma); siz[u] += siz[v]; } } void cal(int op) { int sum = 0; sort(res + 1, res + tot + 1); for (int i = 1; i <= tot; i++) { int mi = res[i].first, ma = res[i].second; if (ma > 1) upd(sum, mul(has[ma], t1.query(1, ma - 1))); upd(sum, t2.query(ma, n)); t1.insert(ma, has[mi]); t2.insert(ma, mul(has[mi], has[ma])); } for (int i = 1; i <= tot; i++) { int mi = res[i].first, ma = res[i].second; t1.insert(ma, MOD - has[mi]); t2.insert(ma, MOD - mul(has[mi], has[ma])); } if (op > 0) upd(ans, sum); else upd(ans, MOD - sum); } void solve(int u) { tot = 0; upd(ans, mul(has[sa[u]], has[sa[u]])); res[++tot] = MP(sa[u], sa[u]); for (int v : V[u]) if (!vis[v]) dfs(v, u, sa[u], sa[u]); cal(1); for (int v : V[u]) { if (vis[v]) continue; tot = 0; dfs(v, u, sa[u], sa[u]); cal(-1); } vis[u] = true; for (int v : V[u]) { if (vis[v]) continue; SZ = siz[v]; Mn = 1e9; getroot(v, u); solve(R); } } int main() { //freopen("0.txt", "r", stdin); read(n); for (int i = 1; i <= n; i++) read(sa[i]), has[i] = sa[i]; for (int i = 1, a, b; i < n; i++) { read(a); read(b); V[a].push_back(b); V[b].push_back(a); } sort(has + 1, has + n + 1); tp = unique(has + 1, has + n + 1) - has - 1; for (int i = 1; i <= n; i++) sa[i] = lower_bound(has + 1, has + tp + 1, sa[i]) - has; SZ = n; Mn = 1e9; getroot(1, 0); solve(R); printf("%d\n", ans); return 0; }
</int,>