列表

详情


2277. 树中最接近路径的节点

给定一个正整数 n,表示树中的节点数,编号从 0n - 1 (含边界)。还给定一个长度为 n - 1 的二维整数数组 edges,其中 edges[i] = [node1i, node2i] 表示有一条 双向 边连接树中的 node1inode2i

给定一个长度为 m ,下标从 0 开始 的整数数组 query,其中 query[i] = [starti, endi, nodei] 意味着对于第 i 个查询,您的任务是从 startiendi 的路径上找到 最接近 nodei 的节点。

返回长度为 m 的整数数组 answer,其中 answer[i] 是第 i 个查询的答案。

 

示例 1:

输入: n = 7, edges = [[0,1],[0,2],[0,3],[1,4],[2,5],[2,6]], query = [[5,3,4],[5,3,6]]
输出: [0,2]
解释:
节点 5 到节点 3 的路径由节点 5、2、0、3 组成。
节点 4 到节点 0 的距离为 2。
节点 0 是距离节点 4 最近的路径上的节点,因此第一个查询的答案是 0。
节点 6 到节点 2 的距离为 1。
节点 2 是距离节点 6 最近的路径上的节点,因此第二个查询的答案是 2。

示例 2:

输入: n = 3, edges = [[0,1],[1,2]], query = [[0,1,2]]
输出: [1]
解释:
从节点 0 到节点 1 的路径由节点 0,1 组成。
节点 2 到节点 1 的距离为 1。
节点 1 是距离节点 2 最近的路径上的节点,因此第一个查询的答案是 1。

示例 3:

输入: n = 3, edges = [[0,1],[1,2]], query = [[0,0,0]]
输出: [0]
解释:
节点 0 到节点 0 的路径由节点 0 组成。
因为 0 是路径上唯一的节点,所以第一个查询的答案是0。

 

提示:

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: vector<int> closestNode(int n, vector<vector<int>>& edges, vector<vector<int>>& query) { } };

golang 解法, 执行用时: 424 ms, 内存消耗: 7.5 MB, 提交时间: 2023-10-21 19:46:36

func closestNode(n int, edges [][]int, query [][]int) []int {
	path := make([][]int, n)
	for _, edge := range edges {
		path[edge[0]] = append(path[edge[0]], edge[1])
		path[edge[1]] = append(path[edge[1]], edge[0])
	}
	ans := make([]int, len(query))
	for i := 0; i < len(query); i++ {
		s, e, node := query[i][0], query[i][1], query[i][2]
		visited := map[int]bool{s: true}
		ends := map[int]bool{}
		var dfs func(i int) bool
		dfs = func(i int) bool {
			if i == e {
				ends[i] = true
				return true
			}
			for _, j := range path[i] {
				if visited[j] {
					continue
				}
				visited[j] = true
				if dfs(j) {
					ends[i] = true
					return true
				}
			}
			return false
		}
        dfs(s)
		visited = map[int]bool{node: true}
		var list = []int{node}
	out:
		for len(list) > 0 {
			var temp []int
			for _, j := range list {
				if ends[j] {
					ans[i] = j
					break out
				}
				for _, k := range path[j] {
					if visited[k] {
						continue
					}
					visited[k] = true
					temp = append(temp, k)
				}
			}
			list = temp
		}
	}
	return ans
}

java 解法, 执行用时: 14 ms, 内存消耗: 42.8 MB, 提交时间: 2023-10-21 19:46:14

class Solution {
        public int[] closestNode(int n, int[][] edges, int[][] query) {
            //1.建图
            Set<Integer>[] graph = new Set[n];
            Arrays.setAll(graph,o->new HashSet<>());
            for(int[] edge:edges){
                int u = edge[0];
                int v = edge[1];
                graph[u].add(v);
                graph[v].add(u);
            }
            //2.删边,定深
            int[] depth = new int[n];
            int[] parent = new int[n];
            parent[0]=-1;
            Arrays.fill(depth,-1);
            depth[0]=0;
            Queue<Integer> queue = new LinkedList<>();
            queue.offer(0);
            while(!queue.isEmpty()){
                int node = queue.poll();
                for(int next:graph[node]){
                    parent[next] = node;
                    depth[next]=depth[node]+1;
                    graph[next].remove(node);
                    queue.offer(next);
                }
            }

            int m = query.length;
            int[] ans = new int[m];
            for(int i = 0; i < m; i++){
                int[] q = query[i];
                int a = q[0];
                int b = q[1];
                int c = q[2];
                
                int d = getPublicParent(a,b,parent,depth);
                int e = getPublicParent(a,c,parent,depth);
                int f = getPublicParent(b,c,parent,depth);

                int res = Math.min(depth[c]-depth[d],Math.min(depth[c]-depth[e],depth[c]-depth[f]));
                int num = depth[c]-depth[d]==res?d:(depth[c]-depth[e] == res?e:f);

                ans[i] = num;
            }
            return ans;
        }

        private int getPublicParent(int a, int b,int[] parent,int[] depth){
            while(depth[a]>depth[b]) {
                a = parent[a];
            }

            while(depth[b]>depth[a]){
                b = parent[b];
            }

            while (a!=b){
                a = parent[a];
                b = parent[b];
            }
            return a;
        }
    }

cpp 解法, 执行用时: 20 ms, 内存消耗: 14.8 MB, 提交时间: 2023-10-21 19:45:50

class Solution {
public:
    vector<int> closestNode(int n, vector<vector<int>>& edges, vector<vector<int>>& query) {
        vector<vector<int>> g(n);
        for (auto &it: edges) {
            int u = it[0], v = it[1];
            g[u].emplace_back(v);
            g[v].emplace_back(u);
        }
        vector<vector<int>> fa(11, vector<int>(n, -1));
        vector<int> dep(n, 1);
        function<void(int, int)> dfs = [&](int u, int p) {
            fa[0][u] = p;
            if (p >= 0) dep[u] = dep[p] + 1;
            for (auto v: g[u]) {
                if (v != p) {
                    dfs(v, u);
                }
            }
        };
        dfs(0, -1);

        for (int k = 0; k < 10; ++k) {
            for (int i = 0; i < n; ++i) {
                if (fa[k][i] >= 0) {
                    fa[k + 1][i] = fa[k][fa[k][i]];
                }
            }
        }

        function<int(int, int)> lca = [&](int a, int b) -> int {
            if(dep[a] > dep[b])
                swap(a, b);
            while (dep[b] > dep[a])
                b = fa[__lg(dep[b] - dep[a])][b];
            if (a == b) return a;
            for (int k = 10; k >= 0; --k) {
                if (fa[k][a] != fa[k][b]) {
                    a = fa[k][a], b = fa[k][b];
                }
            }
            return fa[0][a];
        };
        vector<int> ans(query.size());
        for (int i = 0; i < query.size(); ++i) {
            int u = query[i][0], v = query[i][1], t = query[i][2];
            int uv = lca(u, v), ut = lca(u, t), vt = lca(v, t);
            if (ut == uv) {
                ans[i] = vt;
            } else if (vt == uv) {
                ans[i] = ut;
            } else {
                ans[i] = uv;
            }
        }
        return ans;
    }
};

python3 解法, 执行用时: 2504 ms, 内存消耗: 41.4 MB, 提交时间: 2023-10-21 19:45:11

class Solution:
    def closestNode(self, n: int, edges: List[List[int]], query: List[List[int]]) -> List[int]:
        G = defaultdict(list)
        for a,b in edges:
            G[a].append(b)
            G[b].append(a)
        D = [[math.inf]*n for _ in range(n)]
        for a in range(n):
            que,D[a][a] = [a],0
            while que:
                tmp = []
                for q in que:
                    for b in G[q]:
                        if D[a][b]!=math.inf: continue
                        D[a][b]=D[a][q]+1
                        tmp.append(b)
                que = tmp
        return [min(range(n), key=lambda x: D[x][a]+D[x][b]+D[x][q]) for a,b,q in query]

python3 解法, 执行用时: 76 ms, 内存消耗: 18 MB, 提交时间: 2023-10-21 19:44:40

class Solution:
    def closestNode(self, n: int, edges: List[List[int]], query: List[List[int]]) -> List[int]:
        adjMap = defaultdict(set)
        for u, v in edges:
            adjMap[u].add(v)
            adjMap[v].add(u)
        LCA = LCAManager(n, adjMap)

        res = []  # 答案是最(靠下)深的LCA
        for root1, root2, root3 in query:
            res.append(
                max(
                    LCA.queryLCA(root1, root3),
                    LCA.queryLCA(root2, root3),
                    LCA.queryLCA(root1, root2),
                    key=lambda lca: LCA.depth[lca],
                )
            )
        return res


class LCAManager:
    def __init__(self, n: int, adjMap: DefaultDict[int, Set[int]]) -> None:
        """查询 LCA

        `nlogn` 预处理
        `logn`查询两点的LCA

        Args:
            n (int): 树节点编号 默认 0 ~ n-1 根节点为 0
            adjMap (DefaultDict[int, Set[int]]): 树
        """
        self.depth = defaultdict(lambda: -1)
        self.parent = defaultdict(lambda: -1)
        self._BITLEN = floor(log2(n)) + 1
        self._MAX = n
        self._adjMap = adjMap
        self._dfs(0, -1, 0)
        self._dp = self._initDp(self.parent)

    def queryLCA(self, root1: int, root2: int) -> int:
        """ `logn` 查询 """
        if self.depth[root1] < self.depth[root2]:
            root1, root2 = root2, root1

        for i in range(self._BITLEN - 1, -1, -1):
            if self.depth[self._dp[root1][i]] >= self.depth[root2]:
                root1 = self._dp[root1][i]

        if root1 == root2:
            return root1

        for i in range(self._BITLEN - 1, -1, -1):
            if self._dp[root1][i] != self._dp[root2][i]:
                root1 = self._dp[root1][i]
                root2 = self._dp[root2][i]

        return self._dp[root1][0]

    def _dfs(self, cur: int, pre: int, dep: int) -> None:
        """处理高度、父节点信息"""
        self.depth[cur], self.parent[cur] = dep, pre
        for next in self._adjMap[cur]:
            if next == pre:
                continue
            self._dfs(next, cur, dep + 1)

    def _initDp(self, parent: DefaultDict[int, int]) -> List[List[int]]:
        """nlogn预处理"""
        dp = [[0] * self._BITLEN for _ in range(self._MAX)]
        for i in range(self._MAX):
            dp[i][0] = parent[i]
        for j in range(self._BITLEN - 1):
            for i in range(self._MAX):
                if dp[i][j] == -1:
                    dp[i][j + 1] = -1
                else:
                    dp[i][j + 1] = dp[dp[i][j]][j]
        return dp

python3 解法, 执行用时: 348 ms, 内存消耗: 18.8 MB, 提交时间: 2023-10-21 19:44:24

class Solution:
    def closestNode(self, n: int, edges: List[List[int]], query: List[List[int]]) -> List[int]:
        def dfs(cur: int, pre: int, dep: int) -> None:
            """处理高度、父节点信息"""
            depth[cur], parent[cur] = dep, pre
            for next in adjMap[cur]:
                if next == pre:
                    continue
                dfs(next, cur, dep + 1)

        def getPath(
            root1: int, root2: int, level: DefaultDict[int, int], parent: DefaultDict[int, int]
        ) -> Set[int]:
            """求两个结点间的路径,不断上跳到LCA并记录经过的结点"""
            res = {root1, root2}
            if level[root1] < level[root2]:
                root1, root2 = root2, root1
            diff = level[root1] - level[root2]
            for _ in range(diff):
                root1 = parent[root1]
                res |= {root1}
            while root1 != root2:
                root1 = parent[root1]
                root2 = parent[root2]
                res |= {root1, root2}
            return res

        def bfs(start: int, hit: Set[int]) -> int:
            """求到目标路径的最近交点"""
            visited, queue = set([start]), deque([start])
            while queue:
                cur = queue.popleft()
                if cur in hit:
                    return cur
                for next in adjMap[cur]:
                    if next not in visited:
                        visited.add(next)
                        queue.append(next)
            raise Exception("impossible")

        adjMap = defaultdict(set)
        for u, v in edges:
            adjMap[u].add(v)
            adjMap[v].add(u)

        depth, parent = defaultdict(int), defaultdict(lambda: -1)
        dfs(0, -1, 0)
        return [bfs(root3, getPath(root1, root2, depth, parent)) for root1, root2, root3 in query]

上一题