列表

详情


LCP 05. 发 LeetCoin

力扣决定给一个刷题团队发LeetCoin作为奖励。同时,为了监控给大家发了多少LeetCoin,力扣有时候也会进行查询。

 

该刷题团队的管理模式可以用一棵树表示:

  1. 团队只有一个负责人,编号为1。除了该负责人外,每个人有且仅有一个领导(负责人没有领导);
  2. 不存在循环管理的情况,如A管理B,B管理C,C管理A。

 

力扣想进行的操作有以下三种:

  1. 给团队的一个成员(也可以是负责人)发一定数量的LeetCoin
  2. 给团队的一个成员(也可以是负责人),以及他/她管理的所有人(即他/她的下属、他/她下属的下属,……),发一定数量的LeetCoin
  3. 查询某一个成员(也可以是负责人),以及他/她管理的所有人被发到的LeetCoin之和。

 

输入:

  1. N表示团队成员的个数(编号为1~N,负责人为1);
  2. leadership是大小为(N - 1) * 2的二维数组,其中每个元素[a, b]代表ba的下属;
  3. operations是一个长度为Q的二维数组,代表以时间排序的操作,格式如下:
    1. operations[i][0] = 1: 代表第一种操作,operations[i][1]代表成员的编号,operations[i][2]代表LeetCoin的数量;
    2. operations[i][0] = 2: 代表第二种操作,operations[i][1]代表成员的编号,operations[i][2]代表LeetCoin的数量;
    3. operations[i][0] = 3: 代表第三种操作,operations[i][1]代表成员的编号;

输出:

返回一个数组,数组里是每次查询的返回值(发LeetCoin的操作不需要任何返回值)。由于发的LeetCoin很多,请把每次查询的结果模1e9+7 (1000000007)

 

示例 1:

输入:N = 6, leadership = [[1, 2], [1, 6], [2, 3], [2, 5], [1, 4]], operations = [[1, 1, 500], [2, 2, 50], [3, 1], [2, 6, 15], [3, 1]]
输出:[650, 665]
解释:团队的管理关系见下图。
第一次查询时,每个成员得到的LeetCoin的数量分别为(按编号顺序):500, 50, 50, 0, 50, 0;
第二次查询时,每个成员得到的LeetCoin的数量分别为(按编号顺序):500, 50, 50, 0, 50, 15.

 

限制:

  1. 1 <= N <= 50000
  2. 1 <= Q <= 50000
  3. operations[i][0] != 3 时,1 <= operations[i][2] <= 5000

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: vector<int> bonus(int n, vector<vector<int>>& leadership, vector<vector<int>>& operations) { } };

golang 解法, 执行用时: 320 ms, 内存消耗: 29.9 MB, 提交时间: 2023-08-09 23:00:41

const MOD = 1000000007

type BitTree struct {
	ar []int
	n  int
}

func NewBitTree(n int) *BitTree {
	return &BitTree{make([]int, n+1), n}
}

func (bt *BitTree) Add(p int, v int) {
	p++
	for p <= bt.n {
		bt.ar[p] += v
		p += p & -p
	}
}

func (bt *BitTree) Sum(p int) int {
	o := 0
	p++
	for p > 0 {
		o += bt.ar[p]
		p -= p & -p
	}
	return o
}

func bonus(n int, leadership [][]int, operations [][]int) []int {
	subs := make([][]int, n)
	for _, l := range leadership {
		a := l[0] - 1
		subs[a] = append(subs[a], l[1]-1)
	}

	arIn := make([]int, n)
	arOut := make([]int, n)
	i := 0
	var dfs func(int)
	dfs = func(a int) {
		arIn[a], i = i, i+1
		for _, b := range subs[a] {
			dfs(b)
		}
		arOut[a] = i - 1
	}
	dfs(0)

	bt1, bt2 := NewBitTree(n), NewBitTree(n)
	add := func(l, r, v int) {
		bt1.Add(l, v)
		bt1.Add(r+1, -v)
		bt2.Add(l, v*(l-1))
		bt2.Add(r+1, -v*r)
	}
	sum := func(l, r int) int {
		l--
		o := ((bt1.Sum(r)*r - bt2.Sum(r)) - (bt1.Sum(l)*l - bt2.Sum(l))) % MOD
		if o < 0 {
			o += MOD
		}
		return o
	}

	o := make([]int, 0, 32)
	for _, op := range operations {
		a := op[1] - 1
		switch op[0] {
		case 1:
			add(arIn[a], arIn[a], op[2])
		case 2:
			add(arIn[a], arOut[a], op[2])
		case 3:
			o = append(o, sum(arIn[a], arOut[a]))
		}
	}
	return o
}

python3 解法, 执行用时: 1412 ms, 内存消耗: 95.8 MB, 提交时间: 2023-08-09 22:57:44

M = int(1e9 + 7)

class BIT:
    def __init__(self, n):
        self.n = n + 5
        self.sum = [0 for _ in range(n + 10)]
        self.ntimessum = [0 for _ in range(n + 10)]
    
    def lowbit(self, x):
        return x & (-x)

    # 在 pos 位置加上 k
    def update(self, pos, k):
        x = pos
        while pos <= self.n:
            self.sum[pos] += k
            self.sum[pos] %= M
            self.ntimessum[pos] += k * (x - 1)
            self.ntimessum[pos] %= M
            pos += self.lowbit(pos)
    
    # 区间更新 + 单点查询
    def askis(self, pos):
        if not pos:
            return 0
        ret = 0
        while pos:
            ret += self.sum[pos]
            ret %= M
            pos -= lowbit(pos)
        return ret
    
    # 单点更新 + 区间查询
    def asksi(self, l, r):
        if l > r:
            return 0
        return askis(r) - askis(l - 1)
    
    # 单点更新 + 单点查询
    def askss(self, pos):
        return askis(pos) - askis(pos - 1)
    
    # 区间更新 + 区间查询
    def askii(self, pos):
        if not pos:
            return 0
        ret = 0
        x = pos
        while pos:
            ret += x * self.sum[pos] - self.ntimessum[pos]
            ret %= M
            pos -= self.lowbit(pos)
        return ret

class Solution:
    def bonus(self, n: int, leadership: List[List[int]], operations: List[List[int]]) -> List[int]:
        
        # 邻接表
        g = [[] for _ in range(n + 1)]
        begin = [0 for _ in range(n + 1)]
        end = [0 for _ in range(n + 1)]
        id = 1

        for l in leadership:
            g[l[0]].append(l[1])
        
        # 深搜
        def dfs(cur):
            nonlocal id
            begin[cur] = id
            for child in g[cur]:
                dfs(child)
            end[cur] = id
            id += 1
        dfs(1)
        
        # 树状数组
        b = BIT(n)
        ret = []
        for q in operations:
            if q[0] == 1:
                b.update(end[q[1]], q[2])
                b.update(end[q[1]] + 1, -q[2])
            elif q[0] == 2:
                b.update(begin[q[1]], q[2])
                b.update(end[q[1]] + 1, -q[2])
            else:
                ans = b.askii(end[q[1]]) - b.askii(begin[q[1]] - 1)
                ret.append((ans % M + M) % M)

        return ret

python3 解法, 执行用时: 948 ms, 内存消耗: 53.1 MB, 提交时间: 2023-08-09 22:56:06

'''
使用DFS序遍历子树,并保存相应的DFS序,DFS序可以保证任意子树的所有节点在一段连续空间上,
并保存节点编号到数组索引的映射
计算每一棵子树的节点数(等价于前一步骤中所有子树在数组上的连续空间长度)
后续就是标准的树状数组区间更新及区间查询写法(也可以使用线段树,但树状数组的常数更小),
单点更新可以替换为长度为1个区间更新
'''
mod = 10 ** 9 + 7

def lb(x):
    return x & (-x)

class Solution:
    def bonus(self, n: int, leadership: List[List[int]], operations: List[List[int]]) -> List[int]:
        g1 = defaultdict(list)
        g2 = [0] * (n + 1)
        cnt = [0] + [1] * n

        for a, b in leadership:
            g1[a].append(b)
            g2[b] = a
        
        ts, stk = [0], [1]
        while stk:
            a = stk.pop()
            ts.append(a)
            stk.extend(g1[a][::-1])

        d = {j: i for i, j in enumerate(ts)}

        for b in ts[1:][::-1]:
            cnt[g2[b]] += cnt[b]
 
        maxn = n
        while maxn & (maxn - 1):
            maxn += lb(maxn)
        
        f1 = [0] * (maxn + 1)
        f2 = [0] * (maxn + 1)

        def add(i, x):
            xx = i * x
            while i <= maxn:
                f1[i] += x
                f2[i] += xx
                i += lb(i)
        
        def query(i):
            t1, t2, p = 0, 0, i
            while i:
                t1 += f1[i]
                t2 += f2[i]
                i -= lb(i)
            return ((p + 1) * t1 - t2) % mod

        ans = []
        for op in operations:
            if op[0] == 3:
                p = op[1]
                ans.append((query(d[p] + cnt[p] - 1) - query(d[p] - 1)) % mod)
            else:
                p, c, l = op[1], op[2], 1 if op[0] == 1 else cnt[op[1]]
                add(d[p], c)
                add(d[p] + l, -c)

        return ans

上一题