327. 区间和的个数
给你一个整数数组 nums 以及两个整数 lower 和 upper 。求数组中,值位于范围 [lower, upper](包含 lower 和 upper)之内的 区间和的个数 。
区间和 S(i, j) 表示在 nums 中,位置从 i 到 j 的元素之和,包含 i 和 j(i ≤ j)。
示例 1:
输入:nums = [-2,5,-1], lower = -2, upper = 2 输出:3 解释:存在三个区间:[0,0]、[2,2] 和 [0,2] ,对应的区间和分别是:-2 、-1 、2 。
示例 2:
输入:nums = [0], lower = 0, upper = 0 输出:1
提示:
1 <= nums.length <= 10^5
-2^31 <= nums[i] <= 2^31 - 1
-10^5 <= lower <= upper <= 10^5
原站题解
python3 解法, 执行用时: 7820 ms, 内存消耗: 75.7 MB, 提交时间: 2023-10-08 23:35:45
class SegTree:
    """Array-backed segment tree over positions ``0 .. n - 1``.

    Every leaf stores a counter and every internal node stores the sum of its
    two children, which gives point updates and closed-range sum queries.
    """

    def __init__(self, n):
        self.n = n
        # The node stored at index ``root`` keeps its children at
        # ``2 * root + 1`` and ``2 * root + 2``.
        self.treesum = [0] * (4 * n)

    def update(self, ID, diff):
        """Add ``diff`` to the counter at position ``ID``.

        Raises:
            IndexError: if ``ID`` is outside ``0 .. n - 1`` (the recursive
                descent would otherwise never reach a matching leaf).
        """
        if not 0 <= ID < self.n:
            raise IndexError(f"position {ID} out of range for size {self.n}")
        self._update(0, 0, self.n - 1, ID, diff)

    def query(self, ql, qr):
        """Return the sum of the counters at positions ``ql .. qr`` inclusive.

        Raises:
            IndexError: if the range is empty or not inside ``0 .. n - 1``.
        """
        if not 0 <= ql <= qr < self.n:
            raise IndexError(f"range [{ql}, {qr}] invalid for size {self.n}")
        return self._query(0, 0, self.n - 1, ql, qr)

    def _update(self, root, l, r, ID, diff):
        """Apply the point update inside the node ``root`` covering ``[l, r]``."""
        if l == r == ID:
            self.treesum[root] += diff
            return
        left = 2 * root + 1
        right = 2 * root + 2
        mid = (l + r) >> 1
        if ID <= mid:
            self._update(left, l, mid, ID, diff)
        else:
            self._update(right, mid + 1, r, ID, diff)
        # Re-derive this node's sum from its (possibly changed) children.
        self.treesum[root] = self.treesum[left] + self.treesum[right]

    def _query(self, root, l, r, ql, qr):
        """Sum ``[ql, qr]`` inside node ``root`` covering ``[l, r]``.

        The caller always narrows ``[ql, qr]`` so that it lies within ``[l, r]``.
        """
        if l == ql and r == qr:
            return self.treesum[root]
        left = 2 * root + 1
        right = 2 * root + 2
        mid = (l + r) >> 1
        if qr <= mid:
            return self._query(left, l, mid, ql, qr)
        if mid + 1 <= ql:
            return self._query(right, mid + 1, r, ql, qr)
        # The query straddles the midpoint: split it between both children.
        return (self._query(left, l, mid, ql, mid)
                + self._query(right, mid + 1, r, mid + 1, qr))


class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Count the subarrays of ``nums`` whose sum lies in ``[lower, upper]``.

        With prefix sums ``P``, the subarray sum ``S(i, j)`` equals
        ``P[j + 1] - P[i]``.  For each prefix sum ``p`` the method counts the
        earlier prefix sums that fall in ``[p - upper, p - lower]`` by using a
        segment tree indexed by the rank of each value.
        """
        presum = [0]
        for value in nums:
            presum.append(presum[-1] + value)

        # Discretization: collect every value that is ever inserted or used
        # as a query bound, then map each one to its rank in sorted order.
        values = set()
        for p in presum:
            values.update((p, p - lower, p - upper))
        val_id = {val: idx for idx, val in enumerate(sorted(values))}

        res = 0
        tree = SegTree(len(val_id))
        for p in presum:
            # Query before inserting so that only earlier prefix sums count.
            res += tree.query(val_id[p - upper], val_id[p - lower])
            tree.update(val_id[p], 1)
        return res
python3 解法, 执行用时: 2808 ms, 内存消耗: 31.6 MB, 提交时间: 2023-10-08 23:35:18
class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Count the subarrays of ``nums`` whose sum lies in ``[lower, upper]``.

        The prefix-sum array is merge-sorted; while two sorted halves are
        being merged, the pairs (left element, right element) whose difference
        falls in the range are counted into ``self.res``.
        """
        self.lower = lower
        self.upper = upper
        self.res = 0
        prefixsum = [0]
        for value in nums:
            prefixsum.append(prefixsum[-1] + value)
        self.mergesort(prefixsum, 0, len(prefixsum) - 1)
        return self.res

    def mergesort(self, nums: List[int], L: int, R: int) -> None:
        """Sort ``nums[L..R]`` in place, counting valid pairs along the way."""
        if L >= R:
            return
        mid = L + (R - L) // 2
        self.mergesort(nums, L, mid)
        self.mergesort(nums, mid + 1, R)
        self.merge(nums, L, mid, R)

    def merge(self, nums: List[int], L: int, mid: int, R: int) -> None:
        """Merge the sorted runs ``nums[L..mid]`` and ``nums[mid+1..R]``.

        Before merging, adds to ``self.res`` the number of pairs ``(a, b)``
        with ``a`` in the left run, ``b`` in the right run and
        ``self.lower <= b - a <= self.upper``.
        """
        lower, upper = self.lower, self.upper

        # Counting pass.  Both runs are sorted, so as the left element grows
        # the two window pointers into the right run only ever move forward.
        lo_ptr = hi_ptr = mid + 1
        pairs = 0
        for left_idx in range(L, mid + 1):
            base = nums[left_idx]
            while lo_ptr <= R and nums[lo_ptr] - base < lower:
                lo_ptr += 1
            while hi_ptr <= R and nums[hi_ptr] - base <= upper:
                hi_ptr += 1
            pairs += hi_ptr - lo_ptr
        self.res += pairs

        # Standard two-pointer merge into a scratch list.
        merged = []
        i, j = L, mid + 1
        while i <= mid and j <= R:
            if nums[i] <= nums[j]:
                merged.append(nums[i])
                i += 1
            else:
                merged.append(nums[j])
                j += 1
        # At most one of these two slices is non-empty.
        merged.extend(nums[i:mid + 1])
        merged.extend(nums[j:R + 1])
        nums[L:R + 1] = merged
python3 解法, 执行用时: 3488 ms, 内存消耗: 66.9 MB, 提交时间: 2023-10-08 23:35:05
class BitTree:
    """Fenwick tree (binary indexed tree) over 1-based positions ``1 .. n``."""

    def __init__(self, n):
        # Slot 0 is unused; positions are 1-based.
        self.tree = [0] * (n + 1)
        self.n = n

    def lowbit(self, i: int) -> int:
        """Return the value of the lowest set bit of ``i``."""
        return i & (-i)

    def update(self, i: int, k: int) -> None:
        """Add ``k`` at position ``i`` (1-based).

        Raises:
            IndexError: if ``i`` is not positive; ``lowbit(0)`` is 0, so the
                loop below would otherwise never advance.
        """
        if i <= 0:
            raise IndexError(f"position must be >= 1, got {i}")
        while i <= self.n:
            self.tree[i] += k
            i += self.lowbit(i)

    def presum(self, i: int) -> int:
        """Return the sum of the values at positions ``1 .. i`` (0 if ``i <= 0``)."""
        res = 0
        while i > 0:
            res += self.tree[i]
            i -= self.lowbit(i)
        return res


class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Count the subarrays of ``nums`` whose sum lies in ``[lower, upper]``.

        For each prefix sum ``x`` the method counts the earlier prefix sums in
        ``[x - upper, x - lower]`` with a Fenwick tree indexed by value rank.
        """
        presum = [0]
        for value in nums:
            presum.append(presum[-1] + value)

        # Discretization: every inserted value and every query bound gets a
        # 1-based rank, matching the tree's 1-based positions.
        values = set()
        for x in presum:
            values.update((x, x - lower, x - upper))
        rank = {val: pos for pos, val in enumerate(sorted(values), start=1)}

        BIT = BitTree(len(rank))
        res = 0
        for x in presum:
            # Count first, insert afterwards: only earlier prefix sums count.
            res += BIT.presum(rank[x - lower]) - BIT.presum(rank[x - upper] - 1)
            BIT.update(rank[x], 1)
        return res
python3 解法, 执行用时: 6136 ms, 内存消耗: 31.9 MB, 提交时间: 2023-10-08 23:34:52
class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Count the subarrays of ``nums`` whose sum lies in ``[lower, upper]``.

        Keeps the prefix sums seen so far in a sorted list and, for each new
        prefix sum ``p``, binary-searches the window ``[p - upper, p - lower]``.
        Inserting into a plain list shifts elements, so each insertion is
        linear in the list length.
        """
        # Imported here so the method needs neither ``import bisect`` nor a
        # star-import of its functions at file level.
        from bisect import bisect_left, bisect_right, insort

        presum = [0]
        for value in nums:
            presum.append(presum[-1] + value)

        res = 0
        preWindow = []
        for p in presum:
            left = bisect_left(preWindow, p - upper)
            right = bisect_right(preWindow, p - lower)
            res += right - left
            insort(preWindow, p)
        return res

    # sortedcontainers
    def countRangeSum2(self, nums: List[int], lower: int, upper: int) -> int:
        """Same algorithm as ``countRangeSum`` on top of ``SortedList``.

        Requires the third-party ``sortedcontainers`` package.
        """
        from sortedcontainers import SortedList as SL

        presum = [0]
        for value in nums:
            presum.append(presum[-1] + value)

        res = 0
        preWindow = SL()
        for p in presum:
            left = preWindow.bisect_left(p - upper)
            right = preWindow.bisect_right(p - lower)
            res += right - left
            preWindow.add(p)
        return res
golang 解法, 执行用时: 228 ms, 内存消耗: 11.2 MB, 提交时间: 2023-10-08 23:33:52
import "math/rand" // the rand imported by default is a different package, so name this one explicitly

// node is one treap node. Equal keys are folded into a single node through
// dupCnt instead of being stored as separate nodes.
type node struct {
	ch       [2]*node // ch[0] is the left child, ch[1] the right child
	priority int      // random heap priority
	key      int
	dupCnt   int // how many times key has been inserted
	sz       int // number of keys, counted with multiplicity, in this subtree
}

// cmp reports which child to descend into when looking for b:
// 0 for left, 1 for right, and -1 when b equals the node's key.
func (o *node) cmp(b int) int {
	switch {
	case b < o.key:
		return 0
	case b > o.key:
		return 1
	default:
		return -1
	}
}

// size returns the subtree size and is safe to call on a nil node.
func (o *node) size() int {
	if o != nil {
		return o.sz
	}
	return 0
}

// maintain recomputes sz from dupCnt and the children's sizes.
func (o *node) maintain() {
	o.sz = o.dupCnt + o.ch[0].size() + o.ch[1].size()
}

// rotate lifts the child on side d^1 above o and returns the new subtree
// root (d=0: left rotation, d=1: right rotation).
func (o *node) rotate(d int) *node {
	x := o.ch[d^1]
	o.ch[d^1] = x.ch[d]
	x.ch[d] = o
	// o is now a child of x, so its size has to be refreshed first.
	o.maintain()
	x.maintain()
	return x
}

// treap is a randomized balanced BST with order statistics.
type treap struct {
	root *node
}

// _insert inserts key into the subtree rooted at o and returns the new root
// of that subtree.
func (t *treap) _insert(o *node, key int) *node {
	if o == nil {
		return &node{priority: rand.Int(), key: key, dupCnt: 1, sz: 1}
	}
	if d := o.cmp(key); d >= 0 {
		o.ch[d] = t._insert(o.ch[d], key)
		// Restore the heap order on priorities.
		if o.ch[d].priority > o.priority {
			o = o.rotate(d ^ 1)
		}
	} else {
		o.dupCnt++
	}
	o.maintain()
	return o
}

// insert adds one occurrence of key.
func (t *treap) insert(key int) {
	t.root = t._insert(t.root, key)
}

// rank counts stored keys relative to key:
// equal=false: number of elements strictly less than key
// equal=true: number of elements less than or equal to key
func (t *treap) rank(key int, equal bool) (cnt int) {
	o := t.root
	for o != nil {
		switch o.cmp(key) {
		case 0:
			o = o.ch[0]
		case 1:
			// This node and its whole left subtree are smaller than key.
			cnt += o.dupCnt + o.ch[0].size()
			o = o.ch[1]
		default:
			cnt += o.ch[0].size()
			if equal {
				cnt += o.dupCnt
			}
			return
		}
	}
	return
}

// countRangeSum returns the number of subarrays of nums whose sum lies in
// [lower, upper]. For every prefix sum it counts the earlier prefix sums in
// [sum-upper, sum-lower] with two rank queries, then inserts the sum.
func countRangeSum(nums []int, lower, upper int) (cnt int) {
	preSum := make([]int, len(nums)+1)
	for i, v := range nums {
		preSum[i+1] = preSum[i] + v
	}
	t := &treap{}
	for _, sum := range preSum {
		left, right := sum-upper, sum-lower
		cnt += t.rank(right, true) - t.rank(left, false)
		t.insert(sum)
	}
	return
}
java 解法, 执行用时: 159 ms, 内存消耗: 59.4 MB, 提交时间: 2023-10-08 23:33:30
class Solution {
    /**
     * Counts the subarrays of {@code nums} whose sum lies in
     * {@code [lower, upper]}.
     *
     * For each prefix sum {@code x}, the earlier prefix sums inside
     * {@code [x - upper, x - lower]} are counted through rank queries on a
     * treap, and {@code x} is inserted afterwards.
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        long[] preSum = new long[nums.length + 1];
        for (int i = 0; i < nums.length; ++i) {
            preSum[i + 1] = preSum[i] + nums[i];
        }

        BalancedTree treap = new BalancedTree();
        int ret = 0;
        for (long x : preSum) {
            // Rank of the first stored value >= x - upper; one past the end
            // when no such value exists.
            long numLeft = treap.lowerBound(x - upper);
            int rankLeft;
            if (numLeft == Long.MAX_VALUE) {
                rankLeft = (int) (treap.getSize() + 1);
            } else {
                rankLeft = treap.rank(numLeft)[0];
            }

            // Rank of the last stored value <= x - lower.
            long numRight = treap.upperBound(x - lower);
            int rankRight;
            if (numRight == Long.MAX_VALUE) {
                rankRight = (int) treap.getSize();
            } else {
                rankRight = treap.rank(numRight)[0] - 1;
            }

            ret += rankRight - rankLeft + 1;
            treap.insert(x);
        }
        return ret;
    }
}

/** Treap over {@code long} keys with multiplicity counts and subtree sizes. */
class BalancedTree {
    private class BalancedNode {
        long val;
        long seed;   // random heap priority
        int count;   // multiplicity of val
        int size;    // number of values (with multiplicity) in this subtree
        BalancedNode left;
        BalancedNode right;

        BalancedNode(long val, long seed) {
            this.val = val;
            this.seed = seed;
            this.count = 1;
            this.size = 1;
            this.left = null;
            this.right = null;
        }

        /** Lifts the right child above this node and returns it. */
        BalancedNode leftRotate() {
            BalancedNode pivot = right;
            int subtreeSize = size;
            int demotedSize = count
                    + (left != null ? left.size : 0)
                    + (pivot.left != null ? pivot.left.size : 0);
            right = pivot.left;
            pivot.left = this;
            pivot.size = subtreeSize;
            size = demotedSize;
            return pivot;
        }

        /** Lifts the left child above this node and returns it. */
        BalancedNode rightRotate() {
            BalancedNode pivot = left;
            int subtreeSize = size;
            int demotedSize = count
                    + (right != null ? right.size : 0)
                    + (pivot.right != null ? pivot.right.size : 0);
            left = pivot.right;
            pivot.right = this;
            pivot.size = subtreeSize;
            size = demotedSize;
            return pivot;
        }
    }

    private BalancedNode root;
    private int size;
    private Random rand;

    public BalancedTree() {
        this.root = null;
        this.size = 0;
        this.rand = new Random();
    }

    /** Number of values inserted so far, duplicates included. */
    public long getSize() {
        return size;
    }

    public void insert(long x) {
        ++size;
        root = insert(root, x);
    }

    /** Smallest stored value {@code >= x}, or {@code Long.MAX_VALUE} if none. */
    public long lowerBound(long x) {
        long best = Long.MAX_VALUE;
        BalancedNode cur = root;
        while (cur != null) {
            if (cur.val == x) {
                return x;
            }
            if (cur.val > x) {
                best = cur.val;
                cur = cur.left;
            } else {
                cur = cur.right;
            }
        }
        return best;
    }

    /** Smallest stored value {@code > x}, or {@code Long.MAX_VALUE} if none. */
    public long upperBound(long x) {
        long best = Long.MAX_VALUE;
        BalancedNode cur = root;
        while (cur != null) {
            if (cur.val > x) {
                best = cur.val;
                cur = cur.left;
            } else {
                cur = cur.right;
            }
        }
        return best;
    }

    /**
     * Returns the 1-based rank interval {@code {first, last}} occupied by
     * {@code x}, or {@code {Integer.MIN_VALUE, Integer.MAX_VALUE}} when
     * {@code x} is not stored.
     */
    public int[] rank(long x) {
        int passed = 0;
        BalancedNode cur = root;
        while (cur != null) {
            if (x < cur.val) {
                cur = cur.left;
                continue;
            }
            passed += (cur.left != null ? cur.left.size : 0) + cur.count;
            if (x == cur.val) {
                return new int[]{passed - cur.count + 1, passed};
            }
            cur = cur.right;
        }
        return new int[]{Integer.MIN_VALUE, Integer.MAX_VALUE};
    }

    private BalancedNode insert(BalancedNode node, long x) {
        if (node == null) {
            return new BalancedNode(x, rand.nextInt());
        }
        ++node.size;
        if (x == node.val) {
            ++node.count;
            return node;
        }
        if (x < node.val) {
            node.left = insert(node.left, x);
            if (node.left.seed > node.seed) {
                node = node.rightRotate();
            }
        } else {
            node.right = insert(node.right, x);
            if (node.right.seed > node.seed) {
                node = node.leftRotate();
            }
        }
        return node;
    }
}
cpp 解法, 执行用时: 488 ms, 内存消耗: 142.9 MB, 提交时间: 2023-10-08 23:33:16
// Treap over long long keys that stores a multiplicity and a subtree size per
// node. The tree owns its nodes and frees them in its destructor.
class BalancedTree {
private:
    struct BalancedNode {
        long long val;
        long long seed;  // random heap priority
        int count;       // multiplicity of val
        int size;        // number of values (with multiplicity) in this subtree
        BalancedNode* left;
        BalancedNode* right;

        BalancedNode(long long _val, long long _seed)
            : val(_val), seed(_seed), count(1), size(1), left(nullptr), right(nullptr) {}

        // Lifts the right child above this node and returns it.
        BalancedNode* left_rotate() {
            int prev_size = size;
            int curr_size = (left ? left->size : 0) + (right->left ? right->left->size : 0) + count;
            BalancedNode* root = right;
            right = root->left;
            root->left = this;
            root->size = prev_size;
            size = curr_size;
            return root;
        }

        // Lifts the left child above this node and returns it.
        BalancedNode* right_rotate() {
            int prev_size = size;
            int curr_size = (right ? right->size : 0) + (left->right ? left->right->size : 0) + count;
            BalancedNode* root = left;
            left = root->right;
            root->right = this;
            root->size = prev_size;
            size = curr_size;
            return root;
        }
    };

    BalancedNode* root;
    int size;
    std::mt19937 gen;
    std::uniform_int_distribution<long long> dis;

    BalancedNode* insert(BalancedNode* node, long long x) {
        if (!node) {
            return new BalancedNode(x, dis(gen));
        }
        ++node->size;
        if (x < node->val) {
            node->left = insert(node->left, x);
            if (node->left->seed > node->seed) {
                node = node->right_rotate();
            }
        } else if (x > node->val) {
            node->right = insert(node->right, x);
            if (node->right->seed > node->seed) {
                node = node->left_rotate();
            }
        } else {
            ++node->count;
        }
        return node;
    }

    // Frees every node of the subtree rooted at node.
    static void destroy(BalancedNode* node) {
        if (!node) {
            return;
        }
        destroy(node->left);
        destroy(node->right);
        delete node;
    }

public:
    BalancedTree()
        : root(nullptr), size(0), gen(std::random_device{}()), dis(LLONG_MIN, LLONG_MAX) {}

    ~BalancedTree() { destroy(root); }

    // The tree owns raw node pointers, so a member-wise copy would lead to a
    // double delete.
    BalancedTree(const BalancedTree&) = delete;
    BalancedTree& operator=(const BalancedTree&) = delete;

    // Number of values inserted so far, duplicates included.
    long long get_size() const { return size; }

    void insert(long long x) {
        ++size;
        root = insert(root, x);
    }

    // Smallest stored value >= x, or LLONG_MAX when there is none.
    long long lower_bound(long long x) const {
        BalancedNode* node = root;
        long long ans = LLONG_MAX;
        while (node) {
            if (x == node->val) {
                return x;
            }
            if (x < node->val) {
                ans = node->val;
                node = node->left;
            } else {
                node = node->right;
            }
        }
        return ans;
    }

    // Smallest stored value > x, or LLONG_MAX when there is none.
    long long upper_bound(long long x) const {
        BalancedNode* node = root;
        long long ans = LLONG_MAX;
        while (node) {
            if (x < node->val) {
                ans = node->val;
                node = node->left;
            } else {
                node = node->right;
            }
        }
        return ans;
    }

    // 1-based rank interval {first, last} occupied by x, or
    // {INT_MIN, INT_MAX} when x is not stored.
    std::pair<int, int> rank(long long x) const {
        BalancedNode* node = root;
        int ans = 0;
        while (node) {
            if (x < node->val) {
                node = node->left;
            } else {
                ans += (node->left ? node->left->size : 0) + node->count;
                if (x == node->val) {
                    return {ans - node->count + 1, ans};
                }
                node = node->right;
            }
        }
        return {INT_MIN, INT_MAX};
    }
};

class Solution {
public:
    // Returns the number of subarrays of nums whose sum lies in
    // [lower, upper]. For each prefix sum x, the earlier prefix sums inside
    // [x - upper, x - lower] are counted with rank queries on a treap.
    int countRangeSum(std::vector<int>& nums, int lower, int upper) {
        long long sum = 0;
        std::vector<long long> preSum = {0};
        for (int v : nums) {
            sum += v;
            preSum.push_back(sum);
        }

        // Automatic storage: the treap and its nodes are released on return.
        BalancedTree treap;
        int ret = 0;
        for (long long x : preSum) {
            long long numLeft = treap.lower_bound(x - upper);
            int rankLeft = (numLeft == LLONG_MAX
                                ? static_cast<int>(treap.get_size()) + 1
                                : treap.rank(numLeft).first);
            long long numRight = treap.upper_bound(x - lower);
            int rankRight = (numRight == LLONG_MAX
                                 ? static_cast<int>(treap.get_size())
                                 : treap.rank(numRight).first - 1);
            ret += (rankRight - rankLeft + 1);
            treap.insert(x);
        }
        return ret;
    }
};
cpp 解法, 执行用时: 1224 ms, 内存消耗: 302.6 MB, 提交时间: 2023-10-08 23:33:06
// Fenwick tree (binary indexed tree) over 1-based positions 1..n.
class BIT {
private:
    // Declared in the same order as the constructor initializes them.
    int n;
    std::vector<int> tree;

public:
    BIT(int _n) : n(_n), tree(_n + 1) {}

    // Value of the lowest set bit of x.
    static constexpr int lowbit(int x) {
        return x & (-x);
    }

    // Adds d at position x (1-based).
    void update(int x, int d) {
        while (x <= n) {
            tree[x] += d;
            x += lowbit(x);
        }
    }

    // Sum of positions 1..x; 0 when x is 0.
    int query(int x) const {
        int ans = 0;
        while (x) {
            ans += tree[x];
            x -= lowbit(x);
        }
        return ans;
    }
};

class Solution {
public:
    // Returns the number of subarrays of nums whose sum lies in
    // [lower, upper]. For each prefix sum x, the earlier prefix sums inside
    // [x - upper, x - lower] are counted with a Fenwick tree indexed by the
    // rank of each value.
    int countRangeSum(std::vector<int>& nums, int lower, int upper) {
        long long sum = 0;
        std::vector<long long> preSum = {0};
        for (int v : nums) {
            sum += v;
            preSum.push_back(sum);
        }

        // Discretization: gather every inserted value and every query bound,
        // then sort and deduplicate so the index of a value is its rank.
        std::vector<long long> allNumbers;
        allNumbers.reserve(3 * preSum.size());
        for (long long x : preSum) {
            allNumbers.push_back(x);
            allNumbers.push_back(x - lower);
            allNumbers.push_back(x - upper);
        }
        std::sort(allNumbers.begin(), allNumbers.end());
        allNumbers.erase(std::unique(allNumbers.begin(), allNumbers.end()), allNumbers.end());

        // Every value looked up below was added above, so lower_bound always
        // lands exactly on it.
        auto indexOf = [&allNumbers](long long v) {
            return static_cast<int>(
                std::lower_bound(allNumbers.begin(), allNumbers.end(), v) - allNumbers.begin());
        };

        int ret = 0;
        BIT bit(static_cast<int>(allNumbers.size()));
        for (long long x : preSum) {
            int left = indexOf(x - upper), right = indexOf(x - lower);
            // Ranks are 0-based and the tree is 1-based, hence the shift.
            ret += bit.query(right + 1) - bit.query(left);
            bit.update(indexOf(x) + 1, 1);
        }
        return ret;
    }
};
java 解法, 执行用时: 637 ms, 内存消耗: 107.1 MB, 提交时间: 2023-10-08 23:32:51
class Solution {
    /**
     * Counts the subarrays of {@code nums} whose sum lies in
     * {@code [lower, upper]}.
     *
     * For each prefix sum {@code x}, the earlier prefix sums inside
     * {@code [x - upper, x - lower]} are counted with a Fenwick tree indexed
     * by the rank of each value.
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        long[] preSum = new long[nums.length + 1];
        for (int i = 0; i < nums.length; ++i) {
            preSum[i + 1] = preSum[i] + nums[i];
        }

        // Discretization on a primitive array: gather every inserted value
        // and every query bound, sort, then compact the distinct values to
        // the front. The rank of a value is its index in that prefix.
        long[] allNumbers = new long[3 * preSum.length];
        int filled = 0;
        for (long x : preSum) {
            allNumbers[filled++] = x;
            allNumbers[filled++] = x - lower;
            allNumbers[filled++] = x - upper;
        }
        java.util.Arrays.sort(allNumbers);
        int distinct = 1;
        for (int i = 1; i < allNumbers.length; i++) {
            if (allNumbers[i] != allNumbers[distinct - 1]) {
                allNumbers[distinct++] = allNumbers[i];
            }
        }

        int ret = 0;
        BIT bit = new BIT(distinct);
        for (long x : preSum) {
            // Every key searched for was added above, so binarySearch always
            // returns a non-negative index.
            int left = java.util.Arrays.binarySearch(allNumbers, 0, distinct, x - upper);
            int right = java.util.Arrays.binarySearch(allNumbers, 0, distinct, x - lower);
            // Ranks are 0-based and the tree is 1-based, hence the shift.
            ret += bit.query(right + 1) - bit.query(left);
            bit.update(java.util.Arrays.binarySearch(allNumbers, 0, distinct, x) + 1, 1);
        }
        return ret;
    }
}

/** Fenwick tree (binary indexed tree) over 1-based positions {@code 1..n}. */
class BIT {
    int[] tree;
    int n;

    public BIT(int n) {
        this.n = n;
        this.tree = new int[n + 1];
    }

    /** Value of the lowest set bit of {@code x}. */
    public static int lowbit(int x) {
        return x & (-x);
    }

    /** Adds {@code d} at position {@code x} (1-based). */
    public void update(int x, int d) {
        while (x <= n) {
            tree[x] += d;
            x += lowbit(x);
        }
    }

    /** Sum of positions {@code 1..x}; 0 when {@code x} is 0. */
    public int query(int x) {
        int ans = 0;
        while (x != 0) {
            ans += tree[x];
            x -= lowbit(x);
        }
        return ans;
    }
}
golang 解法, 执行用时: 400 ms, 内存消耗: 38.3 MB, 提交时间: 2023-10-08 23:32:32
// fenwick is a binary indexed tree of counters over 1-based positions.
type fenwick struct {
	tree []int
}

// inc adds one to the counter at position i (1-based).
func (f fenwick) inc(i int) {
	for ; i < len(f.tree); i += i & -i {
		f.tree[i]++
	}
}

// sum returns the total of the counters at positions 1..i.
func (f fenwick) sum(i int) (res int) {
	// i &= i - 1 clears the lowest set bit of i.
	for ; i > 0; i &= i - 1 {
		res += f.tree[i]
	}
	return
}

// query returns the total of the counters at positions l..r inclusive.
func (f fenwick) query(l, r int) (res int) {
	return f.sum(r) - f.sum(l-1)
}

// countRangeSum returns the number of subarrays of nums whose sum lies in
// [lower, upper].
func countRangeSum(nums []int, lower, upper int) (cnt int) {
	n := len(nums)

	// Compute the prefix sums preSum together with allNums, every number
	// the counting pass below will look up. allNums starts with one zero
	// element, which stands for preSum[0].
	allNums := make([]int, 1, 3*n+1)
	preSum := make([]int, n+1)
	for i, v := range nums {
		preSum[i+1] = preSum[i] + v
		allNums = append(allNums, preSum[i+1], preSum[i+1]-lower, preSum[i+1]-upper)
	}

	// Discretize allNums: map each distinct value to its 1-based rank.
	sort.Ints(allNums)
	k := 1
	kth := map[int]int{allNums[0]: k}
	for i := 1; i < len(allNums); i++ {
		if allNums[i] != allNums[i-1] {
			k++
			kth[allNums[i]] = k
		}
	}

	// Walk preSum and use the Fenwick tree to count, for each prefix sum,
	// the earlier prefix sums that form a valid range with it.
	t := fenwick{tree: make([]int, k+1)}
	t.inc(kth[0])
	for _, sum := range preSum[1:] {
		left, right := kth[sum-upper], kth[sum-lower]
		cnt += t.query(left, right)
		t.inc(kth[sum])
	}
	return
}
golang 解法, 执行用时: 1004 ms, 内存消耗: 278.1 MB, 提交时间: 2023-10-08 23:32:20
// node is a segment-tree node covering the value range [l, r]. Children are
// created lazily on first insertion, so only visited ranges take memory.
type node struct {
	l, r, val int // val is the number of inserted values inside [l, r]
	lo, ro    *node
}

// insert records one occurrence of val, creating child nodes on the way down.
func (o *node) insert(val int) {
	o.val++
	if o.l == o.r {
		return
	}
	m := (o.l + o.r) >> 1
	if val <= m {
		if o.lo == nil {
			o.lo = &node{l: o.l, r: m}
		}
		o.lo.insert(val)
	} else {
		if o.ro == nil {
			o.ro = &node{l: m + 1, r: o.r}
		}
		o.ro.insert(val)
	}
}

// query returns how many inserted values lie in [l, r]. It is safe to call
// on a nil node, which stands for a range that was never touched.
func (o *node) query(l, r int) (res int) {
	if o == nil || l > o.r || r < o.l {
		return
	}
	if l <= o.l && o.r <= r {
		return o.val
	}
	return o.lo.query(l, r) + o.ro.query(l, r)
}

// countRangeSum returns the number of subarrays of nums whose sum lies in
// [lower, upper]. The segment tree spans the smallest and largest value
// that is ever inserted or used as a query bound.
func countRangeSum(nums []int, lower, upper int) (cnt int) {
	preSum := make([]int, len(nums)+1)
	for i, v := range nums {
		preSum[i+1] = preSum[i] + v
	}

	lbound, rbound := math.MaxInt64, -math.MaxInt64
	for _, sum := range preSum {
		lbound = min(lbound, sum, sum-lower, sum-upper)
		rbound = max(rbound, sum, sum-lower, sum-upper)
	}

	root := &node{l: lbound, r: rbound}
	for _, sum := range preSum {
		// Query before inserting so that only earlier prefix sums count.
		left, right := sum-upper, sum-lower
		cnt += root.query(left, right)
		root.insert(sum)
	}
	return
}

// min returns the smallest of its arguments; at least one is required.
func min(a ...int) int {
	res := a[0]
	for _, v := range a[1:] {
		if v < res {
			res = v
		}
	}
	return res
}

// max returns the largest of its arguments; at least one is required.
func max(a ...int) int {
	res := a[0]
	for _, v := range a[1:] {
		if v > res {
			res = v
		}
	}
	return res
}
java 解法, 执行用时: 346 ms, 内存消耗: 229.8 MB, 提交时间: 2023-10-08 23:32:06
class Solution {
    /**
     * Counts the subarrays of {@code nums} whose sum lies in
     * {@code [lower, upper]}, using a lazily built segment tree over the
     * raw value range of the prefix sums and their query bounds.
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        int n = nums.length;
        long[] preSum = new long[n + 1];
        for (int i = 0; i < n; ++i) {
            preSum[i + 1] = preSum[i] + nums[i];
        }

        // The tree must cover every inserted value and every query bound.
        long lbound = Long.MAX_VALUE;
        long rbound = Long.MIN_VALUE;
        for (long x : preSum) {
            long shiftedByLower = x - lower;
            long shiftedByUpper = x - upper;
            lbound = Math.min(Math.min(lbound, x), Math.min(shiftedByLower, shiftedByUpper));
            rbound = Math.max(Math.max(rbound, x), Math.max(shiftedByLower, shiftedByUpper));
        }

        SegNode root = new SegNode(lbound, rbound);
        int ret = 0;
        for (long x : preSum) {
            // Query first so that only earlier prefix sums are counted.
            ret += count(root, x - upper, x - lower);
            insert(root, x);
        }
        return ret;
    }

    /** Number of inserted values inside {@code [left, right]}. */
    public int count(SegNode root, long left, long right) {
        if (root == null || left > root.hi || right < root.lo) {
            return 0;
        }
        boolean fullyCovered = left <= root.lo && root.hi <= right;
        if (fullyCovered) {
            return root.add;
        }
        int fromLeft = count(root.lchild, left, right);
        int fromRight = count(root.rchild, left, right);
        return fromLeft + fromRight;
    }

    /** Records one occurrence of {@code val}, creating nodes on the way down. */
    public void insert(SegNode root, long val) {
        SegNode cur = root;
        while (true) {
            cur.add++;
            if (cur.lo == cur.hi) {
                return;
            }
            long mid = (cur.lo + cur.hi) >> 1;
            if (val <= mid) {
                if (cur.lchild == null) {
                    cur.lchild = new SegNode(cur.lo, mid);
                }
                cur = cur.lchild;
            } else {
                if (cur.rchild == null) {
                    cur.rchild = new SegNode(mid + 1, cur.hi);
                }
                cur = cur.rchild;
            }
        }
    }
}

/** Segment-tree node covering the value range {@code [lo, hi]}. */
class SegNode {
    long lo, hi;
    int add;                  // number of inserted values inside [lo, hi]
    SegNode lchild, rchild;   // created on demand

    public SegNode(long left, long right) {
        this.lo = left;
        this.hi = right;
        this.add = 0;
        this.lchild = null;
        this.rchild = null;
    }
}
cpp 解法, 执行用时: 1924 ms, 内存消耗: 764.3 MB, 提交时间: 2023-10-08 23:31:53
// Segment-tree node covering the value range [lo, hi]. Children are created
// on demand; a node owns its children and deletes them in its destructor.
struct SegNode {
    long long lo, hi;
    int add;  // number of inserted values inside [lo, hi]
    SegNode* lchild, *rchild;

    SegNode(long long left, long long right)
        : lo(left), hi(right), add(0), lchild(nullptr), rchild(nullptr) {}

    // Deleting a null pointer is a no-op, so missing children need no check.
    ~SegNode() {
        delete lchild;
        delete rchild;
    }

    // A member-wise copy would make two nodes delete the same children.
    SegNode(const SegNode&) = delete;
    SegNode& operator=(const SegNode&) = delete;
};

class Solution {
public:
    // Records one occurrence of val, creating nodes on the way down.
    void insert(SegNode* root, long long val) {
        root->add++;
        if (root->lo == root->hi) {
            return;
        }
        long long mid = (root->lo + root->hi) >> 1;
        if (val <= mid) {
            if (!root->lchild) {
                root->lchild = new SegNode(root->lo, mid);
            }
            insert(root->lchild, val);
        } else {
            if (!root->rchild) {
                root->rchild = new SegNode(mid + 1, root->hi);
            }
            insert(root->rchild, val);
        }
    }

    // Number of inserted values inside [left, right]; a null root stands
    // for a range that was never touched.
    int count(SegNode* root, long long left, long long right) const {
        if (!root) {
            return 0;
        }
        if (left > root->hi || right < root->lo) {
            return 0;
        }
        if (left <= root->lo && root->hi <= right) {
            return root->add;
        }
        return count(root->lchild, left, right) + count(root->rchild, left, right);
    }

    // Returns the number of subarrays of nums whose sum lies in
    // [lower, upper].
    int countRangeSum(std::vector<int>& nums, int lower, int upper) {
        long long sum = 0;
        std::vector<long long> preSum = {0};
        for (int v : nums) {
            sum += v;
            preSum.push_back(sum);
        }

        // The tree must cover every inserted value and every query bound.
        long long lbound = LLONG_MAX, rbound = LLONG_MIN;
        for (long long x : preSum) {
            lbound = std::min({lbound, x, x - lower, x - upper});
            rbound = std::max({rbound, x, x - lower, x - upper});
        }

        // Automatic storage: the whole tree is released when root goes out
        // of scope.
        SegNode root(lbound, rbound);
        int ret = 0;
        for (long long x : preSum) {
            // Query first so that only earlier prefix sums are counted.
            ret += count(&root, x - upper, x - lower);
            insert(&root, x);
        }
        return ret;
    }
};
java 解法, 执行用时: 930 ms, 内存消耗: 149 MB, 提交时间: 2023-10-08 23:31:24
class Solution {
    /**
     * Counts the subarrays of {@code nums} whose sum lies in
     * {@code [lower, upper]}, using a fully built segment tree over the
     * ranks of the prefix sums and their query bounds.
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        long[] preSum = new long[nums.length + 1];
        for (int i = 0; i < nums.length; ++i) {
            preSum[i + 1] = preSum[i] + nums[i];
        }

        // Discretization on a primitive array: gather every inserted value
        // and every query bound, sort, then compact the distinct values to
        // the front. The rank of a value is its index in that prefix.
        long[] allNumbers = new long[3 * preSum.length];
        int filled = 0;
        for (long x : preSum) {
            allNumbers[filled++] = x;
            allNumbers[filled++] = x - lower;
            allNumbers[filled++] = x - upper;
        }
        java.util.Arrays.sort(allNumbers);
        int distinct = 1;
        for (int i = 1; i < allNumbers.length; i++) {
            if (allNumbers[i] != allNumbers[distinct - 1]) {
                allNumbers[distinct++] = allNumbers[i];
            }
        }

        SegNode root = build(0, distinct - 1);
        int ret = 0;
        for (long x : preSum) {
            // Every key searched for was added above, so binarySearch always
            // returns a non-negative index.
            int left = java.util.Arrays.binarySearch(allNumbers, 0, distinct, x - upper);
            int right = java.util.Arrays.binarySearch(allNumbers, 0, distinct, x - lower);
            // Query first so that only earlier prefix sums are counted.
            ret += count(root, left, right);
            insert(root, java.util.Arrays.binarySearch(allNumbers, 0, distinct, x));
        }
        return ret;
    }

    /** Builds the complete tree covering the index range {@code [left, right]}. */
    public SegNode build(int left, int right) {
        SegNode node = new SegNode(left, right);
        if (left == right) {
            return node;
        }
        int mid = (left + right) / 2;
        node.lchild = build(left, mid);
        node.rchild = build(mid + 1, right);
        return node;
    }

    /** Number of inserted indices inside {@code [left, right]}. */
    public int count(SegNode root, int left, int right) {
        if (left > root.hi || right < root.lo) {
            return 0;
        }
        if (left <= root.lo && root.hi <= right) {
            return root.add;
        }
        return count(root.lchild, left, right) + count(root.rchild, left, right);
    }

    /** Records one occurrence of index {@code val} along its root-to-leaf path. */
    public void insert(SegNode root, int val) {
        root.add++;
        if (root.lo == root.hi) {
            return;
        }
        int mid = (root.lo + root.hi) / 2;
        if (val <= mid) {
            insert(root.lchild, val);
        } else {
            insert(root.rchild, val);
        }
    }
}

/** Segment-tree node covering the index range {@code [lo, hi]}. */
class SegNode {
    int lo, hi, add;          // add = number of inserted indices in [lo, hi]
    SegNode lchild, rchild;

    public SegNode(int left, int right) {
        lo = left;
        hi = right;
        add = 0;
        lchild = null;
        rchild = null;
    }
}
golang 解法, 执行用时: 592 ms, 内存消耗: 80.5 MB, 提交时间: 2023-10-08 23:31:12
// segTree is an array-backed segment tree of counters. Node o keeps its
// children at o<<1 and o<<1|1, and the root is node 1.
type segTree []struct {
	l, r, val int // val is the number of inserted positions inside [l, r]
}

// build initializes node o to cover [l, r] and recursively builds its children.
func (t segTree) build(o, l, r int) {
	t[o].l, t[o].r = l, r
	if l == r {
		return
	}
	m := (l + r) >> 1
	t.build(o<<1, l, m)
	t.build(o<<1|1, m+1, r)
}

// inc adds one to the counter at position i, starting the descent at node o.
func (t segTree) inc(o, i int) {
	if t[o].l == t[o].r {
		t[o].val++
		return
	}
	if i <= (t[o].l+t[o].r)>>1 {
		t.inc(o<<1, i)
	} else {
		t.inc(o<<1|1, i)
	}
	// Re-derive this node's total from its children.
	t[o].val = t[o<<1].val + t[o<<1|1].val
}

// query returns the total of the counters at positions l..r inside node o.
func (t segTree) query(o, l, r int) (res int) {
	if l <= t[o].l && t[o].r <= r {
		return t[o].val
	}
	m := (t[o].l + t[o].r) >> 1
	if r <= m {
		return t.query(o<<1, l, r)
	}
	if l > m {
		return t.query(o<<1|1, l, r)
	}
	return t.query(o<<1, l, r) + t.query(o<<1|1, l, r)
}

// countRangeSum returns the number of subarrays of nums whose sum lies in
// [lower, upper].
func countRangeSum(nums []int, lower, upper int) (cnt int) {
	n := len(nums)

	// Compute the prefix sums preSum together with allNums, every number
	// the counting pass below will look up. allNums starts with one zero
	// element, which stands for preSum[0].
	allNums := make([]int, 1, 3*n+1)
	preSum := make([]int, n+1)
	for i, v := range nums {
		preSum[i+1] = preSum[i] + v
		allNums = append(allNums, preSum[i+1], preSum[i+1]-lower, preSum[i+1]-upper)
	}

	// Discretize allNums: map each distinct value to its 1-based rank.
	sort.Ints(allNums)
	k := 1
	kth := map[int]int{allNums[0]: k}
	for i := 1; i < len(allNums); i++ {
		if allNums[i] != allNums[i-1] {
			k++
			kth[allNums[i]] = k
		}
	}

	// Walk preSum and use the segment tree to count, for each prefix sum,
	// the earlier prefix sums that form a valid range with it.
	t := make(segTree, 4*k)
	t.build(1, 1, k)
	t.inc(1, kth[0])
	for _, sum := range preSum[1:] {
		left, right := kth[sum-upper], kth[sum-lower]
		cnt += t.query(1, left, right)
		t.inc(1, kth[sum])
	}
	return
}
golang 解法, 执行用时: 160 ms, 内存消耗: 10.2 MB, 提交时间: 2023-10-08 23:31:00
// countRangeSum returns the number of subarrays of nums whose sum lies in
// [lower, upper]. It merge-sorts the prefix sums and, at every merge step,
// counts the pairs (v from the left half, w from the right half) with
// lower <= w-v <= upper.
func countRangeSum(nums []int, lower, upper int) int {
	var mergeCount func([]int) int
	mergeCount = func(arr []int) int {
		n := len(arr)
		if n <= 1 {
			return 0
		}

		// Work on copies so that arr can be overwritten by the merge below.
		n1 := append([]int(nil), arr[:n/2]...)
		n2 := append([]int(nil), arr[n/2:]...)
		cnt := mergeCount(n1) + mergeCount(n2) // after the recursion, n1 and n2 are both sorted

		// Count the index pairs. Both halves are sorted, so l and r only
		// ever move forward as v grows.
		l, r := 0, 0
		for _, v := range n1 {
			for l < len(n2) && n2[l]-v < lower {
				l++
			}
			for r < len(n2) && n2[r]-v <= upper {
				r++
			}
			cnt += r - l
		}

		// Merge n1 and n2 back into arr.
		p1, p2 := 0, 0
		for i := range arr {
			if p1 < len(n1) && (p2 == len(n2) || n1[p1] <= n2[p2]) {
				arr[i] = n1[p1]
				p1++
			} else {
				arr[i] = n2[p2]
				p2++
			}
		}
		return cnt
	}

	prefixSum := make([]int, len(nums)+1)
	for i, v := range nums {
		prefixSum[i+1] = prefixSum[i] + v
	}
	return mergeCount(prefixSum)
}
javascript 解法, 执行用时: 208 ms, 内存消耗: 64.2 MB, 提交时间: 2023-10-08 23:30:47
/**
 * Merge-sorts sum[left..right] in place and returns the number of index
 * pairs (i, j) with left <= i < j <= right, taken in the original order,
 * such that lower <= sum[j] - sum[i] <= upper.
 *
 * @param {number[]} sum prefix sums; reordered in place
 * @param {number} lower
 * @param {number} upper
 * @param {number} left first index of the range (inclusive)
 * @param {number} right last index of the range (inclusive)
 * @return {number}
 */
const countRangeSumRecursive = (sum, lower, upper, left, right) => {
    if (left === right) {
        return 0;
    }
    const mid = Math.floor((left + right) / 2);
    let ret = countRangeSumRecursive(sum, lower, upper, left, mid)
        + countRangeSumRecursive(sum, lower, upper, mid + 1, right);

    // First count the index pairs across the two sorted halves. Both halves
    // are sorted, so l and r only ever move forward as i advances.
    let l = mid + 1;
    let r = mid + 1;
    for (let i = left; i <= mid; i++) {
        while (l <= right && sum[l] - sum[i] < lower) l++;
        while (r <= right && sum[r] - sum[i] <= upper) r++;
        ret += r - l;
    }

    // Then merge the two sorted halves.
    const sorted = new Array(right - left + 1);
    let p1 = left;
    let p2 = mid + 1;
    let p = 0;
    while (p1 <= mid || p2 <= right) {
        if (p1 > mid) {
            sorted[p++] = sum[p2++];
        } else if (p2 > right) {
            sorted[p++] = sum[p1++];
        } else if (sum[p1] < sum[p2]) {
            sorted[p++] = sum[p1++];
        } else {
            sorted[p++] = sum[p2++];
        }
    }
    for (let k = 0; k < sorted.length; k++) {
        sum[left + k] = sorted[k];
    }
    return ret;
};

/**
 * Counts the subarrays of nums whose sum lies in [lower, upper].
 *
 * @param {number[]} nums
 * @param {number} lower
 * @param {number} upper
 * @return {number}
 */
var countRangeSum = function(nums, lower, upper) {
    let s = 0;
    const sum = [0];
    for (const v of nums) {
        s += v;
        sum.push(s);
    }
    return countRangeSumRecursive(sum, lower, upper, 0, sum.length - 1);
};
java 解法, 执行用时: 60 ms, 内存消耗: 55.4 MB, 提交时间: 2023-10-08 23:30:33
class Solution {
    /**
     * Counts the subarrays of {@code nums} whose sum lies in
     * {@code [lower, upper]} by merge-sorting the prefix sums.
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        long s = 0;
        long[] sum = new long[nums.length + 1];
        for (int i = 0; i < nums.length; ++i) {
            s += nums[i];
            sum[i + 1] = s;
        }
        return countRangeSumRecursive(sum, lower, upper, 0, sum.length - 1);
    }

    /**
     * Merge-sorts {@code sum[left..right]} in place and returns the number of
     * index pairs {@code (i, j)} with {@code left <= i < j <= right}, taken in
     * the original order, such that
     * {@code lower <= sum[j] - sum[i] <= upper}.
     */
    public int countRangeSumRecursive(long[] sum, int lower, int upper, int left, int right) {
        if (left == right) {
            return 0;
        }
        // Overflow-safe midpoint.
        int mid = left + (right - left) / 2;
        int ret = countRangeSumRecursive(sum, lower, upper, left, mid)
                + countRangeSumRecursive(sum, lower, upper, mid + 1, right);

        // First count the index pairs across the two sorted halves. Both
        // halves are sorted, so l and r only ever move forward.
        int l = mid + 1;
        int r = mid + 1;
        for (int i = left; i <= mid; i++) {
            while (l <= right && sum[l] - sum[i] < lower) {
                l++;
            }
            while (r <= right && sum[r] - sum[i] <= upper) {
                r++;
            }
            ret += r - l;
        }

        // Then merge the two sorted halves.
        long[] sorted = new long[right - left + 1];
        int p1 = left, p2 = mid + 1;
        int p = 0;
        while (p1 <= mid || p2 <= right) {
            if (p1 > mid) {
                sorted[p++] = sum[p2++];
            } else if (p2 > right) {
                sorted[p++] = sum[p1++];
            } else if (sum[p1] < sum[p2]) {
                sorted[p++] = sum[p1++];
            } else {
                sorted[p++] = sum[p2++];
            }
        }
        System.arraycopy(sorted, 0, sum, left, sorted.length);
        return ret;
    }
}
cpp 解法, 执行用时: 580 ms, 内存消耗: 208.4 MB, 提交时间: 2023-10-08 23:30:20
class Solution {
public:
    // Merge-sorts sum[left..right] in place and returns the number of index
    // pairs (i, j) with left <= i < j <= right, taken in the original order,
    // such that lower <= sum[j] - sum[i] <= upper.
    //
    // The element type is a template parameter so that callers holding a
    // vector<long> keep working while countRangeSum below can pass 64-bit
    // prefix sums on every platform.
    template <typename T>
    int countRangeSumRecursive(std::vector<T>& sum, int lower, int upper, int left, int right) {
        if (left == right) {
            return 0;
        }
        int mid = left + (right - left) / 2;
        int ret = countRangeSumRecursive(sum, lower, upper, left, mid)
                + countRangeSumRecursive(sum, lower, upper, mid + 1, right);

        // First count the index pairs across the two sorted halves. Both
        // halves are sorted, so l and r only ever move forward.
        int l = mid + 1;
        int r = mid + 1;
        for (int i = left; i <= mid; ++i) {
            while (l <= right && sum[l] - sum[i] < lower) ++l;
            while (r <= right && sum[r] - sum[i] <= upper) ++r;
            ret += (r - l);
        }

        // Then merge the two sorted halves in place.
        std::inplace_merge(sum.begin() + left, sum.begin() + mid + 1, sum.begin() + right + 1);
        return ret;
    }

    // Returns the number of subarrays of nums whose sum lies in
    // [lower, upper].
    int countRangeSum(std::vector<int>& nums, int lower, int upper) {
        // long long is at least 64 bits wide everywhere, whereas long is
        // only 32 bits on LLP64 platforms and could overflow here.
        long long s = 0;
        std::vector<long long> sum{0};
        sum.reserve(nums.size() + 1);
        for (int v : nums) {
            s += v;
            sum.push_back(s);
        }
        return countRangeSumRecursive(sum, lower, upper, 0, static_cast<int>(sum.size()) - 1);
    }
};