列表

详情


NC15053. 求子树

描述

给定一棵n个点的树,第i个点的权值是vi,初始时根为1

m个操作,每次操作:

1.将树根换为x

2.给出两个点x、y,求满足a在x的子树中、b在y的子树中且va==vb的点对(a,b)的个数(子树按当前的根计算)

输入描述

第一行两个数表示n,m

第二行n个数,表示每个点的点权vi

之后n-1行,每行两个数x,y表示一条边

之后m行,每行为:

1 x表示把根换成x点

2 x y表示一次询问:在当前根下,求a在x点的子树中、b在y点的子树中且va==vb的点对(a,b)个数

输出描述

对于每个询问,输出一行一个数表示答案

示例1

输入:

5 5
1 2 3 4 5
1 2
1 3
3 4
3 5
2 4 5
2 1 5
2 3 5
1 5
2 4 5

输出:

0
1
1
1

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++14(g++5.4) 解法, 执行用时: 692ms, 内存消耗: 78540K, 提交时间: 2019-02-28 16:44:34

#include <bits/stdc++.h>
// Contest shorthand macros. None of them is referenced by the code of this solution;
// they are template boilerplate.
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
#define dep(i, a, b) for (int i = (a); i >= (b); --i)
#define mp make_pair
#define ft first
#define sc second
#define pb push_back
#define pii pair<int, int>
using namespace std;

// Reads the next non-negative decimal integer from stdin via getchar().
// Any non-digit characters before the number are skipped (so a '-' sign is ignored).
// The character is kept in an int: that keeps EOF distinguishable and guarantees isdigit()
// only ever sees EOF or an unsigned-char value. If EOF is reached before any digit, the
// function returns 0 instead of spinning forever.
// NOTE: this solution parses its input with read(); qreadInt is not called.
int qreadInt() {
	int x = 0;
	int ch = getchar();
	while (ch != EOF && !isdigit(ch)) ch = getchar();
	while (ch != EOF && isdigit(ch)) {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x;
}
 
typedef long long ll;
// N bounds vertex-indexed arrays, M bounds query-indexed arrays (ans; the event arrays e/w
// hold up to four events per query). BUF / OUT size the whole-input and output buffers.
const int N = 100010, M = 500010, BUF = 12000000, OUT = 7000000;
// Fast I/O: Buf receives all of stdin and buf is the parse cursor; Out collects the output
// and ou is the write cursor.
char Buf[BUF], *buf = Buf, Out[OUT], *ou = Out;
// Digit scratch stack used by write().
int Outn[30], Outcnt;
// a: per-vertex values (replaced by their compressed rank in main), b: sorted distinct
// values (cb of them), g/v/nxt/ed: adjacency lists (head, target, next, edge counter),
// root: current root used by queries. i, j, op, x, y are scratch variables shared by main.
int n, m, cb, i, j, op, x, y, a[N], b[N], g[N], v[N << 1], nxt[N << 1], ed, root = 1;
// Heavy-light data of the tree rooted at 1: subtree size, heavy child, parent, chain head,
// first / last DFS index of each subtree, and the running DFS counter.
// NOTE: with <bits/stdc++.h> and `using namespace std`, an unqualified `size` is ambiguous
// with std::size under C++17 and later; qualify uses as ::size there.
int size[N], son[N], f[N], top[N], st[N], en[N], dfn;
// cq / ce: query and event counters, lim / pos: Mo block size and block id of each index,
// ap[c]: number of vertices whose compressed value is c, cnt: counting-sort scratch.
int cq, ce, lim, pos[N], ap[N], cnt[N];
// dp[u]: sum of ap[a[w]] over w in u's subtree (rooted at 1); ans: per-query answers;
// now: running sum maintained by the Mo pointer moves.
ll dp[N], ans[M], now;

// Counter cell for one compressed value, used by the Mo sweep: l / r count the vertices of
// that value inside the current left / right DFS-order prefix.
// p[c] is the cell for value c; pool[k] points at the cell of the vertex with DFS index k.
struct P {
    int l;
    int r;
};
P p[N];
P* pool[N];

// Offline prefix event: (l, r) are DFS-order prefix lengths and p is a signed query id.
// The Mo sweep adds the current prefix product to ans[p] when p > 0 and subtracts it from
// ans[-p] otherwise.
struct E {
    int l, r, p;
    E() {}
    E(int _l, int _r, int _p) : l(_l), r(_r), p(_p) {}
};
E e[M << 2], w[M << 2];

// Parses the next non-negative decimal integer from the global input buffer into a and
// advances the cursor buf past it. Bytes below '0' (spaces, newlines, '-') are skipped, so a
// negative number would be read without its sign.
// The skip loop stops at a NUL byte. Buf is zero-initialised and fread only fills the bytes
// it actually read, so exhausted or truncated input (shorter than BUF) parses as 0 instead
// of scanning past the end of Buf.
inline void read(int& a) {
    a = 0;
    while (*buf != '\0' && *buf < '0')
        ++buf;
    while (*buf >= '0')
        a = a * 10 + (*buf++ - '0');
}

// Appends the decimal representation of x to the output buffer at ou (no separator).
// Digits are collected least-significant first in Outn[1..Outcnt] and then emitted in
// reverse order; Outcnt is back at 0 afterwards.
inline void write(ll x) {
    if (x == 0) {
        *ou++ = '0';
        return;
    }
    Outcnt = 0;
    while (x != 0) {
        Outn[++Outcnt] = x % 10 + '0';
        x /= 10;
    }
    for (; Outcnt != 0; --Outcnt)
        *ou++ = Outn[Outcnt];
}

inline void add(int x, int y) {
    v[++ed] = y;
    nxt[ed] = g[x];
    g[x] = ed;
}

void dfs(int x) {
    size[x] = 1, dp[x] = ap[a[x]];
    for (int i = g[x]; i; i = nxt[i])
        if (v[i] != f[x]) {
            f[v[i]] = x;
            dfs(v[i]), size[x] += size[v[i]], dp[x] += dp[v[i]];
            if (size[v[i]] > size[son[x]])
                son[x] = v[i];
        }
}

// Second DFS (heavy-light decomposition). The heavy child is visited first, so every heavy
// chain occupies consecutive DFS indices. For vertex x on the chain headed by y it sets:
//   st[x] / en[x] - first / last DFS index of x's subtree,
//   top[x]        - head of x's heavy chain,
//   pool[st[x]]   - pointer to the counter cell p[a[x]] for x's value.
void dfs2(int x, int y) {
    ++dfn;
    st[x] = dfn;
    pool[dfn] = &p[a[x]];
    top[x] = y;
    if (son[x] != 0)
        dfs2(son[x], y);
    for (int edge = g[x]; edge != 0; edge = nxt[edge]) {
        const int to = v[edge];
        if (to == son[x] || to == f[x])
            continue;
        dfs2(to, to);  // every light child starts a new chain
    }
    en[x] = dfn;
}

// Given x, an ancestor of y in the tree rooted at 1, returns the child of x that lies on the
// path from x down to y. Callers check this precondition via the st/en ranges first; if x is
// not an ancestor of y the climb never meets x's chain.
// y climbs chain by chain and t remembers the head of the last chain it left. If y lands
// exactly on x, that head is the wanted child. Otherwise y stopped on x's own heavy chain
// below x, so the child is x's heavy son.
// t starts at 0 so that the degenerate call x == y (no such child) returns 0 instead of
// reading an uninitialised variable.
inline int lca2(int x, int y) {
    int t = 0;
    while (top[x] != top[y]) y = f[t = top[y]];
    return x == y ? t : son[x];
}

// Number of equal-value pairs (a, b) with a in x's subtree under the current root and b
// anywhere in the tree.
//  - x is the root: its subtree is the whole tree, giving dp[1].
//  - root lies inside x's root-1 subtree: x is a proper ancestor of root, and its subtree is
//    everything except the root-1 subtree of x's child towards root.
//  - otherwise x's subtree is unchanged by rerooting, giving dp[x].
inline ll one(int x) {
    if (x == root) return dp[1];
    const bool rootInsideX = st[x] <= st[root] && en[root] <= en[x];
    return rootInsideX ? dp[1] - dp[lca2(x, root)] : dp[x];
}

// Registers query number cq: count equal-value pairs (a, b) with a in T(x) and b in T(y),
// where T(u) is u's subtree under the current root.
// Notation: S(u) = subtree of u in the tree rooted at 1 (DFS range [st[u], en[u]]),
// pairs(A, B) = number of equal-value pairs with a in A and b in B,
// F(l, r) = pairs(DFS indices 1..l, DFS indices 1..r).
// Terms available from dp / one() go straight into ans[cq]. The remaining p * pairs(S(x), S(y))
// is expanded by 2-D inclusion-exclusion into up to four F events whose p field is +cq or
// -cq; main() evaluates them offline and adds or subtracts them.
inline void addquery(int x, int y) {
    // Normalise so that if either endpoint is the current root, it is x.
    if (y == root)
        swap(x, y);
    // T(root) is the whole tree, so the answer is pairs(all, T(y)) = one(y).
    if (x == root) {
        ans[cq] = one(y);
        return;
    }
    // Sign of the remaining pairs(S(x), S(y)) term (this local shadows the global array p).
    int p = 1;
    // Root lies in S(x), so x is a proper ancestor of root and T(x) = all \ S(c) with
    // c = lca2(x, root): pairs(T(x), T(y)) = one(y) - pairs(S(c), T(y)).
    if (st[x] <= st[root] && en[root] <= en[x])
        ans[cq] = one(y), x = lca2(x, root), p = -1;
    // Likewise for y: T(y) = all \ S(c'), so p * pairs(S(x), T(y)) =
    // p * dp[x] - p * pairs(S(x), S(c')).
    if (st[y] <= st[root] && en[root] <= en[y])
        ans[cq] += dp[x] * p, y = lca2(y, root), p = -p;
    // pairs(S(x), S(y)) = F(en x, en y) - F(st x - 1, en y) - F(en x, st y - 1)
    //                   + F(st x - 1, st y - 1); prefixes of length 0 contribute nothing,
    // so those events are skipped.
    e[++ce] = E(en[x], en[y], cq * p);
    if (st[x] > 1)
        e[++ce] = E(st[x] - 1, en[y], -cq * p);
    if (st[y] > 1)
        e[++ce] = E(en[x], st[y] - 1, -cq * p);
    if (st[x] > 1 && st[y] > 1)
        e[++ce] = E(st[x] - 1, st[y] - 1, cq * p);
}

// Mo pointer moves. Each P cell counts how many vertices of one value lie in the left (l)
// and right (r) DFS-order prefix; `now` is kept equal to the sum over values of l * r.
inline void addl(P* x) { ++x->l; now += x->r; }
inline void dell(P* x) { --x->l; now -= x->r; }
inline void addr(P* x) { ++x->r; now += x->l; }
inline void delr(P* x) { --x->r; now -= x->l; }

int main() {
    fread(Buf, 1, BUF, stdin);
    read(n), read(m);
    for (i = 1; i <= n; i++) read(a[i]), b[i] = a[i];
    sort(b + 1, b + n + 1);
    for (i = 1; i <= n; i++)
        if (b[i] > b[i - 1])
            b[++cb] = b[i];
    for (i = 1; i <= n; i++) ap[a[i] = lower_bound(b + 1, b + cb + 1, a[i]) - b]++;
    for (i = 1; i < n; i++) read(x), read(y), add(x, y), add(y, x);
    dfs(1), dfs2(1, 1);
    while (m--) {
        read(op);
        if (op == 1)
            read(root);
        else
            read(x), read(y), cq++, addquery(x, y);
    }
    for (lim = 1; lim * lim < ce; lim++)
        ;
    lim = n / lim * 3.3;
    if (!lim)
        lim = 1;
    for (i = 1; i <= n; i++) pos[i] = (i - 1) / lim + 1;
    for (i = 1; i <= ce; i++) cnt[e[i].r]++;
    for (i = 1; i <= n; i++) cnt[i] += cnt[i - 1];
    for (i = 1; i <= ce; i++) w[cnt[e[i].r]--] = e[i];
    for (i = 1; i <= n; i++) cnt[i] = 0;
    for (i = 1; i <= ce; i++) cnt[pos[w[i].l]]++;
    for (i = 1; i <= n; i++) cnt[i] += cnt[i - 1];
    for (i = ce; i; i--) e[cnt[pos[w[i].l]]--] = w[i];
    for (x = 0, i = 1; i <= ce; x ^= 1, i = j) {
        for (j = i; j <= ce && pos[e[i].l] == pos[e[j].l]; j++)
            ;
        if (x)
            reverse(e + i, e + j);
    }
    P*(*l) = pool, *(*r) = pool;
    for (i = 1; i <= ce; i++) {
        P*(*L) = pool + e[i].l, *(*R) = pool + e[i].r;
        while (l < L) addl(*(++l));
        while (l > L) dell(*(l--));
        while (r < R) addr(*(++r));
        while (r > R) delr(*(r--));
        if (e[i].p > 0)
            ans[e[i].p] += now;
        else
            ans[-e[i].p] -= now;
    }
    for (i = 1; i <= cq; i++) write(ans[i]), *ou++ = '\n';
    fwrite(Out, 1, ou - Out, stdout);
    return 0;
}

C++11(clang++ 3.9) 解法, 执行用时: 664ms, 内存消耗: 87336K, 提交时间: 2018-01-27 16:18:02

#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
// N bounds vertex-indexed arrays, M bounds query-indexed arrays (ans; e/w hold up to four
// events per query). BUF / OUT size the whole-input and output buffers.
const int N=100010,M=500010,BUF=12000000,OUT=7000000;
// Fast I/O: Buf receives stdin and buf is the parse cursor; Out collects output and ou is the
// write cursor; Outn / Outcnt are write()'s digit scratch.
char Buf[BUF],*buf=Buf,Out[OUT],*ou=Out;int Outn[30],Outcnt;
// a: per-vertex values (compressed in main), b: sorted distinct values (cb of them),
// g/v/nxt/ed: adjacency lists, root: current root for queries; i, j, op, x, y are scratch.
int n,m,cb,i,j,op,x,y,a[N],b[N],g[N],v[N<<1],nxt[N<<1],ed,root=1;
// Heavy-light data of the tree rooted at 1: subtree size, heavy child, parent, chain head,
// first / last DFS index of each subtree, and the DFS counter.
// NOTE: under `using namespace std`, `size` can clash with std::size (C++17) if <iterator>
// is pulled in transitively; qualify uses as ::size in that case.
int size[N],son[N],f[N],top[N],st[N],en[N],dfn;
// cq / ce: query / event counters, lim / pos: Mo block size and block ids, ap[c]: vertices with
// value c, cnt: counting-sort scratch; dp[u]: sum of ap over u's subtree, ans: answers,
// now: running sum of the Mo sweep.
int cq,ce,lim,pos[N],ap[N],cnt[N];ll dp[N],ans[M],now;
// Counter cell for one compressed value: l / r count its vertices in the left / right
// DFS-order prefix of the Mo sweep. p[c] is the cell for value c; pool[k] points at the
// cell of the vertex with DFS index k.
struct P{
  int l;
  int r;
};
P p[N];
P *pool[N];
// Offline prefix event: (l, r) are DFS-order prefix lengths, p is a signed query id
// (p > 0 adds the prefix product to ans[p], p < 0 subtracts it from ans[-p]).
struct E{
  int l,r,p;
  E(){}
  E(int _l,int _r,int _p):l(_l),r(_r),p(_p){}
};
E e[M<<2],w[M<<2];
// Parses the next non-negative decimal integer from the global input buffer into a and
// advances buf past it. Bytes below '0' (spaces, newlines, '-') are skipped, so a sign is
// ignored. Skipping stops at a NUL byte: Buf is zero-initialised and fread only fills what it
// read, so exhausted input (shorter than BUF) parses as 0 instead of running off the buffer.
inline void read(int&a){
  a=0;
  while(*buf!='\0'&&*buf<'0')++buf;
  while(*buf>='0')a=a*10+(*buf++-'0');
}
// Appends x in decimal to the output buffer at ou (no separator). Digits are collected
// least-significant first in Outn[1..Outcnt] and emitted in reverse; Outcnt ends at 0.
inline void write(ll x){
  if(x==0){*ou++='0';return;}
  Outcnt=0;
  while(x!=0){Outn[++Outcnt]=x%10+'0';x/=10;}
  for(;Outcnt!=0;--Outcnt)*ou++=Outn[Outcnt];
}
inline void add(int x,int y){v[++ed]=y;nxt[ed]=g[x];g[x]=ed;}
void dfs(int x){
  size[x]=1,dp[x]=ap[a[x]];
  for(int i=g[x];i;i=nxt[i])if(v[i]!=f[x]){
    f[v[i]]=x;
    dfs(v[i]),size[x]+=size[v[i]],dp[x]+=dp[v[i]];
    if(size[v[i]]>size[son[x]])son[x]=v[i];
  }
}
// Heavy-light DFS: the heavy child is visited first, so each chain gets consecutive DFS
// indices. Sets st/en (DFS range of x's subtree), top[x] = y (chain head) and
// pool[st[x]] = &p[a[x]] (counter cell of x's value).
void dfs2(int x,int y){
  st[x]=++dfn;
  pool[dfn]=&p[a[x]];
  top[x]=y;
  if(son[x]!=0)dfs2(son[x],y);
  for(int edge=g[x];edge!=0;edge=nxt[edge]){
    const int to=v[edge];
    if(to==son[x]||to==f[x])continue;
    dfs2(to,to);
  }
  en[x]=dfn;
}
// Given x, an ancestor of y in the tree rooted at 1 (callers check this via st/en), returns
// the child of x on the path down to y. y climbs chain by chain and t remembers the head of
// the last chain left; if y lands on x that head is the child, otherwise y stopped on x's
// heavy chain and the child is son[x].
// t starts at 0 so the degenerate call x == y returns 0 rather than an uninitialised value.
inline int lca2(int x,int y){
  int t=0;
  while(top[x]!=top[y])y=f[t=top[y]];
  return x==y?t:son[x];
}
// Equal-value pairs (a, b) with a in x's subtree under the current root and b anywhere:
// dp[1] if x is the root; dp[1] minus the root-1 subtree of x's child towards root if root
// lies inside x's root-1 subtree; otherwise dp[x].
inline ll one(int x){
  if(x==root)return dp[1];
  const bool rootInsideX=st[x]<=st[root]&&en[root]<=en[x];
  return rootInsideX?dp[1]-dp[lca2(x,root)]:dp[x];
}
// Registers query number cq: equal-value pairs (a, b) with a in T(x), b in T(y), where T(u)
// is u's subtree under the current root. S(u) is u's root-1 subtree (DFS range
// [st[u], en[u]]) and F(l, r) = pairs(DFS indices 1..l, DFS indices 1..r).
// Directly known parts go into ans[cq]. The rest, p * pairs(S(x), S(y)), becomes up to four
// F events tagged +cq / -cq, which main() evaluates offline.
inline void addquery(int x,int y){
  // Normalise: if either endpoint is the root, make it x.
  if(y==root)swap(x,y);
  // T(root) is the whole tree, so the answer is one(y).
  if(x==root){ans[cq]=one(y);return;}
  // Sign of the remaining pairs(S(x), S(y)) term (shadows the global array p).
  int p=1;
  // x is a proper ancestor of root: T(x) = all \ S(c), c = lca2(x, root), so the answer is
  // one(y) - pairs(S(c), T(y)).
  if(st[x]<=st[root]&&en[root]<=en[x])ans[cq]=one(y),x=lca2(x,root),p=-1;
  // y is a proper ancestor of root: T(y) = all \ S(c'), contributing p * dp[x] directly and
  // -p * pairs(S(x), S(c')).
  if(st[y]<=st[root]&&en[root]<=en[y])ans[cq]+=dp[x]*p,y=lca2(y,root),p=-p;
  // pairs(S(x), S(y)) by 2-D inclusion-exclusion over DFS prefixes; zero-length prefixes
  // contribute nothing and are skipped.
  e[++ce]=E(en[x],en[y],cq*p);
  if(st[x]>1)e[++ce]=E(st[x]-1,en[y],-cq*p);
  if(st[y]>1)e[++ce]=E(en[x],st[y]-1,-cq*p);
  if(st[x]>1&&st[y]>1)e[++ce]=E(st[x]-1,st[y]-1,cq*p);
}
// Mo pointer moves: l / r count a value's vertices in the left / right prefix, and `now`
// stays equal to the sum over values of l * r.
inline void addl(P*x){++x->l;now+=x->r;}
inline void dell(P*x){--x->l;now-=x->r;}
inline void addr(P*x){++x->r;now+=x->l;}
inline void delr(P*x){--x->r;now-=x->l;}
int main(){
  fread(Buf,1,BUF,stdin);read(n),read(m);
  for(i=1;i<=n;i++)read(a[i]),b[i]=a[i];
  sort(b+1,b+n+1);
  for(i=1;i<=n;i++)if(b[i]>b[i-1])b[++cb]=b[i];
  for(i=1;i<=n;i++)ap[a[i]=lower_bound(b+1,b+cb+1,a[i])-b]++;
  for(i=1;i<n;i++)read(x),read(y),add(x,y),add(y,x);
  dfs(1),dfs2(1,1);
  while(m--){
    read(op);
    if(op==1)read(root);else read(x),read(y),cq++,addquery(x,y);
  }
  for(lim=1;lim*lim<ce;lim++);
  lim=n/lim*3.3;
  if(!lim)lim=1;
  for(i=1;i<=n;i++)pos[i]=(i-1)/lim+1;
  for(i=1;i<=ce;i++)cnt[e[i].r]++;
  for(i=1;i<=n;i++)cnt[i]+=cnt[i-1];
  for(i=1;i<=ce;i++)w[cnt[e[i].r]--]=e[i];
  for(i=1;i<=n;i++)cnt[i]=0;
  for(i=1;i<=ce;i++)cnt[pos[w[i].l]]++;
  for(i=1;i<=n;i++)cnt[i]+=cnt[i-1];
  for(i=ce;i;i--)e[cnt[pos[w[i].l]]--]=w[i];
  for(x=0,i=1;i<=ce;x^=1,i=j){
    for(j=i;j<=ce&&pos[e[i].l]==pos[e[j].l];j++);
    if(x)reverse(e+i,e+j);
  }
  P*(*l)=pool,*(*r)=pool;
  for(i=1;i<=ce;i++){
    P*(*L)=pool+e[i].l,*(*R)=pool+e[i].r;
    while(l<L)addl(*(++l));
    while(l>L)dell(*(l--));
    while(r<R)addr(*(++r));
    while(r>R)delr(*(r--));
    if(e[i].p>0)ans[e[i].p]+=now;else ans[-e[i].p]-=now;
  }
  for(i=1;i<=cq;i++)write(ans[i]),*ou++='\n';
  fwrite(Out,1,ou-Out,stdout);
  return 0;
}

上一题