NC17061. 多彩的树
描述
输入描述
第一行输入两个正整数 N 和 K,N 表示节点个数,K 表示颜色种类数量。
第二行输入 N 个正整数,A1, A2, A3, ... ..., AN,Ai 表示第 i 个节点的颜色。接下来 N - 1 行,第 i 行输入两个正整数 Ui 和 Vi,表示节点 Ui 和节点 Vi 之间存在一条无向边,数据保证这 N-1 条边连通了 N 个节点。1 ≤ N ≤ 50000.
1 ≤ K ≤ 10.
1 ≤ Ai ≤ K.
输出描述
输出一个整数表示答案。
示例1
输入:
5 3 1 2 1 2 3 4 2 1 3 2 1 2 5
输出:
4600065
C++(g++ 7.5.0) 解法, 执行用时: 3138ms, 内存消耗: 4592K, 提交时间: 2023-03-30 23:17:37
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N=50005,mo=1e9+7; int clr[N]; vector<int>link[N]; ll f[1<<10]; int cnt[N]; int fa[N]; ll Pow(ll x,ll y){ ll ret=1; while(y){ if (y&1) ret=ret*x%mo; y/=2; x=x*x%mo; } return ret; } void dfs(int u,int _fa,int s){ fa[u]=_fa; if ((1<<clr[u]-1)&s) cnt[u]=1; else cnt[u]=0; for (int v:link[u]){ if (v==_fa) continue; dfs(v,u,s); if (cnt[u]) cnt[u]+=cnt[v]; } } int main(){ int n,m; scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) scanf("%d",&clr[i]); for (int i=1;i<n;i++){ int x,y; scanf("%d%d",&x,&y); link[x].emplace_back(y); link[y].emplace_back(x); } for (int s=1;s<(1<<10);s++){ dfs(1,0,s); for (int i=1;i<=n;i++) if (cnt[i]&&cnt[fa[i]]==0) f[s]+=cnt[i]+1ll*cnt[i]*(cnt[i]-1)/2; } for (int s=1;s<(1<<10);s++) for (int sub=s&(s-1);sub;sub=(sub-1)&s) f[s]-=f[sub]; ll ans=0; for (int s=1;s<(1<<10);s++) ans=(ans+f[s]*Pow(131,__builtin_popcount(s))%mo)%mo; printf("%lld\n",ans); }
C++14(g++5.4) 解法, 执行用时: 4021ms, 内存消耗: 4552K, 提交时间: 2018-07-09 20:17:16
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N=50005,mo=1e9+7; int clr[N]; vector<int>link[N]; ll f[1<<10]; int cnt[N]; int fa[N]; ll Pow(ll x,ll y){ ll ret=1; while(y){ if (y&1) ret=ret*x%mo; y/=2; x=x*x%mo; } return ret; } void dfs(int u,int _fa,int s){ fa[u]=_fa; if ((1<<clr[u]-1)&s) cnt[u]=1; else cnt[u]=0; for (int v:link[u]){ if (v==_fa) continue; dfs(v,u,s); if (cnt[u]) cnt[u]+=cnt[v]; } } int main(){ int n,m; scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) scanf("%d",&clr[i]); for (int i=1;i<n;i++){ int x,y; scanf("%d%d",&x,&y); link[x].emplace_back(y); link[y].emplace_back(x); } for (int s=1;s<(1<<10);s++){ dfs(1,0,s); for (int i=1;i<=n;i++) if (cnt[i]&&cnt[fa[i]]==0) f[s]+=cnt[i]+1ll*cnt[i]*(cnt[i]-1)/2; } for (int s=1;s<(1<<10);s++) for (int sub=s&(s-1);sub;sub=(sub-1)&s) f[s]-=f[sub]; ll ans=0; for (int s=1;s<(1<<10);s++) ans=(ans+f[s]*Pow(131,__builtin_popcount(s))%mo)%mo; printf("%lld\n",ans); }
C++11(clang++ 3.9) 解法, 执行用时: 2658ms, 内存消耗: 2552K, 提交时间: 2020-05-07 16:02:13
#include<bits/stdc++.h> #define ll long long using namespace std; const ll p=1e9+7; int n,k,mx,tot,cnt[500010],e[500010],hd[500010],nt[500010],sz[500010],a[500010]; ll f[500010],g[500010]; void build(int x,int y){ tot++; e[tot]=y; nt[tot]=hd[x];hd[x]=tot; } void dfs(int x,int fa,int s){ int i; sz[x]=0; for(i=hd[x];i;i=nt[i]){ if(e[i]==fa)continue; dfs(e[i],x,s); sz[x]+=sz[e[i]]; } if(s&(1<<a[x])){ sz[x]++; if(x==1||!(s&(1<<a[fa])))f[s]+=1ll*sz[x]*(sz[x]+1)/2; } else sz[x]=0; } int main(){ int i,x,y,s,t; ll now,ans=0; scanf("%d%d",&n,&k); for(i=1;i<=n;i++)scanf("%d",&a[i]); for(i=1;i<n;i++){ scanf("%d%d",&x,&y); build(x,y);build(y,x); } for(i=1;i<=n;i++)a[i]--; mx=1<<k;mx--; for(s=1;s<=mx;s++){ cnt[s]=cnt[s>>1]+(s&1); dfs(1,0,s); for(t=0;t<s;t++)if((s&t)==t)f[s]-=f[t]; g[cnt[s]]+=f[s]; } now=1; for(i=1;i<=n;i++){ now=now*131%p; ans=(ans+g[i]*now)%p; } printf("%lld",ans); }