列表

详情


NC17061. 多彩的树

描述

有一棵树包含 N 个节点,节点编号从 1 到 N。节点总共有 K 种颜色,颜色编号从 1 到 K。第 i 个节点的颜色为 Ai
Fi 表示恰好包含 i 种颜色的路径数量。请计算:

输入描述

第一行输入两个正整数 N 和 K,N 表示节点个数,K 表示颜色种类数量。
第二行输入 N 个正整数,A1, A2, A3, ... ..., AN,Ai 表示第 i 个节点的颜色。
接下来 N - 1 行,第 i 行输入两个正整数 Ui 和 Vi,表示节点 Ui 和节点 Vi 之间存在一条无向边,数据保证这 N-1 条边连通了 N 个节点。
1 ≤ N ≤ 50000.
1 ≤ K ≤ 10.
1 ≤ Ai ≤ K.

输出描述

输出一个整数表示答案。

示例1

输入:

5 3
1 2 1 2 3
4 2
1 3
2 1
2 5

输出:

4600065

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(g++ 7.5.0) 解法, 执行用时: 3138ms, 内存消耗: 4592K, 提交时间: 2023-03-30 23:17:37

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=50005,mo=1e9+7;
int clr[N];
vector<int>link[N];
ll f[1<<10];
int cnt[N];
int fa[N];
ll Pow(ll x,ll y){
    ll ret=1;
    while(y){
        if (y&1) ret=ret*x%mo;
        y/=2;
        x=x*x%mo;
    }
    return ret;
}
void dfs(int u,int _fa,int s){
    fa[u]=_fa;
    if ((1<<clr[u]-1)&s) cnt[u]=1;
    else cnt[u]=0;
    for (int v:link[u]){
        if (v==_fa) continue;
        dfs(v,u,s);
        if (cnt[u]) cnt[u]+=cnt[v];
    }
}
int main(){
    int n,m; scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++) scanf("%d",&clr[i]);
    for (int i=1;i<n;i++){
        int x,y; scanf("%d%d",&x,&y);
        link[x].emplace_back(y);
        link[y].emplace_back(x);
    }
    for (int s=1;s<(1<<10);s++){
        dfs(1,0,s);
        for (int i=1;i<=n;i++)
            if (cnt[i]&&cnt[fa[i]]==0) 
                f[s]+=cnt[i]+1ll*cnt[i]*(cnt[i]-1)/2;
    }
    for (int s=1;s<(1<<10);s++)
        for (int sub=s&(s-1);sub;sub=(sub-1)&s) 
            f[s]-=f[sub];
    ll ans=0;
    for (int s=1;s<(1<<10);s++)
        ans=(ans+f[s]*Pow(131,__builtin_popcount(s))%mo)%mo;
    printf("%lld\n",ans);
}

C++14(g++5.4) 解法, 执行用时: 4021ms, 内存消耗: 4552K, 提交时间: 2018-07-09 20:17:16

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=50005,mo=1e9+7;

int clr[N];
vector<int>link[N];
ll f[1<<10];
int cnt[N];
int fa[N];

ll Pow(ll x,ll y){
	ll ret=1;
	while(y){
		if (y&1) ret=ret*x%mo;
		y/=2;
		x=x*x%mo;
	}
	return ret;
}
void dfs(int u,int _fa,int s){
	fa[u]=_fa;
	if ((1<<clr[u]-1)&s) cnt[u]=1;
	else cnt[u]=0;
	for (int v:link[u]){
		if (v==_fa) continue;
		dfs(v,u,s);
		if (cnt[u]) cnt[u]+=cnt[v];
	}
}
int main(){
	int n,m; scanf("%d%d",&n,&m);
	for (int i=1;i<=n;i++) scanf("%d",&clr[i]);
	for (int i=1;i<n;i++){
		int x,y; scanf("%d%d",&x,&y);
		link[x].emplace_back(y);
		link[y].emplace_back(x);
	}
	for (int s=1;s<(1<<10);s++){
		dfs(1,0,s);
		for (int i=1;i<=n;i++)
			if (cnt[i]&&cnt[fa[i]]==0) 
				f[s]+=cnt[i]+1ll*cnt[i]*(cnt[i]-1)/2;
	}
	for (int s=1;s<(1<<10);s++)
		for (int sub=s&(s-1);sub;sub=(sub-1)&s) 
			f[s]-=f[sub];
	ll ans=0;
	for (int s=1;s<(1<<10);s++)
		ans=(ans+f[s]*Pow(131,__builtin_popcount(s))%mo)%mo;
	printf("%lld\n",ans);
}

C++11(clang++ 3.9) 解法, 执行用时: 2658ms, 内存消耗: 2552K, 提交时间: 2020-05-07 16:02:13

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll p=1e9+7;
int n,k,mx,tot,cnt[500010],e[500010],hd[500010],nt[500010],sz[500010],a[500010];
ll f[500010],g[500010];
void build(int x,int y){
	tot++;
	e[tot]=y;
	nt[tot]=hd[x];hd[x]=tot;
}
void dfs(int x,int fa,int s){
	int i;
	sz[x]=0;
	for(i=hd[x];i;i=nt[i]){
		if(e[i]==fa)continue;
		dfs(e[i],x,s);
		sz[x]+=sz[e[i]];
	}
	if(s&(1<<a[x])){
		sz[x]++;
		if(x==1||!(s&(1<<a[fa])))f[s]+=1ll*sz[x]*(sz[x]+1)/2;
	}
	 else sz[x]=0;
}
int main(){
	int i,x,y,s,t;
	ll now,ans=0;
	scanf("%d%d",&n,&k);
	for(i=1;i<=n;i++)scanf("%d",&a[i]);
	for(i=1;i<n;i++){
		scanf("%d%d",&x,&y);
		build(x,y);build(y,x);
	}
	for(i=1;i<=n;i++)a[i]--;
	mx=1<<k;mx--;
	for(s=1;s<=mx;s++){
		cnt[s]=cnt[s>>1]+(s&1);
		dfs(1,0,s);
		for(t=0;t<s;t++)if((s&t)==t)f[s]-=f[t];
		g[cnt[s]]+=f[s];
	}
	now=1;
	for(i=1;i<=n;i++){
		now=now*131%p;
	  ans=(ans+g[i]*now)%p;
  }
     printf("%lld",ans);
}

上一题