列表

详情


2836. 在传球游戏中最大化函数值

给你一个长度为 n 下标从 0 开始的整数数组 receiver 和一个整数 k 。

总共有 n 名玩家,玩家 编号 互不相同,且为 [0, n - 1] 中的整数。这些玩家玩一个传球游戏,receiver[i] 表示编号为 i 的玩家会传球给编号为 receiver[i] 的玩家。玩家可以传球给自己,也就是说 receiver[i] 可能等于 i 。

你需要从 n 名玩家中选择一名玩家作为游戏开始时唯一手中有球的玩家,球会被传 恰好 k 次。

如果选择编号为 x 的玩家作为开始玩家,定义函数 f(x) 表示从编号为 x 的玩家开始,k 次传球内所有接触过球玩家的编号之  ,如果有玩家多次触球,则 累加多次 。换句话说, f(x) = x + receiver[x] + receiver[receiver[x]] + ... + receiver(k)[x] 。

你的任务时选择开始玩家 x ,目的是 最大化 f(x) 。

请你返回函数的 最大值 。

注意:receiver 可能含有重复元素。

 

示例 1:

传递次数 传球者编号 接球者编号 x + 所有接球者编号
      2
1 2 1 3
2 1 0 3
3 0 2 5
4 2 1 6

 

输入:receiver = [2,0,1], k = 4
输出:6
解释:上表展示了从编号为 x = 2 开始的游戏过程。
从表中可知,f(2) 等于 6 。
6 是能得到最大的函数值。
所以输出为 6 。

示例 2:

传递次数 传球者编号 接球者编号 x + 所有接球者编号
      4
1 4 3 7
2 3 2 9
3 2 1 10

 

输入:receiver = [1,1,1,2,3], k = 3
输出:10
解释:上表展示了从编号为 x = 4 开始的游戏过程。
从表中可知,f(4) 等于 10 。
10 是能得到最大的函数值。
所以输出为 10 。

 

提示:

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: long long getMaxFunctionValue(vector<int>& receiver, long long k) { } };

golang 解法, 执行用时: 420 ms, 内存消耗: 99.2 MB, 提交时间: 2023-08-28 10:24:11

func getMaxFunctionValue(receiver []int, K int64) int64 {
	type pair struct{ pa, sum int }
	n, m := len(receiver), bits.Len(uint(K))
	pa := make([][]pair, n)
	for i, p := range receiver {
		pa[i] = make([]pair, m)
		pa[i][0] = pair{p, p}
	}
	for i := 0; i+1 < m; i++ {
		for x := range pa {
			p := pa[x][i]
			pp := pa[p.pa][i]
			pa[x][i+1] = pair{pp.pa, p.sum + pp.sum} // 合并节点值之和
		}
	}

	ans := 0
	for i := 0; i < n; i++ {
		x := i
		sum := i // 节点值之和,初始为节点 i
		for k := uint(K); k > 0; k &= k - 1 {
			p := pa[x][bits.TrailingZeros(k)]
			sum += p.sum
			x = p.pa
		}
		ans = max(ans, sum)
	}
	return int64(ans)
}

func max(a, b int) int { if b > a { return b }; return a }

cpp 解法, 执行用时: 604 ms, 内存消耗: 234.6 MB, 提交时间: 2023-08-28 10:23:58

class Solution {
public:
    long long getMaxFunctionValue(vector<int> &receiver, long long K) {
        int n = receiver.size();
        int m = 64 - __builtin_clzll(K); // K 的二进制长度
        vector<vector<pair<int, long long>>> pa(m, vector<pair<int, long long>>(n));
        for (int i = 0; i < n; i++)
            pa[0][i] = {receiver[i], receiver[i]};
        for (int i = 0; i < m - 1; i++) {
            for (int x = 0; x < n; x++) {
                auto [p, s] = pa[i][x];
                auto [pp, ss] = pa[i][p];
                pa[i + 1][x] = {pp, s + ss}; // 合并节点值之和
            }
        }

        long long ans = 0;
        for (int i = 0; i < n; i++) {
            long long sum = i;
            int x = i;
            for (long long k = K; k; k &= k - 1) {
                auto [p, s] = pa[__builtin_ctzll(k)][x];
                sum += s;
                x = p;
            }
            ans = max(ans, sum);
        }
        return ans;
    }
};

java 解法, 执行用时: 55 ms, 内存消耗: 143.4 MB, 提交时间: 2023-08-28 10:23:45

class Solution {
    public long getMaxFunctionValue(List<Integer> receiver, long K) {
        int n = receiver.size();
        int m = 64 - Long.numberOfLeadingZeros(K); // K 的二进制长度
        var pa = new int[m][n];
        var sum = new long[m][n];
        for (int i = 0; i < n; i++) {
            pa[0][i] = receiver.get(i);
            sum[0][i] = receiver.get(i);
        }
        for (int i = 0; i < m - 1; i++) {
            for (int x = 0; x < n; x++) {
                int p = pa[i][x];
                pa[i + 1][x] = pa[i][p];
                sum[i + 1][x] = sum[i][x] + sum[i][p]; // 合并节点值之和
            }
        }

        long ans = 0;
        for (int i = 0; i < n; i++) {
            long s = i;
            int x = i;
            for (long k = K; k > 0; k &= k - 1) {
                int ctz = Long.numberOfTrailingZeros(k);
                s += sum[ctz][x];
                x = pa[ctz][x];
            }
            ans = Math.max(ans, s);
        }
        return ans;
    }
}

python3 解法, 执行用时: 4680 ms, 内存消耗: 393.4 MB, 提交时间: 2023-08-28 10:23:32

class Solution:
    def getMaxFunctionValue(self, receiver: List[int], k: int) -> int:
        n = len(receiver)
        m = k.bit_length() - 1
        pa = [[(p, p)] + [None] * m for p in receiver]
        for i in range(m):
            for x in range(n):
                p, s = pa[x][i]
                pp, ss = pa[p][i]
                pa[x][i + 1] = (pp, s + ss)  # 合并节点值之和

        ans = 0
        for i in range(n):
            x = sum = i
            for j in range(m + 1):
                if (k >> j) & 1:  # k 的二进制从低到高第 j 位是 1
                    x, s = pa[x][j]
                    sum += s
            ans = max(ans, sum)
        return ans

上一题