列表

详情


NC212231. 神奇的迷宫

描述

有一个神奇的迷宫,一共有  个房间,还有  条通道,长度均为 。每条通道连接两个不同的房间,这个迷宫构成了个树形结构。
Froggy 和 鏡音リン 要去挑战迷宫。刚开始,他们两人依次传送到到迷宫的某两个房间(可能相同)。对于任意一个人,传送到房间  的概率是 p_i 。假设两人被传送到的房间之间的最短距离是 ,那么他们挑战这个迷宫的困难值是 w_L
请你告诉 Froggy,他们挑战这个迷宫的困难值的期望是多少。

输入描述

第一行输入一个正整数 ,表示房间个数。
第二行输入  个整数 。令 ,则 
第三行输入  个整数 
接下来  行,每行输入两个整数 ,表示第  条通道连接房间  和房间 
保证输入的图是一棵树。

输出描述

显然,最后答案一定可以表示成  的形式。
请输出这个分数对  取模后的结果。

示例1

输入:

3
1 2 3
3 2 1
1 2
2 3

输出:

887328316

说明:

距离  的概率为分别 ,所以最后答案是 ,在模  的意义下的值为 

示例2

输入:

6
1 1 4 5 1 4
1 9 1 9 8 10
1 2
2 3
2 4
1 5
1 6

输出:

249561093

示例3

输入:

10
0 76507 29535 6993 123413 149357 6751 0 21623 0
69396374 78945652 40298263 12836349 53787802 13446291 98276886 60256268 46259091 49311395
2 1
3 2
4 2
5 3
6 4
7 1
8 3
9 1
10 8

输出:

905013619

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++14(g++5.4) 解法, 执行用时: 705ms, 内存消耗: 23028K, 提交时间: 2020-10-09 21:12:38

#include<bits/stdc++.h>
using namespace std;

#define rep(i, a, n) for(int i=(a); i<(n); ++i)
#define per(i, a, n) for(int i=(a); i>(n); --i)
#define pb emplace_back
#define mp make_pair
#define clr(a, b) memset(a, b, sizeof(a))
#define all(x) (x).begin(),(x).end()
#define lowbit(x) (x & -x)
#define fi first
#define se second
#define lson o<<1
#define rson o<<1|1
#define gmid l[o]+r[o]>>1

using LL = long long;
using ULL = unsigned long long;
using pii = pair<int,int>;
using PLL = pair<LL, LL>;
using UI = unsigned int;

const int mod = 998244353;
const int inf = 0x3f3f3f3f;
const double EPS = 1e-8;
const double PI = acos(-1.0);

const int N = 1e5 + 10;
const int M = N * 3;

int n, sz[N], p[N];
LL w[N], ans;
bool mk[N];
vector<int> V[N], P, G;

int pow_mod(LL x, int p){
	LL s = 1;
	while(p){
		if(p & 1)	s = s * x % mod;
		x = x * x % mod;
		p >>= 1;
	}
	return (int)s;
}

void add(int &a, int b){
	a += b;
	if(a >= mod)	a -= mod;
}

void dfs1(int x, int fa){
	sz[x] = 1;
	for(int j : V[x]){
		if(mk[j] || j == fa)	continue;
		dfs1(j, x);
		sz[x] += sz[j];
	}
}

pii dfs2(int x, int fa, int sum){
	pii ret = mp(inf, 0);
	int mx = sum - sz[x];
	for(int j : V[x]){
		if(mk[j] || j == fa)	continue;
		ret = min(ret, dfs2(j, x, sum));
		mx = max(mx, sz[j]);
	}
	return min(ret, mp(mx, x));
}

void getpath(int x, int fa, int dep){
	if(dep >= G.size()){
		G.pb(p[x]);
	} else {
		add(G[dep], p[x]);
	}
	if(dep >= P.size()){
		P.pb(p[x]);
	} else {
		add(P[dep], p[x]);
	}
	for(int j : V[x]){
		if(mk[j] || j == fa)	continue;
		getpath(j, x, dep + 1);
	}
}

int A[M], r[M], C[M];

void ntt(int *x, int lim, int opt){
	int i, j, k, m, gn, g, tmp;
	for(i=0; i<lim; ++i){
		if(r[i] < i)	swap(x[i], x[r[i]]);
	}
	for(m=2; m<=lim; m<<=1){
		k = m >> 1;
		gn = pow_mod(3, (mod - 1) / m);
		for(i=0; i<lim; i+=m){
			g = 1;
			for(j=0; j<k; ++j, g=1LL*g*gn%mod){
				tmp = 1LL * x[i+j+k] * g % mod;
				x[i+j+k] = (x[i+j] - tmp + mod) % mod;
				x[i+j] = (x[i+j] + tmp) % mod;
			}
		}
	}
	if(opt == -1){
		reverse(x+1, x+lim);
		int rev = pow_mod(lim, mod - 2);
		for(i=0; i<lim; ++i)	x[i] = 1LL * x[i] * rev % mod;
	}
}

LL solve(vector<int> &vec){
	int m = vec.size();
	int lim = 1;
	while(lim < (m << 1))	lim <<= 1;
	rep(i, 0, lim){
		r[i] = (i & 1) * (lim >> 1) + (r[i >> 1] >> 1);
		A[i] = i < m ? vec[i] : 0;
	}
	ntt(A, lim, 1);
	rep(i, 0, lim)	C[i] = 1LL * A[i] * A[i] % mod;
	ntt(C, lim, -1);

	LL ret = 0;
	lim = min(lim, n);
	rep(i, 1, lim){
		ret = ret + w[i] * C[i] % mod;
	}
	return ret % mod;
}

void doit(int x){
	dfs1(x, -1);
	x = dfs2(x, -1, sz[x]).se;

	mk[x] = 1;
	for(int j : V[x]){
		if(!mk[j]){
			doit(j);
		}
	}

	P.clear();
	P.pb(p[x]);

	for(int j : V[x]){
		if(mk[j])	continue;
		G.clear();
		G.pb(0);
		getpath(j, x, 1);

		ans = (ans + mod - solve(G)) % mod;
	}

	ans = (ans + solve(P)) % mod;

	mk[x] = 0;
}

int main(){
	scanf("%d", &n);
	LL sum = 0;
	rep(i, 1, n+1){
		scanf("%d", p+i);
		sum += p[i];
	}
	sum = pow_mod(sum, mod - 2);
	rep(i, 1, n+1){
		p[i] = (int)(sum * p[i] % mod);
	}
	rep(i, 0, n){
		scanf("%lld", w+i);
	}

	int x, y;
	rep(i, 1, n){
		scanf("%d %d", &x, &y);
		V[x].pb(y);
		V[y].pb(x);
	}

	memset(mk, 0, sizeof(mk));

	ans = 0;

	doit(1);

	rep(i, 1, n+1){
		ans = ans + w[0] * p[i] % mod * p[i] % mod;
	}

	printf("%lld\n", ans % mod);
	
	return 0;
}

C++11(clang++ 3.9) 解法, 执行用时: 542ms, 内存消耗: 20984K, 提交时间: 2020-10-16 15:16:46

#include<cstdio>
#include<algorithm>
using namespace std;
const int N=2e5+5;
const int mod=998244353;
int fastpow(int x,int y){
	int res=1;
	while(y){
		if(y&1)res=1ll*res*x%mod;
		x=1ll*x*x%mod;
		y>>=1;
	}
	return res;
}
namespace ntt{
	int rev[N],w[N];
	void init(int len){
		for(int i=0;i<len;++i)
			rev[i]=(rev[i>>1]>>1)|((i&1)*(len>>1));
		w[0]=1;w[1]=fastpow(3,(mod-1)/len);
		for(int i=2;i<len;++i)
			w[i]=1ll*w[i-1]*w[1]%mod;
	}
	void dft(int *p,int len,int op){
		for(int i=0;i<len;++i)
			if(i<rev[i])
				swap(p[i],p[rev[i]]);
		for(int i=1;i<len;i<<=1)
			for(int j=0;j<len;j+=i<<1)
				for(int k=0;k<i;++k){
					int x=p[j+k],y=1ll*w[len/i/2*k]*p[j+k+i]%mod;
					p[j+k]=(x+y)%mod;
					p[j+k+i]=(x-y+mod)%mod;
				}
		if(op==-1){
			reverse(p+1,p+len);
			int inv=fastpow(len,mod-2);
			for(int i=0;i<len;++i)
				p[i]=1ll*p[i]*inv%mod;
		}
	}
}
int n,a[N],val[N],ans[N];
vector<int>E[N];
namespace dc{
	int sz[N],w[N],vis[N],poly[N],sum,root,mxlen;
	void dfs(int u,int f){
		sz[u]=1;w[u]=0;
		for(int v:E[u])
			if(v!=f&&!vis[v]){
				dfs(v,u);
				sz[u]+=sz[v];
				w[u]=max(w[u],sz[v]);
			}
		w[u]=max(w[u],sum-sz[u]);
		if(w[u]<w[root])root=u;
	}
	void dfs2(int u,int f,int d){
		sz[u]=1;
		poly[d]=(poly[d]+a[u])%mod;
		mxlen=max(mxlen,d);
		for(int v:E[u])
			if(v!=f&&!vis[v])
				dfs2(v,u,d+1),sz[u]+=sz[v];
	}
	void calc(int u,int d,int coef){
		mxlen=0;
		dfs2(u,0,d);
		int len=1;
		while(len<=mxlen+mxlen)len<<=1;
		ntt::init(len);ntt::dft(poly,len,1);
		for(int i=0;i<len;++i)
			poly[i]=1ll*poly[i]*poly[i]%mod;
		ntt::dft(poly,len,-1);
		for(int i=0;i<len;++i){
			ans[i]=(ans[i]+1ll*poly[i]*coef)%mod;
			poly[i]=0;
		}
	}
	void solve(int u){
		calc(u,0,1);
		vis[u]=1;
		for(int v:E[u])
			if(!vis[v]){
				calc(v,1,mod-1);
				sum=sz[v];root=0;
				dfs(v,0);solve(root);
			}
	}
}
int main(){
	int sigma=0,tmp=0;
	scanf("%d",&n);
	for(int i=1;i<=n;++i)
		scanf("%d",&a[i]),tmp=(tmp+a[i])%mod;
	for(int i=0;i<n;++i)
		scanf("%d",&val[i]);
	for(int i=1,x,y;i<n;++i){
		scanf("%d%d",&x,&y);
		E[x].push_back(y);
		E[y].push_back(x);
	}
	dc::sum=dc::w[0]=n;
	dc::dfs(1,0);
	dc::solve(dc::root);
	for(int i=0;i<n;++i)
		sigma=(sigma+1ll*ans[i]*val[i])%mod;
	sigma=1ll*sigma*fastpow(tmp,mod-3)%mod;
	printf("%d\n",sigma);
	return 0;
}

上一题