NC212231. 神奇的迷宫
描述
输入描述
第一行输入一个正整数,表示房间个数。
第二行输入个整数
。令
,则
。
第三行输入个整数
。
接下来行,每行输入两个整数
,表示第
条通道连接房间
和房间
。
保证输入的图是一棵树。
输出描述
显然,最后答案一定可以表示成的形式。
请输出这个分数对取模后的结果。
示例1
输入:
3 1 2 3 3 2 1 1 2 2 3
输出:
887328316
说明:
距离示例2
输入:
6 1 1 4 5 1 4 1 9 1 9 8 10 1 2 2 3 2 4 1 5 1 6
输出:
249561093
示例3
输入:
10 0 76507 29535 6993 123413 149357 6751 0 21623 0 69396374 78945652 40298263 12836349 53787802 13446291 98276886 60256268 46259091 49311395 2 1 3 2 4 2 5 3 6 4 7 1 8 3 9 1 10 8
输出:
905013619
C++14(g++5.4) 解法, 执行用时: 705ms, 内存消耗: 23028K, 提交时间: 2020-10-09 21:12:38
#include<bits/stdc++.h> using namespace std; #define rep(i, a, n) for(int i=(a); i<(n); ++i) #define per(i, a, n) for(int i=(a); i>(n); --i) #define pb emplace_back #define mp make_pair #define clr(a, b) memset(a, b, sizeof(a)) #define all(x) (x).begin(),(x).end() #define lowbit(x) (x & -x) #define fi first #define se second #define lson o<<1 #define rson o<<1|1 #define gmid l[o]+r[o]>>1 using LL = long long; using ULL = unsigned long long; using pii = pair<int,int>; using PLL = pair<LL, LL>; using UI = unsigned int; const int mod = 998244353; const int inf = 0x3f3f3f3f; const double EPS = 1e-8; const double PI = acos(-1.0); const int N = 1e5 + 10; const int M = N * 3; int n, sz[N], p[N]; LL w[N], ans; bool mk[N]; vector<int> V[N], P, G; int pow_mod(LL x, int p){ LL s = 1; while(p){ if(p & 1) s = s * x % mod; x = x * x % mod; p >>= 1; } return (int)s; } void add(int &a, int b){ a += b; if(a >= mod) a -= mod; } void dfs1(int x, int fa){ sz[x] = 1; for(int j : V[x]){ if(mk[j] || j == fa) continue; dfs1(j, x); sz[x] += sz[j]; } } pii dfs2(int x, int fa, int sum){ pii ret = mp(inf, 0); int mx = sum - sz[x]; for(int j : V[x]){ if(mk[j] || j == fa) continue; ret = min(ret, dfs2(j, x, sum)); mx = max(mx, sz[j]); } return min(ret, mp(mx, x)); } void getpath(int x, int fa, int dep){ if(dep >= G.size()){ G.pb(p[x]); } else { add(G[dep], p[x]); } if(dep >= P.size()){ P.pb(p[x]); } else { add(P[dep], p[x]); } for(int j : V[x]){ if(mk[j] || j == fa) continue; getpath(j, x, dep + 1); } } int A[M], r[M], C[M]; void ntt(int *x, int lim, int opt){ int i, j, k, m, gn, g, tmp; for(i=0; i<lim; ++i){ if(r[i] < i) swap(x[i], x[r[i]]); } for(m=2; m<=lim; m<<=1){ k = m >> 1; gn = pow_mod(3, (mod - 1) / m); for(i=0; i<lim; i+=m){ g = 1; for(j=0; j<k; ++j, g=1LL*g*gn%mod){ tmp = 1LL * x[i+j+k] * g % mod; x[i+j+k] = (x[i+j] - tmp + mod) % mod; x[i+j] = (x[i+j] + tmp) % mod; } } } if(opt == -1){ reverse(x+1, x+lim); int rev = pow_mod(lim, mod - 2); for(i=0; i<lim; ++i) x[i] = 1LL * x[i] * rev % mod; } } LL solve(vector<int> &vec){ int m = vec.size(); int lim = 1; while(lim < (m << 1)) lim <<= 1; rep(i, 0, lim){ r[i] = (i & 1) * (lim >> 1) + (r[i >> 1] >> 1); A[i] = i < m ? vec[i] : 0; } ntt(A, lim, 1); rep(i, 0, lim) C[i] = 1LL * A[i] * A[i] % mod; ntt(C, lim, -1); LL ret = 0; lim = min(lim, n); rep(i, 1, lim){ ret = ret + w[i] * C[i] % mod; } return ret % mod; } void doit(int x){ dfs1(x, -1); x = dfs2(x, -1, sz[x]).se; mk[x] = 1; for(int j : V[x]){ if(!mk[j]){ doit(j); } } P.clear(); P.pb(p[x]); for(int j : V[x]){ if(mk[j]) continue; G.clear(); G.pb(0); getpath(j, x, 1); ans = (ans + mod - solve(G)) % mod; } ans = (ans + solve(P)) % mod; mk[x] = 0; } int main(){ scanf("%d", &n); LL sum = 0; rep(i, 1, n+1){ scanf("%d", p+i); sum += p[i]; } sum = pow_mod(sum, mod - 2); rep(i, 1, n+1){ p[i] = (int)(sum * p[i] % mod); } rep(i, 0, n){ scanf("%lld", w+i); } int x, y; rep(i, 1, n){ scanf("%d %d", &x, &y); V[x].pb(y); V[y].pb(x); } memset(mk, 0, sizeof(mk)); ans = 0; doit(1); rep(i, 1, n+1){ ans = ans + w[0] * p[i] % mod * p[i] % mod; } printf("%lld\n", ans % mod); return 0; }
C++11(clang++ 3.9) 解法, 执行用时: 542ms, 内存消耗: 20984K, 提交时间: 2020-10-16 15:16:46
#include<cstdio> #include<algorithm> using namespace std; const int N=2e5+5; const int mod=998244353; int fastpow(int x,int y){ int res=1; while(y){ if(y&1)res=1ll*res*x%mod; x=1ll*x*x%mod; y>>=1; } return res; } namespace ntt{ int rev[N],w[N]; void init(int len){ for(int i=0;i<len;++i) rev[i]=(rev[i>>1]>>1)|((i&1)*(len>>1)); w[0]=1;w[1]=fastpow(3,(mod-1)/len); for(int i=2;i<len;++i) w[i]=1ll*w[i-1]*w[1]%mod; } void dft(int *p,int len,int op){ for(int i=0;i<len;++i) if(i<rev[i]) swap(p[i],p[rev[i]]); for(int i=1;i<len;i<<=1) for(int j=0;j<len;j+=i<<1) for(int k=0;k<i;++k){ int x=p[j+k],y=1ll*w[len/i/2*k]*p[j+k+i]%mod; p[j+k]=(x+y)%mod; p[j+k+i]=(x-y+mod)%mod; } if(op==-1){ reverse(p+1,p+len); int inv=fastpow(len,mod-2); for(int i=0;i<len;++i) p[i]=1ll*p[i]*inv%mod; } } } int n,a[N],val[N],ans[N]; vector<int>E[N]; namespace dc{ int sz[N],w[N],vis[N],poly[N],sum,root,mxlen; void dfs(int u,int f){ sz[u]=1;w[u]=0; for(int v:E[u]) if(v!=f&&!vis[v]){ dfs(v,u); sz[u]+=sz[v]; w[u]=max(w[u],sz[v]); } w[u]=max(w[u],sum-sz[u]); if(w[u]<w[root])root=u; } void dfs2(int u,int f,int d){ sz[u]=1; poly[d]=(poly[d]+a[u])%mod; mxlen=max(mxlen,d); for(int v:E[u]) if(v!=f&&!vis[v]) dfs2(v,u,d+1),sz[u]+=sz[v]; } void calc(int u,int d,int coef){ mxlen=0; dfs2(u,0,d); int len=1; while(len<=mxlen+mxlen)len<<=1; ntt::init(len);ntt::dft(poly,len,1); for(int i=0;i<len;++i) poly[i]=1ll*poly[i]*poly[i]%mod; ntt::dft(poly,len,-1); for(int i=0;i<len;++i){ ans[i]=(ans[i]+1ll*poly[i]*coef)%mod; poly[i]=0; } } void solve(int u){ calc(u,0,1); vis[u]=1; for(int v:E[u]) if(!vis[v]){ calc(v,1,mod-1); sum=sz[v];root=0; dfs(v,0);solve(root); } } } int main(){ int sigma=0,tmp=0; scanf("%d",&n); for(int i=1;i<=n;++i) scanf("%d",&a[i]),tmp=(tmp+a[i])%mod; for(int i=0;i<n;++i) scanf("%d",&val[i]); for(int i=1,x,y;i<n;++i){ scanf("%d%d",&x,&y); E[x].push_back(y); E[y].push_back(x); } dc::sum=dc::w[0]=n; dc::dfs(1,0); dc::solve(dc::root); for(int i=0;i<n;++i) sigma=(sigma+1ll*ans[i]*val[i])%mod; sigma=1ll*sigma*fastpow(tmp,mod-3)%mod; printf("%d\n",sigma); return 0; }