首页 > 试题广场 >

小美的蛋糕切割

[编程题]小美的蛋糕切割
  • 热度指数:2177 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
小美有一个矩形的蛋糕,共分成了 nm 列,共 n \times m 个区域,每个区域是一个小正方形,已知蛋糕每个区域都有一个美味度。她想切一刀把蛋糕切成两部分,自己吃一部分,小团吃另一部分。

小美希望两个人吃的部分的美味度之和尽可能接近,请你输出|s_1-s_2|的最小值。(其中s_1代表小美吃的美味度,s_2代表小团吃的美味度)。

请务必保证,切下来的区域都是完整的,即不能把某个小正方形切成两个小区域。



输入描述:
第一行输出两个正整数 nm ,代表蛋糕区域的行数和列数。
接下来的 n 行,每行输入 m 个正整数 a_{ij} ,用来表示每个区域的美味度。
1\leq n,m \leq 10^3
1\leq a_{ij} \leq 10^4


输出描述:
一个整数,代表 |s_1-s_2| 的最小值。
示例1

输入

2 3
1 1 4
5 1 4

输出

0

说明

把蛋糕像这样切开:
1 1 | 4
5 1 | 4
左边蛋糕美味度之和是8
右边蛋糕美味度之和是8
所以答案是0。

import java.util.Scanner;

// 注意类名必须为 Main, 不要有任何 package xxx 信息
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();
        int[][] arr = new int[n][m];

        long[] sumRow = new long[n]; // 每行所有元素相加的和
        long[] sumCol = new long[m]; // 每列所有元素相加的和

        for (int i = 0; i < n; i ++) {
            for (int j = 0; j < m; j ++) {
                arr[i][j] = sc.nextInt();
                sumRow[i] += arr[i][j];
                sumCol[j] += arr[i][j];
            }
        }

        if (m == 1 && n == 1) {
            System.out.println(arr[0][0]);
            return;
        }

        /*
        因为只能一刀,且不能切开每个元素,所以只能横一刀或者竖一刀

        以行的最小差值为例,可以把分成上下两部分,分别从最上面和最小面开始累加,哪个小,就让哪个多加一次
         */
        // 先计算行最小差值绝对值
        int i = 0;
        int j = n - 1;
        long sumUp = sumRow[i];
        long sumDown = sumRow[j];
        while (i < j - 1) {

            if (sumUp > sumDown) {
                j--;
                sumDown += sumRow[j];
            } else {
                i++;
                sumUp += sumRow[i];
            }
        }

        // 再计算列最小差值绝对值
        i = 0;
        j = m - 1;
        long sumLeft = sumCol[i];
        long sumRight = sumCol[j];
        while (i < j - 1) {
            if (sumLeft > sumRight) {
                j --;
                sumRight += sumCol[j];
            } else {
                i ++;
                sumLeft += sumCol[i];
            }
        }

        System.out.println(Math.min(Math.abs(sumDown - sumUp), Math.abs(sumLeft - sumRight)));
    }
}


发表于 2023-08-27 16:33:01 回复(0)
#include <iostream>
#include <vector>

using namespace std;


long long fun (vector<vector<int>>& arr, int n, int m) {
    // 求二维矩阵前缀和 之后的每一刀相当于求每个区间的美味度
    vector<vector<long long>> arr_b(n + 1, vector<long long>(m + 1, 0));
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            arr_b[i][j] = arr_b[i - 1][j] + arr_b[i][j - 1] - arr_b[i - 1][j - 1] + arr[i][j];
        }
    }

    long long s1, s2;
    long long min_result = arr_b[n][m];
    // 横切
    for (int i = 1; i < n; i++) {
        s1 = arr_b[i][m];
        s2 = arr_b[n][m] - s1;
        if (abs(s1 - s2) < min_result) min_result = abs(s1 - s2);
    }

    // 竖切
    for (int j = 1; j < m; j++) {
        s1 = arr_b[n][j];
        s2 = arr_b[n][m] - s1;
        if (abs(s1 - s2) < min_result) min_result = abs(s1 - s2);
    }

    return min_result;
}

int main () {

    int n, m;

    scanf("%d %d", &n, &m);

    vector<vector<int>> arr(n + 1, vector<int>(m + 1, 0));

    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            scanf("%d", &arr[i][j]);
        }
    }

    long long result = fun(arr, n, m);

    printf("%lld", result);

    return 0;
}
发表于 2023-08-18 14:58:15 回复(0)
import java.util.Scanner;
public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        // 注意 hasNext 和 hasNextLine 的区别
        int n=in.nextInt();
        long[]row=new long[n];
        int m=in.nextInt();
        long[]col=new long[m];
        int[][]matrix=new int[n][m];
        for (int i = 0; i < n; i++) {
            if (i>0) row[i]=row[i-1];
            for (int j = 0; j < m; j++) {
                matrix[i][j]=in.nextInt();
                row[i]+=matrix[i][j];
            }
        }
        long sum=row[n-1];
        for (int i = 0; i < m; i++) {
            if (i>0) col[i]=col[i-1];
            for (int j = 0; j < n; j++) {
                col[i]+=matrix[j][i];
            }
        }
        long ans=Long.MAX_VALUE;
        for (long a:row){
            ans=Math.min(ans,Math.abs(sum-a*2));
        }
        for (long a:col){
            ans=Math.min(ans,Math.abs(sum-a*2));
        }
        System.out.println(ans);
    }
}

编辑于 2024-03-21 12:28:03 回复(0)
114514 1919 810
答案是0
发表于 2023-08-25 14:15:03 回复(0)
方法1:二分法
#include <iostream>
#include <bits/stdc++.h>
using namespace std;

int main() {
    ios::sync_with_stdio(false);
    int n, m, tmp;
    cin >> n >> m;
    vector<vector<int64_t>> grid(n, vector<int64_t>(m));
    int64_t sum_of_all = 0;
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {
            cin >> tmp;
            grid[i][j] = tmp;
            sum_of_all += tmp;
        }
    }

    vector<int64_t> col_sum(m), row_sum(n);
    for (int i = 0; i < n; ++i) {
        auto& row = grid[i];
        row_sum[i] = std::accumulate(row.begin(), row.end(), 0LL);
    }
    for (int j = 0; j < m; ++j) {
        int64_t local_sum = 0;
        for (int i = 0; i < n; ++i) {
            local_sum += grid[i][j];
        }
        col_sum[j] = local_sum;
    }
    // 目标是找到最接近target的

    int64_t ans = INT64_MAX;

    // 首先竖着切分
    {
        auto get_left_sum = [&] (int mid) {
            int64_t left_sum = 0;
            for (int i = 0; i < mid; ++i) {
                left_sum += col_sum[i];
            }
            return left_sum;
        };
        int left = 0, right = m - 1;
        while (left <= right) {
            int mid = left + (right - left) / 2;
            int64_t left_sum = get_left_sum(mid);
            ans = min(ans, abs(sum_of_all - 2 * left_sum));
            if (left_sum * 2 == sum_of_all) {
                cout << 0 << endl;
                return 0;
            } else if (left_sum * 2 > sum_of_all) {
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }
    }

    if (ans == 0) {
        cout << 0 << endl;
        return 0;
    }

    // 然后横着切分
    {
        auto get_up_sum = [&] (int mid) {
            int64_t up_sum = 0;
            for (int i = 0; i < mid; ++i) {
                up_sum += row_sum[i];
            }
            return up_sum;
        };
        int up = 0, down = n - 1;
        while (up <= down) {
            int mid = up + (down - up) / 2;
            int64_t up_sum = get_up_sum(mid);
            ans = min(ans, abs(sum_of_all - 2 * up_sum));
            if (up_sum * 2 == sum_of_all) {
                cout << 0 << endl;
                return 0;
            } else if (up_sum * 2 > sum_of_all) {
                down = mid - 1;
            } else {
                up = mid + 1;
            }
        }
    }

    cout << ans << endl;
    return 0;
}
方法2:前缀和
#include <iostream>
#include <bits/stdc++.h>
using namespace std;
 
int main()
{
    ios::sync_with_stdio(false);
    int n, m;
    int64_t sum = 0;
    cin >> n >> m;
    vector<vector<int64_t>> grid(n, vector<int64_t>(m));
    vector<vector<int64_t>> presum(n, vector<int64_t>(m));
 
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < m; ++j)
        {
            cin >> grid[i][j];
            sum += grid[i][j];
        }
 
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < m; ++j)
        {
            if (i == 0 && j == 0)
                presum[0][0] = grid[0][0];
            else if (i == 0)
                presum[0][j] = grid[0][j] + presum[0][j - 1];
            else if (j == 0)
                presum[i][0] = grid[i][0] + presum[i - 1][0];
            else
                presum[i][j] = grid[i][j] + presum[i][j - 1] + presum[i - 1][j] - presum[i - 1][j - 1];
        }
 
    assert(presum[n - 1][m - 1] == sum);
 
    int64_t ans = INT64_MAX;
    for (int i = 0; i < n; ++i)
        ans = min<int64_t>(ans, abs(sum - 2 * presum[i][m - 1]));
    for (int j = 0; j < m; ++j)
        ans = min<int64_t>(ans, abs(sum - 2 * presum[n - 1][j]));
 
    cout << ans << endl;
}



发表于 2023-08-19 04:57:29 回复(0)
h, w = [int(n) for n in input().split()]
m = []
for i in range(h):
    m.append([int(n) for n in input().split()])

h_sum = [sum([m_[i] for m_ in m]) for i in range(w)]
w_sum = [sum(m[i]) for i in range(h)]
max = sum(w_sum)

for i in range(1, w):
    part1_sum = sum(h_sum[:i])
    part2_sum = sum(h_sum[i:])
    sub = abs(part1_sum - part2_sum)
    if sub < max:
        max = sub
    
for i in range(1, h):
    part1_sum = sum(w_sum[:i])
    part2_sum = sum(w_sum[i:])
    sub = abs(part1_sum - part2_sum)
    if sub < max:
        max = sub

print(max)

编辑于 2024-04-21 07:58:27 回复(0)

一开始想用golang结果超时,同样的方法放到C++就通过了……还得是C++

用C++需要注意:输入的数据用int确实没问题,但求和的时候如果还用int的话会溢出导致WA,建议除了代表地址的n,m,i,j之外无脑long long。

#include <iostream>
#include <algorithm>

long long matrix[1000][1000];
long long sumByLine[1000];
long long sumByColumn[1000];
long long prefixByLine[1000];
long long prefixByColumn[1000];
long long distanceByLine[1000];
long long distanceByColumn[1000];

int main(int argc, char* argv[]){
    std::cin.tie(nullptr); std::cout.tie(nullptr);std::ios::sync_with_stdio(false);
    int n, m;
    std::cin >> n >> m;
    for(int i = 0; i < n; i++){
        for(int j = 0; j < m; j++){
            std::cin >> matrix[i][j];
        }
    }
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            sumByLine[i] += matrix[i][j];
            sumByColumn[j] += matrix[i][j];
        }
    }
    prefixByLine[0] = sumByLine[0];
    prefixByColumn[0] = sumByColumn[0];
    for (int i = 1; i < n; i++) {
        prefixByLine[i] = prefixByLine[i-1] + sumByLine[i];
    }
    for (int i = 1; i < m; i++) {
        prefixByColumn[i] = prefixByColumn[i-1] + sumByColumn[i];
    }

    for (int i = 0; i < n; i++) {
        distanceByLine[i] = std::abs(prefixByLine[n-1] - 2*prefixByLine[i]);
    }
    for (int i = 0; i < m; i++) {
        distanceByColumn[i] = std::abs(prefixByColumn[m-1] - 2*prefixByColumn[i]);
    }
    long long minDistanceByLine = *(std::min_element(distanceByLine, distanceByLine+n));
    long long minDistanceByColumn = *(std::min_element(distanceByColumn, distanceByColumn+m));
    std::cout << std::min(minDistanceByLine, minDistanceByColumn);
    return 0;
}
编辑于 2024-04-13 14:36:46 回复(0)
import java.util.Scanner;

import static java.lang.Math.abs;

public class Main {
    public static  void main(String[] args) {
        Scanner sc=new Scanner(System.in);
        int x=sc.nextInt();
        int y=sc.nextInt();
        long[][] arr = new long[x][y];
        long sum=0;
        long []line=new long[x];
        long []row=new long[y];
       
        for (int i = 0; i < x; i++) {
            for (int j = 0; j < y; j++) {
                arr[i][j]=sc.nextInt();
                sum+=arr[i][j];
            }
        }
        for (int a = 0; a < x; a++) {
            line[a]=0;
            for (int b = 0; b < y; b++) {
                line[a]+=arr[a][b];
            }


        }
        for (int c = 0; c < y; c++) {
            row[c] = 0;
            for (int d = 0; d < x; d++) {
                row[c] += arr[d][c];
            }
        }
        long lines=line[0];
        long rows=row[0];
        long linemin=abs(sum-2*line[0]);
        long rowmin=abs(sum-2*row[0]);
        for (int p = 1; p < x-1; p++) {
            lines+=line[p];
            if(abs(sum-2*lines)<linemin)linemin=abs(sum-2*lines);
        }
        for(int q=1;q<y-1;q++){
            rows+=row[q];
            if(abs(sum-2*rows)<rowmin)rowmin=abs(sum-2*rows);
        }
        System.out.println(linemin<=rowmin?linemin:rowmin);
    }
}

编辑于 2024-03-31 15:30:50 回复(0)
import java.util.Scanner;
 
// 注意类名必须为 Main, 不要有任何 package xxx 信息
public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int m = in.nextInt();
        long[][] w = new long[n][m];
        // 注意 hasNext 和 hasNextLine 的区别
        for(int i = 0; i < n; i++)
        {
            for(int j = 0; j < m; j++)
            {
                w[i][j] = in.nextLong();
                if(i == 0 && j == 0) continue;
                if(i == 0) {w[i][j] += w[i][j - 1];continue;}
                if(j == 0) {w[i][j] += w[i - 1][j];continue;}
                w[i][j] += w[i - 1][j] + w[i][j - 1] - w[i - 1][j - 1];
            }
        }
        long s1 = 0;
        long s2 = 0;
        long res = Long.MAX_VALUE;
        for(int i = 0; i < n ; i++)
        {
            s1 = w[i][m-1];
            s2 = w[n-1][m-1] - s1;
            res =Math.min(res,Math.abs(s1-s2));
        }
        for(int i = 0; i < m ; i++)
        {
            s1 = w[n-1][i];
            s2 = w[n-1][m-1] - s1;
            res =Math.min(res,Math.abs(s1-s2));
        }
        System.out.println(res);
        return ;
    }
}

前缀和

编辑于 2024-03-04 16:28:11 回复(0)
import java.util.Scanner;

// 注意类名必须为 Main, 不要有任何 package xxx 信息
public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        // 注意 hasNext 和 hasNextLine 的区别
        int n = in.nextInt();
        int m = in.nextInt();
        long[][] arr = new long[n][m];
        long[] hang = new long[n];
        long[] lie = new long[m];
        long sum =0 ;
        for (int i = 0 ; i < n; i++) {
            for (int j = 0 ; j < m; j++) {
                arr[i][j] = in.nextLong();
                hang[i] += arr[i][j];
                lie[j] += arr[i][j];
                sum+= arr[i][j];
            }
        }
        if (m == 1 && n == 1) {
            System.out.println(arr[0][0]);
            return;
        }
        long res = Long.MAX_VALUE;

        for (int dao = 1; dao < Math.max(m, n); dao++) {
            long shang = 0;
            long xia = 0;
            long zuo = 0;
            long you = 0;

            if (dao < n) {
                for(int i =0;i<dao;i++){
                    shang += hang[i];
                }
                xia = sum-shang;
            }
            if (dao < m) {
                for(int i =0;i<dao;i++){
                    zuo += lie[i];
                }
                you = sum-zuo;
            }
            if (dao < n) {
                res = Math.min(res, Math.abs(shang - xia));
            }
            if (dao < m) {
                res = Math.min(res, Math.abs(zuo - you));
            }
        }

        System.out.println(res);
    }
}
编辑于 2024-01-11 14:53:01 回复(0)
有没有大佬帮我看下为啥我这个不行?
#include <stdio.h>


long long Min_arr(long long arr[], int n)
{

    long long MID = arr[n - 1] / 2;
    int left = 0;
    int right = n - 2;
    int mid = 0;

    long long min_arr = 0;
    int i=0;
      while(arr[i]<MID)
      {
        i++;
      }
   if(i==0)
   {
    min_arr=2*arr[i]-arr[n-1];
   }
   else{
    min_arr=(arr[n-1]-2*arr[i-1]<2*arr[i]-arr[n-1]?arr[n-1]-2*arr[i-1]:2*arr[i]-arr[n-1]);
   }

    return min_arr;
}


int main() {

    int n = 0;
    int m = 0;
    scanf("%d %d", &n, &m);
    int arr[2][22];
    long long ROW[2]; //记录每行的叠加和
    long long COL[22]; //记录每列的叠加和
    int k = 0;
    for (k = 0; k < m; k++)
    {
        COL[k] = 0;
    }
    for (k = 0; k < n; k++)
    {
        ROW[k] = 0;
    }
    int i = 0;
    long long row = 0;
    for (i = 0; i < n; i++)
    {
        int j = 0;
        for (j = 0; j < m; j++)
        {
            scanf("%d", &arr[i][j]);
            row += arr[i][j];
            COL[j] += arr[i][j];
        }
        ROW[i] += row;
    }
    long long s = 0;
    for (k = 0; k < m; k++)
    {
        COL[k] += s;
        s = COL[k];
    }
    long long  min_ROW = 0;
    long long  min_COL = 0;
    //按行寻找 |s_1-s_2| 的最小值;
    min_ROW = Min_arr(ROW, n);
    //按列寻找 |s_1-s_2| 的最小值;
    min_COL = Min_arr(COL, m);
    printf("%lld", min_ROW < min_COL ? min_ROW : min_COL);
    return 0;
}

发表于 2023-11-09 13:41:56 回复(0)

二维数组前缀和

import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int h = scanner.nextInt(), w = scanner.nextInt();
        int[][] cake = new int[h][w];
        long[][] presum = new long[h][w];

        for (int i = 0; i < h; i++) {
            for (int j = 0; j < w; j++) {
                cake[i][j] = scanner.nextInt();

                if (i == 0 && j == 0) {
                    presum[i][j] = cake[i][j];
                } else if (j == 0) {
                    presum[i][j] = presum[i - 1][0] + cake[i][j];
                } else if (i == 0) {
                    presum[i][j] = presum[i][j - 1] + cake[i][j];
                } else {
                    presum[i][j] = presum[i - 1][j] + presum[i][j - 1] - presum[i - 1][j - 1] + cake[i][j];
                }
            }
        }

        long res = Integer.MAX_VALUE;

        for (int i = 0; i < h; i++) {
            long top = presum[i][w - 1];
            long bottom = presum[h - 1][w - 1] - top;
            res = Math.min(res, Math.abs(top - bottom));
        }
        for (int j = 0; j < w; j++) {
            long left = presum[h - 1][j];
            long right = presum[h - 1][w - 1] - left;
            res = Math.min(res, Math.abs(left - right));
        }
        System.out.println(res);
    }
}
发表于 2023-10-26 13:41:41 回复(0)
import java.util.Scanner;

// 注意类名必须为 Main, 不要有任何 package xxx 信息
public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        // 注意 hasNext 和 hasNextLine 的区别
        // while (in.hasNextInt()) { // 注意 while 处理多个 case
        //     int a = in.nextInt();
        //     int b = in.nextInt();
        //     System.out.println(a + b);
        // }
        int n = in.nextInt();
        int m = in.nextInt();
        int [][]nums = new int[n][m];
        long sum = 0;//总和
        long row[] = new long[n];//统计每行的和
        long col[] = new long[m];//统计每列的和
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                nums[i][j] = in.nextInt();
                sum += nums[i][j];
                row[i] += nums[i][j];
                col[j] += nums[i][j];
            }
        }
        long sum1 = 0;
        long min1 = 0;
        long sum1_0 = 0;
        long min1_0 = 0;
        for (int i = 0; i < n; i++) {
            sum1 += row[i];
            if (sum1 >= sum / 2) {
                min1 = Math.abs((sum - 2 * sum1));
                sum1_0 = sum1-row[i];
                min1_0 = Math.abs((sum - 2 * sum1_0));
                min1 = Math.min(min1,min1_0);
                break;
            }
        }
        long sum2 = 0;
        long min2 = 0;
        long sum2_0 = 0;
        long min2_0 = 0;
        for (int i = 0; i < m; i++) {
            sum2 += col[i];
            if (sum2 >= sum / 2) {
                min2 = Math.abs((sum - 2 * sum2));
                sum2_0 = sum2-col[i];
                min2_0 = Math.abs((sum - 2 * sum2_0));
                min2 = Math.min(min2,min2_0);
                break;
            }
        }
        System.out.println(Math.min(min1,min2));
    }
}
发表于 2023-10-09 15:40:28 回复(0)
#include <iostream>
#include <bits/stdc++.h>
using namespace std;

int main() {
    ios::sync_with_stdio(false);
    int n, m,tmp;
    long res = INT_MAX;
    cin>>n>>m;

    vector<vector<long>> sum(n+1,vector<long>(m+1,0));
    
    for(int i=1;i<=n;++i){   
       for(int j=1;j<=m;++j){
           cin>>tmp;
           sum[i][j] = tmp+sum[i-1][j]+sum[i][j-1]-sum[i-1][j-1];       
       }
    }
    for(int j=1;j<=m;++j) res = min(res,abs(sum[n][m]-2*sum[n][j]));
    for(int i=1;i<=n;++i) res = min(res,abs(sum[n][m]-2*sum[i][m]));
    cout<<res;
    return 0;
}

发表于 2023-08-19 06:33:27 回复(0)

可以把每行每列的求和算一遍存下来供后续二分时重复使用,用 O(n) 的空间把总求和时间压到 O(nlogn)(其实是 O(n),因为是等比数列求和)

进一步还可以考虑维护以行或列为粒度的前缀和数组,把二分的时间压到 O(logn)

n, m = [int(x) for x in input().strip().split()]

a = [[0] * m for i in range(n)]
row_sum = [0] * n
col_sum = [0] * m
total_sum = 0

for i in range(n):
    a[i] = [int(x) for x in input().strip().split()]
    row_sum[i] = sum(a[i])
    total_sum += row_sum[i]

for j in range(m):
    col_sum[j] = sum([a[i][j] for i in range(n)])

half = total_sum >> 1

# row_idx
l = 0; r = n - 1
if l == r:
    min_row = total_sum
else:
    while (l < r):
        mid = l + r + 1 >> 1
        half_sum = sum(row_sum[:mid])
        if half_sum <= half:
            l = mid
        else:
            r = mid - 1
    if half_sum > half:
        half_sum -= row_sum[mid-1]
    min_row = min(total_sum - 2 * half_sum, 2 * (half_sum + row_sum[r]) - total_sum)

# col_idx
l = 0; r = m - 1
if l == r:
    min_col = total_sum
else:
    while (l < r):
        mid = l + r + 1 >> 1
        half_sum = sum(col_sum[:mid])
        if half_sum <= half:
            l = mid
        else:
            r = mid - 1
    if half_sum > half:
        half_sum -= col_sum[mid-1]
    min_col = min(total_sum - 2 * half_sum, 2 * (half_sum + col_sum[r]) - total_sum)

print(min(min_row, min_col))
编辑于 2023-08-18 05:00:55 回复(0)
直接暴力求解,力求AC,之前直接用的int,结果有一组用例死活过不了,还以为不能暴力求解。后面改为long就好了
public static void main(String[] args) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        String[] strings = reader.readLine().split(" ");
        int n = Integer.parseInt(strings[0]);
        int m = Integer.parseInt(strings[1]);

        int[][] cake = new int[n][m];
        for (int i = 0; i < n; i++) {
            String[] s = reader.readLine().split(" ");
            for (int j = 0; j < m; j++) {
                cake[i][j] = Integer.parseInt(s[j]);
            }
        }

        long minDiff = Long.MAX_VALUE;
        long totalSum = 0;
        for (int[] ints : cake) {
            for (long anInt : ints) {
                totalSum += anInt;
            }
        }
        long sum  = 0;
        if (cake[0].length == 1) {

            System.out.println(totalSum);
        }
        else {

            //先想象把蛋糕横着分开
            for (int i = 1; i < n; i++) {
                for (int j = 0; j < n - i; j++) {
                    for (int k = 0; k < m; k++) {
                        sum += cake[j][k];
                    }
                }
                minDiff = Math.min(minDiff, Math.abs(sum - (totalSum - sum)));
                sum = 0;
            }
            //竖着分开的请况
            for (int i = 1; i < m; i++) {
                for (int[] ints : cake) {
                    for (int k = 0; k < m - i; k++) {
                        sum += ints[k];
                    }
                }
                minDiff = Math.min(minDiff, Math.abs(sum - (totalSum - sum)));
                sum = 0;
            }
            System.out.println(minDiff);
        }

    }


发表于 2023-08-18 00:31:29 回复(0)