列表

详情


1617. 统计子树中城市之间最大距离

给你 n 个城市,编号为从 1 到 n 。同时给你一个大小为 n-1 的数组 edges ,其中 edges[i] = [ui, vi] 表示城市 ui 和 vi 之间有一条双向边。题目保证任意城市之间只有唯一的一条路径。换句话说,所有城市形成了一棵  。

一棵 子树 是城市的一个子集,且子集中任意城市之间可以通过子集中的其他城市和边到达。两个子树被认为不一样的条件是至少有一个城市在其中一棵子树中存在,但在另一棵子树中不存在。

对于 d 从 1 到 n-1 ,请你找到城市间 最大距离 恰好为 d 的所有子树数目。

请你返回一个大小为 n-1 的数组,其中第 d 个元素(下标从 1 开始)是城市间 最大距离 恰好等于 d 的子树数目。

请注意,两个城市间距离定义为它们之间需要经过的边的数目。

 

示例 1:

输入:n = 4, edges = [[1,2],[2,3],[2,4]]
输出:[3,4,0]
解释:
子树 {1,2}, {2,3} 和 {2,4} 最大距离都是 1 。
子树 {1,2,3}, {1,2,4}, {2,3,4} 和 {1,2,3,4} 最大距离都为 2 。
不存在城市间最大距离为 3 的子树。

示例 2:

输入:n = 2, edges = [[1,2]]
输出:[1]

示例 3:

输入:n = 3, edges = [[1,2],[2,3]]
输出:[2,1]

 

提示:

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: vector<int> countSubgraphsForEachDiameter(int n, vector<vector<int>>& edges) { } };

golang 解法, 执行用时: 0 ms, 内存消耗: 1.9 MB, 提交时间: 2023-03-12 09:43:31

func countSubgraphsForEachDiameter(n int, edges [][]int) []int {
    // 建树
    g := make([][]int, n)
    for _, e := range edges {
        x, y := e[0]-1, e[1]-1 // 编号改为从 0 开始
        g[x] = append(g[x], y)
        g[y] = append(g[y], x)
    }

    // 计算树上任意两点的距离
    dis := make([][]int, n)
    for i := range dis {
        // 计算 i 到其余点的距离
        dis[i] = make([]int, n)
        var dfs func(int, int)
        dfs = func(x, fa int) {
            for _, y := range g[x] {
                if y != fa {
                    dis[i][y] = dis[i][x] + 1 // 自顶向下
                    dfs(y, x)
                }
            }
        }
        dfs(i, -1)
    }

    ans := make([]int, n-1)
    for i, di := range dis {
        for j := i + 1; j < n; j++ {
            dj := dis[j]
            d := di[j]
            var dfs func(int, int) int
            dfs = func(x, fa int) int {
                // 能递归到这,说明 x 可以选
                cnt := 1 // 选 x
                for _, y := range g[x] {
                    if y != fa &&
                       (di[y] < d || di[y] == d && y > j) &&
                       (dj[y] < d || dj[y] == d && y > i) { // 满足这些条件就可以选
                        cnt *= dfs(y, x) // 每棵子树互相独立,采用乘法原理
                    }
                }
                if di[x]+dj[x] > d { // x 是可选点
                    cnt++ // 不选 x
                }
                return cnt
            }
            ans[d-1] += dfs(i, -1)
        }
    }
    return ans
}

java 解法, 执行用时: 2 ms, 内存消耗: 40.1 MB, 提交时间: 2023-03-12 09:43:06

class Solution {
    private List<Integer>[] g;
    private int[][] dis;

    public int[] countSubgraphsForEachDiameter(int n, int[][] edges) {
        g = new ArrayList[n];
        Arrays.setAll(g, e -> new ArrayList<>());
        for (var e : edges) {
            int x = e[0] - 1, y = e[1] - 1; // 编号改为从 0 开始
            g[x].add(y);
            g[y].add(x); // 建树
        }

        dis = new int[n][n];
        for (int i = 0; i < n; ++i)
            dfs(i, i, -1); // 计算 i 到其余点的距离

        var ans = new int[n - 1];
        for (int i = 0; i < n; ++i)
            for (int j = i + 1; j < n; ++j)
                ans[dis[i][j] - 1] += dfs2(i, j, dis[i][j], i, -1);
        return ans;
    }

    private void dfs(int i, int x, int fa) {
        for (int y : g[x])
            if (y != fa) {
                dis[i][y] = dis[i][x] + 1; // 自顶向下
                dfs(i, y, x);
            }
    }

    private int dfs2(int i, int j, int d, int x, int fa) {
        // 能递归到这,说明 x 可以选
        int cnt = 1; // 选 x
        for (int y : g[x])
            if (y != fa &&
               (dis[i][y] < d || dis[i][y] == d && y > j) &&
               (dis[j][y] < d || dis[j][y] == d && y > i)) // 满足这些条件就可以选
                cnt *= dfs2(i, j, d, y, x); // 每棵子树互相独立,采用乘法原理
        if (dis[i][x] + dis[j][x] > d)  // x 是可选点
            ++cnt; // 不选 x
        return cnt;
    }
}

python3 解法, 执行用时: 40 ms, 内存消耗: 15 MB, 提交时间: 2023-03-12 09:42:43

'''
枚举直径端点 + 乘法原理
'''
class Solution:
    def countSubgraphsForEachDiameter(self, n: int, edges: List[List[int]]) -> List[int]:
        # 建树
        g = [[] for _ in range(n)]
        for x, y in edges:
            g[x - 1].append(y - 1)
            g[y - 1].append(x - 1)  # 编号改为从 0 开始

        # 计算树上任意两点的距离
        dis = [[0] * n for _ in range(n)]
        def dfs(x: int, fa: int) -> None:
            for y in g[x]:
                if y != fa:
                    dis[i][y] = dis[i][x] + 1  # 自顶向下
                    dfs(y, x)
        for i in range(n):
            dfs(i, -1)  # 计算 i 到其余点的距离

        def dfs2(x: int, fa: int) -> int:
            # 能递归到这,说明 x 可以选
            cnt = 1  # 选 x
            for y in g[x]:
                if y != fa and \
                   (di[y] < d or di[y] == d and y > j) and \
                   (dj[y] < d or dj[y] == d and y > i):  # 满足这些条件就可以选
                    cnt *= dfs2(y, x)  # 每棵子树互相独立,采用乘法原理
            if di[x] + dj[x] > d:  # x 是可选点
                cnt += 1  # 不选 x
            return cnt
        ans = [0] * (n - 1)
        for i, di in enumerate(dis):
            for j in range(i + 1, n):
                dj = dis[j]
                d = di[j]
                ans[d - 1] += dfs2(i, -1)
        return ans

python3 解法, 执行用时: 508 ms, 内存消耗: 14.8 MB, 提交时间: 2023-03-12 09:41:40

class Solution:
    def countSubgraphsForEachDiameter(self, n: int, edges: List[List[int]]) -> List[int]:
        # 建树
        g = [[] for _ in range(n)]
        for x, y in edges:
            g[x - 1].append(y - 1)
            g[y - 1].append(x - 1)  # 编号改为从 0 开始

        ans = [0] * (n - 1)
        #  二进制枚举
        for mask in range(3, 1 << n):
            if (mask & (mask - 1)) == 0:  # 需要至少两个点
                continue
            # 求树的直径
            vis = diameter = 0
            def dfs(x: int) -> int:
                nonlocal vis, diameter
                vis |= 1 << x  # 标记 x 访问过
                max_len = 0
                for y in g[x]:
                    if (vis >> y & 1) == 0 and mask >> y & 1:  # y 没有访问过且在 mask 中
                        ml = dfs(y) + 1
                        diameter = max(diameter, max_len + ml)
                        max_len = max(max_len, ml)
                return max_len
            dfs(mask.bit_length() - 1)  # 从一个在 mask 中的点开始递归
            if vis == mask:
                ans[diameter - 1] += 1
        return ans

golang 解法, 执行用时: 8 ms, 内存消耗: 1.9 MB, 提交时间: 2023-03-12 09:41:13

func countSubgraphsForEachDiameter(n int, edges [][]int) []int {
    // 建树
    g := make([][]int, n)
    for _, e := range edges {
        x, y := e[0]-1, e[1]-1 // 编号改为从 0 开始
        g[x] = append(g[x], y)
        g[y] = append(g[y], x)
    }

    ans := make([]int, n-1)
    // 二进制枚举
    for mask := 3; mask < 1<<n; mask++ {
        if mask&(mask-1) == 0 { // 需要至少两个点
            continue
        }
        // 求树的直径
        vis, diameter := 0, 0
        var dfs func(int) int
        dfs = func(x int) (maxLen int) {
            vis |= 1 << x // 标记 x 访问过
            for _, y := range g[x] {
                if vis>>y&1 == 0 && mask>>y&1 > 0 { // y 没有访问过且在 mask 中
                    ml := dfs(y) + 1
                    diameter = max(diameter, maxLen+ml)
                    maxLen = max(maxLen, ml)
                }
            }
            return
        }
        dfs(bits.TrailingZeros(uint(mask))) // 从一个在 mask 中的点开始递归
        if vis == mask {
            ans[diameter-1]++
        }
    }
    return ans
}

func max(a, b int) int { if a < b { return b }; return a }

java 解法, 执行用时: 26 ms, 内存消耗: 42.1 MB, 提交时间: 2023-03-12 09:40:41

class Solution {
    private List<Integer>[] g;
    private int mask, vis, diameter;

    public int[] countSubgraphsForEachDiameter(int n, int[][] edges) {
        g = new ArrayList[n];
        Arrays.setAll(g, e -> new ArrayList<>());
        for (var e : edges) {
            int x = e[0] - 1, y = e[1] - 1; // 编号改为从 0 开始
            g[x].add(y);
            g[y].add(x); // 建树
        }

        var ans = new int[n - 1];
        // 二进制枚举
        for (mask = 3; mask < 1 << n; ++mask) {
            if ((mask & (mask - 1)) == 0) continue; // 需要至少两个点
            vis = diameter = 0;
            dfs(Integer.numberOfTrailingZeros(mask)); // 从一个在 mask 中的点开始递归
            if (vis == mask)
                ++ans[diameter - 1];
        }
        return ans;
    }

    // 求树的直径
    private int dfs(int x) {
        vis |= 1 << x; // 标记 x 访问过
        int maxLen = 0;
        for (int y : g[x])
            if ((vis >> y & 1) == 0 && (mask >> y & 1) == 1) { // y 没有访问过且在 mask 中
                int ml = dfs(y) + 1;
                diameter = Math.max(diameter, maxLen + ml);
                maxLen = Math.max(maxLen, ml);
            }
        return maxLen;
    }
}

java 解法, 执行用时: 25 ms, 内存消耗: 41.8 MB, 提交时间: 2023-03-12 09:40:17

class Solution {
    private List<Integer>[] g;
    private boolean[] inSet, vis;
    private int[] ans;
    private int n, diameter;

    public int[] countSubgraphsForEachDiameter(int n, int[][] edges) {
        this.n = n;
        g = new ArrayList[n];
        Arrays.setAll(g, e -> new ArrayList<>());
        for (var e : edges) {
            int x = e[0] - 1, y = e[1] - 1; // 编号改为从 0 开始
            g[x].add(y);
            g[y].add(x); // 建树
        }

        ans = new int[n - 1];
        inSet = new boolean[n];
        f(0);
        return ans;
    }

    private void f(int i) {
        if (i == n) {
            for (int v = 0; v < n; ++v)
                if (inSet[v]) {
                    vis = new boolean[n];
                    diameter = 0;
                    dfs(v);
                    break;
                }
            if (diameter > 0 && Arrays.equals(vis, inSet))
                ++ans[diameter - 1];
            return;
        }

        // 不选城市 i
        f(i + 1);

        // 选城市 i
        inSet[i] = true;
        f(i + 1);
        inSet[i] = false; // 恢复现场
    }

    // 求树的直径
    private int dfs(int x) {
        vis[x] = true;
        int maxLen = 0;
        for (int y : g[x])
            if (!vis[y] && inSet[y]) {
                int ml = dfs(y) + 1;
                diameter = Math.max(diameter, maxLen + ml);
                maxLen = Math.max(maxLen, ml);
            }
        return maxLen;
    }
}

golang 解法, 执行用时: 12 ms, 内存消耗: 1.9 MB, 提交时间: 2023-03-12 09:39:59

func countSubgraphsForEachDiameter(n int, edges [][]int) []int {
    // 建树
    g := make([][]int, n)
    for _, e := range edges {
        x, y := e[0]-1, e[1]-1 // 编号改为从 0 开始
        g[x] = append(g[x], y)
        g[y] = append(g[y], x)
    }

    // 求树的直径
    var inSet, vis [15]bool
    var diameter int
    var dfs func(int) int
    dfs = func(x int) (maxLen int) {
        vis[x] = true
        for _, y := range g[x] {
            if !vis[y] && inSet[y] {
                ml := dfs(y) + 1
                diameter = max(diameter, maxLen+ml)
                maxLen = max(maxLen, ml)
            }
        }
        return
    }

    ans := make([]int, n-1)
    var f func(int)
    f = func(i int) {
        if i == n {
            for v, b := range inSet {
                if b {
                    vis, diameter = [15]bool{}, 0
                    dfs(v)
                    break
                }
            }
            if diameter > 0 && vis == inSet {
                ans[diameter-1]++
            }
            return
        }

        // 不选城市 i
        f(i + 1)

        // 选城市 i
        inSet[i] = true
        f(i + 1)
        inSet[i] = false // 恢复现场
    }
    f(0)
    return ans
}

func max(a, b int) int { if a < b { return b }; return a }

python3 解法, 执行用时: 652 ms, 内存消耗: 15.2 MB, 提交时间: 2023-03-12 09:39:38

'''
枚举子集 + 树的直径
'''
class Solution:
    def countSubgraphsForEachDiameter(self, n: int, edges: List[List[int]]) -> List[int]:
        # 建树
        g = [[] for _ in range(n)]
        for x, y in edges:
            g[x - 1].append(y - 1)
            g[y - 1].append(x - 1)  # 编号改为从 0 开始

        ans = [0] * (n - 1)
        in_set = [False] * n
        def f(i: int) -> None:
            if i == n:
                vis = [False] * n
                diameter = 0
                for v, b in enumerate(in_set):
                    if not b: continue
                    # 求树的直径
                    def dfs(x: int) -> int:
                        nonlocal diameter
                        vis[x] = True
                        max_len = 0
                        for y in g[x]:
                            if not vis[y] and in_set[y]:
                                ml = dfs(y) + 1
                                diameter = max(diameter, max_len + ml)
                                max_len = max(max_len, ml)
                        return max_len
                    dfs(v)
                    break
                if diameter and vis == in_set:
                    ans[diameter - 1] += 1
                return
            
            # 不选城市 i
            f(i + 1)

            # 选城市  i
            in_set[i] = True
            f(i + 1)
            in_set[i] = False  # 恢复现场
        f(0)
        return ans

上一题