BZOJ 3196 洛谷 P3380 二逼平衡树(线段树套伸展树)
Description:
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
- 查询k在区间内的排名
- 查询区间内排名为k的值
- 修改某一位值上的数值
- 查询k在区间内的前驱(前驱定义为严格小于x,且最大的数,若不存在输出-2147483647)
- 查询k在区间内的后继(后继定义为严格大于x,且最小的数,若不存在输出2147483647)
Input:
第一行两个数 n,m 表示长度为n的有序序列和m个操作
第二行有n个数,表示有序序列
下面有m行,opt表示操作标号
若opt=1 则为操作1,之后有三个数l,r,k 表示查询k在区间[l,r]的排名
若opt=2 则为操作2,之后有三个数l,r,k 表示查询区间[l,r]内排名为k的数
若opt=3 则为操作3,之后有两个数pos,k 表示将pos位置的数修改为k
若opt=4 则为操作4,之后有三个数l,r,k 表示查询区间[l,r]内k的前驱
若opt=5 则为操作5,之后有三个数l,r,k 表示查询区间[l,r]内k的后继
n,m≤5×104 ,题目保证序列中每个数的数据范围: [0,108]
Output:
对于操作1,2,4,5各输出一行,表示查询结果
Sample Input:
9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5
Sample Output:
2
4
3
4
9
题目链接 题目链接
题目要求使一个有序序列支持查询区间第 k 小、区间内前驱、区间内后继、区间内排名以及单点修改
首先序列第 k 小、前驱、后继以及排名显然可以用伸展树来实现,但是题目要求实现的操作是区间内查询,对于这种情况伸展树就无能为力了,而区间操作又可以用线段树来实现,所以此题最经典的做法就是线段树套伸展树
这是我写的第一道树套树,刚开始学习线段树套伸展树的时候看到有这么一句话
线段树上每个节点都是一颗伸展树
这很好理解,所以我就把 Splay Tree 封装成一个类并将其作为线段树的节点开到 4 倍伸展树数组
显然这样的做***造成内存爆炸,后来才发现这里树套树确实是线段树上每个节点都是一颗伸展树,但是线段树所维护的信息只是伸展树的根节点信息,而伸展树只需要 1 个(不是 n<<2 个)然后对线段树所维护的每个根节点动态分配伸展树节点即可
这道题目的 5 个操作:
- 在查询 [l,r] 对应的 Splay 节点中查询比 k 小的数的个数求和即可(最后结果要 +1 )
- 二分 [l,r] 排名为 k 的值,然后用 1 操作来进行验证
- 在线段树包含此值的所有节点( Splay Tree )上先删除旧值再插入修改值
- 在线段树 [l,r] 中查找前驱并取最大值
- 在线段树 [l,r] 中查找后继并取最小值
AC代码:
#include <bits/stdc++.h>
using namespace std;
const int inf = 2147483647;
const int maxn = 5e4 + 5;
const int maxm = maxn * 25;
int n;
int arr[maxn];
namespace splay_tree {
int rt[maxm], tot;
int fa[maxm], son[maxm][2];
int val[maxm], cnt[maxm];
int sz[maxm];
void Push(int o) {
sz[o] = sz[son[o][0]] + sz[son[o][1]] + cnt[o];
}
bool Get(int o) {
return o == son[fa[o]][1];
}
void Clear(int o) {
son[o][0] = son[o][1] = fa[o] = val[o] = sz[o] = cnt[o] = 0;
}
void Rotate(int o) {
int p = fa[o], q = fa[p], ck = Get(o);
son[p][ck] = son[o][ck ^ 1];
fa[son[o][ck ^ 1]] = p;
son[o][ck ^ 1] = p;
fa[p] = o; fa[o] = q;
if (q) son[q][p == son[q][1]] = o;
Push(p); Push(o);
}
void Splay(int &root, int o) {
for (int f = fa[o]; (f = fa[o]); Rotate(o))
if (fa[f]) Rotate(Get(o) == Get(f) ? f : o);
root = o;
}
void Insert(int &root, int x) {
if (!root) {
val[++tot] = x;
cnt[tot]++;
root = tot;
Push(root);
return;
}
int cur = root, f = 0;
while (true) {
if (val[cur] == x) {
cnt[cur]++;
Push(cur); Push(f);
Splay(root, cur);
break;
}
f = cur;
cur = son[cur][val[cur] < x];
if (!cur) {
val[++tot] = x;
cnt[tot]++;
fa[tot] = f;
son[f][val[f] < x] = tot;
Push(tot); Push(f);
Splay(root, tot);
break;
}
}
}
int GetRank(int &root, int x) {
int ans = 0, cur = root;
while (cur) {
if (x < val[cur]) {
cur = son[cur][0];
continue;
}
ans += sz[son[cur][0]];
if (x == val[cur]) {
Splay(root, cur);
return ans;
}
if (x > val[cur]) {
ans += cnt[cur];
cur = son[cur][1];
}
}
return ans;
}
int GetKth(int &root, int k) {
int cur = root;
while (true) {
if (son[cur][0] && k <= sz[son[cur][0]]) cur = son[cur][0];
else {
k -= cnt[cur] + sz[son[cur][0]];
if (k <= 0) return cur;
cur = son[cur][1];
}
}
}
int Find(int &root, int x) {
int ans = 0, cur = root;
while (cur) {
if (x < val[cur]) {
cur = son[cur][0];
continue;
}
ans += sz[son[cur][0]];
if (x == val[cur]) {
Splay(root, cur);
return ans + 1;
}
ans += cnt[cur];
cur = son[cur][1];
}
}
int GetPrev(int &root) {
int cur = son[root][0];
while (son[cur][1]) cur = son[cur][1];
return cur;
}
int GetPrevVal(int &root, int x) {
int ans = -inf, cur = root;
while (cur) {
if (x > val[cur]) {
ans = max(ans, val[cur]);
cur = son[cur][1];
continue;
}
cur = son[cur][0];
}
return ans;
}
int GetNext(int &root) {
int cur = son[root][1];
while (son[cur][0]) cur = son[cur][0];
return cur;
}
int GetNextVal(int &root, int x) {
int ans = inf, cur = root;
while (cur) {
if (x < val[cur]) {
ans = min(ans, val[cur]);
cur = son[cur][0];
continue;
}
cur = son[cur][1];
}
return ans;
}
void Delete(int &root, int x) {
Find(root, x);
if (cnt[root] > 1) {
cnt[root]--;
Push(root);
return;
}
if (!son[root][0] && !son[root][1]) {
Clear(root);
root = 0;
return;
}
if (!son[root][0]) {
int cur = root;
root = son[root][1];
fa[root] = 0;
Clear(cur);
return;
}
if (!son[root][1]) {
int cur = root;
root = son[root][0];
fa[root] = 0;
Clear(cur);
return;
}
int p = GetPrev(root), cur = root;
Splay(root, p);
fa[son[cur][1]] = p;
son[p][1] = son[cur][1];
Clear(cur);
Push(root);
}
};
namespace seg_tree {
int tree[maxn << 2];
void Build(int o, int l, int r) {
for (int i = l; i <= r; ++i) splay_tree::Insert(tree[o], arr[i - 1]);
if (l == r) return;
int m = (l + r) >> 1;
Build(o << 1, l, m);
Build(o << 1 | 1, m + 1, r);
}
void Modify(int o, int l, int r, int ll, int rr, int u, int v) {
splay_tree::Delete(tree[o], u); splay_tree::Insert(tree[o], v);
if (l == r) return;
int m = (l + r) >> 1;
if (ll <= m) Modify(o << 1, l, m, ll, rr, u, v);
if (rr > m) Modify(o << 1 | 1, m + 1, r, ll, rr, u, v);
}
int QueryRank(int o, int l, int r, int ll, int rr, int v) {
if (ll <= l && rr >= r) return splay_tree::GetRank(tree[o], v);
int m = (l + r) >> 1, ans = 0;
if (ll <= m) ans += QueryRank(o << 1, l, m, ll, rr, v);
if (rr > m) ans += QueryRank(o << 1 | 1, m + 1, r, ll, rr, v);
return ans;
}
int QueryPrev(int o, int l, int r, int ll, int rr, int v) {
if (ll <= l && rr >= r) return splay_tree::GetPrevVal(tree[o], v);
int m = (l + r) >> 1, ans = -inf;
if (ll <= m) ans = max(ans, QueryPrev(o << 1, l, m, ll, rr, v));
if (rr > m) ans = max(ans, QueryPrev(o << 1 | 1, m + 1, r, ll, rr, v));
return ans;
}
int QueryNext(int o, int l, int r, int ll, int rr, int v) {
if (ll <= l && rr >= r) return splay_tree::GetNextVal(tree[o], v);
int m = (l + r) >> 1, ans = inf;
if (ll <= m) ans = min(ans, QueryNext(o << 1, l, m, ll, rr, v));
if (rr > m) ans = min(ans, QueryNext(o << 1 | 1, m + 1, r, ll, rr, v));
return ans;
}
int QueryKth(int ll, int rr, int v) {
int l = 0, r = 1e8 + 10;
while (l < r) {
int m = ((l + r) >> 1) + 1;
if (QueryRank(1, 1, n, ll, rr, m) < v) l = m;
else r = m - 1;
}
return l;
}
};
int main() {
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
int m; cin >> n >> m;
for (int i = 0; i < n; ++i) cin >> arr[i];
splay_tree::tot = 0;
seg_tree::Build(1, 1, n);
for (int i = 0, op, l, r, pos, k; i < m; ++i) {
cin >> op;
if (op == 1) {
cin >> l >> r >> k;
cout << seg_tree::QueryRank(1, 1, n, l, r, k) + 1 << endl;
}
else if (op == 2) {
cin >> l >> r >> k;
cout << seg_tree::QueryKth(l, r, k) << endl;
}
else if (op == 3) {
cin >> pos >> k;
seg_tree::Modify(1, 1, n, pos, pos, arr[pos - 1], k);
arr[pos - 1] = k;
}
else if (op == 4) {
cin >> l >> r >> k;
cout << seg_tree::QueryPrev(1, 1, n, l, r, k) << endl;
}
else if (op == 5) {
cin >> l >> r >> k;
cout << seg_tree::QueryNext(1, 1, n, l, r, k) << endl;
}
}
return 0;
}