列表

详情


100178. 将数组分成最小总代价的子数组 II

给你一个下标从 0 开始长度为 n 的整数数组 nums 和两个  整数 k 和 dist 。

一个数组的 代价 是数组中的 第一个 元素。比方说,[1,2,3] 的代价为 1 ,[3,4,1] 的代价为 3 。

你需要将 nums 分割成 k 个 连续且互不相交 的子数组,满足 第二 个子数组与第 k 个子数组中第一个元素的下标距离 不超过 dist 。换句话说,如果你将 nums 分割成子数组 nums[0..(i1 - 1)], nums[i1..(i2 - 1)], ..., nums[ik-1..(n - 1)] ,那么它需要满足 ik-1 - i1 <= dist 。

请你返回这些子数组的 最小 总代价。

 

示例 1:

输入:nums = [1,3,2,6,4,2], k = 3, dist = 3
输出:5
解释:将数组分割成 3 个子数组的最优方案是:[1,3] ,[2,6,4] 和 [2] 。这是一个合法分割,因为 ik-1 - i1 等于 5 - 2 = 3 ,等于 dist 。总代价为 nums[0] + nums[2] + nums[5] ,也就是 1 + 2 + 2 = 5 。
5 是分割成 3 个子数组的最小总代价。

示例 2:

输入:nums = [10,1,2,2,2,1], k = 4, dist = 3
输出:15
解释:将数组分割成 4 个子数组的最优方案是:[10] ,[1] ,[2] 和 [2,2,1] 。这是一个合法分割,因为 ik-1 - i1 等于 3 - 1 = 2 ,小于 dist 。总代价为 nums[0] + nums[1] + nums[2] + nums[3] ,也就是 10 + 1 + 2 + 2 = 15 。
分割 [10] ,[1] ,[2,2,2] 和 [1] 不是一个合法分割,因为 ik-1 和 i1 的差为 5 - 1 = 4 ,大于 dist 。
15 是分割成 4 个子数组的最小总代价。

示例 3:

输入:nums = [10,8,18,9], k = 3, dist = 1
输出:36
解释:将数组分割成 4 个子数组的最优方案是:[10] ,[8] 和 [18,9] 。这是一个合法分割,因为 ik-1 - i1 等于 2 - 1 = 1 ,等于 dist 。总代价为 nums[0] + nums[1] + nums[2] ,也就是 10 + 8 + 18 = 36 。
分割 [10] ,[8,18] 和 [9] 不是一个合法分割,因为 ik-1 和 i1 的差为 3 - 1 = 2 ,大于 dist 。
36 是分割成 3 个子数组的最小总代价。

 

提示:

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: long long minimumCost(vector<int>& nums, int k, int dist) { } };

golang 解法, 执行用时: 585 ms, 内存消耗: 17.3 MB, 提交时间: 2024-01-22 10:09:21

import "github.com/emirpasic/gods/trees/redblacktree"

func minimumCost(nums []int, k, dist int) int64 {
	k--
	L := redblacktree.NewWithIntComparator()
	R := redblacktree.NewWithIntComparator()
	add := func(t *redblacktree.Tree, x int) {
		c, ok := t.Get(x)
		if ok {
			t.Put(x, c.(int)+1)
		} else {
			t.Put(x, 1)
		}
	}
	del := func(t *redblacktree.Tree, x int) {
		c, _ := t.Get(x)
		if c.(int) > 1 {
			t.Put(x, c.(int)-1)
		} else {
			t.Remove(x)
		}
	}

	sumL := nums[0]
	for _, x := range nums[1 : dist+2] {
		sumL += x
		add(L, x)
	}
	sizeL := dist + 1

	l2r := func() {
		x := L.Right().Key.(int)
		sumL -= x
		sizeL--
		del(L, x)
		add(R, x)
	}
	r2l := func() {
		x := R.Left().Key.(int)
		sumL += x
		sizeL++
		del(R, x)
		add(L, x)
	}
	for sizeL > k {
		l2r()
	}

	ans := sumL
	for i := dist + 2; i < len(nums); i++ {
		// 移除 out
		out := nums[i-dist-1]
		if _, ok := L.Get(out); ok {
			sumL -= out
			sizeL--
			del(L, out)
		} else {
			del(R, out)
		}

		// 添加 in
		in := nums[i]
		if in < L.Right().Key.(int) {
			sumL += in
			sizeL++
			add(L, in)
		} else {
			add(R, in)
		}

		// 维护大小
		if sizeL == k-1 {
			r2l()
		} else if sizeL == k+1 {
			l2r()
		}

		ans = min(ans, sumL)
	}
	return int64(ans)
}

java 解法, 执行用时: 273 ms, 内存消耗: 55.9 MB, 提交时间: 2024-01-22 10:09:05

public class Solution {
    public long minimumCost(int[] nums, int k, int dist) {
        k--;
        sumL = nums[0];
        for (int i = 1; i < dist + 2; i++) {
            sumL += nums[i];
            L.merge(nums[i], 1, Integer::sum);
        }
        sizeL = dist + 1;
        while (sizeL > k) {
            l2r();
        }

        long ans = sumL;
        for (int i = dist + 2; i < nums.length; i++) {
            // 移除 out
            int out = nums[i - dist - 1];
            if (L.containsKey(out)) {
                sumL -= out;
                sizeL--;
                removeOne(L, out);
            } else {
                removeOne(R, out);
            }

            // 添加 in
            int in = nums[i];
            if (in < L.lastKey()) {
                sumL += in;
                sizeL++;
                L.merge(in, 1, Integer::sum);
            } else {
                R.merge(in, 1, Integer::sum);
            }

            // 维护大小
            if (sizeL == k - 1) {
                r2l();
            } else if (sizeL == k + 1) {
                l2r();
            }

            ans = Math.min(ans, sumL);
        }
        return ans;
    }

    private long sumL;
    private int sizeL;
    private final TreeMap<Integer, Integer> L = new TreeMap<>();
    private final TreeMap<Integer, Integer> R = new TreeMap<>();

    private void l2r() {
        int x = L.lastKey();
        removeOne(L, x);
        sumL -= x;
        sizeL--;
        R.merge(x, 1, Integer::sum);
    }

    private void r2l() {
        int x = R.firstKey();
        removeOne(R, x);
        sumL += x;
        sizeL++;
        L.merge(x, 1, Integer::sum);
    }

    private void removeOne(Map<Integer, Integer> m, int x) {
        int cnt = m.get(x);
        if (cnt > 1) {
            m.put(x, cnt - 1);
        } else {
            m.remove(x);
        }
    }
}

cpp 解法, 执行用时: 537 ms, 内存消耗: 152.6 MB, 提交时间: 2024-01-22 10:08:45

class Solution {
public:
    long long minimumCost(vector<int> &nums, int k, int dist) {
        k--;
        long long sum = accumulate(nums.begin(), nums.begin() + dist + 2, 0LL);
        multiset<int> L(nums.begin() + 1, nums.begin() + dist + 2), R;
        auto L2R = [&]() {
            int x = *L.rbegin();
            sum -= x;
            L.erase(L.find(x));
            R.insert(x);
        };
        auto R2L = [&]() {
            int x = *R.begin();
            sum += x;
            R.erase(R.find(x));
            L.insert(x);
        };
        while (L.size() > k) {
            L2R();
        }

        long long ans = sum;
        for (int i = dist + 2; i < nums.size(); i++) {
            // 移除 out
            int out = nums[i - dist - 1];
            auto it = L.find(out);
            if (it != L.end()) {
                sum -= out;
                L.erase(it);
            } else {
                R.erase(R.find(out));
            }

            // 添加 in
            int in = nums[i];
            if (in < *L.rbegin()) {
                sum += in;
                L.insert(in);
            } else {
                R.insert(in);
            }

            // 维护大小
            if (L.size() == k - 1) {
                R2L();
            } else if (L.size() == k + 1) {
                L2R();
            }

            ans = min(ans, sum);
        }
        return ans;
    }
};

python3 解法, 执行用时: 850 ms, 内存消耗: 32 MB, 提交时间: 2024-01-22 10:08:29

# 维护两个有序集合前k-1小
from sortedcontainers import SortedList

class Solution:
    def minimumCost(self, nums: List[int], k: int, dist: int) -> int:
        k -= 1
        sum_left = sum(nums[:dist + 2])
        L = SortedList(nums[1:dist + 2])
        R = SortedList()

        def L2R() -> None:
            x = L.pop()
            nonlocal sum_left
            sum_left -= x
            R.add(x)

        def R2L() -> None:
            x = R.pop(0)
            nonlocal sum_left
            sum_left += x
            L.add(x)

        while len(L) > k:
            L2R()

        ans = sum_left
        for i in range(dist + 2, len(nums)):
            # 移除 out
            out = nums[i - dist - 1]
            if out in L:
                sum_left -= out
                L.remove(out)
            else:
                R.remove(out)

            # 添加 in
            in_val = nums[i]
            if in_val < L[-1]:
                sum_left += in_val
                L.add(in_val)
            else:
                R.add(in_val)

            # 维护大小
            if len(L) == k - 1:
                R2L()
            elif len(L) == k + 1:
                L2R()

            ans = min(ans, sum_left)

        return ans

上一题