题解 | #游游的整数操作#
游游的整数操作
https://ac.nowcoder.com/acm/contest/61571/C
C题也可以用势能线段树解决:
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ls u << 1
#define rs u << 1 | 1
using namespace std;
typedef long long LL;
const int N = 1e5 + 10, mod = 1e9 + 7;
const LL INF = 0x3f3f3f3f3f3f3f3f;
int a[N];
int n, m;
struct Node
{
LL mi, mi2; //最大值、次大值
int cmx, cmi; //最大值个数、次大值个数
LL tmx, tad; //最大值懒标记、加法懒标记
LL sum; //区间和
}tr[N << 2];
void pushup(int u) //根据子区间更新父区间
{
tr[u].sum = (tr[ls].sum + tr[rs].sum) % mod;
if (tr[ls].mi == tr[rs].mi) //如果左区间最小值 == 右区间最小值
{
tr[u].mi = tr[rs].mi, tr[u].cmi = tr[ls].cmi + tr[rs].cmi; //更新最小值,最小值个数
tr[u].mi2 = min(tr[ls].mi2, tr[rs].mi2); //更新次小值
}
else if (tr[ls].mi < tr[rs].mi) //如果左区间最小值 < 右区间最小值
{
tr[u].mi = tr[ls].mi, tr[u].cmi = tr[ls].cmi;
tr[u].mi2 = min(tr[ls].mi2, tr[rs].mi);
}
else
{
tr[u].mi = tr[rs].mi, tr[u].cmi = tr[rs].cmi;
tr[u].mi2 = min(tr[ls].mi, tr[rs].mi2);
}
}
//加法的优先级 > 取最大值的优先级
void push_add(int u, int l, int r, LL v) //传递加法懒标记
{
//更新加法标记的同时,更新min标记
tr[u].sum = (tr[u].sum + (r - l + 1ll) * v % mod + mod) % mod; //注意取模
//整个区间都加上v,最小值,次小值以及最大值懒标记都要 +v (如果存在的话)
tr[u].mi += v;
if (tr[u].mi2 != INF) tr[u].mi2 += v;
if (tr[u].tmx != -INF) tr[u].tmx += v;
tr[u].tad += v; //更新加法懒标记
}
void push_max(int u, LL tag) //传递最大值懒标记
{
if (tr[u].mi >= tag) return;
tr[u].sum = (tr[u].sum + (1ll * tag - tr[u].mi) * tr[u].cmi % mod + mod) % mod;
tr[u].mi = tag = tr[u].tmx = tag;
}
void pushdown(int u, int l, int r) //更新本区间并向子区间传递懒标记
{
int mid = l + r >> 1;
if (tr[u].tad)
{
push_add(ls, l, mid, tr[u].tad);
push_add(rs, mid + 1, r, tr[u].tad);
}
if (tr[u].tmx != -INF) push_max(ls, tr[u].tmx), push_max(rs, tr[u].tmx);
tr[u].tad = 0, tr[u].tmx = -INF; //清除本区间的懒标记
}
void build(int u = 1, int l = 1, int r = n) //初始化势能线段树
{
tr[u].tmx = -INF; //最大值懒标记初始化为负无穷
if (l == r)
{
tr[u].mi = tr[u].sum = a[r];
tr[u].mi2 = INF; //次小值初始化为正无穷
tr[u].cmx = tr[u].cmi = 1; //个数都初始化为1
return;
}
int mid = l + r >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
pushup(u);
}
void add(int L, int R, LL v, int u = 1, int l = 1, int r = n)
{
if (L <= l && r <= R)
{
push_add(u, l, r, v);
return;
}
pushdown(u, l, r);
int mid = l + r >> 1;
if (L <= mid) add(L, R, v, ls, l, mid);
if (R > mid) add(L, R, v, rs, mid + 1, r);
pushup(u);
}
void to_max(int L, int R, LL v, int u = 1, int l = 1, int r = n)
{
if (tr[u].mi >= v) return; //如果区间最小值 >= v,则无需进行任何操作
if (L <= l && r <= R && tr[u].mi2 > v)
{
push_max(u, v);
return;
}
pushdown(u, l, r);
int mid = l + r >> 1;
if (L <= mid) to_max(L, R, v, ls, l, mid);
if (R > mid) to_max(L, R, v, rs, mid + 1, r);
pushup(u);
}
LL query_sum(int L, int R, int u = 1, int l = 1, int r = n)
{
if (L <= l && r <= R) return tr[u].sum % mod;
pushdown(u, l, r);
LL res = 0;
int mid = l + r >> 1;
if (L <= mid) res = query_sum(L, R, ls, l, mid);
if (R > mid) res = (res + query_sum(L, R, rs, mid + 1, r) + mod) % mod;
return res;
}
void solve()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
build();
while (m -- )
{
int opt, x;
scanf("%d%d", &opt, &x);
if (opt == 1) add(1, n, x);
else if (opt == 2) add(1, n, -x), to_max(1, n, 0);
}
printf("%lld\n", query_sum(1, n));
}
int main()
{
int T = 1;
// scanf("%d", &T);
while (T -- ) solve();
return 0;
}