列表

详情


100047. 统计树中的合法路径数目

给你一棵 n 个节点的无向树,节点编号为 1 到 n 。给你一个整数 n 和一个长度为 n - 1 的二维整数数组 edges ,其中 edges[i] = [ui, vi] 表示节点 ui 和 vi 在树中有一条边。

请你返回树中的 合法路径数目 。

如果在节点 a 到节点 b 之间 恰好有一个 节点的编号是质数,那么我们称路径 (a, b) 是 合法的 。

注意:

 

示例 1:

输入:n = 5, edges = [[1,2],[1,3],[2,4],[2,5]]
输出:4
解释:恰好有一个质数编号的节点路径有:
- (1, 2) 因为路径 1 到 2 只包含一个质数 2 。
- (1, 3) 因为路径 1 到 3 只包含一个质数 3 。
- (1, 4) 因为路径 1 到 4 只包含一个质数 2 。
- (2, 4) 因为路径 2 到 4 只包含一个质数 2 。
只有 4 条合法路径。

示例 2:

输入:n = 6, edges = [[1,2],[1,3],[2,4],[3,5],[3,6]]
输出:6
解释:恰好有一个质数编号的节点路径有:
- (1, 2) 因为路径 1 到 2 只包含一个质数 2 。
- (1, 3) 因为路径 1 到 3 只包含一个质数 3 。
- (1, 4) 因为路径 1 到 4 只包含一个质数 2 。
- (1, 6) 因为路径 1 到 6 只包含一个质数 3 。
- (2, 4) 因为路径 2 到 4 只包含一个质数 2 。
- (3, 6) 因为路径 3 到 6 只包含一个质数 3 。
只有 6 条合法路径。

 

提示:

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: long long countPaths(int n, vector<vector<int>>& edges) { } };

rust 解法, 执行用时: 58 ms, 内存消耗: 13.2 MB, 提交时间: 2024-02-27 11:41:55

impl Solution {
  fn dfs(graph: &Vec<Vec<usize>>, nodes: &mut Vec<usize>, parent: usize, node: usize) -> i64 {
  	nodes.push(node);

  	let mut path = 1;
  	for child in graph[node].iter() {
  		if child != &parent {
  			path += Solution::dfs(graph, nodes, node, *child);
  		}
  	}
  	path
  }

  pub fn count_paths(n: i32, edges: Vec<Vec<i32>>) -> i64 {
  	let mut graph: Vec<Vec<usize>> = vec![vec![]; n as usize + 1];
  	let mut is_prime = vec![true; n as usize + 1];
  	let mut ways= vec![0_i64; n as usize + 1];
  	let mut nodes: Vec<usize> = vec![];
  	let (mut ans, mut sum) = (0, 1);

  	// 筛法求质数
  	is_prime[1] = false;
  	for i in 2..=(n as f32).sqrt() as usize {
  		if is_prime[i] {
  			for j in (i.pow(2)..=n as usize).step_by(i) {
  				is_prime[j] = false;
  			}
  		}
  	}

  	for edge in edges.into_iter() {
  		if !is_prime[edge[1] as usize] {
  			graph[edge[0] as usize].push(edge[1] as usize);
  		}
  		if !is_prime[edge[0] as usize] {
  			graph[edge[1] as usize].push(edge[0] as usize);
  		}
  	}

  	for i in 1..=n as usize {
  		if is_prime[i] {
  			sum = 1;
  			for child in graph[i].iter() {
  				if ways[*child] == 0 {
  					nodes.clear();
  					let temp = Solution::dfs(&graph, &mut nodes, i, *child);
  					for node in nodes.iter() {
  						ways[*node] = temp;
  					}
  				}
  				ans += sum * ways[*child];
  				sum += ways[*child];
  			}
  		}
  	}

  	ans
  }
}

python3 解法, 执行用时: 364 ms, 内存消耗: 54.3 MB, 提交时间: 2023-09-24 22:52:05

# 标记 10**5 以内的质数
MX = 10 ** 5 + 1
is_prime = [True] * MX
is_prime[1] = False
for i in range(2, isqrt(MX) + 1):
    if is_prime[i]:
        for j in range(i * i, MX, i):
            is_prime[j] = False

class Solution:
    def countPaths(self, n: int, edges: List[List[int]]) -> int:
        g = [[] for _ in range(n + 1)]
        for x, y in edges:
            g[x].append(y)
            g[y].append(x)

        def dfs(x: int, fa: int) -> None:
            nodes.append(x)
            for y in g[x]:
                if y != fa and not is_prime[y]:
                    dfs(y, x)

        ans = 0
        size = [0] * (n + 1)
        for x in range(1, n + 1):
            if not is_prime[x]:  # 跳过非质数
                continue
            s = 0
            for y in g[x]:  # 质数 x 把这棵树分成了若干个连通块
                if is_prime[y]:
                    continue
                if size[y] == 0:  # 尚未计算过
                    nodes = []
                    dfs(y, -1)  # 遍历 y 所在连通块,在不经过质数的前提下,统计有多少个非质数
                    for z in nodes:
                        size[z] = len(nodes)
                # 这 size[y] 个非质数与之前遍历到的 s 个非质数,两两之间的路径只包含质数 x
                ans += size[y] * s
                s += size[y]
            ans += s  # 从 x 出发的路径
        return ans

java 解法, 执行用时: 46 ms, 内存消耗: 76.9 MB, 提交时间: 2023-09-24 22:51:51

class Solution {
    private final static int MX = (int) 1e5;
    private final static boolean[] np = new boolean[MX + 1]; // 质数=false 非质数=true

    static {
        np[1] = true;
        for (int i = 2; i * i <= MX; i++) {
            if (!np[i]) {
                for (int j = i * i; j <= MX; j += i) {
                    np[j] = true;
                }
            }
        }
    }

    public long countPaths(int n, int[][] edges) {
        List<Integer>[] g = new ArrayList[n + 1];
        Arrays.setAll(g, e -> new ArrayList<>());
        for (var e : edges) {
            int x = e[0], y = e[1];
            g[x].add(y);
            g[y].add(x);
        }

        long ans = 0;
        int[] size = new int[n + 1];
        var nodes = new ArrayList<Integer>();
        for (int x = 1; x <= n; x++) {
            if (np[x]) { // 跳过非质数
                continue;
            }
            int sum = 0;
            for (int y : g[x]) { // 质数 x 把这棵树分成了若干个连通块
                if (!np[y]) {
                    continue;
                }
                if (size[y] == 0) { // 尚未计算过
                    nodes.clear();
                    dfs(y, -1, g, nodes); // 遍历 y 所在连通块,在不经过质数的前提下,统计有多少个非质数
                    for (int z : nodes) {
                        size[z] = nodes.size();
                    }
                }
                // 这 size[y] 个非质数与之前遍历到的 sum 个非质数,两两之间的路径只包含质数 x
                ans += (long) size[y] * sum;
                sum += size[y];
            }
            ans += sum; // 从 x 出发的路径
        }
        return ans;
    }

    private void dfs(int x, int fa, List<Integer>[] g, List<Integer> nodes) {
        nodes.add(x);
        for (int y : g[x]) {
            if (y != fa && np[y]) {
                dfs(y, x, g, nodes);
            }
        }
    }
}

cpp 解法, 执行用时: 512 ms, 内存消耗: 172.7 MB, 提交时间: 2023-09-24 22:51:35

const int MX = 1e5;
bool np[MX + 1]; // 质数=false 非质数=true
int init = []() {
    np[1] = true;
    for (int i = 2; i * i <= MX; i++) {
        if (!np[i]) {
            for (int j = i * i; j <= MX; j += i) {
                np[j] = true;
            }
        }
    }
    return 0;
}();

class Solution {
public:
    long long countPaths(int n, vector<vector<int>> &edges) {
        vector<vector<int>> g(n + 1);
        for (auto &e: edges) {
            int x = e[0], y = e[1];
            g[x].push_back(y);
            g[y].push_back(x);
        }

        vector<int> size(n + 1);
        vector<int> nodes;
        function<void(int, int)> dfs = [&](int x, int fa) {
            nodes.push_back(x);
            for (int y: g[x]) {
                if (y != fa && np[y]) {
                    dfs(y, x);
                }
            }
        };

        long long ans = 0;
        for (int x = 1; x <= n; x++) {
            if (np[x]) continue; // 跳过非质数
            int sum = 0;
            for (int y: g[x]) { // 质数 x 把这棵树分成了若干个连通块
                if (!np[y]) continue;
                if (size[y] == 0) { // 尚未计算过
                    nodes.clear();
                    dfs(y, -1); // 遍历 y 所在连通块,在不经过质数的前提下,统计有多少个非质数
                    for (int z: nodes) {
                        size[z] = nodes.size();
                    }
                }
                // 这 size[y] 个非质数与之前遍历到的 sum 个非质数,两两之间的路径只包含质数 x
                ans += (long long) size[y] * sum;
                sum += size[y];
            }
            ans += sum; // 从 x 出发的路径
        }
        return ans;
    }
};

golang 解法, 执行用时: 240 ms, 内存消耗: 25.6 MB, 提交时间: 2023-09-24 22:51:20

const mx int = 1e5 + 1
var np = [mx]bool{1: true}
func init() { // 质数=false 非质数=true
	for i := 2; i*i < mx; i++ {
		if !np[i] {
			for j := i * i; j < mx; j += i {
				np[j] = true
			}
		}
	}
}

func countPaths(n int, edges [][]int) (ans int64) {
	g := make([][]int, n+1)
	for _, e := range edges {
		x, y := e[0], e[1]
		g[x] = append(g[x], y)
		g[y] = append(g[y], x)
	}

	size := make([]int, n+1)
	var nodes []int
	var dfs func(int, int)
	dfs = func(x, fa int) {
		nodes = append(nodes, x)
		for _, y := range g[x] {
			if y != fa && np[y] {
				dfs(y, x)
			}
		}
	}
	for x := 1; x <= n; x++ {
		if np[x] { // 跳过非质数
			continue
		}
		sum := 0
		for _, y := range g[x] { // 质数 x 把这棵树分成了若干个连通块
			if !np[y] {
				continue
			}
			if size[y] == 0 { // 尚未计算过
				nodes = []int{}
				dfs(y, -1) // 遍历 y 所在连通块,在不经过质数的前提下,统计有多少个非质数
				for _, z := range nodes {
					size[z] = len(nodes)
				}
			}
			// 这 size[y] 个非质数与之前遍历到的 sum 个非质数,两两之间的路径只包含质数 x
			ans += int64(size[y]) * int64(sum)
			sum += size[y]
		}
		ans += int64(sum) // 从 x 出发的路径
	}
	return
}

上一题