class Solution {
public:
int numberOfGoodPaths(vector<int>& vals, vector<vector<int>>& edges) {
}
};
2421. 好路径的数目
给你一棵 n
个节点的树(连通无向无环的图),节点编号从 0
到 n - 1
且恰好有 n - 1
条边。
给你一个长度为 n
下标从 0 开始的整数数组 vals
,分别表示每个节点的值。同时给你一个二维整数数组 edges
,其中 edges[i] = [ai, bi]
表示节点 ai
和 bi
之间有一条 无向 边。
一条 好路径 需要满足以下条件:
请你返回不同好路径的数目。
注意,一条路径和它反向的路径算作 同一 路径。比方说, 0 -> 1
与 1 -> 0
视为同一条路径。单个节点也视为一条合法路径。
示例 1:
输入:vals = [1,3,2,1,3], edges = [[0,1],[0,2],[2,3],[2,4]] 输出:6 解释:总共有 5 条单个节点的好路径。 还有 1 条好路径:1 -> 0 -> 2 -> 4 。 (反方向的路径 4 -> 2 -> 0 -> 1 视为跟 1 -> 0 -> 2 -> 4 一样的路径) 注意 0 -> 2 -> 3 不是一条好路径,因为 vals[2] > vals[0] 。
示例 2:
输入:vals = [1,1,2,2,3], edges = [[0,1],[1,2],[2,3],[2,4]] 输出:7 解释:总共有 5 条单个节点的好路径。 还有 2 条好路径:0 -> 1 和 2 -> 3 。
示例 3:
输入:vals = [1], edges = [] 输出:1 解释:这棵树只有一个节点,所以只有一条好路径。
提示:
n == vals.length
1 <= n <= 3 * 104
0 <= vals[i] <= 105
edges.length == n - 1
edges[i].length == 2
0 <= ai, bi < n
ai != bi
edges
表示一棵合法的树。原站题解
golang 解法, 执行用时: 248 ms, 内存消耗: 12.1 MB, 提交时间: 2023-09-14 00:43:36
func numberOfGoodPaths(vals []int, edges [][]int) int { n := len(vals) g := make([][]int, n) for _, e := range edges { x, y := e[0], e[1] g[x] = append(g[x], y) g[y] = append(g[y], x) // 建图 } // 并查集模板 fa := make([]int, n) // size[x] 表示节点值等于 vals[x] 的节点个数, // 如果按照节点值从小到大合并,size[x] 也是连通块内的等于最大节点值的节点个数 size := make([]int, n) id := make([]int, n) // 后面排序用 for i := range fa { fa[i] = i size[i] = 1 id[i] = i } var find func(int) int find = func(x int) int { if fa[x] != x { fa[x] = find(fa[x]) } return fa[x] } ans := n // 单个节点的好路径 sort.Slice(id, func(i, j int) bool { return vals[id[i]] < vals[id[j]] }) for _, x := range id { vx := vals[x] fx := find(x) for _, y := range g[x] { y = find(y) if y == fx || vals[y] > vx { continue // 只考虑最大节点值不超过 vx 的连通块 } if vals[y] == vx { // 可以构成好路径 ans += size[fx] * size[y] // 乘法原理 size[fx] += size[y] // 统计连通块内节点值等于 vx 的节点个数 } fa[y] = fx // 把小的节点值合并到大的节点值上 } } return ans }
cpp 解法, 执行用时: 528 ms, 内存消耗: 160.2 MB, 提交时间: 2023-09-14 00:43:20
class Solution { public: int numberOfGoodPaths(vector<int> &vals, vector<vector<int>> &edges) { int n = vals.size(); vector<vector<int>> g(n); for (auto &e : edges) { int x = e[0], y = e[1]; g[x].push_back(y); g[y].push_back(x); // 建图 } // 并查集模板 // size[x] 表示节点值等于 vals[x] 的节点个数, // 如果按照节点值从小到大合并,size[x] 也是连通块内的等于最大节点值的节点个数 int id[n], fa[n], size[n]; // id 后面排序用 iota(id, id + n, 0); iota(fa, fa + n, 0); fill(size, size + n, 1); function<int(int)> find = [&](int x) -> int { return fa[x] == x ? x : fa[x] = find(fa[x]); }; int ans = n; // 单个节点的好路径 sort(id, id + n, [&](int i, int j) { return vals[i] < vals[j]; }); for (int x : id) { int vx = vals[x], fx = find(x); for (int y : g[x]) { y = find(y); if (y == fx || vals[y] > vx) continue; // 只考虑最大节点值不超过 vx 的连通块 if (vals[y] == vx) { // 可以构成好路径 ans += size[fx] * size[y]; // 乘法原理 size[fx] += size[y]; // 统计连通块内节点值等于 vx 的节点个数 } fa[y] = fx; // 把小的节点值合并到大的节点值上 } } return ans; } };
java 解法, 执行用时: 48 ms, 内存消耗: 59.4 MB, 提交时间: 2023-09-14 00:43:04
class Solution { public int numberOfGoodPaths(int[] vals, int[][] edges) { int n = vals.length; List<Integer>[] g = new ArrayList[n]; Arrays.setAll(g, e -> new ArrayList<>()); for (var e : edges) { int x = e[0], y = e[1]; g[x].add(y); g[y].add(x); // 建图 } fa = new int[n]; var id = new Integer[n]; for (int i = 0; i < n; i++) fa[i] = id[i] = i; Arrays.sort(id, (i, j) -> vals[i] - vals[j]); // size[x] 表示节点值等于 vals[x] 的节点个数, // 如果按照节点值从小到大合并,size[x] 也是连通块内的等于最大节点值的节点个数 var size = new int[n]; Arrays.fill(size, 1); int ans = n; // 单个节点的好路径 for (var x : id) { int vx = vals[x], fx = find(x); for (var y : g[x]) { y = find(y); if (y == fx || vals[y] > vx) continue; // 只考虑最大节点值不超过 vx 的连通块 if (vals[y] == vx) { // 可以构成好路径 ans += size[fx] * size[y]; // 乘法原理 size[fx] += size[y]; // 统计连通块内节点值等于 vx 的节点个数 } fa[y] = fx; // 把小的节点值合并到大的节点值上 } } return ans; } private int[] fa; private int find(int x) { if (fa[x] != x) fa[x] = find(fa[x]); return fa[x]; } }
python3 解法, 执行用时: 416 ms, 内存消耗: 54.3 MB, 提交时间: 2023-09-14 00:42:51
class Solution: def numberOfGoodPaths(self, vals: List[int], edges: List[List[int]]) -> int: n = len(vals) g = [[] for _ in range(n)] for x, y in edges: g[x].append(y) g[y].append(x) # 建图 # 并查集模板 fa = list(range(n)) # size[x] 表示节点值等于 vals[x] 的节点个数, # 如果按照节点值从小到大合并,size[x] 也是连通块内的等于最大节点值的节点个数 size = [1] * n def find(x: int) -> int: if fa[x] != x: fa[x] = find(fa[x]) return fa[x] ans = n # 单个节点的好路径 for vx, x in sorted(zip(vals, range(n))): fx = find(x) for y in g[x]: y = find(y) if y == fx or vals[y] > vx: continue # 只考虑最大节点值不超过 vx 的连通块 if vals[y] == vx: # 可以构成好路径 ans += size[fx] * size[y] # 乘法原理 size[fx] += size[y] # 统计连通块内节点值等于 vx 的节点个数 fa[y] = fx # 把小的节点值合并到大的节点值上 return ans