洛谷p4178 点分治+线段树
题目链接:洛谷p4178
这道题目是点分治的题目,与模板题不一样的是,这是要统计小于路径长度为k的点对数。我们分析下与模板题的区别,模板要求的是路径等于k的长度的路径是否存在,然后统计的方法是对于每个重心(跟),求出各点与根的路径,然后重心下面的各个子树中,相同的子树不去统计答案,去计算之前非该子树的路径能否凑出k的路径,也就是吧路径分成了(当前子树的路径与非当前子树的路径),然后其实我们也可以利用这个模板题的思想,一旦当前子树的路径与非当前子树的路径能凑出小于k的路径,那么这就是一个就是符合条件的点对,我们可以对于每种路径的数目去记录一下,就是记录有多少条路径可以达到这个路径数,但是答案就是加上这个长度路径的点数。
for (register int x = 1; x <= cnt; x++) {
if (temp[x] > m)continue;
for (register int y = temp[x]; y <= m; y++)
ans += judge[y - temp[x]];
}
temp保存的就是当前子树的路径的长度,然后judge[i]就是保存的存在i这条路径的点的数目。我们去枚举[1,k](代码是有个小优化,因为一些长度是非法的无法更新答案)所有的可能路径,然后去加起来。
这样就足够了吗?并不是,为什么?因为我们算答案只是算了当前子树与非当前子树的路径能不能凑,当前子树的路径我们没去统计,所以其实就是只要在最后,我们把所有judge数组里面[1-k]的点数都加到ans上面去就可以了,因为judge存的就是当前子树的所有路径,然后当前子树只要存在小于k的也就是一个点对。
while (!que.empty())
{
ans += judge[que.front()];
judge[que.front()] = 0;
que.pop();
}
好了如果这样子,他普通的就只有50分,(inline,register,快读乱飞可以达到7,80),因为超时;为什么?我们来看这个统计答案的代码
for (register int x = 1; x <= cnt; x++) {
if (temp[x] > m)continue;
for (register int y = temp[x]; y <= m; y++)
ans += judge[y - temp[x]];
}
这个统计答案是一个平方级别的统计答案,所以爆掉了,怎么优化?我们思考一下,这第二个for算的其实就是
judge[1]+judge[2] + judge[3]+…+judge[m-temp[x]];
这其实就是一个前缀和对不对,然后我们去每次更新judge数组的时候都是单点更新(找到一条路径, 在路径长度的judge下面++);这里就可以用树状数组或者线段树去优化得答案了,就是一个简单的单点更新区间查询。好了这题就做好了。
(用树状数组wa了4个小数据我也不知道为什么,一改线段树就a)
看起来代码很长~~(线段树太长了!)~~ ,其实都是模板套上去然后稍稍改进下(比如judge数组从记录边是否存在变成了存在这条边的数目)
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<queue>
#include<cstring>
using namespace std;
const long long max_ = 100000 + 7;
const int inf = 1e9 + 7;
inline int read() {
int s = 0, f = 1;
char ch = getchar();
while (ch<'0' || ch>'9') {
if (ch == '-')
f = -1;
ch = getchar();
}
while (ch >= '0'&&ch <= '9') {
s = s * 10 + ch - '0';
ch = getchar();
}
return s * f;
}
int head[max_], xiann = 2, n, m, sum, root, vis[max_], size_[max_], maxpart[max_], judge[max_], temp[max_], dis[max_], cnt = 0, ask[max_], ans = 0;
struct k {
int next, to, value;
}xian[max_ * 2];
inline void add_(int a, int b, int c) {
xian[xiann].next = head[a];
xian[xiann].to = b;
xian[xiann].value = c;
head[a] = xiann;
xiann++;
}
inline void getroot(int now, int fa) {
size_[now] = 1, maxpart[now] = 0;
for (register int i = head[now]; i; i = xian[i].next) {
int to = xian[i].to;
if (to == fa || vis[to])continue;
getroot(to, now);
size_[now] += size_[to];
maxpart[now] = max(maxpart[now], size_[to]);
}
maxpart[now] = max(maxpart[now], sum - size_[now]);
if (maxpart[root] > maxpart[now])root = now;
}
inline void getdis(int now, int fa) {
temp[++cnt] = dis[now];
for (register int i = head[now]; i; i = xian[i].next) {
int to = xian[i].to;
if (vis[to] || to == fa)continue;
dis[to] = dis[now] + xian[i].value;
getdis(to, now);
}
}
struct kk {
int lazy, sum, Lson, Rson;
}tree_[max_ * 4];
void build(int L, int R, int node) {
tree_[node].Lson = L;
tree_[node].Rson = R;
if (L == R) {
tree_[node].sum = 0;
tree_[node].lazy = 0;
return;
}
int mid = L + (R - L) / 2, L_tree = node * 2, R_tree = node * 2 + 1;
build(L, mid, L_tree);
build(mid + 1, R, R_tree);
tree_[node].sum = tree_[L_tree].sum + tree_[R_tree].sum;
}
void pushdown_(int node, int vv) {
tree_[node].lazy += vv;
tree_[node].sum += tree_[node].lazy * (tree_[node].Rson - tree_[node].Lson + 1);
tree_[node * 2].lazy += tree_[node].lazy;
tree_[node * 2 + 1].lazy += tree_[node].lazy;
tree_[node].lazy = 0;
}
void update(int L, int R, int node, int aimL, int aimR, int vv) {
if (tree_[node].lazy != 0)pushdown_(node, 0);
if (L > aimR || R < aimL) return;
if (L >= aimL && R <= aimR) {
pushdown_(node, vv);
return;
}
int mid = L + (R - L) / 2, L_tree = node * 2, R_tree = node * 2 + 1;
update(L, mid, L_tree, aimL, aimR, vv);
update(mid + 1, R, R_tree, aimL, aimR, vv);
tree_[node].sum = tree_[L_tree].sum + tree_[R_tree].sum;
}
int askk(int L, int R, int node, int aimL, int aimR) {
if (tree_[node].lazy != 0)pushdown_(node, 0);
if (L > aimR || R < aimL) return 0;
if (L >= aimL && R <= aimR) {
return tree_[node].sum;
}
int mid = L + (R - L) / 2, L_tree = node * 2, R_tree = node * 2 + 1;
int t1 = askk(L, mid, L_tree, aimL, aimR),
t2 = askk(mid + 1, R, R_tree, aimL, aimR);
return t1 + t2;
}
inline void solve(int now) {
queue<int> que;
for (register int i = head[now]; i; i = xian[i].next) {
int to = xian[i].to;
if (vis[to]) continue;
dis[to] = xian[i].value;//dis数组是表示到当前根的距离;
cnt = 0;
getdis(to, now);
for (register int x = 1; x <= cnt; x++) {
if (temp[x] > m)continue;
ans += askk(1, m, 1, 1, m - temp[x]);
// for (register int y = temp[x] + 1; y <= m; y++)
// ans += judge[y - temp[x]];
}
for (register int x = 1; x <= cnt; x++) {
if (temp[x] > m)continue;
que.push(temp[x]);
judge[temp[x]] += 1;
update(1, m, 1, temp[x], temp[x], 1);
}
}
while (!que.empty())
{
if (que.front() <= m)
ans += judge[que.front()];
update(1, m, 1, que.front(), que.front(), -judge[que.front()]);
judge[que.front()] = 0;
que.pop();
}
}
inline void divide(int now) {
vis[now] = 1;
solve(now);
for (register int i = head[now]; i; i = xian[i].next) {
int to = xian[i].to;
if (vis[to])continue;
maxpart[root = 0] = sum = size_[to];
getroot(to, 0);
getroot(root, 0);
divide(root);
}
}
int main() {
n = read();
for (register int i = 2; i <= n; i++) {
int a = read(), b = read(), c = read();
add_(a, b, c);
add_(b, a, c);
}
m = read();
build(1, m, 1);
maxpart[0] = sum = n;
root = 0;
getroot(1, 0);
getroot(root, 0);
divide(root);
printf("%d", ans);
return 0;
}