NC212178. brz的树
描述
输入描述
第一行两个整数 n,m。
第二行 n 个整数,第 i 个整数表示节点 i 的颜色。
接下来 n-1 行每行两个整数 x,y,表示节点 x 和 y 之间有一条边。接下来 m 行每行两个整数 x,y 表示一次询问。
输出描述
输出 m 行,每行一个整数表示答案。
示例1
输入:
5 2 2 3 1 2 1 1 2 1 3 2 4 2 5 2 3 1 4
输出:
2 3
说明:
对于第一组询问,颜色 1 分别在 3,5 号节点上出现,5 在 2 的子树内,所以颜色 1 只在 2 和 3 的子树内出现过,颜色 3 只在 2 号节点上出现过,所以颜色 3 也只在 2 和 3 的子树内出现过,于是答案为 1,3 两种颜色。C++(clang++11) 解法, 执行用时: 195ms, 内存消耗: 50732K, 提交时间: 2020-11-10 00:52:35
#include<bits/stdc++.h> using namespace std; const int M=1e5+9; int n,m,num=0,id=0; int c[M],val[M],tid[M],head[M],dep[M],low[M],tru[M<<1],f[M<<1][23]; struct P{int to,ne;}e[M<<1]; struct A{int l,r,id;}; vector<int>g[M],d[M<<1]; vector<A>h1[M<<1],h2[M<<1]; int pre[M],suf[M],ans[M],X[M],Y[M],a[M<<1]; void add(int i){for(;i<=id;i+=i&-i)a[i]++;} int sum(int i,int rex=0){for(;i;i-=i&-i)rex+=a[i];return rex;} void dfs(int u,int fa){ f[tid[u]=++id][0]=u;dep[u]=dep[fa]+1; for(int i=head[u];i;i=e[i].ne){ int v=e[i].to; if(v!=fa){dfs(v,u);f[++id][0]=u;} } low[u]=id; } int Min(int l,int r){return dep[l]<dep[r]?l:r;} int lca(int x,int y){ if(tid[x]>tid[y])swap(x,y); int k=log2(tid[y]-tid[x]+1); return Min(f[tid[x]][k],f[tid[y]-(1<<k)+1][k]); } void qdfs(int u,int fa){ for(int i=head[u];i;i=e[i].ne){ int v=e[i].to; if(v!=fa){qdfs(v,u);val[u]+=val[v];} } } bool cmp(int x,int y){return tid[x]<tid[y];} int main(){ scanf("%d%d",&n,&m); for(int i=1;i<=n;++i){ scanf("%d",&c[i]); g[c[i]].push_back(i); } for(int i=1,x,y;i<n;++i){ scanf("%d%d",&x,&y); e[++num]=P{y,head[x]};head[x]=num; e[++num]=P{x,head[y]};head[y]=num; } dfs(1,0); for(int j=1;j<=23;++j){ for(int i=1;i<=id-(1<<j)+1;++i){ f[i][j]=Min(f[i][j-1],f[i+(1<<(j-1))][j-1]); } } for(int i=1,s,x,y;i<=n;++i){ if(!g[i].size())continue; sort(g[i].begin(),g[i].end(),cmp); s=g[i].size(); pre[0]=g[i][0]; suf[s-1]=g[i][s-1]; for(int j=1;j<s;++j)pre[j]=lca(pre[j-1],g[i][j]); for(int j=s-2;j>=0;--j)suf[j]=lca(suf[j+1],g[i][j]); for(int j=0;j<s-1;++j){ x=pre[j],y=suf[j+1]; if(tid[x]>tid[y])swap(x,y); if(tid[y]<=low[x])continue; d[tid[x]].push_back(tid[y]); } val[suf[0]]++; } for(int i=1,x,y,z;i<=m;++i){ scanf("%d%d",&x,&y); X[i]=x,Y[i]=y; if(tid[x]>tid[y])swap(x,y); if(tid[y]<=low[x])continue; h1[tid[x]].push_back(A{tid[y],low[y],i}); h2[low[x]].push_back(A{tid[y],low[y],i}); } for(int i=1;i<=id;++i){ for(int j=0,s=h1[i].size();j<s;++j){ int l=h1[i][j].l,r=h1[i][j].r,di=h1[i][j].id; ans[di]-=sum(r)-sum(l-1); } for(int j=0,s=d[i].size();j<s;++j)add(d[i][j]); for(int j=0,s=h2[i].size();j<s;++j){ int l=h2[i][j].l,r=h2[i][j].r,di=h2[i][j].id; ans[di]+=sum(r)-sum(l-1); } } qdfs(1,0); for(int i=1,x,y;i<=m;++i){ x=X[i],y=Y[i]; if(tid[x]>tid[y])swap(x,y); ans[i]+=val[x]; if(tid[y]>low[x])ans[i]+=val[y]; printf("%d\n",ans[i]); } return 0; }