牛客练习赛59 E 石子搬运 dp+三分法
有n堆石子,第i堆石子的石子数量是ai{a_{i}}ai,作为牛客网的一头领头牛,牛牛决定把这些石子搬回牛客。如果牛牛一次搬运的石子数量是k{k}k,那么这堆石子将对牛牛产生k2{k^{2}}k2的负担值。牛牛最多只能搬运m{m}m次,每次搬运可以从一堆石子中选出一些石子搬回牛客,每次搬运不能同时从两堆石子中选取石子,每次只能搬运整数个石子。牛牛是一只聪明的牛,他想出了一种搬运计划可以最小化他搬运完这些石子的负担值的总和,但是突然牛牛的死敌牛能出现了,牛能每次可以施展以下的魔法:
x v 将第x堆石子的数量变为v
这打乱了牛牛的计划,每次牛能施展一次魔法,牛牛就得重新规划他的搬运方案,但是牛能施展魔法的次数太多了,牛牛根本忙活不过来了,于是他请来了聪明的你帮他写一个程序计算。
分析一下问题,由题意我们可以知道,每个石块堆是相互独立的。也就是每一个石块堆拿走几次的最小权值都是固定的,可以算出来的,(就最平均的样子去拿就可以获得最小消耗),并且堆与堆之间是独立的就是可以无视堆的选取顺序,我们直接从1-n去考虑也完全可以不影响最终结果。
先来思考如果不改变石堆的数值怎么算答案。
我们这么设计dp
f[i][j],前i个堆消耗了j次的最小花费。
g[i][j], 第i个堆消耗j次的最小花费,从上一段可以知道这个g[i][j]是可以预处理出来的。
转移的话很明显就是
f[i][j] = min(f[i - 1][k] + g[i][j - k]);
但是这样子的转移是一个n^3,如果对于每一个修改我们都暴力做一次 n^3 的dp复杂度就是O(q*n^3),肯定超时。
我们从转移的过程入手,考虑对于f[i-1][k] 在k从小走到大的时候,由f[i-1][k] 和g[i][j - k]得到的最终权值也就是从大走到小在走到大,那么这先递减在递增的总权值就可以去做一个三分。
这样的复杂度就是变成了O(q * n ^ 2 * log n)可以通过了。
但是这个三分的思路并不会证明,但是的的确确是过的去的。然后我们对于每一个询问重新计算一次g[i][j],在暴力跑一次dp就好了。由于复杂度看起来可以过但是蛮危险的所以要尽量有点小优化。卡卡常。就k的枚举是有范围的,f[i][j]的j也有范围的,大于某个数就也没有继续做下去的必要。
通过的链接在牛客的通过代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<queue>
#include<time.h>
#include<string>
#include<cmath>
#include<stack>
#include<map>
#include<set>
#define int long long
//#define double long double
using namespace std;
#define PI 3.1415926535898
#define eqs 1e-6
const long long max_ = 400 + 7;
int mod = 998244353;
const int inf = 1e9;
const long long INF = 1e18;
int read() {
int s = 0, f = 1;
char ch = getchar();
while (ch<'0' || ch>'9') {
if (ch == '-')
f = -1;
ch = getchar();
}
while (ch >= '0'&&ch <= '9') {
s = s * 10 + ch - '0';
ch = getchar();
}
return s * f;
}
inline int min(int a, int b) {
return a < b ? a : b;
}
inline int max(int a, int b) {
return a > b ? a : b;
}
int f[max_][max_], g[max_][max_], node[max_], n, q, m, sum[max_];
void change(int i) {
g[i][1] = node[i] * node[i];
for (int j = 2; j <= min(node[i], m); j++) {
if (node[i] % j) {
int num = node[i] % j;
g[i][j] = (node[i] / j)*(node[i] / j)*(j - num)
+ ((node[i] / j) + 1)* ((node[i] / j) + 1)*num;
}
else g[i][j] = (node[i] / j)*(node[i] / j)*j;
}
}
signed main() {
n = read(), m = read();
for (int i = 1; i <= n; i++) {
node[i] = read();
}
for (int i = 1; i <= n; i++) change(i);
/*for (int i = 1; i <= n; i++) { for (int j = 1; j <= min(node[i], m); j++) { cout << g[i][j] << " "; } cout << endl; }*/
q = read();
while (q--) {
int a = read(), b = read();
node[a] = b;
change(a);
for (int i = 1; i <= n; i++)sum[i] = sum[i - 1] + node[i];
// memset(f, 127, sizeof(f));
int tt = min(node[1], m - (n - 1));
for (int i = 1; i <= tt; i++) {
f[1][i] = g[1][i];
}
for (int i = 2; i <= n; i++) {
tt = min(sum[i], m - (n - i));
for (int j = i; j <= tt; j++) {
//前i堆处理完了后搬了j次的最小权值
//前i-1堆可以搬的次数是[i-1, min( j - 1,sum[i - 1] )];
//设第i-1堆搬的次数为A
//则第i堆搬的次数是j - A;
int L = i - 1, R = min(j - 1, sum[i - 1]);//A的取值范围
while (L < R) {
int mid = (L + R) >> 1, t1 = mid + 1;
//mid 与 t1去对比
int vmid = f[i - 1][mid] + g[i][j - mid],
vt1 = f[i - 1][t1] + g[i][j - t1];
if (vmid >= vt1)L = mid + 1;//递减舍弃左边
else R = mid;
}
f[i][j] = f[i - 1][R] + g[i][j - R];
}
}
printf("%lld\n", f[n][min(sum[n], m)]);
}
return 0;
}