3428. 最多 K 个元素的子序列的最值之和
给你一个整数数组 nums
和一个正整数 k
,返回所有长度最多为 k
的 子序列 中 最大值 与 最小值 之和的总和。
非空子序列 是指从另一个数组中删除一些或不删除任何元素(且不改变剩余元素的顺序)得到的数组。
由于答案可能非常大,请返回对 109 + 7
取余数的结果。
示例 1:
输入: nums = [1,2,3], k = 2
输出: 24
解释:
数组 nums
中所有长度最多为 2 的子序列如下:
子序列 | 最小值 | 最大值 | 和 |
---|---|---|---|
[1] |
1 | 1 | 2 |
[2] |
2 | 2 | 4 |
[3] |
3 | 3 | 6 |
[1, 2] |
1 | 2 | 3 |
[1, 3] |
1 | 3 | 4 |
[2, 3] |
2 | 3 | 5 |
总和 | 24 |
因此,输出为 24。
示例 2:
输入: nums = [5,0,6], k = 1
输出: 22
解释:
对于长度恰好为 1 的子序列,最小值和最大值均为元素本身。因此,总和为 5 + 5 + 0 + 0 + 6 + 6 = 22
。
示例 3:
输入: nums = [1,1,1], k = 2
输出: 12
解释:
子序列 [1, 1]
和 [1]
各出现 3 次。对于所有这些子序列,最小值和最大值均为 1。因此,总和为 12。
提示:
1 <= nums.length <= 105
0 <= nums[i] <= 109
1 <= k <= min(100, nums.length)
原站题解
cpp 解法, 执行用时: 25 ms, 内存消耗: 80.3 MB, 提交时间: 2025-02-01 10:48:23
const int MOD = 1'000'000'007; const int MX = 100'000; long long F[MX]; // F[i] = i! long long INV_F[MX]; // INV_F[i] = i!^-1 long long pow(long long x, int n) { long long res = 1; for (; n; n /= 2) { if (n % 2) { res = res * x % MOD; } x = x * x % MOD; } return res; } auto init = [] { F[0] = 1; for (int i = 1; i < MX; i++) { F[i] = F[i - 1] * i % MOD; } INV_F[MX - 1] = pow(F[MX - 1], MOD - 2); for (int i = MX - 1; i; i--) { INV_F[i - 1] = INV_F[i] * i % MOD; } return 0; }(); long long comb(int n, int m) { return m > n ? 0 : F[n] * INV_F[m] % MOD * INV_F[n - m] % MOD; } class Solution { public: int minMaxSums(vector<int>& nums, int k) { ranges::sort(nums); int n = nums.size(); long long ans = 0, s = 1; for (int i = 0; i < n; i++) { ans = (ans + s * (nums[i] + nums[n - 1 - i])) % MOD; s = (s * 2 - comb(i, k - 1) + MOD) % MOD; } return ans; } };
cpp 解法, 执行用时: 91 ms, 内存消耗: 80.4 MB, 提交时间: 2025-02-01 10:48:12
const int MOD = 1'000'000'007; const int MX = 100'000; long long F[MX]; // F[i] = i! long long INV_F[MX]; // INV_F[i] = i!^-1 long long pow(long long x, int n) { long long res = 1; for (; n; n /= 2) { if (n % 2) { res = res * x % MOD; } x = x * x % MOD; } return res; } auto init = [] { F[0] = 1; for (int i = 1; i < MX; i++) { F[i] = F[i - 1] * i % MOD; } INV_F[MX - 1] = pow(F[MX - 1], MOD - 2); for (int i = MX - 1; i; i--) { INV_F[i - 1] = INV_F[i] * i % MOD; } return 0; }(); long long comb(int n, int m) { return F[n] * INV_F[m] % MOD * INV_F[n - m] % MOD; } class Solution { public: int minMaxSums(vector<int>& nums, int k) { ranges::sort(nums); int n = nums.size(); long long ans = 0; for (int i = 0; i < n; i++) { long long s = 0; for (int j = 0; j < min(k, i + 1); j++) { s += comb(i, j); } ans = (ans + s % MOD * (nums[i] + nums[n - 1 - i])) % MOD; } return ans; } };
java 解法, 执行用时: 50 ms, 内存消耗: 55.2 MB, 提交时间: 2025-02-01 10:47:49
class Solution { private static final int MOD = 1_000_000_007; private static final int MX = 100_000; private static final long[] F = new long[MX]; // f[i] = i! private static final long[] INV_F = new long[MX]; // inv_f[i] = i!^-1 static { F[0] = 1; for (int i = 1; i < MX; i++) { F[i] = F[i - 1] * i % MOD; } INV_F[MX - 1] = pow(F[MX - 1], MOD - 2); for (int i = MX - 1; i > 0; i--) { INV_F[i - 1] = INV_F[i] * i % MOD; } } public int minMaxSums(int[] nums, int k) { Arrays.sort(nums); int n = nums.length; long ans = 0; long s = 1; for (int i = 0; i < n; i++) { ans = (ans + s * (nums[i] + nums[n - 1 - i])) % MOD; s = (s * 2 - comb(i, k - 1) + MOD) % MOD; } return (int) ans; } private long comb(int n, int m) { return m > n ? 0 : F[n] * INV_F[m] % MOD * INV_F[n - m] % MOD; } private static long pow(long x, int n) { long res = 1; for (; n > 0; n /= 2) { if (n % 2 > 0) { res = res * x % MOD; } x = x * x % MOD; } return res; } }
java 解法, 执行用时: 144 ms, 内存消耗: 55.1 MB, 提交时间: 2025-02-01 10:47:36
class Solution { private static final int MOD = 1_000_000_007; private static final int MX = 100_000; private static final long[] F = new long[MX]; // f[i] = i! private static final long[] INV_F = new long[MX]; // inv_f[i] = i!^-1 static { F[0] = 1; for (int i = 1; i < MX; i++) { F[i] = F[i - 1] * i % MOD; } INV_F[MX - 1] = pow(F[MX - 1], MOD - 2); for (int i = MX - 1; i > 0; i--) { INV_F[i - 1] = INV_F[i] * i % MOD; } } public int minMaxSums(int[] nums, int k) { Arrays.sort(nums); int n = nums.length; long ans = 0; for (int i = 0; i < n; i++) { long s = 0; for (int j = 0; j < Math.min(k, i + 1); j++) { s += comb(i, j); } ans = (ans + s % MOD * (nums[i] + nums[n - 1 - i])) % MOD; } return (int) ans; } private long comb(int n, int m) { return F[n] * INV_F[m] % MOD * INV_F[n - m] % MOD; } private static long pow(long x, int n) { long res = 1; for (; n > 0; n /= 2) { if (n % 2 > 0) { res = res * x % MOD; } x = x * x % MOD; } return res; } }
golang 解法, 执行用时: 57 ms, 内存消耗: 12.2 MB, 提交时间: 2025-02-01 10:47:20
const mod = 1_000_000_007 const mx = 100_000 var f [mx]int // f[i] = i! var invF [mx]int // invF[i] = i!^-1 func init() { f[0] = 1 for i := 1; i < mx; i++ { f[i] = f[i-1] * i % mod } invF[mx-1] = pow(f[mx-1], mod-2) for i := mx - 1; i > 0; i-- { invF[i-1] = invF[i] * i % mod } } func pow(x, n int) int { res := 1 for ; n > 0; n /= 2 { if n%2 > 0 { res = res * x % mod } x = x * x % mod } return res } func comb(n, m int) int { return f[n] * invF[m] % mod * invF[n-m] % mod } func minMaxSums(nums []int, k int) (ans int) { slices.Sort(nums) for i, x := range nums { s := 0 for j := range min(k, i+1) { s += comb(i, j) } ans = (ans + s%mod*(x+nums[len(nums)-1-i])) % mod } return }
golang 解法, 执行用时: 14 ms, 内存消耗: 12.4 MB, 提交时间: 2025-02-01 10:47:05
const mod = 1_000_000_007 const mx = 100_000 var f [mx]int // f[i] = i! var invF [mx]int // invF[i] = i!^-1 func init() { f[0] = 1 for i := 1; i < mx; i++ { f[i] = f[i-1] * i % mod } invF[mx-1] = pow(f[mx-1], mod-2) for i := mx - 1; i > 0; i-- { invF[i-1] = invF[i] * i % mod } } func pow(x, n int) int { res := 1 for ; n > 0; n /= 2 { if n%2 > 0 { res = res * x % mod } x = x * x % mod } return res } func comb(n, m int) int { if m > n { return 0 } return f[n] * invF[m] % mod * invF[n-m] % mod } func minMaxSums(nums []int, k int) (ans int) { slices.Sort(nums) s := 1 for i, x := range nums { ans = (ans + s*(x+nums[len(nums)-1-i])) % mod s = (s*2 - comb(i, k-1) + mod) % mod } return }
python3 解法, 执行用时: 180 ms, 内存消耗: 37.1 MB, 提交时间: 2025-02-01 10:46:51
MOD = 1_000_000_007 MX = 100_000 fac = [0] * MX # f[i] = i! fac[0] = 1 for i in range(1, MX): fac[i] = fac[i - 1] * i % MOD inv_f = [0] * MX # inv_f[i] = i!^-1 inv_f[-1] = pow(fac[-1], -1, MOD) for i in range(MX - 1, 0, -1): inv_f[i - 1] = inv_f[i] * i % MOD def comb(n: int, m: int) -> int: return fac[n] * inv_f[m] * inv_f[n - m] % MOD if m <= n else 0 class Solution: def minMaxSums(self, nums: List[int], k: int) -> int: nums.sort() ans = 0 s = 1 for i, x in enumerate(nums): ans += (x + nums[-1 - i]) * s s = (s * 2 - comb(i, k - 1)) % MOD return ans % MOD
python3 解法, 执行用时: 623 ms, 内存消耗: 29.6 MB, 提交时间: 2025-02-01 10:46:40
# 更快的写法见【预处理】 class Solution: def minMaxSums(self, nums: List[int], k: int) -> int: MOD = 1_000_000_007 nums.sort() ans = 0 s = 1 for i, x in enumerate(nums): ans += (x + nums[-1 - i]) * s s = (s * 2 - comb(i, k - 1)) % MOD return ans % MOD
python3 解法, 执行用时: 2431 ms, 内存消耗: 37.3 MB, 提交时间: 2025-02-01 10:46:24
MOD = 1_000_000_007 MX = 100_000 fac = [0] * MX # f[i] = i! fac[0] = 1 for i in range(1, MX): fac[i] = fac[i - 1] * i % MOD inv_f = [0] * MX # inv_f[i] = i!^-1 inv_f[-1] = pow(fac[-1], -1, MOD) for i in range(MX - 1, 0, -1): inv_f[i - 1] = inv_f[i] * i % MOD def comb(n: int, m: int) -> int: return fac[n] * inv_f[m] * inv_f[n - m] % MOD class Solution: def minMaxSums(self, nums: List[int], k: int) -> int: nums.sort() ans = 0 for i, x in enumerate(nums): s = sum(comb(i, j) for j in range(min(k, i + 1))) % MOD ans += (x + nums[-1 - i]) * s return ans % MOD