Rotate Columns(CodeForces Round #584)(状压DP)
Rotate Columns
题意:给定一个矩阵,可以对矩阵的任意列进行上下滑动(或称旋转),使最大化每一行的最大值 之和。
Easy version: n≤4, m≤100
Hard version: n≤12, m≤2000
思路:直接讲 Hard吧(后面附上 Easy的解法)
- 按列进行 dp,目标是得到 dp[(1<<n)−1]
- 假设当前枚举到第 j列,则考虑当前列所有可能的状态(状态被定义为当前列的哪些行是这一行的最大值,状压)的贡献,而每一种状态会被更新 n次,因为要考虑当前列的滑动(或称旋转)
- 得到每一列所有状态的贡献值后,用其与之前 j−1列所得到的状态互补更新(妙呀!)
- 而这样对于每一组数据( 40组)得到的时间复杂度是: O(m∗(n2∗2n+4n)),显然是过不了的
- 进一步思考:对 n行的最大值有贡献的列数肯定是不超过 n的,并且如果按照每一列的最大值对列进行排序,是不是第 n列以后的列就不会对答案有贡献呢?如果后面某一列对答案有贡献,那说明前 n列至少有一列对答案没有贡献(毕竟只有 n行),那完全可以用前面那一列的最大值来代替这一行呀(旋转一下就行了)。
- 因此,如果 m>n,则我们可以将 m行排序后只使用前 n行,这样的时间复杂度降为: O(n∗(n2∗2n+4n))
题面描述
Hard
#include "bits/stdc++.h"
#define hhh printf("hhh\n")
#define see(x) (cerr<<(#x)<<'='<<(x)<<endl)
using namespace std;
typedef long long ll;
typedef pair<int,int> pr;
inline int read() {int x=0;char c=getchar();while(c<'0'||c>'9')c=getchar();while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();return x;}
const int maxn = 1e5+10;
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;
const double eps = 1e-7;
int n, m;
int a[12][2000], dp[4096], odp[4096], mx[2000], dd[4096];
int id[2000];
bool cmp(int a, int b) { return mx[a]>mx[b]; }
void solve() {
memset(dp,0,sizeof(dp));
memset(mx,0,sizeof(mx));
memset(dd,0,sizeof(dd));
n=read(), m=read();
for(int i=0; i<m; ++i) id[i]=i;
for(int i=0; i<n; ++i) for(int j=0; j<m; ++j)
if((a[i][j]=read())>mx[j]) mx[j]=a[i][j];
sort(id,id+m,cmp);
if(m>n) m=n;
for(int j=0; j<m; ++j) {
memcpy(odp,dp,sizeof(dp));
memset(dd,0,sizeof(dd));
for(int s=0; s<1<<n; ++s) {
for(int times=0; times<n; ++times) { //旋转次数
int d=0;
for(int i=0; i<n; ++i)
if(s>>i&1) d+=a[(i+times)%n][id[j]];
dd[s]=max(dd[s],d); //记录这个状态的贡献值
}
for(int ss=s; ss<1<<n; ss=(ss+1)|s)
dp[ss]=max(dp[ss],odp[ss^s]+dd[s]); //与前面的状态互补更新
}
}
printf("%d\n", dp[(1<<n)-1]);
}
int main() {
//ios::sync_with_stdio(false); cin.tie(0);
int T=read();
while(T--) solve();
}
Easy(排序优化后为 O(n2∗4n))(这个写法复杂度差一点,没卡过去)
#include "bits/stdc++.h"
#define hhh printf("hhh\n")
#define see(x) (cerr<<(#x)<<'='<<(x)<<endl)
using namespace std;
typedef long long ll;
typedef pair<int,int> pr;
inline int read() {int x=0;char c=getchar();while(c<'0'||c>'9')c=getchar();while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();return x;}
const int maxn = 1e5+10;
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;
const double eps = 1e-7;
int n, m;
int a[4][100], dp[100][16];
void solve() {
memset(dp,0,sizeof(dp));
n=read(), m=read();
for(int i=0; i<n; ++i) for(int j=0; j<m; ++j) a[i][j]=read();
for(int j=0; j<m; ++j) { //这里并没有进行排序处理,复杂度为O(m*n*4^n)
for(int s=0; s<1<<n; ++s) {
int d=0, ns=s;
for(int i=0; i<n; ++i) if(s>>i&1) d+=a[i][j];
for(int times=0; times<n; ++times) {
for(int i=0; i<1<<n; ++i)
if(!(i&ns)) dp[j][ns|i]=max(dp[j][ns|i],d+(j?dp[j-1][i]:0));
ns=(ns<<1)+((ns&1<<(n-1))?1-(1<<n):0); //旋转的另外一种写法,稍麻烦
}
}
}
printf("%d\n", dp[m-1][(1<<n)-1]);
}
int main() {
//ios::sync_with_stdio(false); cin.tie(0);
int T=read();
while(T--) solve();
}