列表

详情


NC20607. [ZJOI2017]线段树

描述

线段树是九条可怜很喜欢的一个数据结构,它拥有着简单的结构、优秀的复杂度与强大的功能,因此可怜曾经花了很长时间研究线段树的一些性质。

最近可怜又开始研究起线段树来了,有所不同的是,她把目光放在了更广义的线段树上:在正常的线段树中,对于区间 [l, r],我们会取 ,然后将这个区间分成 [l, m] 和 [m + 1, r] 两个子区间。在广义的线段树中,m 不要求恰好等于区间的中点,但是 m 还是必须满足 l ≤ m < r 的。不难发现在广义的线段树中,树的深度可以达到 O(n) 级别。
例如下面这棵树,就是一棵广义的线段树:


为了方便,我们按照先序遍历给线段树上所有的节点标号,例如在上图中,[2, 3] 的标号是 5,[4, 4] 的标号是 9,不难发现在 [1, n] 上建立的广义线段树,它共有着 2n − 1 个节点。
考虑把线段树上的定位区间操作(就是打懒标记的时候干的事情)移植到广义线段树上,可以发现在广义的线段树上还是可以用传统的线段树上的方法定位区间的,例如在上图中,蓝色节点和蓝色边就是在定位区间 [2, 4] 时经过的点和边,最终定位到的点是 [2, 3] 和 [4, 4]。
如果你对线段树不熟悉,这儿给出定位区间操作形式化的定义:给出区间 [l, r],找出尽可能少的区间互不相交的线段树节点,使得它们区间的并集恰好是 [l, r]。
定义 S[l,r] 为定位区间 [l, r] 得到的点集,例如在上图中,S[2,4] = {5, 9}。定义线段树上两个点 u, v 的距离 d(u, v) 为线段树上 u 到 v 最短路径上的边数,例如在上图中 d(5, 9) = 3。
现在可怜给了你一棵 [1, n] 上的广义的线段树并给了 m 组询问,每组询问给出三个数 u, l, r (l ≤ r),可怜想要知道

输入描述

第一行输入一个整数 n。
接下来一行包含 n - 1 个空格隔开的整数:按照标号递增的顺序,给出广义线段树上所有**非叶子**节点的划分位置 m。不难发现通过这些信息就能唯一确定一棵 [1, n] 上的广义线段树。
接下来一行输入一个整数 m。
之后 m 行每行输入三个整数 u, l, r (1 ≤ u ≤ 2n − 1, 1 ≤ l ≤ r ≤ n),表示一组询问。

输出描述

对于每组询问,输出一个整数表示答案。

示例1

输入:

10
3 1 2 9 6 4 5 7 8
3
7 6 7
18 4 5
14 5 6

输出:

7
11
3

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++11(clang++ 3.9) 解法, 执行用时: 936ms, 内存消耗: 83684K, 提交时间: 2019-03-16 12:12:33

#include<cstdio>  
#include<iostream>  
#include<algorithm>  
#include<cstdlib>  
#include<cstring>
#include<string>
#include<climits>
#include<vector>
#include<cmath>
#include<map>
#include<set>
#define LL long long
 
using namespace std;
 
inline char nc(){
  static char buf[100000],*p1=buf,*p2=buf;
  if (p1==p2) { p2=(p1=buf)+fread(buf,1,100000,stdin); if (p1==p2) return EOF; }
  return *p1++;
}
 
inline void read(int &x){
  char c=nc();int b=1;
  for (;!(c>='0' && c<='9');c=nc()) if (c=='-') b=-1;
  for (x=0;c>='0' && c<='9';x=x*10+c-'0',c=nc()); x*=b;
}
 
inline void read(LL &x){
  char c=nc();LL b=1;
  for (;!(c>='0' && c<='9');c=nc()) if (c=='-') b=-1;
  for (x=0;c>='0' && c<='9';x=x*10+c-'0',c=nc()); x*=b;
}

inline int read(char *s)
{
    char c=nc();int len=0;
    for(;!(c>='A' && c<='Z');c=nc()) if (c==EOF) return 0;
    for(;(c>='A' && c<='Z');s[len++]=c,c=nc());
    s[len++]='\0';
    return len;
}

inline void read(char &x){
  for (x=nc();!(x>='A' && x<='Z');x=nc());
}

int wt,ss[19];
inline void print(int x){

    if (x<0) x=-x,putchar('-');
    if (!x) putchar(48); else {
    for (wt=0;x;ss[++wt]=x%10,x/=10);
    for (;wt;putchar(ss[wt]+48),wt--);}
}
inline void print(LL x){
    if (x<0) x=-x,putchar('-');
    if (!x) putchar(48); else {for (wt=0;x;ss[++wt]=x%10,x/=10);for (;wt;putchar(ss[wt]+48),wt--);}
}

int n,m,s,S,T1,T2,duan[200010],b[200010];
struct data
{
    int l,r,fa,d,c;
    LL rnum,rsum,lnum,lsum;
    vector<int> b;
}a[400010];
int p[400010][25];
struct tepan
{
    int id,x; 
}t1[400010],t2[400010];
 
void dfs(int u)
{
    a[u].c=1;
    for (int i=0;i<a[u].b.size();i++)
    {
        if (!a[a[u].b[i]].d)
        {
            a[a[u].b[i]].d=a[u].d+1;
            p[a[u].b[i]][0]=u;
            dfs(a[u].b[i]);
            a[u].c+=a[a[u].b[i]].c;
        }
    }
}

void pre(int x,int y)
{
    a[x].rsum=a[y].rsum;a[x].lsum=a[y].lsum;
    a[x].lnum=a[y].lnum;a[x].rnum=a[y].rnum;
    if (x==y+1) a[x].rnum++,a[x].rsum+=(LL)a[y].d;
    else a[x].lnum++,a[x].lsum+=(LL)a[y].d;
    if (a[x].l==a[x].r) return ;
    else pre(a[x].b[0],x),pre(a[x].b[1],x);
}
 
void init()
{
    for (int j=1;(1<<j)<=2*n-1;j++)
        for (int i=1;i<=2*n-1;i++)
            if (p[i][j-1]!=-1) p[i][j]=p[p[i][j-1]][j-1];
    for (int i=0;i<a[1].b.size();i++)
        pre(a[1].b[i],1);
}
 
int lca(int x,int y)
{
    if (a[x].d<a[y].d) swap(x,y);
    int i;
    for (i=0;(1<<i)<=a[x].d;i++);i--;
    for (int j=i;j>=0;j--)
        if (a[x].d-(1<<j)>=a[y].d) x=p[x][j];
    if (x==y) return x;
    for (int j=i;j>=0;j--)
        if (p[x][j]!=-1 && p[x][j]!=p[y][j])
            x=p[x][j],y=p[y][j];
    return p[x][0];
}
 
int Find(int x,int y)
{
    y=a[x].d-y;
    int i;
    for (i=0;(1<<i)<=a[x].d;i++);i--;
    for (int j=i;j>=0;j--)
        if (a[x].d-(1<<j)>=y) x=p[x][j];
    return x;
}

void build(int l,int r)
{
    S++;
    if (l==r) {b[l]=S;a[S].l=l,a[S].r=r;return ;}
    s++;a[S].l=l,a[S].r=r;
    int t=S,p=s;
    a[t].b.push_back(S+1);a[S+1].fa=t;
    build(l,duan[p]);
    a[t].b.push_back(S+1);a[S+1].fa=t;
    build(duan[p]+1,r);
}

LL calc(int x,int y,int z,int LCA)
{
    int t;LL res=0;
    if (y!=-1)
    {
        t=lca(x,y);
        if (a[t].d<a[LCA].d) res+=(a[y].rnum-a[a[LCA].b[0]].rnum)*a[t].d;
        else res+=(a[y].rnum-a[t].rnum)*a[t].d+a[t].rsum-a[a[LCA].b[0]].rsum;
        if (a[t].b.size()>0)if (lca(y,a[t].b[1])==t && a[t].d-1>=a[LCA].d && t!=x) res++;
    }
    if (z!=-1)
    { 
        t=lca(x,z);
        if (a[t].d<a[LCA].d) res+=(a[z].lnum-a[a[LCA].b[1]].lnum)*a[t].d;
        else res+=(a[z].lnum-a[t].lnum)*a[t].d+a[t].lsum-a[a[LCA].b[1]].lsum;
        if (a[t].b.size()>0)if (lca(z,a[t].b[0])==t && a[t].d-1>=a[LCA].d && t!=x) res++;
    }
    return res;
}

void pan1(int x)
{
    t1[++T1].id=x,t1[T1].x=a[x].r;
    if (a[x].l==a[x].r) return ;
    pan1(a[x].b[0]);
}

void pan2(int x)
{
    t2[++T2].id=x,t2[T2].x=a[x].l;
    if (a[x].l==a[x].r) return ;
    pan2(a[x].b[1]);
}

void pre_tepan()
{
    T1=0;pan1(1);
    T2=0;pan2(1);
}

int Find1(int x)
{
    int l=1,r=T1,mid,res;
    while (l<=r)
    {
        mid=l+r>>1;
        if (t1[mid].x<=x) r=mid-1,res=mid;else l=mid+1;
    }
    return t1[res].id;
}

int Find2(int x)
{
    int l=1,r=T2,mid,res;
    while (l<=r)
    {
        mid=l+r>>1;
        if (t2[mid].x>=x) r=mid-1,res=mid;else l=mid+1;
    }
    return t2[res].id;
}

int main()
{
    read(n);
    for (int i=1;i<n;i++)
        read(duan[i]);
    s=0;S=0;
    build(1,n);
    a[1].d=1;dfs(1);
    init();
    pre_tepan();
    read(m);
    int x,y,z;
    while (m--)
    {
        read(x);read(y);read(z);
        LL res=0;int t;
        if (y==1 && z==n) print(a[1].d+a[x].d-2*a[1].d),puts("");
        else if (y==1)
        {
            y=Find1(z);
            t=lca(y,b[z+1]);
            res+=a[y].d+a[x].d-2LL*a[lca(x,y)].d;
            res+=a[b[z+1]].lsum-a[a[t].b[1]].lsum;
            res+=(a[b[z+1]].lnum-a[a[t].b[1]].lnum)*(LL)(a[x].d+1);
            res-=2LL*calc(x,-1,b[z+1],t);
            print(res),puts("");
        }
        else if (z==n)
        {
            z=Find2(y);
            t=lca(z,b[y-1]);
            res+=a[z].d+a[x].d-2LL*a[lca(x,z)].d;
            res+=a[b[y-1]].rsum-a[a[t].b[0]].rsum;
            res+=(a[b[y-1]].rnum-a[a[t].b[0]].rnum)*(LL)(a[x].d+1);
            res-=2LL*calc(x,b[y-1],-1,t);
            print(res),puts("");
        }
        else
        { 
            t=lca(b[y-1],b[z+1]);
            res=a[b[y-1]].rsum-a[a[t].b[0]].rsum+a[b[z+1]].lsum-a[a[t].b[1]].lsum;
            res+=(a[b[y-1]].rnum-a[a[t].b[0]].rnum+a[b[z+1]].lnum-a[a[t].b[1]].lnum)*(LL)(a[x].d+1);
            res-=2LL*calc(x,b[y-1],b[z+1],t);
            print(res),puts("");
        }
    }
    return 0;
}

上一题