1782. 统计点对的数目
给你一个无向图,无向图由整数 n
,表示图中节点的数目,和 edges
组成,其中 edges[i] = [ui, vi]
表示 ui
和 vi
之间有一条无向边。同时给你一个代表查询的整数数组 queries
。
第 j
个查询的答案是满足如下条件的点对 (a, b)
的数目:
a < b
cnt
是与 a
或者 b
相连的边的数目,且 cnt
严格大于 queries[j]
。请你返回一个数组 answers
,其中 answers.length == queries.length
且 answers[j]
是第 j
个查询的答案。
请注意,图中可能会有 重复边 。
示例 1:
输入:n = 4, edges = [[1,2],[2,4],[1,3],[2,3],[2,1]], queries = [2,3] 输出:[6,5] 解释:每个点对中,与至少一个点相连的边的数目如上图所示。
示例 2:
输入:n = 5, edges = [[1,5],[1,5],[3,4],[2,5],[1,3],[5,1],[2,3],[2,5]], queries = [1,2,3,4,5] 输出:[10,10,9,8,6]
提示:
2 <= n <= 2 * 104
1 <= edges.length <= 105
1 <= ui, vi <= n
ui != vi
1 <= queries.length <= 20
0 <= queries[j] < edges.length
原站题解
cpp 解法, 执行用时: 468 ms, 内存消耗: 169.6 MB, 提交时间: 2023-08-23 09:37:56
class Solution { public: vector<int> countPairs(int n, vector<vector<int>> &edges, vector<int> &queries) { vector<int> deg(n + 1); unordered_map<int, int> cnt_e; for (auto &e: edges) { int x = e[0], y = e[1]; if (x > y) swap(x, y); deg[x]++; deg[y]++; cnt_e[x << 16 | y]++; } // 统计 deg 中元素的出现次数 unordered_map<int, int> cnt_deg; for (int i = 1; i <= n; i++) cnt_deg[deg[i]]++; // 2) int max_deg = *max_element(deg.begin() + 1, deg.end()); int k = max_deg * 2 + 2; vector<int> cnts(k); for (auto [deg1, c1]: cnt_deg) { for (auto [deg2, c2]: cnt_deg) { if (deg1 < deg2) { cnts[deg1 + deg2] += c1 * c2; } else if (deg1 == deg2) { cnts[deg1 + deg2] += c1 * (c1 - 1) / 2; } } } // 3) for (auto [key, c]: cnt_e) { int s = deg[key >> 16] + deg[key & 0xffff]; cnts[s]--; cnts[s - c]++; } // 4) 计算 cnts 的后缀和 for (int i = k - 1; i > 0; i--) cnts[i - 1] += cnts[i]; for (int &q: queries) q = cnts[min(q + 1, k - 1)]; return queries; } };
javascript 解法, 执行用时: 332 ms, 内存消耗: 80.8 MB, 提交时间: 2023-08-23 09:37:17
/** * @param {number} n * @param {number[][]} edges * @param {number[]} queries * @return {number[]} */ var countPairs = function (n, edges, queries) { const deg = new Array(n + 1).fill(0); const cntE = new Map(); for (let [x, y] of edges) { if (x > y) [x, y] = [y, x]; deg[x]++; deg[y]++; cntE.set(x << 16 | y, (cntE.get(x << 16 | y) ?? 0) + 1); } // 统计 deg 中元素的出现次数 const cntDeg = new Map(); for (let i = 1; i <= n; i++) cntDeg.set(deg[i], (cntDeg.get(deg[i]) ?? 0) + 1); // 2) const cnts = new Array(_.max(deg) * 2 + 2).fill(0); for (const [deg1, c1] of cntDeg.entries()) { for (const [deg2, c2] of cntDeg.entries()) { if (deg1 < deg2) { cnts[deg1 + deg2] += c1 * c2; } else if (deg1 === deg2) { cnts[deg1 + deg2] += c1 * (c1 - 1) >> 1; } } } // 3) for (const [k, c] of cntE) { const s = deg[k >> 16] + deg[k & 0xffff]; cnts[s]--; cnts[s - c]++; } // 4) 计算 cnts 的后缀和 for (let i = cnts.length - 1; i > 0; i--) cnts[i - 1] += cnts[i]; for (let i = 0; i < queries.length; i++) queries[i] = cnts[Math.min(queries[i] + 1, cnts.length - 1)]; return queries; };
java 解法, 执行用时: 56 ms, 内存消耗: 104.8 MB, 提交时间: 2023-08-23 09:36:21
class Solution { public int[] countPairs(int n, int[][] edges, int[] queries) { var deg = new int[n + 1]; var cntE = new HashMap<Integer, Integer>(); for (var e : edges) { int x = e[0], y = e[1]; if (x > y) { int tmp = x; x = y; y = tmp; } deg[x]++; deg[y]++; cntE.merge(x << 16 | y, 1, Integer::sum); } // 统计 deg 中元素的出现次数 var cntDeg = new HashMap<Integer, Integer>(); int maxDeg = 0; for (int i = 1; i <= n; i++) { cntDeg.merge(deg[i], 1, Integer::sum); // cntDeg[deg[i]]++ maxDeg = Math.max(maxDeg, deg[i]); } // 2) var cnts = new int[maxDeg * 2 + 2]; for (var e1 : cntDeg.entrySet()) { int deg1 = e1.getKey(), c1 = e1.getValue(); for (var e2 : cntDeg.entrySet()) { int deg2 = e2.getKey(), c2 = e2.getValue(); if (deg1 < deg2) cnts[deg1 + deg2] += c1 * c2; else if (deg1 == deg2) cnts[deg1 + deg2] += c1 * (c1 - 1) / 2; } } // 3) for (var e : cntE.entrySet()) { int k = e.getKey(), c = e.getValue(); int s = deg[k >> 16] + deg[k & 0xffff]; cnts[s]--; cnts[s - c]++; } // 4) 计算 cnts 的后缀和 for (int i = cnts.length - 1; i > 0; i--) cnts[i - 1] += cnts[i]; for (int i = 0; i < queries.length; i++) queries[i] = cnts[Math.min(queries[i] + 1, cnts.length - 1)]; return queries; } }
golang 解法, 执行用时: 328 ms, 内存消耗: 24.8 MB, 提交时间: 2023-08-23 09:35:28
func countPairs(n int, edges [][]int, queries []int) []int { deg := make([]int, n+1) type edge struct{ x, y int } cntE := map[edge]int{} for _, e := range edges { x, y := e[0], e[1] if x > y { x, y = y, x } deg[x]++ deg[y]++ cntE[edge{x, y}]++ } // 统计 deg 中元素的出现次数 cntDeg := map[int]int{} maxDeg := 0 for _, d := range deg[1:] { cntDeg[d]++ maxDeg = max(maxDeg, d) } // 2) k := maxDeg*2 + 2 cnts := make([]int, k) for deg1, c1 := range cntDeg { for deg2, c2 := range cntDeg { if deg1 < deg2 { cnts[deg1+deg2] += c1 * c2 } else if deg1 == deg2 { cnts[deg1+deg2] += c1 * (c1 - 1) / 2 } } } // 3) for e, c := range cntE { s := deg[e.x] + deg[e.y] cnts[s]-- cnts[s-c]++ } // 4) 计算 cnts 的后缀和 for i := k - 1; i > 0; i-- { cnts[i-1] += cnts[i] } for i, q := range queries { queries[i] = cnts[min(q+1, k-1)] } return queries } func min(a, b int) int { if b < a { return b }; return a } func max(a, b int) int { if b > a { return b }; return a }
python3 解法, 执行用时: 424 ms, 内存消耗: 51.5 MB, 提交时间: 2023-08-23 09:34:51
class Solution: def countPairs(self, n: int, edges: List[List[int]], queries: List[int]) -> List[int]: deg = [0] * (n + 1) cnt_e = defaultdict(int) # 比 Counter 快一点 for x, y in edges: if x > y: x, y = y, x deg[x] += 1 deg[y] += 1 cnt_e[(x, y)] += 1 cnt_deg = Counter(deg[1:]) # 2) cnts = [0] * (max(deg) * 2 + 2) for deg1, c1 in cnt_deg.items(): for deg2, c2 in cnt_deg.items(): if deg1 < deg2: cnts[deg1 + deg2] += c1 * c2 elif deg1 == deg2: cnts[deg1 + deg2] += c1 * (c1 - 1) // 2 # 3) for (x, y), c in cnt_e.items(): s = deg[x] + deg[y] cnts[s] -= 1 cnts[s - c] += 1 # 4) 计算 cnts 的后缀和 for i in range(len(cnts) - 1, 0, -1): cnts[i - 1] += cnts[i] for i, q in enumerate(queries): queries[i] = cnts[min(q + 1, len(cnts) - 1)] return queries
javascript 解法, 执行用时: 324 ms, 内存消耗: 80.8 MB, 提交时间: 2023-08-23 09:29:09
/** * @param {number} n * @param {number[][]} edges * @param {number[]} queries * @return {number[]} */ var countPairs = function (n, edges, queries) { // deg[i] 表示与点 i 相连的边的数目 const deg = new Array(n + 1).fill(0); // 节点编号从 1 到 n const cntE = new Map(); for (let [x, y] of edges) { if (x > y) [x, y] = [y, x]; // 注意 1-2 和 2-1 算同一条边 deg[x]++; deg[y]++; // 统计每条边的出现次数 cntE.set(x << 16 | y, (cntE.get(x << 16 | y) ?? 0) + 1); } const ans = new Array(queries.length).fill(0); const sortedDeg = deg.slice().sort((a, b) => a - b); // 排序,为了双指针 for (let j = 0; j < queries.length; j++) { const q = queries[j]; let left = 1, right = n; // 相向双指针 while (left < right) { if (sortedDeg[left] + sortedDeg[right] <= q) { left++; } else { ans[j] += right - left; right--; } } for (const [k, c] of cntE.entries()) { const s = deg[k >> 16] + deg[k & 0xffff]; // 取出 k 的高 16 位和低 16 位 if (s > q && s - c <= q) { ans[j]--; } } } return ans; };
golang 解法, 执行用时: 384 ms, 内存消耗: 21.4 MB, 提交时间: 2023-08-23 09:28:45
func countPairs(n int, edges [][]int, queries []int) []int { // deg[i] 表示与点 i 相连的边的数目 deg := make([]int, n+1) // 节点编号从 1 到 n type edge struct{ x, y int } cntE := map[edge]int{} for _, e := range edges { x, y := e[0], e[1] if x > y { x, y = y, x } deg[x]++ deg[y]++ // 统计每条边的出现次数,注意 1-2 和 2-1 算同一条边 cntE[edge{x, y}]++ } ans := make([]int, len(queries)) sortedDeg := append([]int(nil), deg...) sort.Ints(sortedDeg) // 排序,为了双指针 for j, q := range queries { left, right := 1, n // 相向双指针 for left < right { if sortedDeg[left]+sortedDeg[right] <= q { left++ } else { ans[j] += right - left right-- } } for e, c := range cntE { s := deg[e.x] + deg[e.y] if s > q && s-c <= q { ans[j]-- } } } return ans }
java 解法, 执行用时: 193 ms, 内存消耗: 104.9 MB, 提交时间: 2023-08-23 09:28:16
class Solution { public int[] countPairs(int n, int[][] edges, int[] queries) { // deg[i] 表示与点 i 相连的边的数目 var deg = new int[n + 1]; // 节点编号从 1 到 n var cntE = new HashMap<Integer, Integer>(); for (var e : edges) { int x = e[0], y = e[1]; if (x > y) { // 交换 x 和 y,因为 1-2 和 2-1 算同一条边 int tmp = x; x = y; y = tmp; } deg[x]++; deg[y]++; // 统计每条边的出现次数 // 用一个 int 存储两个不超过 65535 的数 cntE.merge(x << 16 | y, 1, Integer::sum); // cntE[x<<16|y]++ } var ans = new int[queries.length]; var sortedDeg = deg.clone(); Arrays.sort(sortedDeg); // 排序,为了双指针 for (int j = 0; j < queries.length; j++) { int q = queries[j]; int left = 1, right = n; // 相向双指针 while (left < right) { if (sortedDeg[left] + sortedDeg[right] <= q) { left++; } else { ans[j] += right - left; right--; } } for (var e : cntE.entrySet()) { int k = e.getKey(), c = e.getValue(); int s = deg[k >> 16] + deg[k & 0xffff]; // 取出 k 的高 16 位和低 16 位 if (s > q && s - c <= q) { ans[j]--; } } } return ans; } }
python3 解法, 执行用时: 940 ms, 内存消耗: 51.4 MB, 提交时间: 2023-08-23 09:27:53
''' 点对的degree相加大于cnt, 把degree排序后,用相向双指针。 ''' class Solution: def countPairs(self, n: int, edges: List[List[int]], queries: List[int]) -> List[int]: # deg[i] 表示与点 i 相连的边的数目 deg = [0] * (n + 1) # 节点编号从 1 到 n for x, y in edges: deg[x] += 1 deg[y] += 1 # 统计每条边的出现次数,注意 1-2 和 2-1 算同一条边 cnt_e = Counter(tuple(sorted(e)) for e in edges) ans = [0] * len(queries) sorted_deg = sorted(deg) # 排序,为了双指针 for j, q in enumerate(queries): left, right = 1, n # 相向双指针 while left < right: if sorted_deg[left] + sorted_deg[right] <= q: left += 1 else: ans[j] += right - left right -= 1 # 去重 for (x, y), c in cnt_e.items(): if q < deg[x] + deg[y] <= q + c: ans[j] -= 1 return ans