列表

详情


DP49. 矩阵取数游戏

描述

帅帅经常跟同学玩一个矩阵取数游戏:对于一个给定的 n*m 的矩阵,矩阵中的每个元素 均为非负整数。游戏规则如下:
1.每次取数时须从每行各取走一个元素,共 n 个。m 次后取完矩阵所有元素;
2.每次取走的各个元素只能是该元素所在行的行首或行尾;
3.每次取数都有一个得分值,为每行取数的得分之和,每行取数的得分 = 被取走的元素值 * 2i,其中i表示第 i 次取数(从1开始编号);
4.游戏结束总得分为 m 次取数得分之和。
帅帅想请你帮忙写一个程序,对于任意矩阵,可以求出取数后的最大得分。

数据范围: ,矩阵中的值满足 ,由于得分可能会非常大,所以把值对 取模

输入描述

第一行输入两个正整数 n 和 m ,表示矩阵的长宽。
后续 n 行每行输入 m 个正整数,表示矩阵的元素

输出描述

输出最大得分

示例1

输入:

2 3
1 2 3
3 4 2

输出:

82

说明:

第1次:第1行取行首元素,第2行取行尾元素,本次得分为1 * 2+ 2 * 2= 6
第2次:两行均取行首元素,本次得分为2 * 2+ 3 * 22 = 20
第3次:得分为3 * 2+ 4 * 2= 56。
总得分为6 + 20 + 56 = 82

示例2

输入:

1 4
4 5 0 5

输出:

122

示例3

输入:

2 10
96 56 54 46 86 12 23 88 80 43
16 95 18 29 30 53 88 83 64 67

输出:

316994

原站题解

C 解法, 执行用时: 144ms, 内存消耗: 904KB, 提交时间: 2022-06-22

#include <stdio.h>

int *multi(int *a, int aLen, int b) {
    int *result = (int *)malloc(sizeof(int) * 10);
    
    long carry = 0, v;
    for (int i = 0; i < aLen; i++) {
        v = a[i] * b + carry;
        carry = v / 1000000007;
        result[i] = v % 1000000007;
    }
    
    return result;
}

int *add(int *a, int aLen, int *b, int bLen) {
    int *result = (int *)malloc(sizeof(int) * 10);
    
    long carry = 0, v;
    for (int i = 0; i < aLen; i++) {
        v = a[i] + b[i] + carry;
        carry = v / 1000000007;
        result[i] = v % 1000000007;
    }
    
    return result;
}

int *max(int *a, int aLen, int *b, int bLen) {
    for (int i = aLen - 1; i >= 0; i--) {
        if (a[i] > b[i]) {
            return a;
        } else if (a[i] < b[i]) {
            return b;
        }
    }
    
    return a;
}

int maxRow(int *nums, int len, int **powArray) {
    int ***dp = (int ***)malloc(sizeof(int **) * (len + 1));
    for (int i = 0; i < len + 1; i++) {
        dp[i] = (int **)malloc(sizeof(int *) * (len + 1));
        for (int j = 0; j < len + 1; j++) {
            dp[i][j] = (int *)malloc(sizeof(int) * 10);
            memset(dp[i][j], 0, sizeof(int) * 10);
        }
    }
    
    for (int d = 1; d < len + 1; d++) {
        for (int l = 0; l <= d; l++) {
            int *v = dp[l][d-l];
            if (l != d) {
                int *x = multi(powArray[d-1], 10, nums[len - d + l]);
                int *y = add(dp[l][d-l-1], 10, x, 10);
                free(x);
                int *z = max(y, 10, v, 10);
                if (z != v) {
                    free(v);
                    dp[l][d-l] = z;
                    v = dp[l][d-l];
                } else {
                    free(y);
                }
            }
            
            if (l != 0) {
                int *x = multi(powArray[d-1], 10, nums[l - 1]);
                int *y = add(dp[l-1][d-l], 10, x, 10);
                free(x);
                int *z = max(y, 10, v, 10);
                if (z != v) {
                    free(v);
                    dp[l][d-l] = z;
                    v = dp[l][d-l];
                } else {
                    free(y);
                }
            }
        }
    }
    
    int *maxSum = dp[0][0], result;
    for (int k = 0; k <= len; k++) {
        maxSum = max(dp[k][len-k], 10, maxSum, 10);
    }
    result = maxSum[0];
    
    for (int i = 0; i < len + 1; i++) {
        for (int j = 0; j < len + 1; j++) {
            free(dp[i][j]);
        }
        free(dp[i]);
    }
    free(dp);
    
    return result;
}

int main() {
    int n, m;
    scanf("%d %d", &n, &m);
    
    int nums[101][101];
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            scanf("%d", &nums[i][j]);
        }
    }
    
    long sum = 0, x = 1;
    
    int **powArray = (int **)malloc(sizeof(int *) * m);
    for (int i = 0; i < m; i++) {
        powArray[i] = (int *)malloc(sizeof(int) * 10);
        memset(powArray[i], 0, sizeof(int) * 10);
    }
    powArray[0][0] = 2;
    
    int *result;
    for (int i = 1; i < m; i++) {
        result = multi(powArray[i-1], 10, 2);
        for (int j = 0; j < 10; j++) {
            powArray[i][j] = result[j];
        }
        free(result);
    }
    
    for (int i = 0; i < n; i++) {
        sum = (sum + maxRow(nums[i], m, powArray)) % 1000000007;
    }
    
    printf("%d", sum);
    
    for (int i = 0; i < m; i++) {
        free(powArray[i]);
    }
    free(powArray);
    
    return 0;
}


C++ 解法, 执行用时: 286ms, 内存消耗: 1048KB, 提交时间: 2022-08-05

#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;
using ll = long long;
using VI = vector<int>;

const int N = 110, mod = 1e9 + 7;
int g[N];
VI power2[N];
VI dp[N][N];

int n, m;

VI add(VI a, VI b){
    static VI c;
    c.clear();
    for(int i = 0, t = 0; i < a.size() || i < b.size() ||  t; i++){
        if(i < a.size()) t += a[i];
        if(i < b.size()) t += b[i];
        c.push_back(t % 10);
        t /= 10;
    }
    return c;
}

VI mul(VI a, int b){
    static VI c;
    c.clear();
    ll t = 0;
    for(int i = 0; i < a.size() || t; i++){
        if(i < a.size()) t += (ll)a[i] * b;
        c.push_back(t % 10);
        t /= 10;
    }
    return c;
}

VI maxv(VI a, VI b){
    if(a.size() > b.size()) return a;
    if(a.size() < b.size()) return b;
    for(int i = a.size() - 1; i >= 0; i--){
        if(a[i] > b[i]) return a;
        if(a[i] < b[i]) return b;
    }
    return a;
}

VI f(){
    for(int len = 1; len <= m; len++){
        for(int l = 1; l + len - 1 <= m; l++){
            int r = l + len - 1;
            if(l == r) dp[l][r] = mul({1}, g[l] * 2);
            else{
                auto left = add(mul(dp[l + 1][r], 2), mul({2}, g[l]));
                auto right = add(mul(dp[l][r - 1], 2), mul({2}, g[r]));
                dp[l][r] = maxv(left, right);
            }
        }
    }
    return dp[1][m];
}

void print(VI a){
    ll t = 0;
    for(int i = a.size() - 1; i >= 0; i--){
        t = t * 10 + a[i];
        t %= mod;
    }
    cout << t << endl;
}

int main(){
    cin >> n >> m;

    power2[0] = {1};
    for(int i = 1; i <= m; i++) power2[i] = mul(power2[i - 1], 2);

    VI res(1, 0);
    for(int i = 0; i < n; i++){
        for(int j = 1; j <= m; j++) cin >> g[j];
        res = add(res, f());
    }

    print(res);

    return 0;
}

C++ 解法, 执行用时: 322ms, 内存消耗: 1052KB, 提交时间: 2022-08-06

#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;
using ll = long long;
using VI = vector<int>;

const int N = 110, mod = 1e9 + 7;
int g[N];
VI power2[N];
VI dp[N][N];

int n, m;

VI add(VI a, VI b){
    static VI c;
    c.clear();
    for(int i = 0, t = 0; i < a.size() || i < b.size() ||  t; i++){
        if(i < a.size()) t += a[i];
        if(i < b.size()) t += b[i];
        c.push_back(t % 10);
        t /= 10;
    }
    return c;
}

VI mul(VI a, int b){
    static VI c;
    c.clear();
    ll t = 0;
    for(int i = 0; i < a.size() || t; i++){
        if(i < a.size()) t += (ll)a[i] * b;
        c.push_back(t % 10);
        t /= 10;
    }
    return c;
}

VI maxv(VI a, VI b){
    if(a.size() > b.size()) return a;
    if(a.size() < b.size()) return b;
    for(int i = a.size() - 1; i >= 0; i--){
        if(a[i] > b[i]) return a;
        if(a[i] < b[i]) return b;
    }
    return a;
}

VI f(){
    for(int len = 1; len <= m; len++){
        for(int l = 1; l + len - 1 <= m; l++){
            int r = l + len - 1;
            if(l == r) dp[l][r] = mul({1}, g[l] * 2);
            else{
                auto left = add(mul(dp[l + 1][r], 2), mul({2}, g[l]));
                auto right = add(mul(dp[l][r - 1], 2), mul({2}, g[r]));
                dp[l][r] = maxv(left, right);
            }
        }
    }
    return dp[1][m];
}

void print(VI a){
    ll t = 0;
    for(int i = a.size() - 1; i >= 0; i--){
        t = t * 10 + a[i];
        t %= mod;
    }
    cout << t << endl;
}

int main(){
    cin >> n >> m;

//     power2[0] = {1};
//     for(int i = 1; i <= m; i++) power2[i] = mul(power2[i - 1], 2);

    VI res(1, 0);
    for(int i = 0; i < n; i++){
        for(int j = 1; j <= m; j++) cin >> g[j];
        res = add(res, f());
    }

    print(res);

    return 0;
}

Pypy2 解法, 执行用时: 324ms, 内存消耗: 32108KB, 提交时间: 2022-06-13

import sys

while True:
    firstline = sys.stdin.readline()
    if not firstline:
        break

    (n,m) = [int(field) for field in firstline.rstrip().split(" ")]
    matrix = [[int(field) for field in sys.stdin.readline().rstrip().split(" ")] for i in range(n)]

    answer = 0

    for l in range(n):
        a = [0] + matrix[l]
        f = [[-1] * (m+1)] * (m+1)

        for i in range(0,(m)+1):
            for j in range(0,(m-i)+1):
                if i>0 or j>0:
                    f[i][j] = max((2**(i+j)*a[i]+f[i-1][j] if i>0 else 0), (2**(i+j)*a[m+1-j]+f[i][j-1] if j>0 else 0))
                else:
                    f[i][j] = 0

        answer += max([f[i][m-i] for i in range(0,m+1)])

    mod = int(1e9 + 7)
    print answer % mod

Pypy3 解法, 执行用时: 413ms, 内存消耗: 29416KB, 提交时间: 2022-05-27

n,m = map(int, input().split())
ans = 0
for i in range(n):
    lst = list(map(int, input().split()))
    dp = [[0] * (m+1) for j in range(m+1)]
    for ln in range(1,m+1):
        for l in range (ln+1):
            r = ln - l
            if l > 0:
                dp[l][r] = (dp[l-1][r] + lst[l-1] * 2**ln)
            if r > 0:
                dp[l][r] = max((dp[l][r-1] + lst[m-r] * 2**ln), dp[l][r])
    mx = 0
    for l in range(m+1):
        mx = max(dp[l][m-l], mx)
    ans += mx
    ans %= 1000000007
print(ans)

上一题