class Solution {
public:
int maximumPoints(vector<vector<int>>& edges, vector<int>& coins, int k) {
}
};
100108. 收集所有金币可获得的最大积分
节点 0
处现有一棵由 n
个节点组成的无向树,节点编号从 0
到 n - 1
。给你一个长度为 n - 1
的二维 整数 数组 edges
,其中 edges[i] = [ai, bi]
表示在树上的节点 ai
和 bi
之间存在一条边。另给你一个下标从 0 开始、长度为 n
的数组 coins
和一个整数 k
,其中 coins[i]
表示节点 i
处的金币数量。
从根节点开始,你必须收集所有金币。要想收集节点上的金币,必须先收集该节点的祖先节点上的金币。
节点 i
上的金币可以用下述方法之一进行收集:
coins[i] - k
点积分。如果 coins[i] - k
是负数,你将会失去 abs(coins[i] - k)
点积分。floor(coins[i] / 2)
点积分。如果采用这种方法,节点 i
子树中所有节点 j
的金币数 coins[j]
将会减少至 floor(coins[j] / 2)
。返回收集 所有 树节点的金币之后可以获得的最大积分。
示例 1:
输入:edges = [[0,1],[1,2],[2,3]], coins = [10,10,3,3], k = 5 输出:11 解释: 使用第一种方法收集节点 0 上的所有金币。总积分 = 10 - 5 = 5 。 使用第一种方法收集节点 1 上的所有金币。总积分 = 5 + (10 - 5) = 10 。 使用第二种方法收集节点 2 上的所有金币。所以节点 3 上的金币将会变为 floor(3 / 2) = 1 ,总积分 = 10 + floor(3 / 2) = 11 。 使用第二种方法收集节点 3 上的所有金币。总积分 = 11 + floor(1 / 2) = 11. 可以证明收集所有节点上的金币能获得的最大积分是 11 。
示例 2:
输入:edges = [[0,1],[0,2]], coins = [8,4,4], k = 0 输出:16 解释: 使用第一种方法收集所有节点上的金币,因此,总积分 = (8 - 0) + (4 - 0) + (4 - 0) = 16 。
提示:
n == coins.length
2 <= n <= 105
0 <= coins[i] <= 104
edges.length == n - 1
0 <= edges[i][0], edges[i][1] < n
0 <= k <= 104
原站题解
java 解法, 执行用时: 114 ms, 内存消耗: 111.7 MB, 提交时间: 2023-10-30 07:47:33
class Solution { public int maximumPoints(int[][] edges, int[] coins, int k) { int n = coins.length; List<Integer>[] g = new ArrayList[n]; Arrays.setAll(g, e -> new ArrayList<>()); for (int[] e : edges) { int x = e[0], y = e[1]; g[x].add(y); g[y].add(x); } int[][] memo = new int[n][14]; for (int[] m : memo) { Arrays.fill(m, -1); // -1 表示没有计算过 } return dfs(0, 0, -1, memo, g, coins, k); } private int dfs(int i, int j, int fa, int[][] memo, List<Integer>[] g, int[] coins, int k) { if (memo[i][j] != -1) { // 之前计算过 return memo[i][j]; } int res1 = (coins[i] >> j) - k; int res2 = coins[i] >> (j + 1); for (int ch : g[i]) { if (ch == fa) continue; res1 += dfs(ch, j, i, memo, g, coins, k); // 不右移 if (j < 13) { // j+1 >= 14 相当于 res2 += 0,无需递归 res2 += dfs(ch, j + 1, i, memo, g, coins, k); // 右移 } } return memo[i][j] = Math.max(res1, res2); // 记忆化 } }
java 解法, 执行用时: 38 ms, 内存消耗: 129 MB, 提交时间: 2023-10-30 07:47:18
class Solution { public int maximumPoints(int[][] edges, int[] coins, int k) { List<Integer>[] g = new ArrayList[coins.length]; Arrays.setAll(g, e -> new ArrayList<>()); for (int[] e : edges) { int x = e[0], y = e[1]; g[x].add(y); g[y].add(x); } return dfs(0, -1, g, coins, k)[0]; } private int[] dfs(int x, int fa, List<Integer>[] g, int[] coins, int k) { int[] res1 = new int[14]; int[] res2 = new int[14]; for (int y : g[x]) { if (y == fa) continue; int[] r = dfs(y, x, g, coins, k); for (int j = 0; j < r.length; j++) { res1[j] += r[j]; if (j < 13) { res2[j] += r[j + 1]; } } } for (int j = 0; j < res1.length; j++) { res1[j] = Math.max(res1[j] + (coins[x] >> j) - k, res2[j] + (coins[x] >> (j + 1))); } return res1; } }
golang 解法, 执行用时: 236 ms, 内存消耗: 61.8 MB, 提交时间: 2023-10-30 07:47:03
func maximumPoints(edges [][]int, coins []int, k int) int { n := len(coins) g := make([][]int, n) for _, e := range edges { x, y := e[0], e[1] g[x] = append(g[x], y) g[y] = append(g[y], x) } memo := make([][14]int, n) for i := range memo { for j := range memo[i] { memo[i][j] = -1 } } var dfs func(int, int, int) int dfs = func(i, j, fa int) int { p := &memo[i][j] if *p != -1 { return *p } res1 := coins[i]>>j - k res2 := coins[i] >> (j + 1) for _, ch := range g[i] { if ch != fa { res1 += dfs(ch, j, i) // 不右移 if j < 13 { // j+1 >= 14 相当于 res2 += 0 无需递归 res2 += dfs(ch, j+1, i) // 右移 } } } *p = max(res1, res2) return *p } return dfs(0, 0, -1) } func maximumPoints2(edges [][]int, coins []int, k int) int { n := len(coins) g := make([][]int, n) for _, e := range edges { x, y := e[0], e[1] g[x] = append(g[x], y) g[y] = append(g[y], x) } var dfs func(int, int) [14]int dfs = func(x, fa int) (res1 [14]int) { res2 := [14]int{} for _, y := range g[x] { if y != fa { r := dfs(y, x) for j, v := range r { res1[j] += v if j < 13 { res2[j] += r[j+1] } } } } for j := 0; j < 14; j++ { res1[j] = max(res1[j]+coins[x]>>j-k, res2[j]+coins[x]>>(j+1)) } return } return dfs(0, -1)[0] } func max(a, b int) int { if b > a { return b }; return a }
cpp 解法, 执行用时: 408 ms, 内存消耗: 234.3 MB, 提交时间: 2023-10-30 07:46:23
class Solution { public: int maximumPoints(vector<vector<int>> &edges, vector<int> &coins, int k) { vector<vector<int>> g(coins.size()); for (auto &e : edges) { int x = e[0], y = e[1]; g[x].push_back(y); g[y].push_back(x); } function<array<int, 14>(int, int)> dfs = [&](int x, int fa) -> array<int, 14> { array<int, 14> res1{}, res2{}; for (int y : g[x]) { if (y == fa) continue; auto r = dfs(y, x); for (int j = 0; j < 14; j++) { res1[j] += r[j]; if (j < 13) { res2[j] += r[j + 1]; } } } for (int j = 0; j < 14; j++) { res1[j] = max(res1[j] + (coins[x] >> j) - k, res2[j] + (coins[x] >> (j + 1))); } return res1; }; return dfs(0, -1)[0]; } int maximumPoints2(vector<vector<int>> &edges, vector<int> &coins, int k) { int n = coins.size(); vector<vector<int>> g(n); for (auto &e: edges) { int x = e[0], y = e[1]; g[x].push_back(y); g[y].push_back(x); } vector<vector<int>> memo(n, vector<int>(14, -1)); // -1 表示没有计算过 function<int(int, int, int)> dfs = [&](int i, int j, int fa) -> int { auto &res = memo[i][j]; // 注意这里是引用 if (res != -1) { // 之前计算过 return res; } int res1 = (coins[i] >> j) - k; int res2 = coins[i] >> (j + 1); for (int ch: g[i]) { if (ch == fa) continue; res1 += dfs(ch, j, i); // 不右移 if (j < 13) { // j+1 >= 14 相当于 res2 += 0,无需递归 res2 += dfs(ch, j + 1, i); // 右移 } } return res = max(res1, res2); // 记忆化 }; return dfs(0, 0, -1); } };
python3 解法, 执行用时: 2592 ms, 内存消耗: 390.5 MB, 提交时间: 2023-10-30 07:45:50
class Solution: def maximumPoints(self, edges: List[List[int]], coins: List[int], k: int) -> int: g = [[] for _ in coins] for x, y in edges: g[x].append(y) g[y].append(x) @cache def dfs(i: int, j: int, fa: int) -> int: res1 = (coins[i] >> j) - k res2 = coins[i] >> (j + 1) for ch in g[i]: if ch != fa: res1 += dfs(ch, j, i) # 不右移 if j < 13: # j+1 >= 14 相当于 res2 += 0,无需递归 res2 += dfs(ch, j + 1, i) # 右移 return max(res1, res2) return dfs(0, 0, -1) # 自底向上 def maximumPoints2(self, edges: List[List[int]], coins: List[int], k: int) -> int: g = [[] for _ in coins] for x, y in edges: g[x].append(y) g[y].append(x) @cache def dfs(i: int, j: int, fa: int) -> int: res1 = (coins[i] >> j) - k res2 = coins[i] >> (j + 1) for ch in g[i]: if ch != fa: res1 += dfs(ch, j, i) # 不右移 if j < 13: # j+1 >= 14 相当于 res2 += 0,无需递归 res2 += dfs(ch, j + 1, i) # 右移 return max(res1, res2) return dfs(0, 0, -1)