矩阵取数游戏 | 区间dp
矩阵取数游戏
https://ac.nowcoder.com/acm/problem/16645
链接:https://ac.nowcoder.com/acm/problem/16645
来源:牛客网
题目描述
帅帅经常跟同学玩一个矩阵取数游戏:对于一个给定的n*m的矩阵,矩阵中的每个元素aij均为非负整数。游戏规则如下:
1.每次取数时须从每行各取走一个元素,共n个。m次后取完矩阵所有元素;
2.每次取走的各个元素只能是该元素所在行的行首或行尾;
3.每次取数都有一个得分值,为每行取数的得分之和,每行取数的得分 = 被取走的元素值 * 2i,其中i表示第i次取数(从1开始编号);
4.游戏结束总得分为m次取数得分之和。
1.每次取数时须从每行各取走一个元素,共n个。m次后取完矩阵所有元素;
2.每次取走的各个元素只能是该元素所在行的行首或行尾;
3.每次取数都有一个得分值,为每行取数的得分之和,每行取数的得分 = 被取走的元素值 * 2i,其中i表示第i次取数(从1开始编号);
4.游戏结束总得分为m次取数得分之和。
帅帅想请你帮忙写一个程序,对于任意矩阵,可以求出取数后的最大得分。
题目思路:
经典区间dp,考虑初状态不确定,从末状态开始(最后一定是取了单独的最后一个数)
之后对每一行进行区间dp就可以了
/*** keep hungry and calm CoolGuang!***/ #pragma GCC optimize(3) #include <bits/stdc++.h> #include<stdio.h> #include<queue> #include<algorithm> #include<string.h> #include<iostream> #define debug(x) cout<<#x<<":"<<x<<endl; #define ls k<<1 #define rs k<<1|1 #define _CRT_SECURE_NO_WARNINGS #pragma GCC optimize("Ofast","unroll-loops","omit-frame-pointer","inline") #pragma GCC option("arch=native","tune=native","no-zero-upper") #pragma GCC target("avx2") using namespace std; typedef long long ll; typedef unsigned long long ull; typedef pair<int,int> pp; const ll INF=1e17; const int maxn=2e5+6; const int mod=998244353; const double eps=1e-3; inline bool read(ll &num) {char in;bool IsN=false; in=getchar();if(in==EOF) return false;while(in!='-'&&(in<'0'||in>'9')) in=getchar();if(in=='-'){ IsN=true;num=0;}else num=in-'0';while(in=getchar(),in>='0'&&in<='9'){num*=10,num+=in-'0';}if(IsN) num=-num;return true;} ll n,m,p; __int128 read(){ __int128 x=0,f=1; char ch=getchar(); while(!isdigit(ch)&&ch!='-')ch=getchar(); if(ch=='-')f=-1; while(isdigit(ch))x=x*10+ch-'0',ch=getchar(); return f*x; } void print(__int128 x){ if(x<0)putchar('-'),x=-x; if(x>9)print(x/10); putchar(x%10+'0'); } __int128 a[1005][1005],dp[1005][1005],f[1005]; int main(){ read(n);read(m); f[0] = 1; for(int i=1;i<=m;i++) f[i] = f[i-1]*2; for(int i=1;i<=n;i++) for(int k=1;k<=m;k++) a[i][k] = read(); __int128 ans = 0; for(int i=1;i<=n;i++){ for(int k=1;k<=m;k++) dp[k][k] = a[i][k]*f[m]; for(int len=1;len<=m;len++){ for(int s=1;s+len<=m;s++){ int t = s+len; dp[s][t] = max(dp[s+1][t]+a[i][s]*f[m-len],dp[s][t-1]+a[i][t]*f[m-len]); } } ans += dp[1][m]; } print(ans); return 0; } /** 10 10 1 2 3 4 5 6 7 8 7 8 7 8 7 8 7 8 7 8 7 8 **/