列表

详情


NC207668. JustAPureDataStructureProblem

描述

Mia is learning data structures this semester. Knowing this, her boyfriend John gives her a difficult problem about designing a new data structure.

In this problem, you are going to design a data structure that stores a sequence of elements and supports the following operations, where n denotes the total number of elements currently stored in the data structure. (Note: the list of operations itself is missing from this copy of the statement.)

输入描述

The first line contains an integer T (1 ≤ T ≤ 10), the number of test cases.

For each test case, the first line contains one integer m (1 ≤ m ≤ 300,000), the number of operations.

Each of the following m lines contains one operation, in the format described above.

It is guaranteed that n will never exceed 1,000,000 after any operation.

输出描述

For each test case, output the results of all Single Element Queries (Operation 6), one per line.


**All positions are guaranteed to be valid. If you get a runtime error (RE), please re-check your solution.**

示例1

输入:

2
13
1 1 2 1
1 2 1 2
3 3 2
2 4 3
4 4 4
1 5 3 5
3 7 0
3 1 2
5 1
6 1
6 2
6 3
6 4
9
1 1 1 1
3 1 0
1 1 2 1
4 2 2
3 2 0
1 1 2 2
5 2
6 1
6 2

输出:

1
2
3
4
2
1

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++14(g++5.4) 解法, 执行用时: 1437ms, 内存消耗: 7384K, 提交时间: 2020-08-20 12:01:16

#include <bits/stdc++.h>
#pragma GCC optimize(3, "Ofast")
using namespace std;

// Reads the next decimal integer from stdin.
// Every non-digit character before the number is skipped; a '-' among those
// skipped characters makes the result negative. Returns 0 if end of input is
// reached before any digit (the previous version looped forever at EOF, because
// a plain `char` can never compare equal to a digit once getchar() reports EOF).
inline int read() {
	int s = 0, w = 1;
	int ch = getchar();  // int, so EOF stays distinguishable from real bytes
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-') w = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		s = s * 10 + (ch - '0');
		ch = getchar();
	}
	return s * w;
}

typedef unsigned long long ull;
typedef long long ll;
// NOTE: despite its name, pii is pair<long long, int>; kth() returns this type
// as {node index, number of elements before that node}.
typedef pair<ll, int> pii;

// Inclusive ascending / descending for-loops.
#define f(i, l, r) for (int i = l; i <= r; i++)
#define rf(i, r, l) for (int i = r; i >= l; i--)
#define all(x) (x).begin(), (x).end()
// Read one int / one long long from stdin with scanf.
#define sf(a) scanf("%d", &a)
#define llsf(a) scanf("%lld", &a)
// Self-referential macros: each expands to itself, so they have no effect.
#define l l
#define r r
// Heap-style child indices of a node `o` and midpoint of `l` and `r`
// (they refer to whatever o / l / r are in scope at the point of use).
#define lch (o << 1)
#define rch (o << 1 | 1)
#define mid ((l + r) >> 1)
#define mem(x, y) memset(x, y, sizeof(x))
// Conditional single-step reduction of x into [0, mod).
#define mod1(x) ((x >= mod) && (x -= mod))
#define mod2(x) ((x < 0) && (x += mod))

const int inf = 2e9 + 7;
const ll INF = INT64_MAX;
double eps = 1e-6;
const int mod = 1e9 + 7;
// Capacity of the per-node arrays; main() restarts node allocation for every test case.
const int N = 1e6 + 10;
const double pi = acos(-1.0);

// Source of treap priorities; the fixed seed makes runs reproducible.
mt19937 mt(233333);
// Treap node storage, indexed from 1 (index 0 means "no node"; the code reads
// sz[0] and expects it to be 0):
//   v[i]        value of node i; the node stands for num[i] consecutive copies of it
//   ls[i]/rs[i] left / right child
//   num[i]      number of copies held by node i itself
//   sz[i]       total number of copies in the subtree rooted at i
//   fix[i]      random priority (merge() keeps smaller priorities nearer the root)
int v[N], ls[N], rs[N], num[N], sz[N], fix[N];
// Children of the node stored in a variable named `rt` at the point of use.
#define lc ls[rt]
#define rc rs[rt]
// Refreshes sz[rt]: the node's own copy count plus the sizes of both child subtrees.
void upd(int rt) {
	int total = num[rt];
	total += sz[lc];
	total += sz[rc];
	sz[rt] = total;
}
// cnt: index of the most recently allocated node; root: root of the current sequence.
// main() resets both to 0 after each test case.
int cnt, root;
// Allocates the next node index as a childless leaf holding n copies of val,
// draws its priority from mt, and returns the new index.
int newnode(int val, int n) {
	const int id = ++cnt;
	ls[id] = 0;
	rs[id] = 0;
	v[id] = val;
	num[id] = n;
	sz[id] = n;
	fix[id] = mt();
	return id;
}
// Locates the node whose run contains the k-th element (1-based) of the subtree
// rooted at rt. Returns {node, number of elements of that subtree before the node}.
//   - k == 0 or an empty subtree yields {0, 0}.
//   - If k exceeds the subtree size, the walk leaves the tree along the right edge
//     and yields {0, subtree size}; main() relies on this when an insertion position
//     is one past the end of the sequence.
// Fully iterative, and every exit path returns a value (the old loop could, in
// principle, flow off the end of this non-void function).
pii kth(int rt, int k) {
	if (!k) return { 0, 0 };
	int before = 0;  // elements skipped while descending to the right
	while (rt) {
		const int lsz = sz[lc];
		if (k <= lsz) {
			rt = lc;                          // target is in the left subtree
		} else if (k <= lsz + num[rt]) {
			return { rt, before + lsz };      // target is inside this node's run
		} else {
			const int skip = lsz + num[rt];   // skip the left subtree and this run
			before += skip;
			k -= skip;
			rt = rc;
		}
	}
	return { 0, before };
}
// Splits the treap rooted at rt: l receives the first k elements, r the rest.
// k must fall on a node boundary (callers derive it from kth()/num[]), because a
// single node's run is never divided here.
void split(int rt, int k, int &l, int &r) {
	if (k == 0) {
		l = 0;
		r = rt;
		return;
	}
	if (k == sz[rt]) {
		l = rt;
		r = 0;
		return;
	}
	if (k <= sz[lc]) {
		// Cut lies inside the left subtree: rt (with its right side) goes to r.
		r = rt;
		split(lc, k, l, lc);
	} else {
		// Cut lies after rt's run: rt (with its left side) goes to l.
		l = rt;
		split(rc, k - sz[lc] - num[rt], rc, r);
	}
	upd(rt);
}

// Joins treaps x and y, where every element of x precedes every element of y,
// and returns the root of the result. The node with the smaller fix[] value ends
// up on top (ties go to y).
int merge(int x, int y) {
	if (!x) return y;
	if (!y) return x;
	if (fix[x] >= fix[y]) {
		ls[y] = merge(x, ls[y]);
		upd(y);
		return y;
	}
	rs[x] = merge(rs[x], y);
	upd(x);
	return x;
}
void merge(vector<int> vt) {
	int l = 0;
	for (auto r : vt) {
		pii L = kth(l, sz[l]), R = kth(r, 1);
		if (v[L.first] == v[R.first]) {
			int ll, md, rr, numb = num[L.first] + num[R.first];
			split(l, L.second, ll, md);
			split(r, num[R.first], md, rr);
			num[md] = sz[md] = numb;
			l = merge(merge(ll, md), rr);
		}
		else l = merge(l, r);
	}
	root = l;
}

void print(int rt) {
	if (!rt)return;
	print(lc);
	f(i, 1, num[rt])cout << v[rt] << " ";
	print(rc);
}
int main() {
#ifdef local
	int start = clock();
	freopen("in.txt", "r", stdin);
#endif
	int yyyy = 0;
	int _; sf(_);
	while (_--) {
		int m; sf(m);
		while (m--) {
			yyyy++;
			int op, p, c, x; op = read(); p = read();
			//cout << ls[120] << '\n';
			pii now = kth(root, p);
			int l, md, r,numb= num[now.first];
			split(root, now.second, l, md);
			split(md, numb, md, r);
			int ml=0, mr=md;
			auto divide = [&]() {
				if (now.second + 1 < p) {
					ml = newnode(v[md], p - now.second-1);
					num[mr] -= p - now.second-1;
					sz[mr] -= p - now.second-1;
				}
			}; int yy;
			//cout << yyyy << '\n';
			
			switch (op)
			{
			case 1:
				c = read(); x = read();
				divide();
				yy = newnode(x, c);
				merge({ l,ml,yy,mr,r }); break;
			case 2:
				x = read();
				v[md] = x;
				merge({ l, md, r }); break;
			case 3:
				c = read();
				num[md] = sz[md] = c; if (!c)md = 0;
				merge({ l,md,r }); break;
			case 4:
				x = read();
				divide();
				num[mr]--; sz[mr]--; if (!num[mr])mr = 0;
				yy = newnode(x, 1);
				merge({ l,ml,yy,mr,r }); break;
			case 5:
				num[md]--; sz[md]--; if (!num[md])md = 0;
				merge({ l,md,r }); break;
			case 6:
				printf("%d\n",v[md]); 
				merge({ l,md,r });
				break;
			}
			//print(root); cout << '\n';
		}
		//f(i, 0, cnt)tree[i] = { 0,0,0,0,0,0};
		cnt = 0; root = 0;
	}

#ifdef local
	int end = clock();
	cout << '\n' << end - start << "ms";
#endif* /
}

C++11(clang++ 3.9) 解法, 执行用时: 3221ms, 内存消耗: 25276K, 提交时间: 2020-06-10 16:35:46

#include<bits/stdc++.h>
// Node-pool capacity per test case: each operation allocates at most two nodes
// (operations 1 and 4 in solve()), so m <= 300,000 operations use indices <= 600,000.
#define maxn 600050
using namespace std;

// n: index of the most recently allocated node (not the element count from the
//    statement); m: operations remaining in the current test case.
int n, m;
// Root of the current sequence's treap.
int rt;
// sz: copies in the subtree; rnd: random priority (merge keeps smaller values nearer the root).
int sz[maxn], rnd[maxn];
// Left / right child; 0 means "no node".
int ls[maxn], rs[maxn];
// cnt: copies held by this node itself; val: the value those copies share.
int cnt[maxn], val[maxn];

// Allocates the next node index as a childless leaf holding c copies of w and
// returns that index. The priority comes from rand(), which main() seeds.
int newnode(int w, int c = 1) {
    const int id = ++n;
    val[id] = w;
    cnt[id] = c;
    sz[id] = c;
    ls[id] = 0;
    rs[id] = 0;
    rnd[id] = rand();
    return id;
}

// Recomputes sz[k] from the node's own copy count and both child subtrees.
void update(int k) {
    const int children = sz[ls[k]] + sz[rs[k]];
    sz[k] = children + cnt[k];
}

// Splits the treap rooted at rt: l receives the first k elements, r the rest.
// k must fall on a node boundary, because a single node's run is never divided here.
void spilt(int rt, int k, int& l, int& r) {
    if (k == 0) {
        l = 0;
        r = rt;
        return;
    }
    if (k == sz[rt]) {
        l = rt;
        r = 0;
        return;
    }
    if (k <= sz[ls[rt]]) {
        // Cut lies inside the left subtree: rt (with its right side) goes to r.
        r = rt;
        spilt(ls[rt], k, l, ls[rt]);
    } else {
        // Cut lies after rt's run: rt (with its left side) goes to l.
        l = rt;
        spilt(rs[rt], k - sz[ls[rt]] - cnt[rt], rs[rt], r);
    }
    update(rt);
}

// Joins treaps x and y (all of x precedes all of y) and returns the new root.
// The node with the smaller rnd[] value ends up on top; ties go to y.
int merge(int x, int y) {
    if (x == 0) return y;
    if (y == 0) return x;
    if (rnd[x] >= rnd[y]) {
        ls[y] = merge(x, ls[y]);
        update(y);
        return y;
    }
    rs[x] = merge(rs[x], y);
    update(x);
    return x;
}

// Finds the node whose run contains the rk-th element (1-based) of the subtree
// rooted at k. Returns {node, number of elements of that subtree before the node}.
// rk == 0 yields {0, 0}. If rk exceeds the subtree size, the walk runs off the
// right edge and yields {0, subtree size}.
pair<int, int> find(int k, int rk) {
    int skipped = 0;
    while (k != 0) {
        const int lsz = sz[ls[k]];
        if (lsz < rk && rk <= lsz + cnt[k])
            return make_pair(k, skipped + lsz);
        if (lsz >= rk) {
            k = ls[k];
        } else {
            const int step = lsz + cnt[k];
            skipped += step;
            rk -= step;
            k = rs[k];
        }
    }
    return make_pair(0, skipped);
}

int merge(vector<int> v) {
    int rt = 0;
    for (int k : v) {
        auto pr = find(rt, sz[rt]), pl = find(k, 1);
        if (val[pr.first] == val[pl.first]) {
            int l, mid, r, c = cnt[pl.first] + cnt[pr.first];
            spilt(rt, pr.second, l, mid);
            spilt(k, cnt[pl.first], mid, r);
            cnt[mid] = sz[mid] = c;
            rt = merge(merge(l, mid), r);
        }
        else rt = merge(rt, k);
    }
    return rt;
}

// Debug helper: prints the runs of the subtree rooted at k in order, each as
// "(count,value) ". Uses an explicit stack instead of recursion.
void travel(int k) {
    std::vector<int> pending;
    int node = k;
    while (node != 0 || !pending.empty()) {
        while (node != 0) {
            pending.push_back(node);
            node = ls[node];
        }
        node = pending.back();
        pending.pop_back();
        cout << "(" << cnt[node] << "," << val[node] << ") ";
        node = rs[node];
    }
}

// Debug helper: prints the argument's source text, its subtree size in
// parentheses, then travel()'s dump of the runs and a newline. It expands to
// three separate statements, so brace it when used under if/else.
#define print(k) cout<<#k<<"("<<sz[k]<<"):"; travel(k); cout<<endl;

void solve() {

    rt = n = 0;

    scanf("%d", &m);
    while (m--) {
        int op, rk, c, x;
        scanf("%d%d", &op, &rk);

        pair<int, int> pr = find(rt, rk);

        c = cnt[pr.first];
        int l, m1, m2 = 0, r;
        spilt(rt, pr.second, l, r);

        assert(0 <= c && c <= sz[r]);
        spilt(r, c, m1, r);

        auto spilt2 = [&]() {
            m2 = m1, m1 = 0;
            if (rk > pr.second + 1) {
                m1 = newnode(val[m2], rk - pr.second - 1);
                cnt[m2] = cnt[m2] - cnt[m1];
                sz[m2] = sz[m2] - sz[m1];
            }
        };


        switch (op) {
        case 1: {
            scanf("%d%d", &c, &x);
            spilt2();
            rt = merge(vector<int>{l, m1, newnode(x, c), m2, r});
            break;
        }
        case 2: {
            scanf("%d", &x);
            val[m1] = x;
            rt = merge(vector<int>{l, m1, r});
            break;
        }
        case 3: {
            scanf("%d", &c);
            cnt[m1] = sz[m1] = c;
            if (!c) m1 = 0;
            rt = merge(vector<int>{l, m1, r});
            break;
        }
        case 4: {
            scanf("%d", &x);
            spilt2();
            --cnt[m2], --sz[m2];
            if (cnt[m2] == 0) m2 = 0;
            rt = merge(vector<int>{l, m1, newnode(x), m2, r});
            break;
        }
        case 5: {
            --cnt[m1], --sz[m1];
            if (cnt[m1] == 0) m1 = 0;
            rt = merge(vector<int>{l, m1, r});
            break;
        }
        case 6: {
            printf("%d\n", val[m1]);
            rt = merge(vector<int>{l, m1, r});
            break;
        }
        }
    }
}

// Program entry: seeds rand() (used for treap priorities in newnode()), reads the
// number of test cases and runs solve() once per case.
int main() {
    srand(time(0));
    int T;
    scanf("%d", &T);
    while (T--) {
        solve();
    }
    return 0;
}

上一题