列表

详情


1782. 统计点对的数目

给你一个无向图,无向图由整数 n  ,表示图中节点的数目,和 edges 组成,其中 edges[i] = [ui, vi] 表示 ui 和 vi 之间有一条无向边。同时给你一个代表查询的整数数组 queries 。

j 个查询的答案是满足如下条件的点对 (a, b) 的数目:

请你返回一个数组 answers ,其中 answers.length == queries.length 且 answers[j] 是第 j 个查询的答案。

请注意,图中可能会有 重复边 。

 

示例 1:

输入:n = 4, edges = [[1,2],[2,4],[1,3],[2,3],[2,1]], queries = [2,3]
输出:[6,5]
解释:每个点对中,与至少一个点相连的边的数目如上图所示。

示例 2:

输入:n = 5, edges = [[1,5],[1,5],[3,4],[2,5],[1,3],[5,1],[2,3],[2,5]], queries = [1,2,3,4,5]
输出:[10,10,9,8,6]

 

提示:

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: vector<int> countPairs(int n, vector<vector<int>>& edges, vector<int>& queries) { } };

cpp 解法, 执行用时: 468 ms, 内存消耗: 169.6 MB, 提交时间: 2023-08-23 09:37:56

class Solution {
public:
    vector<int> countPairs(int n, vector<vector<int>> &edges, vector<int> &queries) {
        vector<int> deg(n + 1);
        unordered_map<int, int> cnt_e;
        for (auto &e: edges) {
            int x = e[0], y = e[1];
            if (x > y) swap(x, y);
            deg[x]++;
            deg[y]++;
            cnt_e[x << 16 | y]++;
        }

        // 统计 deg 中元素的出现次数
        unordered_map<int, int> cnt_deg;
        for (int i = 1; i <= n; i++)
            cnt_deg[deg[i]]++;

        // 2)
        int max_deg = *max_element(deg.begin() + 1, deg.end());
        int k = max_deg * 2 + 2;
        vector<int> cnts(k);
        for (auto [deg1, c1]: cnt_deg) {
            for (auto [deg2, c2]: cnt_deg) {
                if (deg1 < deg2) {
                    cnts[deg1 + deg2] += c1 * c2;
                } else if (deg1 == deg2) {
                    cnts[deg1 + deg2] += c1 * (c1 - 1) / 2;
                }
            }
        }

        // 3)
        for (auto [key, c]: cnt_e) {
            int s = deg[key >> 16] + deg[key & 0xffff];
            cnts[s]--;
            cnts[s - c]++;
        }

        // 4) 计算 cnts 的后缀和
        for (int i = k - 1; i > 0; i--)
            cnts[i - 1] += cnts[i];

        for (int &q: queries)
            q = cnts[min(q + 1, k - 1)];
        return queries;
    }
};

javascript 解法, 执行用时: 332 ms, 内存消耗: 80.8 MB, 提交时间: 2023-08-23 09:37:17

/**
 * @param {number} n
 * @param {number[][]} edges
 * @param {number[]} queries
 * @return {number[]}
 */
var countPairs = function (n, edges, queries) {
    const deg = new Array(n + 1).fill(0);
    const cntE = new Map();
    for (let [x, y] of edges) {
        if (x > y) [x, y] = [y, x];
        deg[x]++;
        deg[y]++;
        cntE.set(x << 16 | y, (cntE.get(x << 16 | y) ?? 0) + 1);
    }

    // 统计 deg 中元素的出现次数
    const cntDeg = new Map();
    for (let i = 1; i <= n; i++)
        cntDeg.set(deg[i], (cntDeg.get(deg[i]) ?? 0) + 1);

    // 2)
    const cnts = new Array(_.max(deg) * 2 + 2).fill(0);
    for (const [deg1, c1] of cntDeg.entries()) {
        for (const [deg2, c2] of cntDeg.entries()) {
            if (deg1 < deg2) {
                cnts[deg1 + deg2] += c1 * c2;
            } else if (deg1 === deg2) {
                cnts[deg1 + deg2] += c1 * (c1 - 1) >> 1;
            }
        }
    }

    // 3)
    for (const [k, c] of cntE) {
        const s = deg[k >> 16] + deg[k & 0xffff];
        cnts[s]--;
        cnts[s - c]++;
    }

    // 4) 计算 cnts 的后缀和
    for (let i = cnts.length - 1; i > 0; i--)
        cnts[i - 1] += cnts[i];

    for (let i = 0; i < queries.length; i++)
        queries[i] = cnts[Math.min(queries[i] + 1, cnts.length - 1)];
    return queries;
};

java 解法, 执行用时: 56 ms, 内存消耗: 104.8 MB, 提交时间: 2023-08-23 09:36:21

class Solution {
    public int[] countPairs(int n, int[][] edges, int[] queries) {
        var deg = new int[n + 1];
        var cntE = new HashMap<Integer, Integer>();
        for (var e : edges) {
            int x = e[0], y = e[1];
            if (x > y) {
                int tmp = x;
                x = y;
                y = tmp;
            }
            deg[x]++;
            deg[y]++;
            cntE.merge(x << 16 | y, 1, Integer::sum);
        }

        // 统计 deg 中元素的出现次数
        var cntDeg = new HashMap<Integer, Integer>();
        int maxDeg = 0;
        for (int i = 1; i <= n; i++) {
            cntDeg.merge(deg[i], 1, Integer::sum); // cntDeg[deg[i]]++
            maxDeg = Math.max(maxDeg, deg[i]);
        }

        // 2)
        var cnts = new int[maxDeg * 2 + 2];
        for (var e1 : cntDeg.entrySet()) {
            int deg1 = e1.getKey(), c1 = e1.getValue();
            for (var e2 : cntDeg.entrySet()) {
                int deg2 = e2.getKey(), c2 = e2.getValue();
                if (deg1 < deg2)
                    cnts[deg1 + deg2] += c1 * c2;
                else if (deg1 == deg2)
                    cnts[deg1 + deg2] += c1 * (c1 - 1) / 2;
            }
        }

        // 3)
        for (var e : cntE.entrySet()) {
            int k = e.getKey(), c = e.getValue();
            int s = deg[k >> 16] + deg[k & 0xffff];
            cnts[s]--;
            cnts[s - c]++;
        }

        // 4) 计算 cnts 的后缀和
        for (int i = cnts.length - 1; i > 0; i--)
            cnts[i - 1] += cnts[i];

        for (int i = 0; i < queries.length; i++)
            queries[i] = cnts[Math.min(queries[i] + 1, cnts.length - 1)];
        return queries;
    }
}

golang 解法, 执行用时: 328 ms, 内存消耗: 24.8 MB, 提交时间: 2023-08-23 09:35:28

func countPairs(n int, edges [][]int, queries []int) []int {
    deg := make([]int, n+1)
    type edge struct{ x, y int }
    cntE := map[edge]int{}
    for _, e := range edges {
        x, y := e[0], e[1]
        if x > y {
            x, y = y, x
        }
        deg[x]++
        deg[y]++
        cntE[edge{x, y}]++
    }

    // 统计 deg 中元素的出现次数
    cntDeg := map[int]int{}
    maxDeg := 0
    for _, d := range deg[1:] {
        cntDeg[d]++
        maxDeg = max(maxDeg, d)
    }

    // 2)
    k := maxDeg*2 + 2
    cnts := make([]int, k)
    for deg1, c1 := range cntDeg {
        for deg2, c2 := range cntDeg {
            if deg1 < deg2 {
                cnts[deg1+deg2] += c1 * c2
            } else if deg1 == deg2 {
                cnts[deg1+deg2] += c1 * (c1 - 1) / 2
            }
        }
    }

    // 3)
    for e, c := range cntE {
        s := deg[e.x] + deg[e.y]
        cnts[s]--
        cnts[s-c]++
    }

    // 4) 计算 cnts 的后缀和
    for i := k - 1; i > 0; i-- {
        cnts[i-1] += cnts[i]
    }

    for i, q := range queries {
        queries[i] = cnts[min(q+1, k-1)]
    }
    return queries
}

func min(a, b int) int { if b < a { return b }; return a }
func max(a, b int) int { if b > a { return b }; return a }

python3 解法, 执行用时: 424 ms, 内存消耗: 51.5 MB, 提交时间: 2023-08-23 09:34:51

class Solution:
    def countPairs(self, n: int, edges: List[List[int]], queries: List[int]) -> List[int]:
        deg = [0] * (n + 1)
        cnt_e = defaultdict(int)  # 比 Counter 快一点
        for x, y in edges:
            if x > y: x, y = y, x
            deg[x] += 1
            deg[y] += 1
            cnt_e[(x, y)] += 1
        cnt_deg = Counter(deg[1:])

        # 2)
        cnts = [0] * (max(deg) * 2 + 2)
        for deg1, c1 in cnt_deg.items():
            for deg2, c2 in cnt_deg.items():
                if deg1 < deg2:
                    cnts[deg1 + deg2] += c1 * c2
                elif deg1 == deg2:
                    cnts[deg1 + deg2] += c1 * (c1 - 1) // 2

        # 3)
        for (x, y), c in cnt_e.items():
            s = deg[x] + deg[y]
            cnts[s] -= 1
            cnts[s - c] += 1

        # 4) 计算 cnts 的后缀和
        for i in range(len(cnts) - 1, 0, -1):
            cnts[i - 1] += cnts[i]

        for i, q in enumerate(queries):
            queries[i] = cnts[min(q + 1, len(cnts) - 1)]
        return queries

javascript 解法, 执行用时: 324 ms, 内存消耗: 80.8 MB, 提交时间: 2023-08-23 09:29:09

/**
 * @param {number} n
 * @param {number[][]} edges
 * @param {number[]} queries
 * @return {number[]}
 */
var countPairs = function (n, edges, queries) {
    // deg[i] 表示与点 i 相连的边的数目
    const deg = new Array(n + 1).fill(0); // 节点编号从 1 到 n
    const cntE = new Map();
    for (let [x, y] of edges) {
        if (x > y) [x, y] = [y, x]; // 注意 1-2 和 2-1 算同一条边
        deg[x]++;
        deg[y]++;
        // 统计每条边的出现次数
        cntE.set(x << 16 | y, (cntE.get(x << 16 | y) ?? 0) + 1);
    }

    const ans = new Array(queries.length).fill(0);
    const sortedDeg = deg.slice().sort((a, b) => a - b); // 排序,为了双指针
    for (let j = 0; j < queries.length; j++) {
        const q = queries[j];
        let left = 1, right = n; // 相向双指针
        while (left < right) {
            if (sortedDeg[left] + sortedDeg[right] <= q) {
                left++;
            } else {
                ans[j] += right - left;
                right--;
            }
        }
        for (const [k, c] of cntE.entries()) {
            const s = deg[k >> 16] + deg[k & 0xffff]; // 取出 k 的高 16 位和低 16 位
            if (s > q && s - c <= q) {
                ans[j]--;
            }
        }
    }
    return ans;
};

golang 解法, 执行用时: 384 ms, 内存消耗: 21.4 MB, 提交时间: 2023-08-23 09:28:45

func countPairs(n int, edges [][]int, queries []int) []int {
    // deg[i] 表示与点 i 相连的边的数目
    deg := make([]int, n+1) // 节点编号从 1 到 n
    type edge struct{ x, y int }
    cntE := map[edge]int{}
    for _, e := range edges {
        x, y := e[0], e[1]
        if x > y {
            x, y = y, x
        }
        deg[x]++
        deg[y]++
        // 统计每条边的出现次数,注意 1-2 和 2-1 算同一条边
        cntE[edge{x, y}]++
    }

    ans := make([]int, len(queries))
    sortedDeg := append([]int(nil), deg...)
    sort.Ints(sortedDeg) // 排序,为了双指针
    for j, q := range queries {
        left, right := 1, n // 相向双指针
        for left < right {
            if sortedDeg[left]+sortedDeg[right] <= q {
                left++
            } else {
                ans[j] += right - left
                right--
            }
        }
        for e, c := range cntE {
            s := deg[e.x] + deg[e.y]
            if s > q && s-c <= q {
                ans[j]--
            }
        }
    }
    return ans
}

java 解法, 执行用时: 193 ms, 内存消耗: 104.9 MB, 提交时间: 2023-08-23 09:28:16

class Solution {
    public int[] countPairs(int n, int[][] edges, int[] queries) {
        // deg[i] 表示与点 i 相连的边的数目
        var deg = new int[n + 1]; // 节点编号从 1 到 n
        var cntE = new HashMap<Integer, Integer>();
        for (var e : edges) {
            int x = e[0], y = e[1];
            if (x > y) {
                // 交换 x 和 y,因为 1-2 和 2-1 算同一条边
                int tmp = x;
                x = y;
                y = tmp;
            }
            deg[x]++;
            deg[y]++;
            // 统计每条边的出现次数
            // 用一个 int 存储两个不超过 65535 的数
            cntE.merge(x << 16 | y, 1, Integer::sum); // cntE[x<<16|y]++
        }

        var ans = new int[queries.length];
        var sortedDeg = deg.clone();
        Arrays.sort(sortedDeg); // 排序,为了双指针
        for (int j = 0; j < queries.length; j++) {
            int q = queries[j];
            int left = 1, right = n; // 相向双指针
            while (left < right) {
                if (sortedDeg[left] + sortedDeg[right] <= q) {
                    left++;
                } else {
                    ans[j] += right - left;
                    right--;
                }
            }
            for (var e : cntE.entrySet()) {
                int k = e.getKey(), c = e.getValue();
                int s = deg[k >> 16] + deg[k & 0xffff]; // 取出 k 的高 16 位和低 16 位
                if (s > q && s - c <= q) {
                    ans[j]--;
                }
            }
        }
        return ans;
    }
}

python3 解法, 执行用时: 940 ms, 内存消耗: 51.4 MB, 提交时间: 2023-08-23 09:27:53

'''
点对的degree相加大于cnt, 把degree排序后,用相向双指针。
'''
class Solution:
    def countPairs(self, n: int, edges: List[List[int]], queries: List[int]) -> List[int]:
        # deg[i] 表示与点 i 相连的边的数目
        deg = [0] * (n + 1)  # 节点编号从 1 到 n
        for x, y in edges:
            deg[x] += 1
            deg[y] += 1
        # 统计每条边的出现次数,注意 1-2 和 2-1 算同一条边
        cnt_e = Counter(tuple(sorted(e)) for e in edges)

        ans = [0] * len(queries)
        sorted_deg = sorted(deg)  # 排序,为了双指针
        for j, q in enumerate(queries):
            left, right = 1, n  # 相向双指针
            while left < right:
                if sorted_deg[left] + sorted_deg[right] <= q:
                    left += 1
                else:
                    ans[j] += right - left
                    right -= 1
            # 去重
            for (x, y), c in cnt_e.items():
                if q < deg[x] + deg[y] <= q + c:
                    ans[j] -= 1
        return ans

上一题