列表

详情


327. 区间和的个数

给你一个整数数组 nums 以及两个整数 lower 和 upper 。求数组中,值位于范围 [lower, upper] (包含 lower 和 upper)之内的 区间和的个数 。

区间和 S(i, j) 表示在 nums 中,位置从 i 到 j 的元素之和,包含 i 和 j (i ≤ j)。

 

示例 1:
输入:nums = [-2,5,-1], lower = -2, upper = 2
输出:3
解释:存在三个区间:[0,0]、[2,2] 和 [0,2] ,对应的区间和分别是:-2 、-1 、2 。

示例 2:

输入:nums = [0], lower = 0, upper = 0
输出:1

 

提示:

相似题目

计算右侧小于当前元素的个数

翻转对

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
// Default LeetCode C++ template restored from the editor cache; it contains no implementation.
// NOTE(review): the body is empty, so this non-void function falls off the end without returning a value.
class Solution { public: int countRangeSum(vector<int>& nums, int lower, int upper) { } };

python3 解法, 执行用时: 7820 ms, 内存消耗: 75.7 MB, 提交时间: 2023-10-08 23:35:45

class SegTree:
    """Array-backed segment tree of point values supporting range sums.

    Node 0 is the root and covers leaf indices [0, n - 1]; node ``k`` has
    children ``2k + 1`` and ``2k + 2``.
    """

    def __init__(self, n):
        self.n = n
        self.treesum = [0] * (4 * n)

    def update(self, ID, diff):
        """Add ``diff`` to the leaf at index ``ID``."""
        self._update(0, 0, self.n - 1, ID, diff)

    def query(self, ql, qr):
        """Return the sum of the leaves in the inclusive range [ql, qr]."""
        return self._query(0, 0, self.n - 1, ql, qr)

    def _update(self, root, l, r, ID, diff):
        # Walk down to the leaf, then refresh every ancestor on the way back.
        if l == r == ID:
            self.treesum[root] += diff
            return
        mid = (l + r) // 2
        lchild, rchild = 2 * root + 1, 2 * root + 2
        if ID > mid:
            self._update(rchild, mid + 1, r, ID, diff)
        else:
            self._update(lchild, l, mid, ID, diff)
        self.treesum[root] = self.treesum[lchild] + self.treesum[rchild]

    def _query(self, root, l, r, ql, qr):
        # The query range is kept inside [l, r]; it is split at mid when it
        # straddles both children.
        if (l, r) == (ql, qr):
            return self.treesum[root]
        mid = (l + r) // 2
        lchild, rchild = 2 * root + 1, 2 * root + 2
        if qr <= mid:
            return self._query(lchild, l, mid, ql, qr)
        if ql >= mid + 1:
            return self._query(rchild, mid + 1, r, ql, qr)
        left_part = self._query(lchild, l, mid, ql, mid)
        right_part = self._query(rchild, mid + 1, r, mid + 1, qr)
        return left_part + right_part


class Solution:

    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Count ranges [i, j] whose sum lies in [lower, upper].

        A range sum equals ``presum[j + 1] - presum[i]``. For each prefix sum
        ``p``, taken in order, we count the earlier prefix sums inside
        ``[p - upper, p - lower]``. Every value that is queried or inserted
        is coordinate-compressed and tallied in a ``SegTree``.
        """
        presum = [0]
        for x in nums:
            presum.append(presum[-1] + x)

        # Coordinate compression over every value that may be looked up.
        candidates = set()
        for p in presum:
            candidates.update((p, p - lower, p - upper))
        val_id = {v: idx for idx, v in enumerate(sorted(candidates))}

        tree = SegTree(len(val_id))
        res = 0
        for p in presum:
            res += tree.query(val_id[p - upper], val_id[p - lower])
            tree.update(val_id[p], 1)
        return res

python3 解法, 执行用时: 2808 ms, 内存消耗: 31.6 MB, 提交时间: 2023-10-08 23:35:18

class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Count range sums in [lower, upper] by merge-sorting prefix sums.

        The answer is accumulated in ``self.res`` while ``mergesort`` runs.
        """
        self.lower, self.upper = lower, upper
        self.res = 0
        prefixsum = [0] * (len(nums) + 1)
        for idx, v in enumerate(nums):
            prefixsum[idx + 1] = prefixsum[idx] + v
        self.mergesort(prefixsum, 0, len(prefixsum) - 1)
        return self.res

    def mergesort(self, nums: List[int], L: int, R: int) -> None:
        """Sort nums[L..R] in place, counting qualifying cross-half pairs."""
        if L >= R:
            return
        mid = L + (R - L) // 2
        self.mergesort(nums, L, mid)
        self.mergesort(nums, mid + 1, R)
        self.merge(nums, L, mid, R)

    def merge(self, nums: List[int], L: int, mid: int, R: int) -> None:
        """Merge the sorted runs nums[L..mid] and nums[mid+1..R].

        Before merging, count pairs (a, b) with ``a`` from the left run and
        ``b`` from the right run such that lower <= b - a <= upper. Both runs
        are sorted, so the window [lo, hi) of valid ``b`` only moves right as
        ``a`` grows.
        """
        lo = hi = mid + 1
        for a in nums[L:mid + 1]:
            while lo <= R and nums[lo] - a < self.lower:
                lo += 1
            while hi <= R and nums[hi] - a <= self.upper:
                hi += 1
            self.res += hi - lo

        merged = []
        i, j = L, mid + 1
        while i <= mid and j <= R:
            if nums[j] < nums[i]:
                merged.append(nums[j])
                j += 1
            else:
                merged.append(nums[i])
                i += 1
        merged.extend(nums[i:mid + 1])
        merged.extend(nums[j:R + 1])
        nums[L:R + 1] = merged

python3 解法, 执行用时: 3488 ms, 内存消耗: 66.9 MB, 提交时间: 2023-10-08 23:35:05

class BitTree:
    """Fenwick (binary indexed) tree over 1-based positions 1..n."""

    def __init__(self, n):
        self.tree = [0] * (n + 1)
        self.n = n

    def lowbit(self, i: int) -> int:
        """Return the value of the lowest set bit of ``i``."""
        return i & (-i)

    def update(self, i: int, k: int) -> None:
        """Add ``k`` at 1-based position ``i``."""
        pos = i
        while pos <= self.n:
            self.tree[pos] += k
            pos += self.lowbit(pos)

    def presum(self, i: int) -> int:
        """Return the sum of positions 1..i (inclusive)."""
        total = 0
        pos = i
        while pos > 0:
            total += self.tree[pos]
            pos -= self.lowbit(pos)
        return total

class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Count range sums in [lower, upper] with a Fenwick tree.

        Prefix sums are visited in order. For each prefix sum ``x``, the
        earlier prefix sums in [x - upper, x - lower] each start a range
        whose sum is valid. Every queried or inserted value is
        coordinate-compressed (deduplicated and sorted), so the tree needs
        only one slot per distinct value.
        """
        # presum[k] = nums[0] + ... + nums[k - 1]
        presum = [0] * (len(nums) + 1)
        for i, v in enumerate(nums):
            presum[i + 1] = presum[i] + v

        # Coordinate compression: distinct value -> 0-based rank.
        all_num = set()
        for x in presum:
            all_num.update((x, x - lower, x - upper))
        val_id = {x: i for i, x in enumerate(sorted(all_num))}

        # Fenwick positions are 1-based, hence the +1 offsets below.
        bit = BitTree(len(val_id))
        res = 0
        for x in presum:
            id_l = val_id[x - upper]
            id_r = val_id[x - lower]
            # Earlier prefix sums whose rank lies in [id_l, id_r].
            res += bit.presum(id_r + 1) - bit.presum(id_l)
            bit.update(val_id[x] + 1, 1)
        return res

python3 解法, 执行用时: 6136 ms, 内存消耗: 31.9 MB, 提交时间: 2023-10-08 23:34:52

class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """Count range sums in [lower, upper] with a sorted list of prefix sums.

        ``insort`` shifts list elements on every insertion, so this is O(n^2)
        in the worst case. ``countRangeSum2`` is a variant backed by
        ``sortedcontainers.SortedList``.
        """
        # Imported locally so the method does not depend on names injected
        # into the global namespace by the judge environment.
        from bisect import bisect_left, bisect_right, insort

        presum = [0] * (len(nums) + 1)
        for i, v in enumerate(nums):
            presum[i + 1] = presum[i] + v

        res = 0
        pre_window = []  # sorted prefix sums seen so far
        for p in presum:
            # Earlier prefix sums q with lower <= p - q <= upper.
            hi = bisect_right(pre_window, p - lower)
            lo = bisect_left(pre_window, p - upper)
            res += hi - lo
            insort(pre_window, p)
        return res

    # sortedcontainers
    def countRangeSum2(self, nums: List[int], lower: int, upper: int) -> int:
        """Same algorithm as ``countRangeSum``, keeping the window in a SortedList.

        Requires the third-party ``sortedcontainers`` package.
        """
        from sortedcontainers import SortedList as SL

        presum = [0] * (len(nums) + 1)
        for i, v in enumerate(nums):
            presum[i + 1] = presum[i] + v

        res = 0
        pre_window = SL()
        for p in presum:
            hi = pre_window.bisect_right(p - lower)
            lo = pre_window.bisect_left(p - upper)
            res += hi - lo
            pre_window.add(p)
        return res

golang 解法, 执行用时: 228 ms, 内存消耗: 11.2 MB, 提交时间: 2023-10-08 23:33:52

import "math/rand" // 默认导入的 rand 不是这个库,需要显式指明

// node is a treap node: ordered as a BST on key, and (via _insert's rotations)
// as a max-heap on priority. Equal keys share one node, counted by dupCnt.
type node struct {
    // ch[0] holds smaller keys, ch[1] holds larger keys.
    ch       [2]*node
    // random priority; _insert rotates a child above its parent when the child's priority is larger.
    priority int
    key      int
    // number of inserted copies of key.
    dupCnt   int
    // total number of keys (with multiplicity) in this subtree.
    sz       int
}

// cmp returns the child index where b belongs relative to o.key:
// 0 if b < o.key, 1 if b > o.key, and -1 if b == o.key.
func (o *node) cmp(b int) int {
    switch {
    case b < o.key:
        return 0
    case b > o.key:
        return 1
    default:
        return -1
    }
}

// size returns the subtree size; a nil node counts as an empty subtree.
func (o *node) size() int {
    if o != nil {
        return o.sz
    }
    return 0
}

// maintain recomputes sz from this node's multiplicity and its children's sizes.
func (o *node) maintain() {
    o.sz = o.dupCnt + o.ch[0].size() + o.ch[1].size()
}

// rotate lifts child ch[d^1] into o's position and returns it as the new subtree root.
// d=0 lifts the right child (left rotation); d=1 lifts the left child (right rotation).
func (o *node) rotate(d int) *node {
    x := o.ch[d^1]
    o.ch[d^1] = x.ch[d]
    x.ch[d] = o
    // o is now x's child, so its size must be refreshed before x's.
    o.maintain()
    x.maintain()
    return x
}

// treap is a randomized balanced BST used as a multiset of ints.
type treap struct {
    root *node
}

// _insert adds one copy of key to the subtree rooted at o and returns the
// (possibly rotated) new subtree root.
func (t *treap) _insert(o *node, key int) *node {
    if o == nil {
        return &node{priority: rand.Int(), key: key, dupCnt: 1, sz: 1}
    }
    if d := o.cmp(key); d >= 0 {
        o.ch[d] = t._insert(o.ch[d], key)
        // Restore heap order: lift the child if it now outranks o.
        if o.ch[d].priority > o.priority {
            o = o.rotate(d ^ 1)
        }
    } else {
        // key already present: just bump its multiplicity.
        o.dupCnt++
    }
    o.maintain()
    return o
}

// insert adds one copy of key to the treap.
func (t *treap) insert(key int) {
    t.root = t._insert(t.root, key)
}

// rank counts stored keys (with multiplicity) relative to key:
// equal=false: number of elements less than key
// equal=true: number of elements less than or equal to key
func (t *treap) rank(key int, equal bool) (cnt int) {
    for o := t.root; o != nil; {
        switch c := o.cmp(key); {
        case c == 0:
            // key < o.key: o and its right subtree are all larger.
            o = o.ch[0]
        case c > 0:
            // key > o.key: o and its whole left subtree are smaller.
            cnt += o.dupCnt + o.ch[0].size()
            o = o.ch[1]
        default:
            // key == o.key: the left subtree is strictly smaller; o's copies count only if equal.
            cnt += o.ch[0].size()
            if equal {
                cnt += o.dupCnt
            }
            return
        }
    }
    return
}

// countRangeSum counts ranges [i, j] with lower <= nums[i]+...+nums[j] <= upper.
// Prefix sums are visited left to right. For each running sum, the treap holds
// every earlier prefix sum, and the valid ones lie in [running-upper, running-lower].
func countRangeSum(nums []int, lower, upper int) (cnt int) {
    t := &treap{}
    // The empty prefix (sum 0) comes first; querying an empty treap would add nothing.
    t.insert(0)
    running := 0
    for _, v := range nums {
        running += v
        atMostHi := t.rank(running-lower, true)   // earlier sums <= running-lower
        belowLo := t.rank(running-upper, false)   // earlier sums <  running-upper
        cnt += atMostHi - belowLo
        t.insert(running)
    }
    return
}

java 解法, 执行用时: 159 ms, 内存消耗: 59.4 MB, 提交时间: 2023-10-08 23:33:30

class Solution {
    public int countRangeSum(int[] nums, int lower, int upper) {
        // preSum[k] = nums[0] + ... + nums[k-1], accumulated in 64 bits.
        long[] preSum = new long[nums.length + 1];
        for (int i = 1; i <= nums.length; i++) {
            preSum[i] = preSum[i - 1] + nums[i - 1];
        }

        // For each prefix sum x, count earlier prefix sums in [x - upper, x - lower]
        // by turning both value bounds into 1-based ranks inside the treap.
        BalancedTree treap = new BalancedTree();
        int ret = 0;
        for (long x : preSum) {
            long first = treap.lowerBound(x - upper);   // smallest stored value >= x - upper
            int rankLeft = first == Long.MAX_VALUE
                    ? (int) (treap.getSize() + 1)
                    : treap.rank(first)[0];
            long beyond = treap.upperBound(x - lower);  // smallest stored value > x - lower
            int rankRight = beyond == Long.MAX_VALUE
                    ? (int) treap.getSize()
                    : treap.rank(beyond)[0] - 1;
            ret += rankRight - rankLeft + 1;
            treap.insert(x);
        }
        return ret;
    }
}

/**
 * Treap (randomized balanced BST) multiset of longs with order statistics.
 * Each node stores one distinct value, its multiplicity ({@code count}) and the
 * number of stored values in its subtree ({@code size}).
 */
class BalancedTree {
    private class BalancedNode {
        long val;
        // Random heap priority: insert rotates a child above its parent when the child's seed is larger.
        long seed;
        // Multiplicity of val.
        int count;
        // Number of values (with multiplicity) in this subtree.
        int size;
        BalancedNode left;
        BalancedNode right;

        BalancedNode(long val, long seed) {
            this.val = val;
            this.seed = seed;
            this.count = 1;
            this.size = 1;
            this.left = null;
            this.right = null;
        }

        // Lifts the right child into this node's place and returns it as the new subtree root.
        // The subtree's total size is unchanged, so the new root takes this node's old size.
        BalancedNode leftRotate() {
            int prevSize = size;
            int currSize = (left != null ? left.size : 0) + (right.left != null ? right.left.size : 0) + count;
            BalancedNode root = right;
            right = root.left;
            root.left = this;
            root.size = prevSize;
            size = currSize;
            return root;
        }

        // Mirror of leftRotate: lifts the left child.
        BalancedNode rightRotate() {
            int prevSize = size;
            int currSize = (right != null ? right.size : 0) + (left.right != null ? left.right.size : 0) + count;
            BalancedNode root = left;
            left = root.right;
            root.right = this;
            root.size = prevSize;
            size = currSize;
            return root;
        }
    }

    private BalancedNode root;
    // Total number of stored values (with multiplicity).
    private int size;
    private Random rand;

    public BalancedTree() {
        this.root = null;
        this.size = 0;
        this.rand = new Random();
    }

    /** Returns the number of stored values, counting duplicates. */
    public long getSize() {
        return size;
    }

    /** Inserts one occurrence of x. */
    public void insert(long x) {
        ++size;
        root = insert(root, x);
    }

    /**
     * Returns the smallest stored value >= x, or Long.MAX_VALUE if there is none.
     * NOTE(review): the sentinel is ambiguous if Long.MAX_VALUE itself is stored.
     */
    public long lowerBound(long x) {
        BalancedNode node = root;
        long ans = Long.MAX_VALUE;
        while (node != null) {
            if (x == node.val) {
                return x;
            }
            if (x < node.val) {
                ans = node.val;
                node = node.left;
            } else {
                node = node.right;
            }
        }
        return ans;
    }

    /** Returns the smallest stored value > x, or Long.MAX_VALUE if there is none. */
    public long upperBound(long x) {
        BalancedNode node = root;
        long ans = Long.MAX_VALUE;
        while (node != null) {
            if (x < node.val) {
                ans = node.val;
                node = node.left;
            } else {
                node = node.right;
            }
        }
        return ans;
    }

    /**
     * Returns {first, last}: the 1-based ranks (sorted order, counting duplicates)
     * occupied by x, or {Integer.MIN_VALUE, Integer.MAX_VALUE} if x is not stored.
     */
    public int[] rank(long x) {
        BalancedNode node = root;
        int ans = 0;
        while (node != null) {
            if (x < node.val) {
                node = node.left;
            } else {
                // node and its left subtree are all <= x.
                ans += (node.left != null ? node.left.size : 0) + node.count;
                if (x == node.val) {
                    return new int[]{ans - node.count + 1, ans};
                }
                node = node.right;
            }
        }
        return new int[]{Integer.MIN_VALUE, Integer.MAX_VALUE};
    }

    // Recursive insert; returns the (possibly rotated) subtree root. Sizes are
    // incremented on the way down because exactly one value is added below.
    private BalancedNode insert(BalancedNode node, long x) {
        if (node == null) {
            return new BalancedNode(x, rand.nextInt());
        }
        ++node.size;
        if (x < node.val) {
            node.left = insert(node.left, x);
            if (node.left.seed > node.seed) {
                node = node.rightRotate();
            }
        } else if (x > node.val) {
            node.right = insert(node.right, x);
            if (node.right.seed > node.seed) {
                node = node.leftRotate();
            }
        } else {
            ++node.count;
        }
        return node;
    }
}

cpp 解法, 执行用时: 488 ms, 内存消耗: 142.9 MB, 提交时间: 2023-10-08 23:33:16

// Treap (randomized balanced BST) multiset of long long values with order
// statistics. Each node holds one distinct value, its multiplicity (count) and
// the number of stored values in its subtree (size). The tree owns its nodes
// and frees them when it is destroyed.
class BalancedTree {
private:
    struct BalancedNode {
        long long val;
        long long seed;   // random heap priority; a child with a larger seed is rotated above its parent
        int count;        // multiplicity of val
        int size;         // values (with multiplicity) in this subtree
        BalancedNode* left;
        BalancedNode* right;

        BalancedNode(long long _val, long long _seed): val(_val), seed(_seed), count(1), size(1), left(nullptr), right(nullptr) {}

        // Lifts the right child into this node's place and returns it. The
        // subtree's total size is unchanged, so the new root takes the old size.
        BalancedNode* left_rotate() {
            int prev_size = size;
            int curr_size = (left ? left->size : 0) + (right->left ? right->left->size : 0) + count;
            BalancedNode* root = right;
            right = root->left;
            root->left = this;
            root->size = prev_size;
            size = curr_size;
            return root;
        }

        // Mirror of left_rotate: lifts the left child.
        BalancedNode* right_rotate() {
            int prev_size = size;
            int curr_size = (right ? right->size : 0) + (left->right ? left->right->size : 0) + count;
            BalancedNode* root = left;
            left = root->right;
            root->right = this;
            root->size = prev_size;
            size = curr_size;
            return root;
        }
    };

private:
    BalancedNode* root;
    int size;   // total values stored (with multiplicity)
    mt19937 gen;
    uniform_int_distribution<long long> dis;

private:
    // Inserts x below node and returns the (possibly rotated) subtree root.
    // Sizes are bumped on the way down since exactly one value is added.
    BalancedNode* insert(BalancedNode* node, long long x) {
        if (!node) {
            return new BalancedNode(x, dis(gen));
        }
        ++node->size;
        if (x < node->val) {
            node->left = insert(node->left, x);
            if (node->left->seed > node->seed) {
                node = node->right_rotate();
            }
        }
        else if (x > node->val) {
            node->right = insert(node->right, x);
            if (node->right->seed > node->seed) {
                node = node->left_rotate();
            }
        }
        else {
            ++node->count;
        }
        return node;
    }

    // Frees every node of the subtree rooted at node (post-order).
    static void destroy(BalancedNode* node) {
        if (!node) {
            return;
        }
        destroy(node->left);
        destroy(node->right);
        delete node;
    }

public:
    BalancedTree(): root(nullptr), size(0), gen(random_device{}()), dis(LLONG_MIN, LLONG_MAX) {}

    // Releases every node allocated by insert.
    ~BalancedTree() {
        destroy(root);
    }

    // The tree owns raw node pointers, so a shallow copy would double-free.
    BalancedTree(const BalancedTree&) = delete;
    BalancedTree& operator=(const BalancedTree&) = delete;

    // Number of stored values, counting duplicates.
    long long get_size() const {
        return size;
    }

    // Inserts one occurrence of x.
    void insert(long long x) {
        ++size;
        root = insert(root, x);
    }

    // Smallest stored value >= x, or LLONG_MAX if there is none.
    long long lower_bound(long long x) const {
        BalancedNode* node = root;
        long long ans = LLONG_MAX;
        while (node) {
            if (x == node->val) {
                return x;
            }
            if (x < node->val) {
                ans = node->val;
                node = node->left;
            }
            else {
                node = node->right;
            }
        }
        return ans;
    }

    // Smallest stored value > x, or LLONG_MAX if there is none.
    long long upper_bound(long long x) const {
        BalancedNode* node = root;
        long long ans = LLONG_MAX;
        while (node) {
            if (x < node->val) {
                ans = node->val;
                node = node->left;
            }
            else {
                node = node->right;
            }
        }
        return ans;
    }

    // 1-based {first, last} ranks (sorted order, counting duplicates) occupied
    // by x, or {INT_MIN, INT_MAX} if x is not stored.
    pair<int, int> rank(long long x) const {
        BalancedNode* node = root;
        int ans = 0;
        while (node) {
            if (x < node->val) {
                node = node->left;
            }
            else {
                ans += (node->left ? node->left->size : 0) + node->count;
                if (x == node->val) {
                    return {ans - node->count + 1, ans};
                }
                node = node->right;
            }
        }
        return {INT_MIN, INT_MAX};
    }
};

class Solution {
public:
    // Counts ranges [i, j] with lower <= nums[i] + ... + nums[j] <= upper.
    // For each prefix sum x, every earlier prefix sum in [x - upper, x - lower]
    // starts a valid range. That count is derived from 1-based ranks in a treap.
    int countRangeSum(vector<int>& nums, int lower, int upper) {
        long long sum = 0;
        vector<long long> preSum = {0};
        for (int v: nums) {
            sum += v;
            preSum.push_back(sum);
        }

        // Automatic storage: the tree object is released when the function returns.
        BalancedTree treap;
        int ret = 0;
        for (long long x: preSum) {
            // Rank of the smallest stored value >= x - upper (size + 1 if none).
            long long numLeft = treap.lower_bound(x - upper);
            int rankLeft = (numLeft == LLONG_MAX ? treap.get_size() + 1 : treap.rank(numLeft).first);
            // Rank of the largest stored value <= x - lower (size if all qualify).
            long long numRight = treap.upper_bound(x - lower);
            int rankRight = (numRight == LLONG_MAX ? treap.get_size() : treap.rank(numRight).first - 1);
            ret += (rankRight - rankLeft + 1);
            treap.insert(x);
        }
        return ret;
    }
};

cpp 解法, 执行用时: 1224 ms, 内存消耗: 302.6 MB, 提交时间: 2023-10-08 23:33:06

// Fenwick (binary indexed) tree of int counts over 1-based positions 1..n.
class BIT {
private:
    vector<int> tree;
    int n;

public:
    // Initializers listed in declaration order (tree, then n).
    BIT(int _n): tree(_n + 1), n(_n) {}

    // Value of the lowest set bit of x.
    static constexpr int lowbit(int x) {
        return x & (-x);
    }

    // Adds d at position x.
    void update(int x, int d) {
        for (; x <= n; x += lowbit(x)) {
            tree[x] += d;
        }
    }

    // Sum of positions 1..x.
    int query(int x) const {
        int total = 0;
        for (; x != 0; x -= lowbit(x)) {
            total += tree[x];
        }
        return total;
    }
};

class Solution {
public:
    // Counts ranges whose sum lies in [lower, upper] using a Fenwick tree over
    // coordinate-compressed prefix sums and query bounds.
    int countRangeSum(vector<int>& nums, int lower, int upper) {
        // preSum[k] = nums[0] + ... + nums[k-1], in 64 bits to avoid overflow.
        vector<long long> preSum(1, 0);
        preSum.reserve(nums.size() + 1);
        long long running = 0;
        for (int v: nums) {
            running += v;
            preSum.push_back(running);
        }

        // Every value that is inserted or used as a query bound, sorted and unique.
        set<long long> allNumbers;
        for (long long x: preSum) {
            allNumbers.insert({x, x - lower, x - upper});
        }
        // Coordinate compression: value -> 0-based rank.
        unordered_map<long long, int> values;
        int rank = 0;
        for (long long x: allNumbers) {
            values[x] = rank++;
        }

        BIT bit(values.size());
        int ret = 0;
        for (long long x: preSum) {
            int left = values[x - upper];
            int right = values[x - lower];
            // Earlier prefix sums whose rank is in [left, right]; the BIT is 1-based.
            ret += bit.query(right + 1) - bit.query(left);
            bit.update(values[x] + 1, 1);
        }
        return ret;
    }
};

java 解法, 执行用时: 637 ms, 内存消耗: 107.1 MB, 提交时间: 2023-10-08 23:32:51

class Solution {
    public int countRangeSum(int[] nums, int lower, int upper) {
        // preSum[k] = nums[0] + ... + nums[k-1] in 64-bit arithmetic.
        long[] preSum = new long[nums.length + 1];
        for (int i = 0; i < nums.length; ++i) {
            preSum[i + 1] = preSum[i] + nums[i];
        }

        // Every value that is inserted or used as a bound; TreeSet sorts and deduplicates.
        TreeSet<Long> allNumbers = new TreeSet<Long>();
        for (long x : preSum) {
            allNumbers.add(x);
            allNumbers.add(x - lower);
            allNumbers.add(x - upper);
        }
        // Coordinate compression: value -> 0-based rank in sorted order.
        Map<Long, Integer> values = new HashMap<Long, Integer>();
        int rank = 0;
        for (long x : allNumbers) {
            values.put(x, rank++);
        }

        BIT bit = new BIT(values.size());
        int ret = 0;
        for (long x : preSum) {
            int left = values.get(x - upper);
            int right = values.get(x - lower);
            // Earlier prefix sums with rank in [left, right]; BIT positions are 1-based.
            ret += bit.query(right + 1) - bit.query(left);
            bit.update(values.get(x) + 1, 1);
        }
        return ret;
    }
}

/** Fenwick (binary indexed) tree of int counts over 1-based positions 1..n. */
class BIT {
    int[] tree;
    int n;

    public BIT(int n) {
        this.n = n;
        this.tree = new int[n + 1];
    }

    /** Returns the value of the lowest set bit of x. */
    public static int lowbit(int x) {
        return x & (-x);
    }

    /** Adds d at position x. */
    public void update(int x, int d) {
        for (int i = x; i <= n; i += lowbit(i)) {
            tree[i] += d;
        }
    }

    /** Returns the sum of positions 1..x. */
    public int query(int x) {
        int total = 0;
        for (int i = x; i != 0; i -= lowbit(i)) {
            total += tree[i];
        }
        return total;
    }
}

golang 解法, 执行用时: 400 ms, 内存消耗: 38.3 MB, 提交时间: 2023-10-08 23:32:32

// fenwick is a 1-based binary indexed tree of counts; tree[0] is unused.
type fenwick struct {
    tree []int
}

// inc adds one at position i.
func (f fenwick) inc(i int) {
    for i < len(f.tree) {
        f.tree[i]++
        i += i & -i
    }
}

// sum returns the total of positions 1..i; each step strips the lowest set bit of i.
func (f fenwick) sum(i int) (res int) {
    for i > 0 {
        res += f.tree[i]
        i -= i & -i
    }
    return
}

// query returns the total of positions l..r inclusive.
func (f fenwick) query(l, r int) (res int) {
    res = f.sum(r) - f.sum(l-1)
    return
}

// countRangeSum counts ranges whose sum lies in [lower, upper] using a Fenwick
// tree over coordinate-compressed prefix sums.
func countRangeSum(nums []int, lower, upper int) (cnt int) {
    n := len(nums)

    // Prefix sums, plus every value the tree is later queried or updated with.
    // The empty prefix 0 is inserted but never queried, so its bounds are omitted.
    preSum := make([]int, n+1)
    allNums := make([]int, 0, 3*n+1)
    allNums = append(allNums, 0)
    for i := 1; i <= n; i++ {
        s := preSum[i-1] + nums[i-1]
        preSum[i] = s
        allNums = append(allNums, s, s-lower, s-upper)
    }

    // Coordinate compression to 1-based ranks.
    sort.Ints(allNums)
    kth := make(map[int]int, len(allNums))
    k := 0
    for i, v := range allNums {
        if i == 0 || v != allNums[i-1] {
            k++
            kth[v] = k
        }
    }

    // Seed with the empty prefix. Then, for each later prefix sum, count the
    // earlier ones in [sum-upper, sum-lower] before inserting it.
    t := fenwick{make([]int, k+1)}
    t.inc(kth[0])
    for _, sum := range preSum[1:] {
        cnt += t.query(kth[sum-upper], kth[sum-lower])
        t.inc(kth[sum])
    }
    return
}

golang 解法, 执行用时: 1004 ms, 内存消耗: 278.1 MB, 提交时间: 2023-10-08 23:32:20

// node is a lazily allocated segment-tree node covering keys [l, r];
// val counts the keys inserted into that range.
type node struct {
    l, r, val int
    lo, ro    *node
}

// insert records one occurrence of val, creating child nodes on demand.
func (o *node) insert(val int) {
    for cur := o; ; {
        cur.val++
        if cur.l == cur.r {
            return
        }
        m := (cur.l + cur.r) >> 1
        if val <= m {
            if cur.lo == nil {
                cur.lo = &node{l: cur.l, r: m}
            }
            cur = cur.lo
        } else {
            if cur.ro == nil {
                cur.ro = &node{l: m + 1, r: cur.r}
            }
            cur = cur.ro
        }
    }
}

// query counts inserted keys in [l, r]; a nil receiver counts as empty.
func (o *node) query(l, r int) (res int) {
    if o == nil || r < o.l || o.r < l {
        return 0
    }
    if o.l >= l && o.r <= r {
        return o.val
    }
    res = o.lo.query(l, r) + o.ro.query(l, r)
    return
}

// countRangeSum counts ranges whose sum lies in [lower, upper] using a
// dynamically allocated segment tree keyed directly by prefix-sum value.
func countRangeSum(nums []int, lower, upper int) (cnt int) {
    preSum := make([]int, 1, len(nums)+1)
    for _, v := range nums {
        preSum = append(preSum, preSum[len(preSum)-1]+v)
    }

    // The tree's key range must cover every inserted prefix sum and every query bound.
    lbound, rbound := math.MaxInt64, -math.MaxInt64
    for _, s := range preSum {
        lbound = min(lbound, s, s-lower, s-upper)
        rbound = max(rbound, s, s-lower, s-upper)
    }

    root := &node{l: lbound, r: rbound}
    for _, s := range preSum {
        cnt += root.query(s-upper, s-lower)
        root.insert(s)
    }
    return
}

// min returns the smallest of its arguments; at least one argument is required.
func min(a ...int) int {
    res := a[0]
    for i := 1; i < len(a); i++ {
        if a[i] < res {
            res = a[i]
        }
    }
    return res
}

// max returns the largest of its arguments; at least one argument is required.
func max(a ...int) int {
    res := a[0]
    for i := 1; i < len(a); i++ {
        if a[i] > res {
            res = a[i]
        }
    }
    return res
}

java 解法, 执行用时: 346 ms, 内存消耗: 229.8 MB, 提交时间: 2023-10-08 23:32:06

class Solution {
    public int countRangeSum(int[] nums, int lower, int upper) {
        // preSum[k] = nums[0] + ... + nums[k-1] in 64-bit arithmetic.
        long[] preSum = new long[nums.length + 1];
        for (int i = 0; i < nums.length; ++i) {
            preSum[i + 1] = preSum[i] + nums[i];
        }

        // Key range of the dynamic tree: covers all prefix sums and all query bounds.
        long lbound = Long.MAX_VALUE, rbound = Long.MIN_VALUE;
        for (long x : preSum) {
            lbound = Math.min(lbound, Math.min(x, Math.min(x - lower, x - upper)));
            rbound = Math.max(rbound, Math.max(x, Math.max(x - lower, x - upper)));
        }

        SegNode root = new SegNode(lbound, rbound);
        int ret = 0;
        for (long x : preSum) {
            ret += count(root, x - upper, x - lower);
            insert(root, x);
        }
        return ret;
    }

    /** Counts inserted values in [left, right] under root; a null subtree is empty. */
    public int count(SegNode root, long left, long right) {
        if (root == null || left > root.hi || right < root.lo) {
            return 0;
        }
        if (left <= root.lo && root.hi <= right) {
            return root.add;
        }
        return count(root.lchild, left, right) + count(root.rchild, left, right);
    }

    /** Records one occurrence of val, allocating child nodes on demand. */
    public void insert(SegNode root, long val) {
        SegNode node = root;
        while (true) {
            node.add++;
            if (node.lo == node.hi) {
                return;
            }
            long mid = (node.lo + node.hi) >> 1;
            if (val <= mid) {
                if (node.lchild == null) {
                    node.lchild = new SegNode(node.lo, mid);
                }
                node = node.lchild;
            } else {
                if (node.rchild == null) {
                    node.rchild = new SegNode(mid + 1, node.hi);
                }
                node = node.rchild;
            }
        }
    }
}

/** Node of a dynamically allocated segment tree over the value range [lo, hi]. */
class SegNode {
    long lo, hi;
    int add;                 // number of values inserted into [lo, hi]
    SegNode lchild, rchild;  // allocated lazily; null until needed

    public SegNode(long left, long right) {
        this.lo = left;
        this.hi = right;
        this.add = 0;
        this.lchild = this.rchild = null;
    }
}

cpp 解法, 执行用时: 1924 ms, 内存消耗: 764.3 MB, 提交时间: 2023-10-08 23:31:53

// Node of a dynamically allocated segment tree over the value range [lo, hi];
// `add` counts the values inserted into that range, children start as nullptr.
struct SegNode {
    long long lo, hi;
    int add;
    SegNode* lchild, *rchild;

    SegNode(long long left, long long right)
        : lo(left),
          hi(right),
          add(0),
          lchild(nullptr),
          rchild(nullptr) {}
};

class Solution {
public:
    void insert(SegNode* root, long long val) {
        root->add++;
        if (root->lo == root->hi) {
            return;
        }
        long long mid = (root->lo + root->hi) >> 1;
        if (val <= mid) {
            if (!root->lchild) {
                root->lchild = new SegNode(root->lo, mid);
            }
            insert(root->lchild, val);
        }
        else {
            if (!root->rchild) {
                root->rchild = new SegNode(mid + 1, root->hi);
            }
            insert(root->rchild, val);
        }
    }

    int count(SegNode* root, long long left, long long right) const {
        if (!root) {
            return 0;
        }
        if (left > root->hi || right < root->lo) {
            return 0;
        }
        if (left <= root->lo && root->hi <= right) {
            return root->add;
        }
        return count(root->lchild, left, right) + count(root->rchild, left, right);
    }

    int countRangeSum(vector<int>& nums, int lower, int upper) {
        long long sum = 0;
        vector<long long> preSum = {0};
        for(int v: nums) {
            sum += v;
            preSum.push_back(sum);
        }
        
        long long lbound = LLONG_MAX, rbound = LLONG_MIN;
        for (long long x: preSum) {
            lbound = min({lbound, x, x - lower, x - upper});
            rbound = max({rbound, x, x - lower, x - upper});
        }
        
        SegNode* root = new SegNode(lbound, rbound);
        int ret = 0;
        for (long long x: preSum) {
            ret += count(root, x - upper, x - lower);
            insert(root, x);
        }
        return ret;
    }
};

java 解法, 执行用时: 930 ms, 内存消耗: 149 MB, 提交时间: 2023-10-08 23:31:24

class Solution {
    public int countRangeSum(int[] nums, int lower, int upper) {
        // preSum[k] = nums[0] + ... + nums[k-1] in 64-bit arithmetic.
        long[] preSum = new long[nums.length + 1];
        for (int i = 0; i < nums.length; ++i) {
            preSum[i + 1] = preSum[i] + nums[i];
        }

        // Sorted, de-duplicated set of every value inserted or used as a bound.
        TreeSet<Long> allNumbers = new TreeSet<Long>();
        for (long x : preSum) {
            allNumbers.add(x);
            allNumbers.add(x - lower);
            allNumbers.add(x - upper);
        }
        // Coordinate compression: value -> 0-based rank.
        Map<Long, Integer> values = new HashMap<Long, Integer>();
        int rank = 0;
        for (long x : allNumbers) {
            values.put(x, rank++);
        }

        SegNode root = build(0, values.size() - 1);
        int ret = 0;
        for (long x : preSum) {
            int left = values.get(x - upper);
            int right = values.get(x - lower);
            ret += count(root, left, right);
            insert(root, values.get(x));
        }
        return ret;
    }

    /** Builds a complete segment tree over ranks [left, right] with zero counts. */
    public SegNode build(int left, int right) {
        SegNode node = new SegNode(left, right);
        if (left != right) {
            int mid = (left + right) / 2;
            node.lchild = build(left, mid);
            node.rchild = build(mid + 1, right);
        }
        return node;
    }

    /** Counts inserted ranks in [left, right] within root's subtree. */
    public int count(SegNode root, int left, int right) {
        if (right < root.lo || left > root.hi) {
            return 0;
        }
        if (root.lo >= left && root.hi <= right) {
            return root.add;
        }
        return count(root.lchild, left, right) + count(root.rchild, left, right);
    }

    /** Adds one occurrence of rank val along the root-to-leaf path. */
    public void insert(SegNode root, int val) {
        SegNode node = root;
        while (true) {
            node.add++;
            if (node.lo == node.hi) {
                return;
            }
            int mid = (node.lo + node.hi) / 2;
            node = val <= mid ? node.lchild : node.rchild;
        }
    }
}

/** Node of a segment tree built over compressed ranks [lo, hi]; add counts inserted ranks. */
class SegNode {
    int lo, hi, add;
    SegNode lchild, rchild;

    public SegNode(int left, int right) {
        this.lo = left;
        this.hi = right;
        this.add = 0;
        this.lchild = this.rchild = null;
    }
}

golang 解法, 执行用时: 592 ms, 内存消耗: 80.5 MB, 提交时间: 2023-10-08 23:31:12

// segTree is an array-backed segment tree rooted at index 1; node o has
// children 2o and 2o+1, and val counts the keys inserted into [l, r].
type segTree []struct {
    l, r, val int
}

// build assigns the range [l, r] to node o and recursively to its descendants.
func (t segTree) build(o, l, r int) {
    t[o].l, t[o].r = l, r
    if l != r {
        mid := (l + r) >> 1
        t.build(2*o, l, mid)
        t.build(2*o+1, mid+1, r)
    }
}

// inc adds one occurrence of key i and refreshes the counts on the path.
func (t segTree) inc(o, i int) {
    if t[o].l == t[o].r {
        t[o].val++
        return
    }
    child := o << 1
    if i > (t[o].l+t[o].r)>>1 {
        child |= 1
    }
    t.inc(child, i)
    t[o].val = t[o<<1].val + t[o<<1|1].val
}

// query returns the number of inserted keys in [l, r] below node o.
func (t segTree) query(o, l, r int) (res int) {
    cur := t[o]
    if l <= cur.l && cur.r <= r {
        return cur.val
    }
    mid := (cur.l + cur.r) >> 1
    switch {
    case r <= mid:
        return t.query(o<<1, l, r)
    case l > mid:
        return t.query(o<<1|1, l, r)
    default:
        return t.query(o<<1, l, r) + t.query(o<<1|1, l, r)
    }
}

// countRangeSum counts ranges whose sum lies in [lower, upper] using a static
// segment tree over coordinate-compressed prefix sums.
func countRangeSum(nums []int, lower, upper int) (cnt int) {
    n := len(nums)

    // Prefix sums, and every value the tree is later queried or updated with.
    preSum := make([]int, n+1)
    allNums := make([]int, 0, 3*n+1)
    allNums = append(allNums, 0)
    for i := 1; i <= n; i++ {
        s := preSum[i-1] + nums[i-1]
        preSum[i] = s
        allNums = append(allNums, s, s-lower, s-upper)
    }

    // Coordinate compression to 1-based ranks 1..k.
    sort.Ints(allNums)
    kth := make(map[int]int, len(allNums))
    k := 0
    for i, v := range allNums {
        if i > 0 && v == allNums[i-1] {
            continue
        }
        k++
        kth[v] = k
    }

    // Seed with the empty prefix. Then count earlier prefix sums in
    // [sum-upper, sum-lower] before inserting each new one.
    t := make(segTree, 4*k)
    t.build(1, 1, k)
    t.inc(1, kth[0])
    for _, sum := range preSum[1:] {
        cnt += t.query(1, kth[sum-upper], kth[sum-lower])
        t.inc(1, kth[sum])
    }
    return
}

golang 解法, 执行用时: 160 ms, 内存消耗: 10.2 MB, 提交时间: 2023-10-08 23:31:00

// countRangeSum counts the ranges [i, j] (i <= j) of nums whose sum lies in
// [lower, upper]. It counts prefix-sum pairs while merge-sorting the
// prefix sums.
func countRangeSum(nums []int, lower, upper int) int {
	pre := make([]int, len(nums)+1)
	for i, x := range nums {
		pre[i+1] = pre[i] + x
	}

	// sortCount sorts a in place (ascending) and returns the number of pairs
	// of positions p < q in the unsorted a with lower <= a[q]-a[p] <= upper.
	var sortCount func(a []int) int
	sortCount = func(a []int) int {
		size := len(a)
		if size < 2 {
			return 0
		}

		half := size / 2
		left := append([]int(nil), a[:half]...)
		right := append([]int(nil), a[half:]...)
		total := sortCount(left) + sortCount(right) // both halves now sorted

		// For each left value x, right[lo:hi] holds the right values in
		// [x+lower, x+upper]. As x increases, both ends only move forward.
		lo, hi := 0, 0
		for _, x := range left {
			for lo < len(right) && right[lo]-x < lower {
				lo++
			}
			for hi < len(right) && right[hi]-x <= upper {
				hi++
			}
			total += hi - lo
		}

		// Merge the sorted halves back into a.
		i, j := 0, 0
		for k := range a {
			if i < len(left) && (j == len(right) || left[i] <= right[j]) {
				a[k] = left[i]
				i++
				continue
			}
			a[k] = right[j]
			j++
		}
		return total
	}

	return sortCount(pre)
}

javascript 解法, 执行用时: 208 ms, 内存消耗: 64.2 MB, 提交时间: 2023-10-08 23:30:47

/**
 * Merge-sort helper for countRangeSum.
 *
 * Counts the pairs of positions p < q inside sum[left..right] with
 * lower <= sum[q] - sum[p] <= upper. As a side effect it sorts
 * sum[left..right] in ascending order. The sorting makes the counting step
 * linear: once both halves are sorted, the window of right-half values that
 * pair with a left-half value only moves forward.
 *
 * @param {number[]} sum - prefix sums; the slice sum[left..right] is sorted in place
 * @param {number} lower - inclusive lower bound on sum[q] - sum[p]
 * @param {number} upper - inclusive upper bound on sum[q] - sum[p]
 * @param {number} left - first index of the slice (inclusive)
 * @param {number} right - last index of the slice (inclusive)
 * @return {number} the number of qualifying pairs
 */
const countRangeSumRecursive = (sum, lower, upper, left, right) => {
    if (left === right) {
        return 0;
    }

    const mid = Math.floor((left + right) / 2);
    let ret = countRangeSumRecursive(sum, lower, upper, left, mid)
        + countRangeSumRecursive(sum, lower, upper, mid + 1, right);

    // Count cross pairs: for each sum[i] in the sorted left half, sum[l..r-1]
    // are the right-half values within [sum[i] + lower, sum[i] + upper].
    let l = mid + 1;
    let r = mid + 1;
    for (let i = left; i <= mid; i++) {
        while (l <= right && sum[l] - sum[i] < lower) l++;
        while (r <= right && sum[r] - sum[i] <= upper) r++;
        ret += r - l;
    }

    // Merge the two sorted halves so the caller sees a sorted slice.
    const sorted = new Array(right - left + 1);
    let p1 = left;
    let p2 = mid + 1;
    let p = 0;
    while (p1 <= mid || p2 <= right) {
        if (p1 > mid) {
            sorted[p++] = sum[p2++];
        } else if (p2 > right) {
            sorted[p++] = sum[p1++];
        } else if (sum[p1] < sum[p2]) {
            sorted[p++] = sum[p1++];
        } else {
            sorted[p++] = sum[p2++];
        }
    }
    for (let k = 0; k < sorted.length; k++) {
        sum[left + k] = sorted[k];
    }
    return ret;
}
/**
 * Count the ranges [i, j] (i <= j) of nums whose sum lies in [lower, upper].
 *
 * @param {number[]} nums
 * @param {number} lower - inclusive lower bound on a range sum
 * @param {number} upper - inclusive upper bound on a range sum
 * @return {number}
 */
var countRangeSum = function(nums, lower, upper) {
    // sum[k] is the total of the first k elements; sum[0] = 0.
    const sum = [0];
    for (let k = 0; k < nums.length; k++) {
        sum.push(sum[k] + nums[k]);
    }
    // A range sum nums[i..j] equals sum[j + 1] - sum[i], so count prefix pairs.
    const last = sum.length - 1;
    return countRangeSumRecursive(sum, lower, upper, 0, last);
};

java 解法, 执行用时: 60 ms, 内存消耗: 55.4 MB, 提交时间: 2023-10-08 23:30:33

class Solution {
    /**
     * Counts the ranges [i, j] (i <= j) of nums whose sum lies in [lower, upper].
     * Range sums are differences of prefix sums, stored as long so that
     * adding many ints cannot overflow.
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        long[] prefix = new long[nums.length + 1];
        for (int k = 0; k < nums.length; k++) {
            prefix[k + 1] = prefix[k] + nums[k];
        }
        return countRangeSumRecursive(prefix, lower, upper, 0, prefix.length - 1);
    }

    /**
     * Counts pairs of positions p < q in sum[left..right] with
     * lower <= sum[q] - sum[p] <= upper. As a side effect, sorts
     * sum[left..right] ascending in place (merge sort).
     */
    public int countRangeSumRecursive(long[] sum, int lower, int upper, int left, int right) {
        if (left == right) {
            return 0;
        }
        int mid = (left + right) / 2;
        int count = countRangeSumRecursive(sum, lower, upper, left, mid)
                + countRangeSumRecursive(sum, lower, upper, mid + 1, right);

        // Both halves are sorted: for each left value, [lo, hi) is the window of
        // right-half values within [sum[i] + lower, sum[i] + upper].
        int lo = mid + 1;
        int hi = mid + 1;
        for (int i = left; i <= mid; i++) {
            while (lo <= right && sum[lo] - sum[i] < lower) {
                lo++;
            }
            while (hi <= right && sum[hi] - sum[i] <= upper) {
                hi++;
            }
            count += hi - lo;
        }

        // Merge the halves; on ties the right-half element goes first.
        long[] merged = new long[right - left + 1];
        int a = left;
        int b = mid + 1;
        for (int k = 0; k < merged.length; k++) {
            boolean takeRight = a > mid || (b <= right && sum[b] <= sum[a]);
            merged[k] = takeRight ? sum[b++] : sum[a++];
        }
        System.arraycopy(merged, 0, sum, left, merged.length);
        return count;
    }
}

cpp 解法, 执行用时: 580 ms, 内存消耗: 208.4 MB, 提交时间: 2023-10-08 23:30:20

class Solution {
public:
    // Counts pairs of positions p < q in sum[left..right] with
    // lower <= sum[q] - sum[p] <= upper, and sorts sum[left..right] ascending
    // in place as a side effect (merge sort).
    //
    // Templated on the element type: existing callers passing vector<long>
    // still compile, while countRangeSum uses long long. `long` is only
    // 32 bits on some platforms (e.g. MSVC/Windows), where prefix sums of
    // 32-bit ints and their differences can overflow.
    template <typename T>
    int countRangeSumRecursive(vector<T>& sum, int lower, int upper, int left, int right) {
        if (left == right) {
            return 0;
        }
        int mid = left + (right - left) / 2;
        int n1 = countRangeSumRecursive(sum, lower, upper, left, mid);
        int n2 = countRangeSumRecursive(sum, lower, upper, mid + 1, right);
        int ret = n1 + n2;

        // Both halves are sorted: for each sum[i] on the left, sum[l..r-1] are
        // the right-half values within [sum[i] + lower, sum[i] + upper].
        int l = mid + 1;
        int r = mid + 1;
        for (int i = left; i <= mid; ++i) {
            while (l <= right && sum[l] - sum[i] < lower) ++l;
            while (r <= right && sum[r] - sum[i] <= upper) ++r;
            ret += r - l;
        }

        // Merge the two sorted halves back into sum[left..right].
        vector<T> sorted(right - left + 1);
        int p1 = left, p2 = mid + 1;
        int p = 0;
        while (p1 <= mid || p2 <= right) {
            if (p1 > mid) {
                sorted[p++] = sum[p2++];
            } else if (p2 > right) {
                sorted[p++] = sum[p1++];
            } else if (sum[p1] < sum[p2]) {
                sorted[p++] = sum[p1++];
            } else {
                sorted[p++] = sum[p2++];
            }
        }
        const int len = static_cast<int>(sorted.size());
        for (int k = 0; k < len; ++k) {
            sum[left + k] = sorted[k];
        }
        return ret;
    }

    // Counts the ranges [i, j] (i <= j) of nums whose sum lies in [lower, upper].
    int countRangeSum(vector<int>& nums, int lower, int upper) {
        // sum[k] = nums[0] + ... + nums[k - 1]; 64-bit on every platform.
        vector<long long> sum{0};
        sum.reserve(nums.size() + 1);
        long long s = 0;
        for (int v : nums) {
            s += v;
            sum.push_back(s);
        }
        return countRangeSumRecursive(sum, lower, upper, 0, static_cast<int>(sum.size()) - 1);
    }
};

上一题