class Solution {
public:
long long minimumCost(vector<int>& nums, int k, int dist) {
}
};
100178. 将数组分成最小总代价的子数组 II
给你一个下标从 0 开始长度为 n
的整数数组 nums
和两个 正 整数 k
和 dist
。
一个数组的 代价 是数组中的 第一个 元素。比方说,[1,2,3]
的代价为 1
,[3,4,1]
的代价为 3
。
你需要将 nums
分割成 k
个 连续且互不相交 的子数组,满足 第二 个子数组与第 k
个子数组中第一个元素的下标距离 不超过 dist
。换句话说,如果你将 nums
分割成子数组 nums[0..(i1 - 1)], nums[i1..(i2 - 1)], ..., nums[ik-1..(n - 1)]
,那么它需要满足 ik-1 - i1 <= dist
。
请你返回这些子数组的 最小 总代价。
示例 1:
输入:nums = [1,3,2,6,4,2], k = 3, dist = 3 输出:5 解释:将数组分割成 3 个子数组的最优方案是:[1,3] ,[2,6,4] 和 [2] 。这是一个合法分割,因为 ik-1 - i1 等于 5 - 2 = 3 ,等于 dist 。总代价为 nums[0] + nums[2] + nums[5] ,也就是 1 + 2 + 2 = 5 。 5 是分割成 3 个子数组的最小总代价。
示例 2:
输入:nums = [10,1,2,2,2,1], k = 4, dist = 3 输出:15 解释:将数组分割成 4 个子数组的最优方案是:[10] ,[1] ,[2] 和 [2,2,1] 。这是一个合法分割,因为 ik-1 - i1 等于 3 - 1 = 2 ,小于 dist 。总代价为 nums[0] + nums[1] + nums[2] + nums[3] ,也就是 10 + 1 + 2 + 2 = 15 。 分割 [10] ,[1] ,[2,2,2] 和 [1] 不是一个合法分割,因为 ik-1 和 i1 的差为 5 - 1 = 4 ,大于 dist 。 15 是分割成 4 个子数组的最小总代价。
示例 3:
输入:nums = [10,8,18,9], k = 3, dist = 1 输出:36 解释:将数组分割成 4 个子数组的最优方案是:[10] ,[8] 和 [18,9] 。这是一个合法分割,因为 ik-1 - i1 等于 2 - 1 = 1 ,等于 dist 。总代价为 nums[0] + nums[1] + nums[2] ,也就是 10 + 8 + 18 = 36 。 分割 [10] ,[8,18] 和 [9] 不是一个合法分割,因为 ik-1 和 i1 的差为 3 - 1 = 2 ,大于 dist 。 36 是分割成 3 个子数组的最小总代价。
提示:
3 <= n <= 105
1 <= nums[i] <= 109
3 <= k <= n
k - 2 <= dist <= n - 2
原站题解
golang 解法, 执行用时: 585 ms, 内存消耗: 17.3 MB, 提交时间: 2024-01-22 10:09:21
import "github.com/emirpasic/gods/trees/redblacktree" func minimumCost(nums []int, k, dist int) int64 { k-- L := redblacktree.NewWithIntComparator() R := redblacktree.NewWithIntComparator() add := func(t *redblacktree.Tree, x int) { c, ok := t.Get(x) if ok { t.Put(x, c.(int)+1) } else { t.Put(x, 1) } } del := func(t *redblacktree.Tree, x int) { c, _ := t.Get(x) if c.(int) > 1 { t.Put(x, c.(int)-1) } else { t.Remove(x) } } sumL := nums[0] for _, x := range nums[1 : dist+2] { sumL += x add(L, x) } sizeL := dist + 1 l2r := func() { x := L.Right().Key.(int) sumL -= x sizeL-- del(L, x) add(R, x) } r2l := func() { x := R.Left().Key.(int) sumL += x sizeL++ del(R, x) add(L, x) } for sizeL > k { l2r() } ans := sumL for i := dist + 2; i < len(nums); i++ { // 移除 out out := nums[i-dist-1] if _, ok := L.Get(out); ok { sumL -= out sizeL-- del(L, out) } else { del(R, out) } // 添加 in in := nums[i] if in < L.Right().Key.(int) { sumL += in sizeL++ add(L, in) } else { add(R, in) } // 维护大小 if sizeL == k-1 { r2l() } else if sizeL == k+1 { l2r() } ans = min(ans, sumL) } return int64(ans) }
java 解法, 执行用时: 273 ms, 内存消耗: 55.9 MB, 提交时间: 2024-01-22 10:09:05
public class Solution { public long minimumCost(int[] nums, int k, int dist) { k--; sumL = nums[0]; for (int i = 1; i < dist + 2; i++) { sumL += nums[i]; L.merge(nums[i], 1, Integer::sum); } sizeL = dist + 1; while (sizeL > k) { l2r(); } long ans = sumL; for (int i = dist + 2; i < nums.length; i++) { // 移除 out int out = nums[i - dist - 1]; if (L.containsKey(out)) { sumL -= out; sizeL--; removeOne(L, out); } else { removeOne(R, out); } // 添加 in int in = nums[i]; if (in < L.lastKey()) { sumL += in; sizeL++; L.merge(in, 1, Integer::sum); } else { R.merge(in, 1, Integer::sum); } // 维护大小 if (sizeL == k - 1) { r2l(); } else if (sizeL == k + 1) { l2r(); } ans = Math.min(ans, sumL); } return ans; } private long sumL; private int sizeL; private final TreeMap<Integer, Integer> L = new TreeMap<>(); private final TreeMap<Integer, Integer> R = new TreeMap<>(); private void l2r() { int x = L.lastKey(); removeOne(L, x); sumL -= x; sizeL--; R.merge(x, 1, Integer::sum); } private void r2l() { int x = R.firstKey(); removeOne(R, x); sumL += x; sizeL++; L.merge(x, 1, Integer::sum); } private void removeOne(Map<Integer, Integer> m, int x) { int cnt = m.get(x); if (cnt > 1) { m.put(x, cnt - 1); } else { m.remove(x); } } }
cpp 解法, 执行用时: 537 ms, 内存消耗: 152.6 MB, 提交时间: 2024-01-22 10:08:45
class Solution { public: long long minimumCost(vector<int> &nums, int k, int dist) { k--; long long sum = accumulate(nums.begin(), nums.begin() + dist + 2, 0LL); multiset<int> L(nums.begin() + 1, nums.begin() + dist + 2), R; auto L2R = [&]() { int x = *L.rbegin(); sum -= x; L.erase(L.find(x)); R.insert(x); }; auto R2L = [&]() { int x = *R.begin(); sum += x; R.erase(R.find(x)); L.insert(x); }; while (L.size() > k) { L2R(); } long long ans = sum; for (int i = dist + 2; i < nums.size(); i++) { // 移除 out int out = nums[i - dist - 1]; auto it = L.find(out); if (it != L.end()) { sum -= out; L.erase(it); } else { R.erase(R.find(out)); } // 添加 in int in = nums[i]; if (in < *L.rbegin()) { sum += in; L.insert(in); } else { R.insert(in); } // 维护大小 if (L.size() == k - 1) { R2L(); } else if (L.size() == k + 1) { L2R(); } ans = min(ans, sum); } return ans; } };
python3 解法, 执行用时: 850 ms, 内存消耗: 32 MB, 提交时间: 2024-01-22 10:08:29
# 维护两个有序集合前k-1小 from sortedcontainers import SortedList class Solution: def minimumCost(self, nums: List[int], k: int, dist: int) -> int: k -= 1 sum_left = sum(nums[:dist + 2]) L = SortedList(nums[1:dist + 2]) R = SortedList() def L2R() -> None: x = L.pop() nonlocal sum_left sum_left -= x R.add(x) def R2L() -> None: x = R.pop(0) nonlocal sum_left sum_left += x L.add(x) while len(L) > k: L2R() ans = sum_left for i in range(dist + 2, len(nums)): # 移除 out out = nums[i - dist - 1] if out in L: sum_left -= out L.remove(out) else: R.remove(out) # 添加 in in_val = nums[i] if in_val < L[-1]: sum_left += in_val L.add(in_val) else: R.add(in_val) # 维护大小 if len(L) == k - 1: R2L() elif len(L) == k + 1: L2R() ans = min(ans, sum_left) return ans