列表

详情


2462. 雇佣 K 位工人的总代价

给你一个下标从 0 开始的整数数组 costs ,其中 costs[i] 是雇佣第 i 位工人的代价。

同时给你两个整数 k 和 candidates 。我们想根据以下规则恰好雇佣 k 位工人:

返回雇佣恰好 k 位工人的总代价。

 

示例 1:

输入:costs = [17,12,10,2,7,2,11,20,8], k = 3, candidates = 4
输出:11
解释:我们总共雇佣 3 位工人。总代价一开始为 0 。
- 第一轮雇佣,我们从 [17,12,10,2,7,2,11,20,8] 中选择。最小代价是 2 ,有两位工人,我们选择下标更小的一位工人,即第 3 位工人。总代价是 0 + 2 = 2 。
- 第二轮雇佣,我们从 [17,12,10,7,2,11,20,8] 中选择。最小代价是 2 ,下标为 4 ,总代价是 2 + 2 = 4 。
- 第三轮雇佣,我们从 [17,12,10,7,11,20,8] 中选择,最小代价是 7 ,下标为 3 ,总代价是 4 + 7 = 11 。注意下标为 3 的工人同时在最前面和最后面 4 位工人中。
总雇佣代价是 11 。

示例 2:

输入:costs = [1,2,4,1], k = 3, candidates = 3
输出:4
解释:我们总共雇佣 3 位工人。总代价一开始为 0 。
- 第一轮雇佣,我们从 [1,2,4,1] 中选择。最小代价为 1 ,有两位工人,我们选择下标更小的一位工人,即第 0 位工人,总代价是 0 + 1 = 1 。注意,下标为 1 和 2 的工人同时在最前面和最后面 3 位工人中。
- 第二轮雇佣,我们从 [2,4,1] 中选择。最小代价为 1 ,下标为 2 ,总代价是 1 + 1 = 2 。
- 第三轮雇佣,少于 3 位工人,我们从剩余工人 [2,4] 中选择。最小代价是 2 ,下标为 0 。总代价为 2 + 2 = 4 。
总雇佣代价是 4 。

 

提示:

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: long long totalCost(vector<int>& costs, int k, int candidates) { } };

rust 解法, 执行用时: 18 ms, 内存消耗: 3.2 MB, 提交时间: 2024-05-01 10:15:26

use std::collections::BinaryHeap;

impl Solution {
    pub fn total_cost(mut costs: Vec<i32>, k: i32, candidates: i32) -> i64 {
        let n = costs.len();
        let k = k as usize;
        let c = candidates as usize;
        if c * 2 + k > n {
            costs.sort_unstable();
            return costs[..k].iter().map(|&x| x as i64).sum();
        }

        let mut pre = BinaryHeap::new();
        let mut suf = BinaryHeap::new();
        for i in 0..c {
            pre.push(-costs[i]); // 加负号,变成最小堆
            suf.push(-costs[n - 1 - i]);
        }

        let mut ans = 0;
        let mut i = c;
        let mut j = n - 1 - c;
        for _ in 0..k {
            if pre.peek().unwrap() >= suf.peek().unwrap() {
                ans -= pre.pop().unwrap() as i64;
                pre.push(-costs[i]);
                i += 1;
            } else {
                ans -= suf.pop().unwrap() as i64;
                suf.push(-costs[j]);
                j -= 1;
            }
        }
        ans
    }
}

javascript 解法, 执行用时: 190 ms, 内存消耗: 68.8 MB, 提交时间: 2024-05-01 10:15:10

/**
 * @param {number[]} costs
 * @param {number} k
 * @param {number} candidates
 * @return {number}
 */
var totalCost = function(costs, k, candidates) {
    const n = costs.length;
    let ans = 0;
    if (candidates * 2 + k > n) {
        costs.sort((a, b) => a - b);
        for (let i = 0; i < k; i++) {
            ans += costs[i];
        }
        return ans;
    }

    const pre = new MinPriorityQueue();
    const suf = new MinPriorityQueue();
    for (let i = 0; i < candidates; i++) {
        pre.enqueue(costs[i]);
        suf.enqueue(costs[n - 1 - i]);
    }

    let i = candidates;
    let j = n - 1 - candidates;
    while (k--) {
        if (pre.front().element <= suf.front().element) {
            ans += pre.dequeue().element;
            pre.enqueue(costs[i++]);
        } else {
            ans += suf.dequeue().element;
            suf.enqueue(costs[j--]);
        }
    }
    return ans;
};

golang 解法, 执行用时: 71 ms, 内存消耗: 9 MB, 提交时间: 2024-05-01 10:14:49

func totalCost(costs []int, k, candidates int) (ans int64) {
    n := len(costs)
    if candidates*2+k > n {
        slices.Sort(costs)
        for _, x := range costs[:k] {
            ans += int64(x)
        }
        return
    }

    pre := hp{costs[:candidates]}
    suf := hp{costs[len(costs)-candidates:]}
    heap.Init(&pre)
    heap.Init(&suf)
    for i, j := candidates, n-1-candidates; k > 0; k-- {
        if pre.IntSlice[0] <= suf.IntSlice[0] {
            ans += int64(pre.replace(costs[i]))
            i++
        } else {
            ans += int64(suf.replace(costs[j]))
            j--
        }
    }
    return
}

type hp struct{ sort.IntSlice }
func (h *hp) Push(v any)        { h.IntSlice = append(h.IntSlice, v.(int)) }
func (h *hp) Pop() any          { a := h.IntSlice; v := a[len(a)-1]; h.IntSlice = a[:len(a)-1]; return v }
func (h *hp) replace(v int) int { top := h.IntSlice[0]; h.IntSlice[0] = v; heap.Fix(h, 0); return top }

cpp 解法, 执行用时: 89 ms, 内存消耗: 72.5 MB, 提交时间: 2024-05-01 10:14:34

class Solution {
public:
    long long totalCost(vector<int>& costs, int k, int candidates) {
        int n = costs.size();
        if (candidates * 2 + k > n) {
            ranges::nth_element(costs, costs.begin() + k);
            return accumulate(costs.begin(), costs.begin() + k, 0LL);
        }

        priority_queue<int, vector<int>, greater<>> pre, suf;
        for (int i = 0; i < candidates; i++) {
            pre.push(costs[i]);
            suf.push(costs[n - 1 - i]);
        }

        long long ans = 0;
        int i = candidates, j = n - 1 - candidates;
        while (k--) {
            if (pre.top() <= suf.top()) {
                ans += pre.top();
                pre.pop();
                pre.push(costs[i++]);
            } else {
                ans += suf.top();
                suf.pop();
                suf.push(costs[j--]);
            }
        }
        return ans;
    }
};

java 解法, 执行用时: 22 ms, 内存消耗: 53.9 MB, 提交时间: 2024-05-01 10:14:18

class Solution {
    public long totalCost(int[] costs, int k, int candidates) {
        int n = costs.length;
        long ans = 0;
        if (candidates * 2 + k > n) {
            Arrays.sort(costs);
            for (int i = 0; i < k; i++) {
                ans += costs[i];
            }
            return ans;
        }

        PriorityQueue<Integer> pre = new PriorityQueue<>();
        PriorityQueue<Integer> suf = new PriorityQueue<>();
        for (int i = 0; i < candidates; i++) {
            pre.offer(costs[i]);
            suf.offer(costs[n - 1 - i]);
        }

        int i = candidates;
        int j = n - 1 - candidates;
        while (k-- > 0) {
            if (pre.peek() <= suf.peek()) {
                ans += pre.poll();
                pre.offer(costs[i++]);
            } else {
                ans += suf.poll();
                suf.offer(costs[j--]);
            }
        }
        return ans;
    }
}

python3 解法, 执行用时: 156 ms, 内存消耗: 21.6 MB, 提交时间: 2022-11-09 10:56:05

class Solution:
    def totalCost(self, costs: List[int], k: int, candidates: int) -> int:
        ans, n = 0, len(costs)
        if candidates * 2 < n:
            pre = costs[:candidates]
            heapify(pre)
            suf = costs[-candidates:]
            heapify(suf)
            i, j = candidates, n - 1 - candidates
            while k and i <= j:
                if pre[0] <= suf[0]:
                    ans += heapreplace(pre, costs[i])
                    i += 1
                else:
                    ans += heapreplace(suf, costs[j])
                    j -= 1
                k -= 1
            costs = pre + suf
        costs.sort()
        return ans + sum(costs[:k])  # 也可以用快速选择算法求前 k 小

上一题