列表

详情


1622. 奇妙序列

请你实现三个 API:append、addAll 和 multAll 来实现奇妙序列。

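请实现 Fancy 类:

- Fancy() 初始化一个空序列对象。
- void append(val) 将整数 val 添加在序列末尾。
- void addAll(inc) 将序列中所有现有数值都增加 inc。
- void multAll(m) 将序列中所有现有数值都乘以整数 m。
- int getIndex(idx) 返回下标为 idx 处的数值(下标从 0 开始),结果对 10^9 + 7 取余;如果下标大于等于序列的长度,返回 -1。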

示例:

输入:
["Fancy", "append", "addAll", "append", "multAll", "getIndex", "addAll", "append", "multAll", "getIndex", "getIndex", "getIndex"]
[[], [2], [3], [7], [2], [0], [3], [10], [2], [0], [1], [2]]
输出:
[null, null, null, null, null, 10, null, null, null, 26, 34, 20]

解释:
Fancy fancy = new Fancy();
fancy.append(2);   // 奇妙序列:[2]
fancy.addAll(3);   // 奇妙序列:[2+3] -> [5]
fancy.append(7);   // 奇妙序列:[5, 7]
fancy.multAll(2);  // 奇妙序列:[5*2, 7*2] -> [10, 14]
fancy.getIndex(0); // 返回 10
fancy.addAll(3);   // 奇妙序列:[10+3, 14+3] -> [13, 17]
fancy.append(10);  // 奇妙序列:[13, 17, 10]
fancy.multAll(2);  // 奇妙序列:[13*2, 17*2, 10*2] -> [26, 34, 20]
fancy.getIndex(0); // 返回 26
fancy.getIndex(1); // 返回 34
fancy.getIndex(2); // 返回 20

 

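提示:

- 1 <= val, inc, m <= 100
- 0 <= idx <= 10^5
- 总共最多会有 10^5 次对 append、addAll、multAll 和 getIndex 的调用。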

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
// Default C++ starter template shown on the problem page: every method is an
// empty stub to be filled in by the solver.
// NOTE(review): getIndex has no return statement; flowing off the end of a
// non-void function is undefined behavior if this stub is compiled and called
// as-is.
class Fancy {
public:
    Fancy() {
    }
    void append(int val) {
    }
    void addAll(int inc) {
    }
    void multAll(int m) {
    }
    int getIndex(int idx) {
    }
};
/**
 * Your Fancy object will be instantiated and called as such:
 * Fancy* obj = new Fancy();
 * obj->append(val);
 * obj->addAll(inc);
 * obj->multAll(m);
 * int param_4 = obj->getIndex(idx);
 */

python3 解法, 执行用时: 8544 ms, 内存消耗: 66.5 MB, 提交时间: 2022-07-28 15:35:05

class SegTree:
    """Lazily allocated segment tree over the closed index range [l, r].

    Supports range add, range multiply and range-sum queries, all modulo
    ``MOD``. Each node keeps a pending affine tag ``x -> mul * x + add`` that
    has already been applied to its own ``treesum`` but not yet to its
    children. Children are created on first access, so only the paths that
    are actually touched consume memory.
    """

    # Fixed attribute set: the tree can allocate a very large number of nodes,
    # and __slots__ drops the per-instance __dict__.
    __slots__ = ("treesum", "l", "r", "left", "right", "add", "mul")

    MOD = 10 ** 9 + 7

    def __init__(self, l: int, r: int):
        self.treesum = 0  # sum of the values in [l, r], modulo MOD
        self.l = l
        self.r = r
        self.left = None  # child covering [l, mid], created on demand
        self.right = None  # child covering [mid + 1, r], created on demand
        self.add = 0  # additive part of the pending tag
        self.mul = 1  # multiplicative part of the pending tag

    @property
    def _mid(self) -> int:
        """Split point: the left child covers [l, mid], the right [mid + 1, r]."""
        return (self.l + self.r) // 2

    @property
    def _left(self) -> "SegTree":
        """Left child, allocated on first access."""
        if self.left is None:
            self.left = SegTree(self.l, self._mid)
        return self.left

    @property
    def _right(self) -> "SegTree":
        """Right child, allocated on first access."""
        if self.right is None:
            self.right = SegTree(self._mid + 1, self.r)
        return self.right

    def push_up(self) -> None:
        """Recompute this node's sum from its two children."""
        self.treesum = (self._left.treesum + self._right.treesum) % self.MOD

    def _misses(self, ul: int, ur: int) -> bool:
        """Return True if [ul, ur] is empty or does not intersect [l, r]."""
        return ul > ur or ur < self.l or self.r < ul

    def add_update(self, ul: int, ur: int, addval: int) -> None:
        """Add ``addval`` to every position of [ul, ur] that lies in [l, r].

        An empty or disjoint range is a no-op. Without this guard such a range
        would reach a leaf, split it into a child with the same bounds, and
        recurse without end.
        """
        if self._misses(ul, ur):
            return
        if ul <= self.l and self.r <= ur:
            self.treesum = (self.treesum + (self.r - self.l + 1) * addval) % self.MOD
            # Reduced like the other tag fields so it stays a small integer.
            self.add = (self.add + addval) % self.MOD
            return

        self.lazy_push_down()
        if ul <= self._mid:
            self._left.add_update(ul, ur, addval)
        if self._mid + 1 <= ur:
            self._right.add_update(ul, ur, addval)
        self.push_up()

    def mul_update(self, ul: int, ur: int, mulval: int) -> None:
        """Multiply every position of [ul, ur] that lies in [l, r] by ``mulval``.

        An empty or disjoint range is a no-op, for the same reason as in
        ``add_update``.
        """
        if self._misses(ul, ur):
            return
        if ul <= self.l and self.r <= ur:
            # m * (mul * x + add) == (m * mul) * x + m * add: both tag parts scale.
            self.treesum = self.treesum * mulval % self.MOD
            self.add = self.add * mulval % self.MOD
            self.mul = self.mul * mulval % self.MOD
            return

        self.lazy_push_down()
        if ul <= self._mid:
            self._left.mul_update(ul, ur, mulval)
        if self._mid + 1 <= ur:
            self._right.mul_update(ul, ur, mulval)
        self.push_up()

    def query(self, ql: int, qr: int) -> int:
        """Return the sum of the positions in [ql, qr] ∩ [l, r], modulo MOD."""
        if qr < self.l or self.r < ql:
            return 0
        if ql <= self.l and self.r <= qr:
            return self.treesum

        self.lazy_push_down()
        return (self._left.query(ql, qr) + self._right.query(ql, qr)) % self.MOD

    def lazy_push_down(self) -> None:
        """Apply this node's pending tag to both children, then reset it."""
        if self.add == 0 and self.mul == 1:
            return  # identity tag: nothing to propagate
        mod = self.MOD
        mul, add = self.mul, self.add
        for child in (self._left, self._right):
            # Each child value v becomes mul * v + add, so the child's sum is
            # scaled by mul and gains `add` once per covered position.
            child.treesum = (child.treesum * mul + (child.r - child.l + 1) * add) % mod
            # Compose tags: mul * (child.mul * x + child.add) + add.
            child.mul = child.mul * mul % mod
            child.add = (child.add * mul + add) % mod
        self.add = 0
        self.mul = 1

class Fancy:
    """Fancy sequence backed by a lazy segment tree.

    Position i of the tree holds the i-th appended value. addAll and multAll
    only touch positions 0..n that exist at call time, so values appended
    later are unaffected by earlier global operations.
    """

    def __init__(self):
        # Index of the most recently appended element; -1 while empty.
        self.n = -1
        # Fixed capacity: positions 0..10**5 (assumed to cover the problem's
        # call limit — confirm against the constraints).
        self.ST = SegTree(0, 10 ** 5)

    def append(self, val: int) -> None:
        """Store ``val`` at the next free position."""
        self.n = self.n + 1
        self.ST.add_update(self.n, self.n, val)

    def addAll(self, inc: int) -> None:
        """Add ``inc`` to every element currently in the sequence."""
        # An add can arrive before any append; skip it rather than passing
        # the empty range [0, -1] to the tree.
        if self.n < 0:
            return
        self.ST.add_update(0, self.n, inc)

    def multAll(self, m: int) -> None:
        """Multiply every element currently in the sequence by ``m``."""
        if self.n < 0:
            return
        self.ST.mul_update(0, self.n, m)

    def getIndex(self, idx: int) -> int:
        """Return the element at ``idx`` as stored by the tree, or -1 past the end."""
        return self.ST.query(idx, idx) if idx <= self.n else -1


# Your Fancy object will be instantiated and called as such:
# obj = Fancy()
# obj.append(val)
# obj.addAll(inc)
# obj.multAll(m)
# param_4 = obj.getIndex(idx)

cpp 解法, 执行用时: 472 ms, 内存消耗: 184 MB, 提交时间: 2022-07-28 15:34:22

// Modulus applied to every stored coefficient.
const int64_t MOD = 1000000007;

// Affine map f(x) = a * x + b, coefficients kept modulo MOD.
struct Node {
    int64_t a, b;
    Node(int a0 = 1, int b0 = 0) : a(a0), b(b0) {}
    // Compose in place so that afterwards this map equals t(old_this(x)),
    // i.e. `t` is applied after the map already stored here.
    void operator+=(const Node &t) {
        const int64_t nextA = a * t.a % MOD;
        const int64_t nextB = (b * t.a + t.b) % MOD;
        a = nextA;
        b = nextB;
    }
};
// Fenwick-tree storage of composed maps; reset to identity by Fancy's
// constructor.
Node A[200000];
// Fancy sequence solved with a Fenwick tree whose nodes hold composed affine
// maps (Node: x -> a*x + b).
//
// Each addAll/multAll becomes a Node inserted at p = 100000 - nums.size(),
// propagating upward (i += i&-i). getIndex(idx) folds the nodes on the
// downward path from q = 100000 - (idx+1). That path collects exactly the
// operations with p <= q, i.e. those issued while nums.size() > idx, which
// means element idx already existed.
//
// Affine maps do not commute, so order matters. nums.size() never shrinks,
// so p never increases over time. Along a query path, a node with a higher
// index covers larger p values and therefore holds only earlier operations.
// The query visits nodes from high index to low, and each node composes its
// own operations in arrival order (Node::operator+=). Together these apply
// the operations in the order they were issued.
//
// NOTE(review): if nums.size() reached 100000 the update start index would be
// 0 and `i += i&-i` would never advance. This appears to rely on the call
// limit keeping the size below that whenever addAll/multAll run — confirm.
class Fancy {
public:
    // Raw appended values; the operations live in A, not here.
    vector<int64_t> nums;
    // A is a global, so reset every node to the identity map for this
    // instance.
    Fancy(){
        for(int i=0;i<200000;i++)
            A[i].a=1,A[i].b=0;
    }
    // Append without transforming; later queries apply the relevant ops.
    void append(int val) {
        nums.push_back(val);
    }
    // Record x -> x + inc for all current elements (no-op when empty).
    void addAll(int inc) {
        if(nums.empty())
            return;
        Node op(1,inc);
        for(int i = 100000-nums.size();i<200000;i+=i&-i)
            A[i]+=op;
    }
    // Record x -> m * x for all current elements (no-op when empty).
    void multAll(int m) {
        if(nums.empty())
            return;
        Node op(m,0);
        for(int i = 100000-nums.size();i<200000;i+=i&-i)
            A[i]+=op;
    }
    // Return element idx after all operations issued since it was appended,
    // modulo MOD; -1 if idx is past the end.
    int getIndex(int idx) {
        if(idx>=nums.size())
            return -1;
        // Start from the identity map and compose the path's nodes in order.
        Node op;
        for(int i = 100000-(idx+1);i>=1;i-=i&-i)
            op += A[i];
        return (op.a*nums[idx]+op.b)%MOD;
    }
};


/**
 * Your Fancy object will be instantiated and called as such:
 * Fancy* obj = new Fancy();
 * obj->append(val);
 * obj->addAll(inc);
 * obj->multAll(m);
 * int param_4 = obj->getIndex(idx);
 */

上一题