class Solution {
public:
vector<long long> handleQuery(vector<int>& nums1, vector<int>& nums2, vector<vector<int>>& queries) {
}
};
2569. 更新数组后处理求和查询
给你两个下标从 0 开始的数组 nums1
和 nums2
,和一个二维数组 queries
表示一些操作。总共有 3 种类型的操作:
queries[i] = [1, l, r]
。你需要将 nums1
从下标 l
到下标 r
的所有 0
反转成 1
或将 1
反转成 0
。l
和 r
下标都从 0 开始。queries[i] = [2, p, 0]
。对于 0 <= i < n
中的所有下标,令 nums2[i] = nums2[i] + nums1[i] * p
。queries[i] = [3, 0, 0]
。求 nums2
中所有元素的和。请你返回一个数组,包含所有第三种操作类型的答案。
示例 1:
输入:nums1 = [1,0,1], nums2 = [0,0,0], queries = [[1,1,1],[2,1,0],[3,0,0]] 输出:[3] 解释:第一个操作后 nums1 变为 [1,1,1] 。第二个操作后,nums2 变成 [1,1,1] ,所以第三个操作的答案为 3 。所以返回 [3] 。
示例 2:
输入:nums1 = [1], nums2 = [5], queries = [[2,0,0],[3,0,0]] 输出:[5] 解释:第一个操作后,nums2 保持不变为 [5] ,所以第二个操作的答案是 5 。所以返回 [5] 。
提示:
1 <= nums1.length,nums2.length <= 105
nums1.length = nums2.length
1 <= queries.length <= 105
queries[i].length = 3
0 <= l <= r <= nums1.length - 1
0 <= p <= 106
0 <= nums1[i] <= 1
0 <= nums2[i] <= 109
原站题解
python3 解法, 执行用时: 456 ms, 内存消耗: 44 MB, 提交时间: 2023-02-26 11:14:25
class Solution: def handleQuery(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]: n = len(nums1) s = sum(nums2) x = int(''.join(map(str,nums1[::-1])),2) ans = [] for op,l,r in queries: if op == 1: y = (1<<(r-l+1))-1 y <<= l x ^= y elif op == 2: s += l*x.bit_count() else: ans.append(s) return ans
java 解法, 执行用时: 308 ms, 内存消耗: 69.3 MB, 提交时间: 2023-02-26 11:14:03
class Solution { public long[] handleQuery(int[] nums1, int[] nums2, int[][] queries) { int n = nums1.length; long sum = 0; BitSet bitSet = new BitSet(n); for (int i = 0; i < n; ++i) { sum += nums2[i]; if (nums1[i] == 1) bitSet.set(i); } List<Long> ans = new ArrayList<>(); for (int[] q : queries) { if (q[0] == 1) { // 反转 [l, r] bitSet.flip(q[1], q[2] + 1); } else if (q[0] == 2) { // 1的总个数 * p sum += (long) bitSet.cardinality() * q[1]; } else { ans.add(sum); } } return ans.stream().mapToLong(Long::longValue).toArray(); } }
golang 解法, 执行用时: 212 ms, 内存消耗: 32.6 MB, 提交时间: 2023-02-26 11:13:23
type seg []struct { l, r, cnt1 int flip bool } func (t seg) maintain(o int) { t[o].cnt1 = t[o<<1].cnt1 + t[o<<1|1].cnt1 } func (t seg) build(a []int, o, l, r int) { t[o].l, t[o].r = l, r if l == r { t[o].cnt1 = a[l-1] return } m := (l + r) >> 1 t.build(a, o<<1, l, m) t.build(a, o<<1|1, m+1, r) t.maintain(o) } func (t seg) do(O int) { o := &t[O] o.cnt1 = o.r - o.l + 1 - o.cnt1 o.flip = !o.flip } func (t seg) spread(o int) { if t[o].flip { t.do(o << 1) t.do(o<<1 | 1) t[o].flip = false } } func (t seg) update(o, l, r int) { if l <= t[o].l && t[o].r <= r { t.do(o) return } t.spread(o) m := (t[o].l + t[o].r) >> 1 if l <= m { t.update(o<<1, l, r) } if m < r { t.update(o<<1|1, l, r) } t.maintain(o) } func handleQuery(nums1, nums2 []int, queries [][]int) (ans []int64) { sum := 0 for _, x := range nums2 { sum += x } t := make(seg, len(nums1)*4) t.build(nums1, 1, 1, len(nums1)) for _, q := range queries { if q[0] == 1 { t.update(1, q[1]+1, q[2]+1) } else if q[0] == 2 { sum += q[1] * t[1].cnt1 } else { ans = append(ans, int64(sum)) } } return }
cpp 解法, 执行用时: 256 ms, 内存消耗: 114.6 MB, 提交时间: 2023-02-26 11:13:06
class Solution { static constexpr int MX = 4e5 + 1; int cnt1[MX]; bool flip[MX]; void maintain(int o) { cnt1[o] = cnt1[o * 2] + cnt1[o * 2 + 1]; } void do_(int o, int l, int r) { cnt1[o] = r - l + 1 - cnt1[o]; flip[o] = !flip[o]; } // 初始化线段树 o,l,r=1,1,n void build(vector<int> &a, int o, int l, int r) { if (l == r) { cnt1[o] = a[l - 1]; return; } int m = (l + r) / 2; build(a, o * 2, l, m); build(a, o * 2 + 1, m + 1, r); maintain(o); } // 反转区间 [L,R] o,l,r=1,1,n void update(int o, int l, int r, int L, int R) { if (L <= l && r <= R) { do_(o, l, r); return; } int m = (l + r) / 2; if (flip[o]) { do_(o * 2, l, m); do_(o * 2 + 1, m + 1, r); flip[o] = false; } if (m >= L) update(o * 2, l, m, L, R); if (m < R) update(o * 2 + 1, m + 1, r, L, R); maintain(o); } public: vector<long long> handleQuery(vector<int> &nums1, vector<int> &nums2, vector<vector<int>> &queries) { int n = nums1.size(); build(nums1, 1, 1, n); vector<long long> ans; long long sum = accumulate(nums2.begin(), nums2.end(), 0LL); for (auto &q : queries) { if (q[0] == 1) update(1, 1, n, q[1] + 1, q[2] + 1); else if (q[0] == 2) sum += 1LL * q[1] * cnt1[1]; else ans.push_back(sum); } return ans; } };
java 解法, 执行用时: 29 ms, 内存消耗: 92.4 MB, 提交时间: 2023-02-26 11:12:47
class Solution { public long[] handleQuery(int[] nums1, int[] nums2, int[][] queries) { int n = nums1.length, m = 0, i = 0; cnt1 = new int[n * 4]; flip = new boolean[n * 4]; build(nums1, 1, 1, n); var sum = 0L; for (var x : nums2) sum += x; for (var q : queries) if (q[0] == 3) ++m; var ans = new long[m]; for (var q : queries) { if (q[0] == 1) update(1, 1, n, q[1] + 1, q[2] + 1); else if (q[0] == 2) sum += (long) q[1] * cnt1[1]; else ans[i++] = sum; } return ans; } private int[] cnt1; private boolean[] flip; private void maintain(int o) { cnt1[o] = cnt1[o * 2] + cnt1[o * 2 + 1]; } private void do_(int o, int l, int r) { cnt1[o] = r - l + 1 - cnt1[o]; flip[o] = !flip[o]; } // 初始化线段树 o,l,r=1,1,n private void build(int[] a, int o, int l, int r) { if (l == r) { cnt1[o] = a[l - 1]; return; } int m = (l + r) / 2; build(a, o * 2, l, m); build(a, o * 2 + 1, m + 1, r); maintain(o); } // 反转区间 [L,R] o,l,r=1,1,n private void update(int o, int l, int r, int L, int R) { if (L <= l && r <= R) { do_(o, l, r); return; } int m = (l + r) / 2; if (flip[o]) { do_(o * 2, l, m); do_(o * 2 + 1, m + 1, r); flip[o] = false; } if (m >= L) update(o * 2, l, m, L, R); if (m < R) update(o * 2 + 1, m + 1, r, L, R); maintain(o); } }
python3 解法, 执行用时: 940 ms, 内存消耗: 47.4 MB, 提交时间: 2023-02-26 11:12:27
''' 线段树 由于操作2和操作3更新和统计的是所有nums2[i]的值,那么我们其实只需要维护nums1中1的个数。 用线段树维护区间内1的个数cnt1,以及区间反转标记fip。 ''' class Solution: def handleQuery(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]: n = len(nums1) cnt1 = [0] * (4 * n) flip = [False] * (4 * n) def maintain(o: int) -> None: cnt1[o] = cnt1[o * 2] + cnt1[o * 2 + 1] def do(o: int, l: int, r: int) -> None: cnt1[o] = r - l + 1 - cnt1[o] flip[o] = not flip[o] # 初始化线段树 o,l,r=1,1,n def build(o: int, l: int, r: int) -> None: if l == r: cnt1[o] = nums1[l - 1] return m = (l + r) // 2 build(o * 2, l, m) build(o * 2 + 1, m + 1, r) maintain(o) # 反转区间 [L,R] o,l,r=1,1,n def update(o: int, l: int, r: int, L: int, R: int) -> None: if L <= l and r <= R: do(o, l, r) return m = (l + r) // 2 if flip[o]: do(o * 2, l, m) do(o * 2 + 1, m + 1, r) flip[o] = False if m >= L: update(o * 2, l, m, L, R) if m < R: update(o * 2 + 1, m + 1, r, L, R) maintain(o) build(1, 1, n) ans, s = [], sum(nums2) for op, l, r in queries: if op == 1: update(1, 1, n, l + 1, r + 1) elif op == 2: s += l * cnt1[1] else: ans.append(s) return ans