列表

详情


NC226598. 势能线段树模板题二

描述

智乃酱最近在学习势能线段树
对于势能线段树,假设线段树的节点数为,操作数目为
若势能可被某种操作重置或者增加,则必须考虑最坏情况下能够提供操作总数*|单次操作重置势能的节点总数|*|节点势能上限|的时间复杂度。
则势能线段树则总时间复杂度为
一般来讲,在使用lazy_tag的情况下, |线段树单次操作影响到的节点数目|就是,本题中节点势能上限近似是个(其实是6)所以总复杂度是级别的。
使用势能线段树时要定义势能、势能初始值(势能最大值)、0势能点。

本题中有两种可以定义势能与0势能的方法,你可以都尝试一下。
  1. 定义区间开根次数cnt为势能,势能初始值为势能上限=6,定义0势能点为cnt=0,但是此方法重置势能比较困难,强烈不推荐使用。
  2. 定义区间最大值max与区间最小值min的差值diff为势能,势能初始值为差值diff,定义0势能点为diff=0。
--------------------------------------------------------------------------------------------------------------------------
给你一个长度大小为的正整数数组,进行次操作,操作有下列两种。
  1. 给定区间对区间中所有数字开根号向下取整,即
  2. 给定区间,对区间中每个数字加上一个正整数
  3. 查询给定区间的元素和,即求

输入描述

第一行输入两个正整数
接下来一行输入个正整数
接下来行,每行首先输入一个正整数
时,表示操作一,然后继续输入两个正整数表示对区间开根号向下取整。
时,表示操作二,然后继续输入三个正整数表示给区间加上一个正整数
时,表示操作二,然后继续输入两个正整数表示求区间的区间和。

输出描述

对于每一个当,输出区间和。

示例1

输入:

10 3
10 10 10 10 10 10 10 10 10 10
3 1 10
1 1 5
3 1 10

输出:

100
65

示例2

输入:

10 3
10 10 10 10 10 10 10 10 10 10
1 1 5
2 1 10 5
3 1 10

输出:

115

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(g++ 7.5.0) 解法, 执行用时: 797ms, 内存消耗: 22428K, 提交时间: 2023-04-09 10:33:49

#include <iostream>
#include <cmath>

using namespace std;

using i64 = long long;

const int N = 200010;

struct Node {
    int l, r;
    i64 mx, mn;
    i64 add;
    i64 sum;
}tr[N << 2];
int n, m;
int w[N];
int a, b, c;

void pushup(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
    tr[u].mx = max(tr[u << 1].mx, tr[u << 1 | 1].mx);
    tr[u].mn = min(tr[u << 1].mn, tr[u << 1 | 1].mn);
}

void pushadd(int u, i64 v) {
    tr[u].add += v;
    tr[u].sum += v * (tr[u].r - tr[u].l + 1);
    tr[u].mx += v, tr[u].mn += v;
}

void pushdown(int u) {
    if (tr[u].add) {
        pushadd(u << 1, tr[u].add);
        pushadd(u << 1 | 1, tr[u].add);
        tr[u].add = 0;
    }
}

void build(int u, int l, int r) {
    tr[u] = {l, r};
    if (l == r) return void(tr[u].mx = tr[u].mn = tr[u].sum = w[l]);
    int mid = l + r >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

void modify(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) {
        int delta1 = tr[u].mx - (i64)sqrt(tr[u].mx);
        int delta2 = tr[u].mn - (i64)sqrt(tr[u].mn);
        if (delta1 == delta2) return pushadd(u, -delta1);
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (l <= mid) modify(u << 1, l, r);
    if (r > mid) modify(u << 1 | 1, l, r);
    pushup(u);
}

i64 query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    int mid = tr[u].l + tr[u].r >> 1;
    pushdown(u);
    i64 res = 0;
    if (l <= mid) res += query(u << 1, l, r);
    if (r > mid) res += query(u << 1 | 1, l, r);
    return res;
}

void add(int u, int l, int r, int v) {
    if (tr[u].l >= l && tr[u].r <= r) return pushadd(u, v);
    int mid = tr[u].l + tr[u].r >> 1;
    pushdown(u);
    if (l <= mid) add(u << 1, l, r, v);
    if (r > mid) add(u << 1 | 1, l, r, v);
    pushup(u);
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%d", w + i);
    build(1, 1, n);
    while (m--) {
        int op, l, r, v;
        scanf("%d%d%d", &op, &l, &r);
        if (op == 1) modify(1, l, r);
        if (op == 2) scanf("%d", &v), add(1, l, r, v);
        if (op == 3) printf("%lld\n", query(1, l, r));
    }
    return 0;
}

C++(clang++ 11.0.1) 解法, 执行用时: 735ms, 内存消耗: 27304K, 提交时间: 2022-09-29 20:40:54

#include<iostream>
#include<cmath>
#define _for(i,L,R) for(int i=L;i<=R;++i)
#define int long long
using namespace std;

const int N=2e5+5;
int n,m,a[N];
class Node{
public:
	int l,r,mn,mx,s;
	int add;
	#define ls u<<1
	#define rs u<<1|1
	#define mid (tr[u].l+tr[u].r>>1)
}tr[N<<2];

void pushup(int u)
{
	tr[u].s=tr[ls].s+tr[rs].s;
	tr[u].mn=min(tr[ls].mn,tr[rs].mn);
	tr[u].mx=max(tr[ls].mx,tr[rs].mx);
}

void build(int u,int l,int r)
{
	if(l==r) return tr[u]={l,r,a[r],a[r],a[r]},void();
	tr[u]={l,r};
	build(ls,l,mid); build(rs,mid+1,r);
	pushup(u); 
}

void change(int u,int x)
{
	tr[u].s+=(tr[u].r-tr[u].l+1)*x;
	tr[u].mx+=x;
	tr[u].mn+=x;
	tr[u].add+=x;
}

void pushdown(int u)
{
	if(tr[u].add!=0){
		change(ls,tr[u].add);
		change(rs,tr[u].add);
		tr[u].add=0;
	} 
}

void uqdate1(int u,int l,int r)
{
	if(l<=tr[u].l and tr[u].r<=r){
		int d1=tr[u].mx-(int)sqrt(tr[u].mx);
		int d2=tr[u].mn-(int)sqrt(tr[u].mn);
		if(d1==d2) return change(u,-d1),void();
	}
	pushdown(u);
	if(l<=mid) uqdate1(ls,l,r);
	if(r>mid) uqdate1(rs,l,r);
	pushup(u);
}

void uqdate2(int u,int l,int r,int x)
{
	if(l<=tr[u].l and tr[u].r<=r) return change(u,x),void();
	pushdown(u);
	if(l<=mid) uqdate2(ls,l,r,x);
	if(r>mid) uqdate2(rs,l,r,x);
	pushup(u);
}

int query(int u,int l,int r)
{
	if(l<=tr[u].l and tr[u].r<=r) return tr[u].s;
	pushdown(u);
	int res=0;
	if(l<=mid) res+=query(ls,l,r);
	if(r>mid) res+=query(rs,l,r);
	return res; 
}

signed main()
{
	scanf("%lld%lld",&n,&m);
	_for(i,1,n) scanf("%lld",a+i);
	build(1,1,n);
	while(m--){
		int op,l,r;
		scanf("%lld%lld%lld",&op,&l,&r);
		if(op==1) uqdate1(1,l,r);
		else if(op==2){
			int x;
			scanf("%lld",&x);
			uqdate2(1,l,r,x);
		}
		else printf("%lld\n",query(1,l,r));
	}
    return 0;
}

上一题