列表

详情


1439. 有序矩阵中的第 k 个最小数组和

给你一个 m * n 的矩阵 mat,以及一个整数 k ,矩阵中的每一行都以非递减的顺序排列。

你可以从每一行中选出 1 个元素形成一个数组。返回所有可能数组中的第 k 个 最小 数组和。

 

示例 1:

输入:mat = [[1,3,11],[2,4,6]], k = 5
输出:7
解释:从每一行中选出一个元素,前 k 个和最小的数组分别是:
[1,2], [1,4], [3,2], [3,4], [1,6]。其中第 5 个的和是 7 。  

示例 2:

输入:mat = [[1,3,11],[2,4,6]], k = 9
输出:17

示例 3:

输入:mat = [[1,10,10],[1,4,5],[2,3,6]], k = 7
输出:9
解释:从每一行中选出一个元素,前 k 个和最小的数组分别是:
[1,1,2], [1,1,3], [1,4,2], [1,4,3], [1,1,6], [1,5,2], [1,5,3]。其中第 7 个的和是 9 。 

示例 4:

输入:mat = [[1,1,10],[2,2,9]], k = 7
输出:12

 

提示:

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: int kthSmallest(vector<vector<int>>& mat, int k) { } };

golang 解法, 执行用时: 4 ms, 内存消耗: 6.5 MB, 提交时间: 2023-05-28 10:22:16

// 373. 查找和最小的 K 对数字
func kSmallestPairs(nums1, nums2 []int, k int) []int {
    n, m := len(nums1), len(nums2)
    ans := make([]int, 0, min(k, n*m)) // 预分配空间
    h := hp{{nums1[0] + nums2[0], 0, 0}}
    for len(h) > 0 && len(ans) < k {
        p := heap.Pop(&h).(tuple)
        i, j := p.i, p.j
        ans = append(ans, nums1[i]+nums2[j]) // 数对和
        if j == 0 && i+1 < n {
            heap.Push(&h, tuple{nums1[i+1] + nums2[0], i + 1, 0})
        }
        if j+1 < m {
            heap.Push(&h, tuple{nums1[i] + nums2[j+1], i, j + 1})
        }
    }
    return ans
}

func kthSmallest(mat [][]int, k int) int {
    a := []int{0}
    for _, row := range mat {
        a = kSmallestPairs(row, a, k)
    }
    return a[k-1]
}

type tuple struct{ s, i, j int }
type hp []tuple
func (h hp) Len() int           { return len(h) }
func (h hp) Less(i, j int) bool { return h[i].s < h[j].s }
func (h hp) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *hp) Push(v any)        { *h = append(*h, v.(tuple)) }
func (h *hp) Pop() any          { a := *h; v := a[len(a)-1]; *h = a[:len(a)-1]; return v }
func min(a, b int) int { if b < a { return b }; return a }

java 解法, 执行用时: 4 ms, 内存消耗: 42.1 MB, 提交时间: 2023-05-28 10:21:54

class Solution {
    public int kthSmallest(int[][] mat, int k) {
        var a = new int[]{0};
        for (var row : mat)
            a = kSmallestPairs(row, a, k);
        return a[k - 1];
    }

    // 373. 查找和最小的 K 对数字
    private int[] kSmallestPairs(int[] nums1, int[] nums2, int k) {
        int n = nums1.length, m = nums2.length, sz = 0;
        var ans = new int[Math.min(k, n * m)];
        var pq = new PriorityQueue<int[]>((a, b) -> a[0] - b[0]);
        pq.add(new int[]{nums1[0] + nums2[0], 0, 0});
        while (!pq.isEmpty() && sz < k) {
            var p = pq.poll();
            int i = p[1], j = p[2];
            ans[sz++] = nums1[i] + nums2[j]; // 数对和
            if (j == 0 && i + 1 < n)
                pq.add(new int[]{nums1[i + 1] + nums2[0], i + 1, 0});
            if (j + 1 < m)
                pq.add(new int[]{nums1[i] + nums2[j + 1], i, j + 1});
        }
        return ans;
    }
}

python3 解法, 执行用时: 72 ms, 内存消耗: 16.3 MB, 提交时间: 2023-05-28 10:21:22

# 最小堆
class Solution:
    # 373. 查找和最小的 K 对数字
    def kSmallestPairs(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
        ans = []
        h = [(nums1[0] + nums2[0], 0, 0)]
        while h and len(ans) < k:
            _, i, j = heappop(h)
            ans.append(nums1[i] + nums2[j])  # 数对和
            if j == 0 and i + 1 < len(nums1):
                heappush(h, (nums1[i + 1] + nums2[0], i + 1, 0))
            if j + 1 < len(nums2):
                heappush(h, (nums1[i] + nums2[j + 1], i, j + 1))
        return ans

    def kthSmallest(self, mat: List[List[int]], k: int) -> int:
        a = mat[0][:k]
        for row in mat[1:]:
            a = self.kSmallestPairs(row, a, k)
        return a[-1]

python3 解法, 执行用时: 84 ms, 内存消耗: 16.6 MB, 提交时间: 2023-05-28 10:20:56

# 二分
class Solution:
    def kthSmallest(self, mat: List[List[int]], k: int) -> int:
        def check(s: int) -> bool:
            left_k = k
            # 返回是否找到 k 个子数组和
            def dfs(i: int, s: int) -> bool:
                if i < 0:  # 能递归到这里,说明数组和不超过二分的 s
                    nonlocal left_k
                    left_k -= 1
                    return left_k == 0  # 是否找到 k 个
                for x in mat[i]:  # 「枚举选哪个」,注意 mat[i] 是有序的
                    if x - mat[i][0] > s:  # 选 x 不选 mat[i][0]
                        break  # 剪枝:后面的元素更大,无需枚举
                    if dfs(i - 1, s - (x - mat[i][0])):  # 选 x 不选 mat[i][0]
                        return True  # 找到 k 个就一直返回 True,不再递归
                return False
            return dfs(len(mat) - 1, s - sl)  # 先把第一列的所有数都选上

        sl = sum(row[0] for row in mat)
        sr = sum(row[-1] for row in mat)
        return sl + bisect_left(range(sl, sr), True, key=check)

golang 解法, 执行用时: 4 ms, 内存消耗: 3 MB, 提交时间: 2023-05-28 10:20:25

func kthSmallest(mat [][]int, k int) int {
    sl, sr := 0, 0
    for _, row := range mat {
        sl += row[0]
        sr += row[len(row)-1]
    }
    return sl + sort.Search(sr-sl, func(s int) bool {
        leftK := k
        // 返回是否找到 k 个子数组和
        var dfs func(int, int) bool
        dfs = func(i, s int) bool {
            if i < 0 { // 能递归到这里,说明数组和满足要求
                leftK--
                return leftK == 0 // 是否找到 k 个
            }
            for _, x := range mat[i] { // 「枚举选哪个」,注意 mat[i] 是有序的
                if x-mat[i][0] > s { // 选 x 不选 mat[i][0]
                    break // 剪枝:后面的元素更大,无需枚举
                }
                if dfs(i-1, s-(x-mat[i][0])) { // 选 x 不选 mat[i][0]
                    return true // 找到 k 个就一直返回 true,不再递归
                }
            }
            return false
        }
        return dfs(len(mat)-1, s) // 这里的 s 已经把 sl 减掉了
    })
}

java 解法, 执行用时: 1 ms, 内存消耗: 40.5 MB, 提交时间: 2023-05-28 10:20:13

class Solution {
    private int leftK;

    public int kthSmallest(int[][] mat, int k) {
        int sl = 0, sr = 0;
        for (var row : mat) {
            sl += row[0];
            sr += row[row.length - 1];
        }
        // 二分模板 https://www.bilibili.com/video/BV1AP41137w7/
        int left = sl - 1, right = sr; // 开区间 (sl-1,sr)
        while (left + 1 < right) { // 开区间不为空
            // 循环不变量:
            // f(left) < k
            // f(right) >= k
            int mid = (left + right) >>> 1;
            leftK = k;
            if (dfs(mat, mat.length - 1, mid - sl)) // 先把第一列的所有数都选上
                right = mid; // 二分范围缩小至开区间 (left, mid)
            else // f(mid) < k
                left = mid; // 二分范围缩小至开区间 (mid, right)
        }
        return right;
    }

    // 返回是否找到 k 个子数组和
    private boolean dfs(int[][] mat, int i, int s) {
        if (i < 0) // 能递归到这里,说明数组和不超过二分的 mid
            return --leftK == 0; // 是否找到 k 个
        for (int x : mat[i]) { // 「枚举选哪个」,注意 mat[i] 是有序的
            if (x - mat[i][0] > s) // 选 x 不选 mat[i][0]
                break; // 剪枝:后面的元素更大,无需枚举
            if (dfs(mat, i - 1, s - (x - mat[i][0]))) // 选 x 不选 mat[i][0]
                return true; // 找到 k 个就一直返回 true,不再递归
        }
        return false;
    }
}

java 解法, 执行用时: 75 ms, 内存消耗: 42.6 MB, 提交时间: 2023-05-28 10:20:03

class Solution {
    public int kthSmallest(int[][] mat, int k) {
        var a = new int[]{0};
        for (var row : mat) {
            var b = new int[a.length * row.length];
            int i = 0;
            for (int x : a)
                for (int y : row)
                    b[i++] = x + y;
            Arrays.sort(b);
            if (b.length > k) // 保留最小的 k 个
                b = Arrays.copyOfRange(b, 0, k);
            a = b;
        }
        return a[k - 1];
    }
}

golang 解法, 执行用时: 160 ms, 内存消耗: 6.7 MB, 提交时间: 2023-05-28 10:19:39

func kthSmallest(mat [][]int, k int) int {
    a := []int{0}
    for _, row := range mat {
        b := make([]int, 0, len(a)*len(row)) // 预分配空间
        for _, x := range a {
            for _, y := range row {
                b = append(b, x+y)
            }
        }
        sort.Ints(b)
        if len(b) > k { // 保留最小的 k 个
            b = b[:k]
        }
        a = b
    }
    return a[k-1]
}

python3 解法, 执行用时: 176 ms, 内存消耗: 16.6 MB, 提交时间: 2023-05-28 10:19:06

# 暴力解法
class Solution:
    def kthSmallest(self, mat: List[List[int]], k: int) -> int:
        a = mat[0][:k]
        for row in mat[1:]:
            a = sorted(x + y for x in a for y in row)[:k]
        return a[-1]

上一题