class Solution {
public:
int sumCounts(vector<int>& nums) {
}
};
100074. 子数组不同元素数目的平方和 II
给你一个下标从 0 开始的整数数组 nums
。
定义 nums
一个子数组的 不同计数 值如下:
nums[i..j]
表示 nums
中所有下标在 i
到 j
范围内的元素构成的子数组(满足 0 <= i <= j < nums.length
),那么我们称子数组 nums[i..j]
中不同值的数目为 nums[i..j]
的不同计数。请你返回 nums
中所有子数组的 不同计数 的 平方 和。
由于答案可能会很大,请你将它对 109 + 7
取余 后返回。
子数组指的是一个数组里面一段连续 非空 的元素序列。
示例 1:
输入:nums = [1,2,1] 输出:15 解释:六个子数组分别为: [1]: 1 个互不相同的元素。 [2]: 1 个互不相同的元素。 [1]: 1 个互不相同的元素。 [1,2]: 2 个互不相同的元素。 [2,1]: 2 个互不相同的元素。 [1,2,1]: 2 个互不相同的元素。 所有不同计数的平方和为 12 + 12 + 12 + 22 + 22 + 22 = 15 。
示例 2:
输入:nums = [2,2] 输出:3 解释:三个子数组分别为: [2]: 1 个互不相同的元素。 [2]: 1 个互不相同的元素。 [2,2]: 1 个互不相同的元素。 所有不同计数的平方和为 12 + 12 + 12 = 3 。
提示:
1 <= nums.length <= 105
1 <= nums[i] <= 105
原站题解
cpp 解法, 执行用时: 572 ms, 内存消耗: 77.3 MB, 提交时间: 2023-10-29 10:44:12
class Solution { public: int sumCounts(vector<int>& nums) { int n = nums.size(); const int MOD = 1e9 + 7; // sum1:区间和,sum2:区间平方和 long long sum1[n * 4 + 10], sum2[n * 4 + 10]; // 因为是区间修改,所以要懒标记下推 int lazy[n * 4 + 10]; memset(sum1, 0, sizeof(sum1)); memset(sum2, 0, sizeof(sum2)); memset(lazy, 0, sizeof(lazy)); // 根据公式维护区间加 K auto add = [&](int id, int l, int r, int K) { int len = r - l + 1; sum2[id] = (sum2[id] + 2LL * K * sum1[id] + 1LL * K * K % MOD * len) % MOD; sum1[id] = (sum1[id] + 1LL * K * len) % MOD; }; // 懒标记下推 auto down = [&](int id, int l, int r) { int nxt = id << 1, mid = (l + r) >> 1; lazy[nxt] += lazy[id]; add(nxt, l, mid, lazy[id]); lazy[nxt | 1] += lazy[id]; add(nxt | 1, mid + 1, r, lazy[id]); lazy[id] = 0; }; // 区间加 1 function<void(int, int, int, int, int)> modify = [&](int id, int l, int r, int ql, int qr) { if (ql <= l && r <= qr) { add(id, l, r, 1); lazy[id]++; } else { if (lazy[id]) down(id, l, r); int nxt = id << 1, mid = (l + r) >> 1; if (ql <= mid) modify(nxt, l, mid, ql, qr); if (qr > mid) modify(nxt | 1, mid + 1, r, ql, qr); sum1[id] = (sum1[nxt] + sum1[nxt | 1]) % MOD; sum2[id] = (sum2[nxt] + sum2[nxt | 1]) % MOD; } }; long long ans = 0; // last[x] 表示元素 x 最近出现在哪个下标 unordered_map<int, int> last; for (int i = 1; i <= n; i++) { int &old = last[nums[i - 1]]; modify(1, 1, n, old + 1, i); old = i; // 答案就是 [1, i] 这段区间的 sum2 之和 ans = (ans + sum2[1]) % MOD; } return ans; } };
golang 解法, 执行用时: 124 ms, 内存消耗: 27.7 MB, 提交时间: 2023-10-29 10:43:47
type lazySeg []struct{ sum, todo int } func (t lazySeg) do(o, l, r, add int) { t[o].sum += add * (r - l + 1) t[o].todo += add } // o=1 [l,r] 1<=l<=r<=n // 把 [L,R] 加一,同时返回加一之前的区间和 func (t lazySeg) add1(o, l, r, L, R int) (res int) { if L <= l && r <= R { res = t[o].sum t.do(o, l, r, 1) return } m := (l + r) >> 1 if add := t[o].todo; add != 0 { t.do(o<<1, l, m, add) t.do(o<<1|1, m+1, r, add) t[o].todo = 0 } if L <= m { res = t.add1(o<<1, l, m, L, R) } if m < R { res += t.add1(o<<1|1, m+1, r, L, R) } t[o].sum = t[o<<1].sum + t[o<<1|1].sum return } func sumCounts(nums []int) (ans int) { last := map[int]int{} n := len(nums) t := make(lazySeg, n*4) s := 0 for i, x := range nums { i++ j := last[x] s += t.add1(1, 1, n, j+1, i)*2 + i - j ans = (ans + s) % 1_000_000_007 last[x] = i } return }
cpp 解法, 执行用时: 288 ms, 内存消耗: 91.3 MB, 提交时间: 2023-10-29 10:43:24
class Solution { vector<long long> sum; vector<int> todo; void do_(int o, int l, int r, int add) { sum[o] += (long long) add * (r - l + 1); todo[o] += add; } // o=1 [l,r] 1<=l<=r<=n // 把 [L,R] 加一,同时返回加一之前的区间和 long long query_and_add1(int o, int l, int r, int L, int R) { if (L <= l && r <= R) { long long res = sum[o]; do_(o, l, r, 1); return res; } int m = (l + r) / 2; int add = todo[o]; if (add != 0) { do_(o * 2, l, m, add); do_(o * 2 + 1, m + 1, r, add); todo[o] = 0; } long long res = 0; if (L <= m) res += query_and_add1(o * 2, l, m, L, R); if (m < R) res += query_and_add1(o * 2 + 1, m + 1, r, L, R); sum[o] = sum[o * 2] + sum[o * 2 + 1]; return res; } public: int sumCounts(vector<int> &nums) { int n = nums.size(); sum.resize(n * 4); todo.resize(n * 4); long long ans = 0, s = 0; unordered_map<int, int> last; for (int i = 1; i <= n; i++) { int x = nums[i - 1]; int j = last.count(x) ? last[x] : 0; s += query_and_add1(1, 1, n, j + 1, i) * 2 + i - j; ans = (ans + s) % 1'000'000'007; last[x] = i; } return ans; } };
java 解法, 执行用时: 108 ms, 内存消耗: 65.3 MB, 提交时间: 2023-10-29 10:43:02
class Solution { private long[] sum; private int[] todo; public int sumCounts(int[] nums) { int n = nums.length; sum = new long[n * 4]; todo = new int[n * 4]; long ans = 0, s = 0; var last = new HashMap<Integer, Integer>(); for (int i = 1; i <= n; i++) { int x = nums[i - 1]; int j = last.getOrDefault(x, 0); s += queryAndAdd1(1, 1, n, j + 1, i) * 2 + i - j; ans = (ans + s) % 1_000_000_007; last.put(x, i); } return (int) ans; } private void do_(int o, int l, int r, int add) { sum[o] += (long) add * (r - l + 1); todo[o] += add; } // o=1 [l,r] 1<=l<=r<=n // 把 [L,R] 加一,同时返回加一之前的区间和 private long queryAndAdd1(int o, int l, int r, int L, int R) { if (L <= l && r <= R) { long res = sum[o]; do_(o, l, r, 1); return res; } int m = (l + r) / 2; int add = todo[o]; if (add != 0) { do_(o * 2, l, m, add); do_(o * 2 + 1, m + 1, r, add); todo[o] = 0; } long res = 0; if (L <= m) res += queryAndAdd1(o * 2, l, m, L, R); if (m < R) res += queryAndAdd1(o * 2 + 1, m + 1, r, L, R); sum[o] = sum[o * 2] + sum[o * 2 + 1]; return res; } }
python3 解法, 执行用时: 4348 ms, 内存消耗: 70.5 MB, 提交时间: 2023-10-29 10:42:51
class Solution: def sumCounts(self, nums: List[int]) -> int: n = len(nums) sum = [0] * (n * 4) todo = [0] * (n * 4) def do(o: int, l: int, r: int, add: int) -> None: sum[o] += add * (r - l + 1) todo[o] += add # o=1 [l,r] 1<=l<=r<=n # 把 [L,R] 加一,同时返回加一之前的区间和 def query_and_add1(o: int, l: int, r: int, L: int, R: int) -> int: if L <= l and r <= R: res = sum[o] do(o, l, r, 1) return res m = (l + r) // 2 add = todo[o] if add: do(o * 2, l, m, add) do(o * 2 + 1, m + 1, r, add) todo[o] = 0 res = 0 if L <= m: res += query_and_add1(o * 2, l, m, L, R) if m < R: res += query_and_add1(o * 2 + 1, m + 1, r, L, R) sum[o] = sum[o * 2] + sum[o * 2 + 1] return res ans = s = 0 last = {} for i, x in enumerate(nums, 1): j = last.get(x, 0) s += query_and_add1(1, 1, n, j + 1, i) * 2 + i - j ans += s last[x] = i return ans % 1_000_000_007