NC24906. 神经网络
描述
输入描述
第一行一个整数 $n$,表示Nowing・钟的神经结点个数。
第二行有 $n$ 个整数,其中第 $i$ 个整数 $w_i$ 表示第 $i$ 个神经结点的神经权重。
接下来是 $n-1$ 行,每行两个整数 $u, v$,表示有一条神经纤维连接编号为 $u, v$ 的两个神经结点。数据保证给出的神经结点以及神经纤维构成一棵树。
输出描述
输出一行一个整数,表示产生的总神经脉冲量期望值在模998244353意义下的结果。
也就是说,如果实际期望值为 $\frac{a}{b}$,那么 $ans$ 应满足 $0 \le ans < 998244353$ 且 $ans \cdot b \equiv a \pmod{998244353}$。
示例1
输入:
4
16 13 8 9
1 2
3 1
3 4
输出:
665496476
说明:
样例中实际的总神经脉冲量期望值为 $\frac{722}{3}$,但因为 $665496476 \times 3 \equiv 722 \pmod{998244353}$,故应输出 665496476。
C++14(g++5.4) 解法, 执行用时: 1526ms, 内存消耗: 31444K, 提交时间: 2019-11-05 14:10:48
// Solution 1 (C++14): centroid decomposition + NTT.
//
// What the program prints (read off the code; it reproduces the sample output):
// with S(u, v) = sum of node weights on the u-v path (both ends included) and
// d(u, v) = number of edges on that path, the output is
//     2 * sum over unordered pairs {u, v}, u != v, of S(u, v) / d(u, v)   (mod P).
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

using ll = long long;

const int N = 1 << 18 | 7;              // array bound; NTT lengths never exceed 2^18
const int P = 7 * 17 << 23 | 1, G = 3;  // P = 119 * 2^23 + 1 = 998244353; 3 is a primitive root mod P

namespace NTT {
// w[i][0] is a primitive 2^i-th root of unity mod P, w[i][1] its inverse (1 <= i <= 18).
int w[19][2];

// Returns a^b mod P by binary exponentiation.
int power_mod(int a, int b) {
    int ret = 1;
    for (a %= P; b; b >>= 1, a = 1ll * a * a % P)
        if (b & 1) ret = 1ll * ret * a % P;
    return ret;
}

// Fills the root-of-unity table; must be called once before ntt().
void ntt_init() {
    for (int i = 1; i < 19; i++) {
        w[i][0] = power_mod(G, (P - 1) >> i);
        w[i][1] = power_mod(w[i][0], P - 2);
    }
}

// In-place iterative NTT of y[0, len); len must be a power of two <= 2^18.
// on == 0: forward transform; on == 1: inverse transform, including the 1/len scaling.
void ntt(int *y, int len, int on) {
    if (len < 2) return;  // a length-1 transform is the identity; also avoids a negative shift below
    static int r[N], nl;  // bit-reversal permutation, cached for the last length nl
    const int top = __builtin_ctz(len) - 1;
    if (nl != len) {
        nl = len;
        for (int i = 0; i < len; i++) r[i] = (r[i >> 1] >> 1) | (i & 1) << top;
    }
    for (int i = 0; i < len; i++)
        if (i < r[i]) std::swap(y[i], y[r[i]]);
    for (int half = 1, lev = 1; half < len; half <<= 1, lev++) {
        const int wn = w[lev][on];
        for (int j = 0; j < len; j += half << 1) {
            int ww = 1;
            for (int k = j; k < j + half; k++, ww = 1ll * ww * wn % P) {
                const int u = y[k], v = 1ll * y[k + half] * ww % P;
                y[k] = (u + v) % P;
                y[k + half] = (u - v + P) % P;
            }
        }
    }
    if (on) {
        const int invl = power_mod(len, P - 2);
        for (int i = 0; i < len; i++) y[i] = 1ll * y[i] * invl % P;
    }
}
}  // namespace NTT

int ans;     // running answer mod P
int w[N];    // node weights, 1-indexed
int inv[N];  // inv[i] = modular inverse of i mod P, filled in main()

// Multiplies polynomials a and b with NTT and, for every coefficient c_i of the
// product with 1 <= i < a.size() + b.size() - 1, adds 2 * c_i * inv[i] to ans (mod P).
// Index i is a path length in edges, so this divides each path-weight sum by its length.
void conv(std::vector<int> &a, std::vector<int> &b) {
    static int x[N], y[N];
    const int l1 = a.size(), l2 = b.size();
    int l = l1 + l2 - 1;
    while (l & (l - 1)) l += l & -l;  // round up to a power of two
    for (int i = 0; i < l; i++) x[i] = i < l1 ? a[i] : 0;
    for (int i = 0; i < l; i++) y[i] = i < l2 ? b[i] : 0;
    NTT::ntt(x, l, 0);
    NTT::ntt(y, l, 0);
    for (int i = 0; i < l; i++) x[i] = 1ll * x[i] * y[i] % P;
    NTT::ntt(x, l, 1);
    for (int i = 1; i < l1 + l2 - 1; i++) ans = (ans + 2ll * x[i] * inv[i] % P) % P;
}

std::vector<int> cnt[N], val[N], E[N];  // per-group depth profiles (see solve); adjacency lists
int vis[N], sz[N], id[N];               // removed-centroid flags; subtree sizes; group order

// Returns a centroid of the component of size maxs that contains u, where sz[]
// holds subtree sizes rooted at the original u (filled by prepare()): repeatedly
// steps into a child whose subtree holds at least half of the component.
int getroot(int u, int fa, int maxs) {
    for (;;) {
        int next = 0;  // node ids are 1-based, so 0 means "no heavy child"
        for (int v : E[u])
            if (v != fa && !vis[v] && sz[v] * 2 >= maxs) {
                next = v;
                break;
            }
        if (!next) return u;
        fa = u;
        u = next;
    }
}

// DFS below the current centroid. dep is x's distance from the centroid and noww
// the weight sum of the path from depth 1 down to x's parent. Records, per depth,
// the node count in c and the path-weight sum (centroid's weight excluded) in v.
void getpoly(int x, int fa, int dep, int noww, std::vector<int> &c, std::vector<int> &v) {
    noww = (noww + w[x]) % P;
    if (dep < static_cast<int>(c.size())) {
        c[dep]++;
        v[dep] = (v[dep] + noww) % P;
    } else {  // dep == c.size(): first node seen at this depth
        c.push_back(1);
        v.push_back(noww);
    }
    for (int e : E[x]) {
        if (e == fa || vis[e]) continue;
        getpoly(e, x, dep + 1, noww, c, v);
    }
}

// Adds the contribution of every path through root inside root's component.
// Group 0 is root alone (count 1 at depth 0, weight w[root]); group k >= 1 is the
// subtree of the k-th unvisited neighbour. Groups are merged smallest-first; before
// each merge, conv() accounts for all pairs (node in new group, node in merged prefix).
void solve(int root) {
    cnt[0].assign(1, 1);
    val[0].assign(1, w[root]);
    id[0] = 0;  // id[] is shared across calls; reset every slot that will be sorted
    int cc = 1;
    for (int e : E[root]) {
        if (vis[e]) continue;
        cnt[cc].assign(1, 0);
        val[cc].assign(1, 0);
        getpoly(e, root, 1, 0, cnt[cc], val[cc]);
        id[cc] = cc;
        cc++;
    }
    std::sort(id, id + cc, [](int x, int y) { return cnt[x].size() < cnt[y].size(); });
    for (int i = 1; i < cc; i++) {
        const int cur = id[i], acc = id[i - 1];  // acc holds the merged prefix (root weight included)
        conv(cnt[cur], val[acc]);
        conv(cnt[acc], val[cur]);
        // Fold the centroid's weight into cur's path sums, then merge the prefix into cur
        // (cur is at least as deep as acc thanks to the sort).
        for (int j = 1; j < static_cast<int>(cnt[cur].size()); j++)
            val[cur][j] = (val[cur][j] + 1ll * cnt[cur][j] * w[root] % P) % P;
        for (int j = 0; j < static_cast<int>(cnt[acc].size()); j++) {
            cnt[cur][j] += cnt[acc][j];
            val[cur][j] = (val[cur][j] + val[acc][j]) % P;
        }
    }
}

// Computes sz[] for the component containing u, rooted at u.
void prepare(int u, int fa) {
    sz[u] = 1;
    for (int e : E[u]) {
        if (e == fa || vis[e]) continue;
        prepare(e, u);
        sz[u] += sz[e];
    }
}

// Centroid-decomposes the component containing root.
void work(int root) {
    prepare(root, 0);
    if (sz[root] <= 1) return;
    const int c = getroot(root, 0, sz[root]);
    vis[c] = 1;
    solve(c);
    for (int e : E[c])
        if (!vis[e]) work(e);
}

int main() {
    NTT::ntt_init();
    inv[1] = 1;
    for (int i = 2; i < N; i++) inv[i] = static_cast<ll>(inv[P % i]) * (P - P / i) % P;
    int n;
    if (std::scanf("%d", &n) != 1) return 0;
    for (int i = 1; i <= n; i++) std::scanf("%d", w + i);
    for (int i = 1; i < n; i++) {
        int u, v;
        std::scanf("%d%d", &u, &v);
        E[u].push_back(v);
        E[v].push_back(u);
    }
    work(1);
    std::printf("%d\n", ans);
    return 0;
}
C++11(clang++ 3.9) 解法, 执行用时: 1089ms, 内存消耗: 33008K, 提交时间: 2019-04-23 12:45:45
// Solution 2 (C++11): centroid decomposition + NTT, with the centroid search
// done by a single size-computing DFS (getroot) and component sizes refreshed
// by getpoly.
//
// What the program prints (read off the code; it reproduces the sample output):
// with S(u, v) = sum of node weights on the u-v path (both ends included) and
// d(u, v) = number of edges on that path, the output is
//     2 * sum over unordered pairs {u, v}, u != v, of S(u, v) / d(u, v)   (mod P).
#include <algorithm>
#include <climits>
#include <cstdio>
#include <utility>
#include <vector>

using ll = long long;

const int N = 1 << 18 | 7;              // array bound; NTT lengths never exceed 2^18
const int P = 7 * 17 << 23 | 1, G = 3;  // P = 119 * 2^23 + 1 = 998244353; 3 is a primitive root mod P

namespace NTT {
// w[i][0] is a primitive 2^i-th root of unity mod P, w[i][1] its inverse (1 <= i <= 18).
int w[19][2];

// Returns a^b mod P by binary exponentiation.
int power_mod(int a, int b) {
    int ret = 1;
    for (a %= P; b; b >>= 1, a = 1ll * a * a % P)
        if (b & 1) ret = 1ll * ret * a % P;
    return ret;
}

// Fills the root-of-unity table; must be called once before ntt().
void ntt_init() {
    for (int i = 1; i < 19; i++) {
        w[i][0] = power_mod(G, (P - 1) >> i);
        w[i][1] = power_mod(w[i][0], P - 2);
    }
}

// In-place iterative NTT of y[0, len); len must be a power of two <= 2^18.
// on == 0: forward transform; on == 1: inverse transform, including the 1/len scaling.
void ntt(int *y, int len, int on) {
    if (len < 2) return;  // a length-1 transform is the identity; also avoids a negative shift below
    static int r[N], nl;  // bit-reversal permutation, cached for the last length nl
    const int top = __builtin_ctz(len) - 1;
    if (nl != len) {
        nl = len;
        for (int i = 0; i < len; i++) r[i] = (r[i >> 1] >> 1) | (i & 1) << top;
    }
    for (int i = 0; i < len; i++)
        if (i < r[i]) std::swap(y[i], y[r[i]]);
    for (int half = 1, lev = 1; half < len; half <<= 1, lev++) {
        const int wn = w[lev][on];
        for (int j = 0; j < len; j += half << 1) {
            int ww = 1;
            for (int k = j; k < j + half; k++, ww = 1ll * ww * wn % P) {
                const int u = y[k], v = 1ll * y[k + half] * ww % P;
                y[k] = (u + v) % P;
                y[k + half] = (u - v + P) % P;
            }
        }
    }
    if (on) {
        const int invl = power_mod(len, P - 2);
        for (int i = 0; i < len; i++) y[i] = 1ll * y[i] * invl % P;
    }
}
}  // namespace NTT

int ans;     // running answer mod P
int w[N];    // node weights, 1-indexed
int inv[N];  // inv[i] = modular inverse of i mod P, filled in main()

// Multiplies polynomials a and b with NTT and, for every coefficient c_i of the
// product with 1 <= i < a.size() + b.size() - 1, adds 2 * c_i * inv[i] to ans (mod P).
// Index i is a path length in edges, so this divides each path-weight sum by its length.
void conv(std::vector<int> &a, std::vector<int> &b) {
    static int x[N], y[N];
    const int l1 = a.size(), l2 = b.size();
    int l = l1 + l2 - 1;
    while (l & (l - 1)) l += l & -l;  // round up to a power of two
    for (int i = 0; i < l; i++) x[i] = i < l1 ? a[i] : 0;
    for (int i = 0; i < l; i++) y[i] = i < l2 ? b[i] : 0;
    NTT::ntt(x, l, 0);
    NTT::ntt(y, l, 0);
    for (int i = 0; i < l; i++) x[i] = 1ll * x[i] * y[i] % P;
    NTT::ntt(x, l, 1);
    for (int i = 1; i < l1 + l2 - 1; i++) ans = (ans + 2ll * x[i] * inv[i] % P) % P;
}

std::vector<int> cnt[N], val[N], E[N];  // per-group depth profiles (see solve); adjacency lists
int vis[N], sz[N], id[N];               // removed-centroid flags; subtree sizes; group order
int maxs, rt, rtv;  // getroot(): component size, best candidate so far, its largest remaining piece

// Fills sz[] for the part of the component below x (parent fa) and keeps in rt the
// node whose removal leaves the smallest largest piece of a component of maxs nodes.
void getroot(int x, int fa) {
    sz[x] = 1;
    int t = 0;  // largest child subtree of x
    for (int e : E[x]) {
        if (vis[e] || e == fa) continue;
        getroot(e, x);
        t = std::max(t, sz[e]);
        sz[x] += sz[e];
    }
    // The piece on the parent's side has maxs - sz[x] nodes; using anything else
    // here can select a non-centroid and break the O(log n) decomposition depth.
    t = std::max(t, maxs - sz[x]);
    if (t < rtv) {
        rtv = t;
        rt = x;
    }
}

// DFS below the current centroid. dep is x's distance from the centroid and noww
// the weight sum of the path from depth 1 down to x's parent. Records, per depth,
// the node count in c and the path-weight sum (centroid's weight excluded) in v.
// Also refreshes sz[x] relative to the centroid, which work() reads for each neighbour.
void getpoly(int x, int fa, int dep, int noww, std::vector<int> &c, std::vector<int> &v) {
    noww = (noww + w[x]) % P;
    if (dep < static_cast<int>(c.size())) {
        c[dep]++;
        v[dep] = (v[dep] + noww) % P;
    } else {  // dep == c.size(): first node seen at this depth
        c.push_back(1);
        v.push_back(noww);
    }
    sz[x] = 1;
    for (int e : E[x]) {
        if (e == fa || vis[e]) continue;
        getpoly(e, x, dep + 1, noww, c, v);
        sz[x] += sz[e];
    }
}

// Adds the contribution of every path through root inside root's component.
// Group 0 is root alone (count 1 at depth 0, weight w[root]); group k >= 1 is the
// subtree of the k-th unvisited neighbour. Groups are merged smallest-first; before
// each merge, conv() accounts for all pairs (node in new group, node in merged prefix).
void solve(int root) {
    cnt[0].assign(1, 1);
    val[0].assign(1, w[root]);
    id[0] = 0;  // id[] is shared across calls; reset every slot that will be sorted
    int cc = 1;
    for (int e : E[root]) {
        if (vis[e]) continue;
        cnt[cc].assign(1, 0);
        val[cc].assign(1, 0);
        getpoly(e, root, 1, 0, cnt[cc], val[cc]);
        id[cc] = cc;
        cc++;
    }
    std::sort(id, id + cc, [](int x, int y) { return cnt[x].size() < cnt[y].size(); });
    for (int i = 1; i < cc; i++) {
        const int cur = id[i], acc = id[i - 1];  // acc holds the merged prefix (root weight included)
        conv(cnt[cur], val[acc]);
        conv(cnt[acc], val[cur]);
        // Fold the centroid's weight into cur's path sums, then merge the prefix into cur
        // (cur is at least as deep as acc thanks to the sort).
        for (int j = 1; j < static_cast<int>(cnt[cur].size()); j++)
            val[cur][j] = (val[cur][j] + 1ll * cnt[cur][j] * w[root] % P) % P;
        for (int j = 0; j < static_cast<int>(cnt[acc].size()); j++) {
            cnt[cur][j] += cnt[acc][j];
            val[cur][j] = (val[cur][j] + val[acc][j]) % P;
        }
    }
}

// Centroid-decomposes the component containing root. sz[root] must already hold
// that component's size (main() sets sz[1] = n; getpoly() refreshes it for neighbours).
void work(int root) {
    if (sz[root] <= 1) return;
    maxs = sz[root];
    rtv = INT_MAX;
    getroot(root, 0);
    const int c = rt;  // rt is global and is overwritten by the recursive calls below
    vis[c] = 1;
    solve(c);
    for (int e : E[c])
        if (!vis[e]) work(e);
}

int main() {
    NTT::ntt_init();
    inv[1] = 1;
    for (int i = 2; i < N; i++) inv[i] = static_cast<ll>(inv[P % i]) * (P - P / i) % P;
    int n;
    if (std::scanf("%d", &n) != 1) return 0;
    for (int i = 1; i <= n; i++) std::scanf("%d", w + i);
    for (int i = 1; i < n; i++) {
        int u, v;
        std::scanf("%d%d", &u, &v);
        E[u].push_back(v);
        E[v].push_back(u);
    }
    sz[1] = n;
    work(1);
    std::printf("%d\n", ans);
    return 0;
}