列表

详情


3428. 最多 K 个元素的子序列的最值之和

给你一个整数数组 nums 和一个正整数 k,返回所有长度最多为 k子序列 中 最大值 与 最小值 之和的总和。

非空子序列 是指从另一个数组中删除一些或不删除任何元素(且不改变剩余元素的顺序)得到的数组。

由于答案可能非常大,请返回对 109 + 7 取余数的结果。

 

示例 1:

输入: nums = [1,2,3], k = 2

输出: 24

解释:

数组 nums 中所有长度最多为 2 的子序列如下:

子序列 最小值 最大值
[1] 1 1 2
[2] 2 2 4
[3] 3 3 6
[1, 2] 1 2 3
[1, 3] 1 3 4
[2, 3] 2 3 5
总和     24

因此,输出为 24。

示例 2:

输入: nums = [5,0,6], k = 1

输出: 22

解释:

对于长度恰好为 1 的子序列,最小值和最大值均为元素本身。因此,总和为 5 + 5 + 0 + 0 + 6 + 6 = 22

示例 3:

输入: nums = [1,1,1], k = 2

输出: 12

解释:

子序列 [1, 1][1] 各出现 3 次。对于所有这些子序列,最小值和最大值均为 1。因此,总和为 12。

 

提示:

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: int minMaxSums(vector<int>& nums, int k) { } };

cpp 解法, 执行用时: 25 ms, 内存消耗: 80.3 MB, 提交时间: 2025-02-01 10:48:23

const int MOD = 1'000'000'007;
const int MX = 100'000;

long long F[MX]; // F[i] = i!
long long INV_F[MX]; // INV_F[i] = i!^-1

long long pow(long long x, int n) {
    long long res = 1;
    for (; n; n /= 2) {
        if (n % 2) {
            res = res * x % MOD;
        }
        x = x * x % MOD;
    }
    return res;
}

auto init = [] {
    F[0] = 1;
    for (int i = 1; i < MX; i++) {
        F[i] = F[i - 1] * i % MOD;
    }

    INV_F[MX - 1] = pow(F[MX - 1], MOD - 2);
    for (int i = MX - 1; i; i--) {
        INV_F[i - 1] = INV_F[i] * i % MOD;
    }
    return 0;
}();

long long comb(int n, int m) {
    return m > n ? 0 : F[n] * INV_F[m] % MOD * INV_F[n - m] % MOD;
}

class Solution {
public:
    int minMaxSums(vector<int>& nums, int k) {
        ranges::sort(nums);
        int n = nums.size();
        long long ans = 0, s = 1;
        for (int i = 0; i < n; i++) {
            ans = (ans + s * (nums[i] + nums[n - 1 - i])) % MOD;
            s = (s * 2 - comb(i, k - 1) + MOD) % MOD;
        }
        return ans;
    }
};

cpp 解法, 执行用时: 91 ms, 内存消耗: 80.4 MB, 提交时间: 2025-02-01 10:48:12

const int MOD = 1'000'000'007;
const int MX = 100'000;

long long F[MX]; // F[i] = i!
long long INV_F[MX]; // INV_F[i] = i!^-1

long long pow(long long x, int n) {
    long long res = 1;
    for (; n; n /= 2) {
        if (n % 2) {
            res = res * x % MOD;
        }
        x = x * x % MOD;
    }
    return res;
}

auto init = [] {
    F[0] = 1;
    for (int i = 1; i < MX; i++) {
        F[i] = F[i - 1] * i % MOD;
    }

    INV_F[MX - 1] = pow(F[MX - 1], MOD - 2);
    for (int i = MX - 1; i; i--) {
        INV_F[i - 1] = INV_F[i] * i % MOD;
    }
    return 0;
}();

long long comb(int n, int m) {
    return F[n] * INV_F[m] % MOD * INV_F[n - m] % MOD;
}

class Solution {
public:
    int minMaxSums(vector<int>& nums, int k) {
        ranges::sort(nums);
        int n = nums.size();
        long long ans = 0;
        for (int i = 0; i < n; i++) {
            long long s = 0;
            for (int j = 0; j < min(k, i + 1); j++) {
                s += comb(i, j);
            }
            ans = (ans + s % MOD * (nums[i] + nums[n - 1 - i])) % MOD;
        }
        return ans;
    }
};

java 解法, 执行用时: 50 ms, 内存消耗: 55.2 MB, 提交时间: 2025-02-01 10:47:49

class Solution {
    private static final int MOD = 1_000_000_007;
    private static final int MX = 100_000;

    private static final long[] F = new long[MX]; // f[i] = i!
    private static final long[] INV_F = new long[MX]; // inv_f[i] = i!^-1

    static {
        F[0] = 1;
        for (int i = 1; i < MX; i++) {
            F[i] = F[i - 1] * i % MOD;
        }

        INV_F[MX - 1] = pow(F[MX - 1], MOD - 2);
        for (int i = MX - 1; i > 0; i--) {
            INV_F[i - 1] = INV_F[i] * i % MOD;
        }
    }

    public int minMaxSums(int[] nums, int k) {
        Arrays.sort(nums);
        int n = nums.length;
        long ans = 0;
        long s = 1;
        for (int i = 0; i < n; i++) {
            ans = (ans + s * (nums[i] + nums[n - 1 - i])) % MOD;
            s = (s * 2 - comb(i, k - 1) + MOD) % MOD;
        }
        return (int) ans;
    }

    private long comb(int n, int m) {
        return m > n ? 0 : F[n] * INV_F[m] % MOD * INV_F[n - m] % MOD;
    }

    private static long pow(long x, int n) {
        long res = 1;
        for (; n > 0; n /= 2) {
            if (n % 2 > 0) {
                res = res * x % MOD;
            }
            x = x * x % MOD;
        }
        return res;
    }
}

java 解法, 执行用时: 144 ms, 内存消耗: 55.1 MB, 提交时间: 2025-02-01 10:47:36

class Solution {
    private static final int MOD = 1_000_000_007;
    private static final int MX = 100_000;

    private static final long[] F = new long[MX]; // f[i] = i!
    private static final long[] INV_F = new long[MX]; // inv_f[i] = i!^-1

    static {
        F[0] = 1;
        for (int i = 1; i < MX; i++) {
            F[i] = F[i - 1] * i % MOD;
        }

        INV_F[MX - 1] = pow(F[MX - 1], MOD - 2);
        for (int i = MX - 1; i > 0; i--) {
            INV_F[i - 1] = INV_F[i] * i % MOD;
        }
    }

    public int minMaxSums(int[] nums, int k) {
        Arrays.sort(nums);
        int n = nums.length;
        long ans = 0;
        for (int i = 0; i < n; i++) {
            long s = 0;
            for (int j = 0; j < Math.min(k, i + 1); j++) {
                s += comb(i, j);
            }
            ans = (ans + s % MOD * (nums[i] + nums[n - 1 - i])) % MOD;
        }
        return (int) ans;
    }

    private long comb(int n, int m) {
        return F[n] * INV_F[m] % MOD * INV_F[n - m] % MOD;
    }

    private static long pow(long x, int n) {
        long res = 1;
        for (; n > 0; n /= 2) {
            if (n % 2 > 0) {
                res = res * x % MOD;
            }
            x = x * x % MOD;
        }
        return res;
    }
}

golang 解法, 执行用时: 57 ms, 内存消耗: 12.2 MB, 提交时间: 2025-02-01 10:47:20

const mod = 1_000_000_007
const mx = 100_000

var f [mx]int    // f[i] = i!
var invF [mx]int // invF[i] = i!^-1

func init() {
	f[0] = 1
	for i := 1; i < mx; i++ {
		f[i] = f[i-1] * i % mod
	}

	invF[mx-1] = pow(f[mx-1], mod-2)
	for i := mx - 1; i > 0; i-- {
		invF[i-1] = invF[i] * i % mod
	}
}

func pow(x, n int) int {
	res := 1
	for ; n > 0; n /= 2 {
		if n%2 > 0 {
			res = res * x % mod
		}
		x = x * x % mod
	}
	return res
}

func comb(n, m int) int {
	return f[n] * invF[m] % mod * invF[n-m] % mod
}

func minMaxSums(nums []int, k int) (ans int) {
	slices.Sort(nums)
	for i, x := range nums {
		s := 0
		for j := range min(k, i+1) {
			s += comb(i, j)
		}
		ans = (ans + s%mod*(x+nums[len(nums)-1-i])) % mod
	}
	return
}

golang 解法, 执行用时: 14 ms, 内存消耗: 12.4 MB, 提交时间: 2025-02-01 10:47:05

const mod = 1_000_000_007
const mx = 100_000

var f [mx]int    // f[i] = i!
var invF [mx]int // invF[i] = i!^-1

func init() {
	f[0] = 1
	for i := 1; i < mx; i++ {
		f[i] = f[i-1] * i % mod
	}

	invF[mx-1] = pow(f[mx-1], mod-2)
	for i := mx - 1; i > 0; i-- {
		invF[i-1] = invF[i] * i % mod
	}
}

func pow(x, n int) int {
	res := 1
	for ; n > 0; n /= 2 {
		if n%2 > 0 {
			res = res * x % mod
		}
		x = x * x % mod
	}
	return res
}

func comb(n, m int) int {
	if m > n {
		return 0
	}
	return f[n] * invF[m] % mod * invF[n-m] % mod
}

func minMaxSums(nums []int, k int) (ans int) {
	slices.Sort(nums)
	s := 1
	for i, x := range nums {
		ans = (ans + s*(x+nums[len(nums)-1-i])) % mod
		s = (s*2 - comb(i, k-1) + mod) % mod
	}
	return
}

python3 解法, 执行用时: 180 ms, 内存消耗: 37.1 MB, 提交时间: 2025-02-01 10:46:51

MOD = 1_000_000_007
MX = 100_000

fac = [0] * MX  # f[i] = i!
fac[0] = 1
for i in range(1, MX):
    fac[i] = fac[i - 1] * i % MOD

inv_f = [0] * MX  # inv_f[i] = i!^-1
inv_f[-1] = pow(fac[-1], -1, MOD)
for i in range(MX - 1, 0, -1):
    inv_f[i - 1] = inv_f[i] * i % MOD

def comb(n: int, m: int) -> int:
    return fac[n] * inv_f[m] * inv_f[n - m] % MOD if m <= n else 0

class Solution:
    def minMaxSums(self, nums: List[int], k: int) -> int:
        nums.sort()
        ans = 0
        s = 1
        for i, x in enumerate(nums):
            ans += (x + nums[-1 - i]) * s
            s = (s * 2 - comb(i, k - 1)) % MOD
        return ans % MOD

python3 解法, 执行用时: 623 ms, 内存消耗: 29.6 MB, 提交时间: 2025-02-01 10:46:40

# 更快的写法见【预处理】
class Solution:
    def minMaxSums(self, nums: List[int], k: int) -> int:
        MOD = 1_000_000_007
        nums.sort()
        ans = 0
        s = 1
        for i, x in enumerate(nums):
            ans += (x + nums[-1 - i]) * s
            s = (s * 2 - comb(i, k - 1)) % MOD
        return ans % MOD

python3 解法, 执行用时: 2431 ms, 内存消耗: 37.3 MB, 提交时间: 2025-02-01 10:46:24

MOD = 1_000_000_007
MX = 100_000

fac = [0] * MX  # f[i] = i!
fac[0] = 1
for i in range(1, MX):
    fac[i] = fac[i - 1] * i % MOD

inv_f = [0] * MX  # inv_f[i] = i!^-1
inv_f[-1] = pow(fac[-1], -1, MOD)
for i in range(MX - 1, 0, -1):
    inv_f[i - 1] = inv_f[i] * i % MOD

def comb(n: int, m: int) -> int:
    return fac[n] * inv_f[m] * inv_f[n - m] % MOD

class Solution:
    def minMaxSums(self, nums: List[int], k: int) -> int:
        nums.sort()
        ans = 0
        for i, x in enumerate(nums):
            s = sum(comb(i, j) for j in range(min(k, i + 1))) % MOD
            ans += (x + nums[-1 - i]) * s
        return ans % MOD

上一题