列表

详情


NC212178. brz的树

描述

蒟蒻 最近造了一棵可爱的树,有一天他不小心将颜料桶打翻了,树上的每个节点都染上了颜色。

蒟蒻 有点不知所措,但是神仙 看到之后,顺手就造了个题然后 1s 就切飞了。现在他把这个题给了蒟蒻,并声称:这个水题都不会做你还能自称是这棵树的主人吗?

蒟蒻 才不会认输,大胆地接下了这个问题:树是一棵以 1 为根的数,树上每个节点有一个颜色,节点 i 的颜色为 c_i,神仙 会进行 m 次询问,每次给出 x,y,问有多少种颜色只在 x 和 y 的子树内出现过。

蒟蒻 不想认输,但是挣扎了一会发现并不会做,你能帮助蒟蒻 吗?(当然啦,你不帮这题就没分了:)

输入描述

第一行两个整数 n,m。

第二行 n 个整数,第 i 个整数 c_i 表示节点 i 的颜色。

接下来 n-1 行每行两个整数 x,y,表示节点 x 和 y 之间有一条边。

接下来 m 行每行两个整数 x,y 表示一次询问。




输出描述

输出 m 行,每行一个整数表示答案。

示例1

输入:

5 2
2 3 1 2 1
1 2
1 3
2 4
2 5
2 3
1 4

输出:

2
3

说明:

对于第一组询问,颜色 1 分别在 3,5 号节点上出现,5 在 2 的子树内,所以颜色 1 只在 2 和 3 的子树内出现过,颜色 3 只在 2 号节点上出现过,所以颜色 3 也只在 2 和 3 的子树内出现过,于是答案为 1,3 两种颜色。

对于第二组询问,1 的子树就是整棵树,所以包含了所有三种颜色,答案为 3。

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(clang++11) 解法, 执行用时: 195ms, 内存消耗: 50732K, 提交时间: 2020-11-10 00:52:35

#include<bits/stdc++.h>
using namespace std;
const int M=1e5+9;
int n,m,num=0,id=0;
int c[M],val[M],tid[M],head[M],dep[M],low[M],tru[M<<1],f[M<<1][23];
struct P{int to,ne;}e[M<<1];
struct A{int l,r,id;};
vector<int>g[M],d[M<<1];
vector<A>h1[M<<1],h2[M<<1];
int pre[M],suf[M],ans[M],X[M],Y[M],a[M<<1];
void add(int i){for(;i<=id;i+=i&-i)a[i]++;}
int sum(int i,int rex=0){for(;i;i-=i&-i)rex+=a[i];return rex;}
void dfs(int u,int fa){
	f[tid[u]=++id][0]=u;dep[u]=dep[fa]+1;
	for(int i=head[u];i;i=e[i].ne){
		int v=e[i].to;
		if(v!=fa){dfs(v,u);f[++id][0]=u;}
	}
	low[u]=id;
}
int Min(int l,int r){return dep[l]<dep[r]?l:r;}
int lca(int x,int y){
	if(tid[x]>tid[y])swap(x,y);
	int k=log2(tid[y]-tid[x]+1);
	return Min(f[tid[x]][k],f[tid[y]-(1<<k)+1][k]);
}
void qdfs(int u,int fa){
	for(int i=head[u];i;i=e[i].ne){
		int v=e[i].to;
		if(v!=fa){qdfs(v,u);val[u]+=val[v];}
	}
}
bool cmp(int x,int y){return tid[x]<tid[y];}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;++i){
		scanf("%d",&c[i]);
		g[c[i]].push_back(i);
	}
	for(int i=1,x,y;i<n;++i){
		scanf("%d%d",&x,&y);
		e[++num]=P{y,head[x]};head[x]=num;
		e[++num]=P{x,head[y]};head[y]=num;
	}
	dfs(1,0);
	for(int j=1;j<=23;++j){
		for(int i=1;i<=id-(1<<j)+1;++i){
			f[i][j]=Min(f[i][j-1],f[i+(1<<(j-1))][j-1]);
		}
	}
	for(int i=1,s,x,y;i<=n;++i){
		if(!g[i].size())continue;
		sort(g[i].begin(),g[i].end(),cmp);
		s=g[i].size();
		pre[0]=g[i][0];
		suf[s-1]=g[i][s-1];
		for(int j=1;j<s;++j)pre[j]=lca(pre[j-1],g[i][j]);
		for(int j=s-2;j>=0;--j)suf[j]=lca(suf[j+1],g[i][j]);
		for(int j=0;j<s-1;++j){
			x=pre[j],y=suf[j+1];
			if(tid[x]>tid[y])swap(x,y);
			if(tid[y]<=low[x])continue;
			d[tid[x]].push_back(tid[y]);
		}
		val[suf[0]]++;
	}
	for(int i=1,x,y,z;i<=m;++i){
		scanf("%d%d",&x,&y);
		X[i]=x,Y[i]=y;
		if(tid[x]>tid[y])swap(x,y);
		if(tid[y]<=low[x])continue;
		h1[tid[x]].push_back(A{tid[y],low[y],i});
		h2[low[x]].push_back(A{tid[y],low[y],i});
	}
	for(int i=1;i<=id;++i){
		for(int j=0,s=h1[i].size();j<s;++j){
			int l=h1[i][j].l,r=h1[i][j].r,di=h1[i][j].id;
			ans[di]-=sum(r)-sum(l-1);
		}
		for(int j=0,s=d[i].size();j<s;++j)add(d[i][j]);
		for(int j=0,s=h2[i].size();j<s;++j){
			int l=h2[i][j].l,r=h2[i][j].r,di=h2[i][j].id;
			ans[di]+=sum(r)-sum(l-1);
		}
	}
	qdfs(1,0);
	for(int i=1,x,y;i<=m;++i){
		x=X[i],y=Y[i];
		if(tid[x]>tid[y])swap(x,y);
		ans[i]+=val[x];
		if(tid[y]>low[x])ans[i]+=val[y];
		printf("%d\n",ans[i]);
	}
	return 0;
}

上一题