列表

详情


834. 树中距离之和

给定一个无向、连通的树。树中有 n 个标记为 0...n-1 的节点以及 n-1 条边 。

给定整数 n 和数组 edges , edges[i] = [ai, bi]表示树中的节点 ai 和 bi 之间有一条边。

返回长度为 n 的数组 answer ,其中 answer[i] 是树中第 i 个节点与所有其他节点之间的距离之和。

 

示例 1:

输入: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
输出: [8,12,6,10,10,10]
解释: 树如图所示。
我们可以计算出 dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5) 
也就是 1 + 1 + 2 + 2 + 2 = 8。 因此,answer[0] = 8,以此类推。

示例 2:

输入: n = 1, edges = []
输出: [0]

示例 3:

输入: n = 2, edges = [[1,0]]
输出: [1,1]

 

提示:

相似题目

在二叉树中分配硬币

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: vector<int> sumOfDistancesInTree(int n, vector<vector<int>>& edges) { } };

cpp 解法, 执行用时: 296 ms, 内存消耗: 104.3 MB, 提交时间: 2023-07-16 16:45:02

class Solution {
public:
    vector<int> sumOfDistancesInTree(int n, vector<vector<int>> &edges) {
        vector<vector<int>> g(n); // g[x] 表示 x 的所有邻居
        for (auto &e: edges) {
            int x = e[0], y = e[1];
            g[x].push_back(y);
            g[y].push_back(x);
        }

        vector<int> ans(n);
        vector<int> size(n, 1); // 注意这里初始化成 1 了,下面只需要累加儿子的子树大小
        function<void(int, int, int)> dfs = [&](int x, int fa, int depth) {
            ans[0] += depth; // depth 为 0 到 x 的距离
            for (int y: g[x]) { // 遍历 x 的邻居 y
                if (y != fa) { // 避免访问父节点
                    dfs(y, x, depth + 1); // x 是 y 的父节点
                    size[x] += size[y]; // 累加 x 的儿子 y 的子树大小
                }
            }
        };
        dfs(0, -1, 0); // 0 没有父节点

        function<void(int, int)> reroot = [&](int x, int fa) {
            for (int y: g[x]) { // 遍历 x 的邻居 y
                if (y != fa) { // 避免访问父节点
                    ans[y] = ans[x] + n - 2 * size[y];
                    reroot(y, x); // x 是 y 的父节点
                }
            }
        };
        reroot(0, -1); // 0 没有父节点
        return ans;
    }
};

javascript 解法, 执行用时: 272 ms, 内存消耗: 74.5 MB, 提交时间: 2023-07-16 16:44:24

/**
 * @param {number} n
 * @param {number[][]} edges
 * @return {number[]}
 */
var sumOfDistancesInTree = function (n, edges) {
    let g = Array(n).fill(null).map(() => []); // g[x] 表示 x 的所有邻居
    for (const [x, y] of edges) {
        g[x].push(y);
        g[y].push(x);
    }

    let ans = Array(n).fill(0);
    let size = Array(n).fill(1); // 注意这里初始化成 1 了,下面只需要累加儿子的子树大小
    function dfs(x, fa, depth) {
        ans[0] += depth; // depth 为 0 到 x 的距离
        for (const y of g[x]) { // 遍历 x 的邻居 y
            if (y !== fa) { // 避免访问父节点
                dfs(y, x, depth + 1); // x 是 y 的父节点
                size[x] += size[y]; // 累加 x 的儿子 y 的子树大小
            }
        }
    }
    dfs(0, -1, 0); // 0 没有父节点

    function reroot(x, fa) {
        for (const y of g[x]) { // 遍历 x 的邻居 y
            if (y !== fa) { // 避免访问父节点
                ans[y] = ans[x] + n - 2 * size[y];
                reroot(y, x); // x 是 y 的父节点
            }
        }
    }
    reroot(0, -1); // 0 没有父节点
    return ans;
};

golang 解法, 执行用时: 132 ms, 内存消耗: 15.3 MB, 提交时间: 2023-07-16 16:44:08

func sumOfDistancesInTree(n int, edges [][]int) []int {
    g := make([][]int, n) // g[x] 表示 x 的所有邻居
    for _, e := range edges {
        x, y := e[0], e[1]
        g[x] = append(g[x], y)
        g[y] = append(g[y], x)
    }

    ans := make([]int, n)
    size := make([]int, n)
    var dfs func(int, int, int)
    dfs = func(x, fa, depth int) {
        ans[0] += depth // depth 为 0 到 x 的距离
        size[x] = 1
        for _, y := range g[x] { // 遍历 x 的邻居 y
            if y != fa { // 避免访问父节点
                dfs(y, x, depth+1) // x 是 y 的父节点
                size[x] += size[y] // 累加 x 的儿子 y 的子树大小
            }
        }
    }
    dfs(0, -1, 0) // 0 没有父节点

    var reroot func(int, int)
    reroot = func(x, fa int) {
        for _, y := range g[x] { // 遍历 x 的邻居 y
            if y != fa { // 避免访问父节点
                ans[y] = ans[x] + n - 2*size[y]
                reroot(y, x) // x 是 y 的父节点
            }
        }
    }
    reroot(0, -1) // 0 没有父节点
    return ans
}

java 解法, 执行用时: 36 ms, 内存消耗: 59 MB, 提交时间: 2023-07-16 16:43:54

class Solution {
    private List<Integer>[] g;
    private int[] ans, size;

    public int[] sumOfDistancesInTree(int n, int[][] edges) {
        g = new ArrayList[n]; // g[x] 表示 x 的所有邻居
        Arrays.setAll(g, e -> new ArrayList<>());
        for (var e : edges) {
            int x = e[0], y = e[1];
            g[x].add(y);
            g[y].add(x);
        }
        ans = new int[n];
        size = new int[n];
        dfs(0, -1, 0); // 0 没有父节点
        reroot(0, -1); // 0 没有父节点
        return ans;
    }

    private void dfs(int x, int fa, int depth) {
        ans[0] += depth; // depth 为 0 到 x 的距离
        size[x] = 1;
        for (int y : g[x]) { // 遍历 x 的邻居 y
            if (y != fa) { // 避免访问父节点
                dfs(y, x, depth + 1); // x 是 y 的父节点
                size[x] += size[y]; // 累加 x 的儿子 y 的子树大小
            }
        }
    }

    private void reroot(int x, int fa) {
        for (int y : g[x]) { // 遍历 x 的邻居 y
            if (y != fa) { // 避免访问父节点
                ans[y] = ans[x] + g.length - 2 * size[y];
                reroot(y, x); // x 是 y 的父节点
            }
        }
    }
}

python3 解法, 执行用时: 264 ms, 内存消耗: 63 MB, 提交时间: 2023-07-16 16:42:46

'''
换根dp
从0出发dfs,累加0到每个节点的距离,得到ans[0]
dfs的同时,计算每棵子树的大小size[i].
然后从0出发再次dfs,设y是x的儿子,那么:
ans[y] = ans[x] + n -2*size[y]
利用该公式可以自顶向下递推得到每个ans[i]
'''
class Solution:
    def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
        g = [[] for _ in range(n)]  # g[x] 表示 x 的所有邻居
        for x, y in edges:
            g[x].append(y)
            g[y].append(x)

        ans = [0] * n
        size = [1] * n  # 注意这里初始化成 1 了,下面只需要累加儿子的子树大小
        def dfs(x: int, fa: int, depth: int) -> None:
            ans[0] += depth  # depth 为 0 到 x 的距离
            for y in g[x]:  # 遍历 x 的邻居 y
                if y != fa:  # 避免访问父节点
                    dfs(y, x, depth + 1)  # x 是 y 的父节点
                    size[x] += size[y]  # 累加 x 的儿子 y 的子树大小
        dfs(0, -1, 0)  # 0 没有父节点

        def reroot(x: int, fa: int) -> None:
            for y in g[x]:  # 遍历 x 的邻居 y
                if y != fa:  # 避免访问父节点
                    ans[y] = ans[x] + n - 2 * size[y]
                    reroot(y, x)  # x 是 y 的父节点
        reroot(0, -1)  # 0 没有父节点
        return ans

上一题