列表

详情


100074. 子数组不同元素数目的平方和 II

给你一个下标从 0 开始的整数数组 nums 。

定义 nums 一个子数组的 不同计数 值如下:

请你返回 nums 中所有子数组的 不同计数 的 平方 和。

由于答案可能会很大,请你将它对 109 + 7 取余 后返回。

子数组指的是一个数组里面一段连续 非空 的元素序列。

 

示例 1:

输入:nums = [1,2,1]
输出:15
解释:六个子数组分别为:
[1]: 1 个互不相同的元素。
[2]: 1 个互不相同的元素。
[1]: 1 个互不相同的元素。
[1,2]: 2 个互不相同的元素。
[2,1]: 2 个互不相同的元素。
[1,2,1]: 2 个互不相同的元素。
所有不同计数的平方和为 12 + 12 + 12 + 22 + 22 + 22 = 15 。

示例 2:

输入:nums = [2,2]
输出:3
解释:三个子数组分别为:
[2]: 1 个互不相同的元素。
[2]: 1 个互不相同的元素。
[2,2]: 1 个互不相同的元素。
所有不同计数的平方和为 12 + 12 + 12 = 3 。

 

提示:

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: int sumCounts(vector<int>& nums) { } };

cpp 解法, 执行用时: 572 ms, 内存消耗: 77.3 MB, 提交时间: 2023-10-29 10:44:12

class Solution {
public:
    int sumCounts(vector<int>& nums) {
        int n = nums.size();

        const int MOD = 1e9 + 7;
		// sum1:区间和,sum2:区间平方和
        long long sum1[n * 4 + 10], sum2[n * 4 + 10];
		// 因为是区间修改,所以要懒标记下推
        int lazy[n * 4 + 10];
        memset(sum1, 0, sizeof(sum1)); memset(sum2, 0, sizeof(sum2));
        memset(lazy, 0, sizeof(lazy));

		// 根据公式维护区间加 K
        auto add = [&](int id, int l, int r, int K) {
            int len = r - l + 1;
            sum2[id] = (sum2[id] + 2LL * K * sum1[id] + 1LL * K * K % MOD * len) % MOD;
            sum1[id] = (sum1[id] + 1LL * K * len) % MOD;
        };

		// 懒标记下推
        auto down = [&](int id, int l, int r) {
            int nxt = id << 1, mid = (l + r) >> 1;
            lazy[nxt] += lazy[id]; add(nxt, l, mid, lazy[id]);
            lazy[nxt | 1] += lazy[id]; add(nxt | 1, mid + 1, r, lazy[id]);
            lazy[id] = 0;
        };

		// 区间加 1
        function<void(int, int, int, int, int)> modify = [&](int id, int l, int r, int ql, int qr) {
            if (ql <= l && r <= qr) {
                add(id, l, r, 1);
                lazy[id]++;
            } else {
                if (lazy[id]) down(id, l, r);
                int nxt = id << 1, mid = (l + r) >> 1;
                if (ql <= mid) modify(nxt, l, mid, ql, qr);
                if (qr > mid) modify(nxt | 1, mid + 1, r, ql, qr);
                sum1[id] = (sum1[nxt] + sum1[nxt | 1]) % MOD;
                sum2[id] = (sum2[nxt] + sum2[nxt | 1]) % MOD;
            }
        };

        long long ans = 0;
		// last[x] 表示元素 x 最近出现在哪个下标
        unordered_map<int, int> last;
        for (int i = 1; i <= n; i++) {
            int &old = last[nums[i - 1]];
            modify(1, 1, n, old + 1, i);
            old = i;
			// 答案就是 [1, i] 这段区间的 sum2 之和
            ans = (ans + sum2[1]) % MOD;
        }
        return ans;
    }
};

golang 解法, 执行用时: 124 ms, 内存消耗: 27.7 MB, 提交时间: 2023-10-29 10:43:47

type lazySeg []struct{ sum, todo int }

func (t lazySeg) do(o, l, r, add int) {
	t[o].sum += add * (r - l + 1)
	t[o].todo += add
}

// o=1  [l,r] 1<=l<=r<=n
// 把 [L,R] 加一,同时返回加一之前的区间和
func (t lazySeg) add1(o, l, r, L, R int) (res int) {
	if L <= l && r <= R {
		res = t[o].sum
		t.do(o, l, r, 1)
		return
	}
	m := (l + r) >> 1
	if add := t[o].todo; add != 0 {
		t.do(o<<1, l, m, add)
		t.do(o<<1|1, m+1, r, add)
		t[o].todo = 0
	}
	if L <= m {
		res = t.add1(o<<1, l, m, L, R)
	}
	if m < R {
		res += t.add1(o<<1|1, m+1, r, L, R)
	}
	t[o].sum = t[o<<1].sum + t[o<<1|1].sum
	return
}

func sumCounts(nums []int) (ans int) {
	last := map[int]int{}
	n := len(nums)
	t := make(lazySeg, n*4)
	s := 0
	for i, x := range nums {
		i++
		j := last[x]
		s += t.add1(1, 1, n, j+1, i)*2 + i - j
		ans = (ans + s) % 1_000_000_007
		last[x] = i
	}
	return
}

cpp 解法, 执行用时: 288 ms, 内存消耗: 91.3 MB, 提交时间: 2023-10-29 10:43:24

class Solution {
    vector<long long> sum;
    vector<int> todo;

    void do_(int o, int l, int r, int add) {
        sum[o] += (long long) add * (r - l + 1);
        todo[o] += add;
    }

    // o=1  [l,r] 1<=l<=r<=n
    // 把 [L,R] 加一,同时返回加一之前的区间和
    long long query_and_add1(int o, int l, int r, int L, int R) {
        if (L <= l && r <= R) {
            long long res = sum[o];
            do_(o, l, r, 1);
            return res;
        }

        int m = (l + r) / 2;
        int add = todo[o];
        if (add != 0) {
            do_(o * 2, l, m, add);
            do_(o * 2 + 1, m + 1, r, add);
            todo[o] = 0;
        }

        long long res = 0;
        if (L <= m) res += query_and_add1(o * 2, l, m, L, R);
        if (m < R)  res += query_and_add1(o * 2 + 1, m + 1, r, L, R);
        sum[o] = sum[o * 2] + sum[o * 2 + 1];
        return res;
    }

public:
    int sumCounts(vector<int> &nums) {
        int n = nums.size();
        sum.resize(n * 4);
        todo.resize(n * 4);

        long long ans = 0, s = 0;
        unordered_map<int, int> last;
        for (int i = 1; i <= n; i++) {
            int x = nums[i - 1];
            int j = last.count(x) ? last[x] : 0;
            s += query_and_add1(1, 1, n, j + 1, i) * 2 + i - j;
            ans = (ans + s) % 1'000'000'007;
            last[x] = i;
        }
        return ans;
    }
};

java 解法, 执行用时: 108 ms, 内存消耗: 65.3 MB, 提交时间: 2023-10-29 10:43:02

class Solution {
    private long[] sum;
    private int[] todo;

    public int sumCounts(int[] nums) {
        int n = nums.length;
        sum = new long[n * 4];
        todo = new int[n * 4];

        long ans = 0, s = 0;
        var last = new HashMap<Integer, Integer>();
        for (int i = 1; i <= n; i++) {
            int x = nums[i - 1];
            int j = last.getOrDefault(x, 0);
            s += queryAndAdd1(1, 1, n, j + 1, i) * 2 + i - j;
            ans = (ans + s) % 1_000_000_007;
            last.put(x, i);
        }
        return (int) ans;
    }

    private void do_(int o, int l, int r, int add) {
        sum[o] += (long) add * (r - l + 1);
        todo[o] += add;
    }

    // o=1  [l,r] 1<=l<=r<=n
    // 把 [L,R] 加一,同时返回加一之前的区间和
    private long queryAndAdd1(int o, int l, int r, int L, int R) {
        if (L <= l && r <= R) {
            long res = sum[o];
            do_(o, l, r, 1);
            return res;
        }

        int m = (l + r) / 2;
        int add = todo[o];
        if (add != 0) {
            do_(o * 2, l, m, add);
            do_(o * 2 + 1, m + 1, r, add);
            todo[o] = 0;
        }

        long res = 0;
        if (L <= m) res += queryAndAdd1(o * 2, l, m, L, R);
        if (m < R)  res += queryAndAdd1(o * 2 + 1, m + 1, r, L, R);
        sum[o] = sum[o * 2] + sum[o * 2 + 1];
        return res;
    }
}

python3 解法, 执行用时: 4348 ms, 内存消耗: 70.5 MB, 提交时间: 2023-10-29 10:42:51

class Solution:
    def sumCounts(self, nums: List[int]) -> int:
        n = len(nums)
        sum = [0] * (n * 4)
        todo = [0] * (n * 4)

        def do(o: int, l: int, r: int, add: int) -> None:
            sum[o] += add * (r - l + 1)
            todo[o] += add

        # o=1  [l,r] 1<=l<=r<=n
        # 把 [L,R] 加一,同时返回加一之前的区间和
        def query_and_add1(o: int, l: int, r: int, L: int, R: int) -> int:
            if L <= l and r <= R:
                res = sum[o]
                do(o, l, r, 1)
                return res

            m = (l + r) // 2
            add = todo[o]
            if add:
                do(o * 2, l, m, add)
                do(o * 2 + 1, m + 1, r, add)
                todo[o] = 0

            res = 0
            if L <= m: res += query_and_add1(o * 2, l, m, L, R)
            if m < R:  res += query_and_add1(o * 2 + 1, m + 1, r, L, R)
            sum[o] = sum[o * 2] + sum[o * 2 + 1]
            return res

        ans = s = 0
        last = {}
        for i, x in enumerate(nums, 1):
            j = last.get(x, 0)
            s += query_and_add1(1, 1, n, j + 1, i) * 2 + i - j
            ans += s
            last[x] = i
        return ans % 1_000_000_007

上一题