列表

详情


NC24906. 神经网络

描述

Nowing・钟喜欢对称的东西,但可惜世界上绝大部分事物都是不对称的,因此Nowing精神崩溃了。众所周知,Sega・黄精通神经网络,因此他打算帮助Nowing修复他的神经。

当Nowing・钟精神正常时,他的神经可以看作是一棵由n个神经元组成的树。这n个神经元编号为,且恰由n-1条神经纤维连通,每个神经元都有一个神经权重a_i。而当Nowing精神崩溃后,这n-1条神经纤维全部断裂,n个神经元变得两两互不相通。Sega・黄的工作就是按某种顺序依次修复这n-1条神经纤维。

每当Sega・黄修复完一条神经纤维,便会有一些新的神经元对相连通,故会产生一些神经脉冲。假设Sega准备修复一条连接u,v神经元的神经纤维,设在修复前u所在连通块的结点集合为G(u),v所在连通块的结点集合为G(v)。
那么在修复这条神经纤维后,产生的神经脉冲值计算公式如下:



其中dist(a,b)指在原本的神经元树上,在不走重复神经元的情况下,由神经元a走到神经元b时经过的所有神经元的神经权重之和。


我们来举个例子。考虑上图,图中每个结点上的数值就是它本身的神经权重,其中结点u的神经权重为8,结点v的神经权重为2。现在要修复图中的虚线边,则修复该边所产生的神经脉冲按如下过程计算:



修复过程中产生的总神经脉冲量,是指在修复这n-1条神经纤维时,所产生的神经脉冲数值之和。对于Sega・黄来说,他每次会从还没有被修复的神经纤维中随机等概率地选择一条进行修复,因此最终产生的总神经脉冲量往往是难以预测的。不过Sega并不在意这些,他在意的是,最终产生的总神经脉冲量的期望值是多少。

不难证明这个期望值一定是个有理数,所以Sega只想知道这个期望值在模998244353意义下的结果,但他觉得这个问题太简单了,于是就扔给你来做。你能解决这个问题吗?

输入描述

第一行一个整数,表示Nowing・钟的神经结点个数。

第二行有n个整数,其中第i个整数表示第i个神经结点的神经权重。

接下来是n-1行,每行两个整数,表示有一条神经纤维连接编号为u,v的两个神经结点。数据保证给出的神经结点以及神经纤维构成一棵树。

输出描述

输出一行一个整数,表示产生的总神经脉冲量期望值在模998244353意义下的结果。
也就是说,如果实际期望值为,那么ans应满足

示例1

输入:

4
16 13 8 9
1 2
3 1
3 4

输出:

665496476

说明:

样例的中实际的总神经脉冲量期望值为\frac{722}{3},但因为665496476 \times 3 \equiv 722 \pmod{998244353},故应输出665496476。

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++14(g++5.4) 解法, 执行用时: 1526ms, 内存消耗: 31444K, 提交时间: 2019-11-05 14:10:48

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1 << 18 | 7;
const int P = 7 * 17 << 23 | 1, G = 3;

namespace NTT {
  int w[19][2];
  int power_mod(int a, int b) {
    int ret = 1;
    for(a %= P; b; b>>=1,a=1ll*a*a%P)
      if(b&1) ret = 1ll*ret*a%P;
    return ret;
  }
  void ntt_init() {
    for(int i = 1; i < 19; i++) {
      w[i][0] = power_mod(G, P-1>>i);
      w[i][1] = power_mod(w[i][0], P-2);
    }
  }
  void ntt(int *y, int len, int on) {
    static int r[N], nl, ww, wn, u, v;
    int i, j, k, l = __builtin_ctz(len) - 1;
    if(nl != len) {
      for(i = 0, nl = len; i < len; i++)
        r[i] = (r[i>>1]>>1)|(i&1)<<l;
    }
    for(i = 0; i < len; i++)
      if(i < r[i]) swap(y[i], y[r[i]]);
    for(i = 1, l = 1; i < len; i <<= 1, l++)
      for(j = 0, wn = w[l][on]; j < len; j+=i<<1)
        for(k = j, ww = 1; k < j + i; k++, ww = 1ll*ww*wn%P)
          u = y[k], v = 1ll*y[k+i]*ww%P,
          y[k] = (u + v) % P, y[k+i] = (u - v + P) % P;
    if(on) {
      int invl = power_mod(len, P - 2);
      for(i = 0; i < len; i++) y[i] = 1ll*y[i]*invl%P;
    }
  }
}

int ans, w[N], inv[N];
void conv(vector<int>&a, vector<int>&b) {
  using namespace NTT;
  static int x[N], y[N], i;
  int l1=a.size(), l2=b.size(), l=l1+l2-1;
  while(l&(l-1)) l+=l&-l;
  for(i = 0; i < l1; i++) x[i] = a[i];
  for(i = l1; i < l; i++) x[i] = 0;
  for(i = 0; i < l2; i++) y[i] = b[i];
  for(i = l2; i < l; i++) y[i] = 0;
  ntt(x, l, 0), ntt(y, l, 0);
  for(i = 0; i < l; i++) x[i] = 1ll*x[i]*y[i]%P;
  ntt(x, l, 1);
  for(i = 1; i < l1+l2-1; i++)
    (ans += 2ll*x[i]*inv[i]%P) %= P;
}

vector<int> cnt[N], val[N], E[N];
int vis[N], sz[N], id[N];
int getroot(int u, int fa, int maxs) {
  if(sz[u] * 2 < maxs) return fa;
  for(auto &v : E[u]) {
    if(v == fa || vis[v]) continue;
    int t = getroot(v, u, maxs);
    if(sz[v] * 2 >= maxs) return t;
  }
  return u;
}
void getpoly(int x, int fa, int dep, int noww,
             vector<int>&c,vector<int>&v) {
  noww = (noww + w[x]) % P;
  if(dep < c.size()) c[dep]++, v[dep] = (v[dep]+noww)%P;
  else c.push_back(1), v.push_back(noww);
  for(int &e : E[x]) {
    if(e==fa||vis[e]) continue;
    getpoly(e, x, dep+1, noww, c, v);
  }
}
void solve(int root) {
  cnt[0].resize(1), val[0].resize(1);
  cnt[0][0] = 1, val[0][0] = w[root];
  int cc = 1;
  for(int &e : E[root]) {
    if(vis[e]) continue;
    cnt[cc].resize(1), val[cc].resize(1);
    cnt[cc][0] = val[cc][0] = 0;
    getpoly(e, root, 1, 0, cnt[cc], val[cc]);
    id[cc] = cc, cc++;
  }
  sort(id, id + cc, [](const int &x, const int &y) {
    return cnt[x].size() < cnt[y].size();
  });
  for(int i = 1; i < cc; i++) {
    conv(cnt[id[i]], val[id[i-1]]);
    conv(cnt[id[i-1]], val[id[i]]);
    for(int j = 1; j < int(cnt[id[i]].size()); j++)
      (val[id[i]][j] += 1ll*cnt[id[i]][j]*w[root]%P) %= P;
    for(int j = 0; j < int(cnt[id[i-1]].size()); j++) {
      cnt[id[i]][j] += cnt[id[i-1]][j];
      (val[id[i]][j] += val[id[i-1]][j])%=P;
    }
  }
}
void prepare(int u, int fa) {
  sz[u] = 1;
  for(auto &e : E[u]) {
    if(e==fa||vis[e]) continue;
    prepare(e, u);
    sz[u] += sz[e];
  }
}
void work(int root) {
  prepare(root, 0);
  if(sz[root] <= 1) return;
  int u = getroot(root, 0, sz[root]);
  vis[u] = 1; solve(u);
  for(int &e : E[u]) {
    if(vis[e]) continue;
    work(e);
  }
}

int main() {
  NTT::ntt_init();
  inv[1] = 1;
  for(int i = 2; i < N; i++) inv[i] = (ll)inv[P%i]*(P-P/i)%P;
  int n; scanf("%d", &n);
  for(int i = 1; i <= n; i++) scanf("%d", w + i);
  for(int i = 1, u, v; i < n; i++) {
    scanf("%d%d", &u, &v);
    E[u].push_back(v);
    E[v].push_back(u);
  }
  work(1);
  printf("%d\n", ans);
  return 0;
}

C++11(clang++ 3.9) 解法, 执行用时: 1089ms, 内存消耗: 33008K, 提交时间: 2019-04-23 12:45:45

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1 << 18 | 7;
const int P = 7 * 17 << 23 | 1, G = 3;

namespace NTT {
	int w[19][2];
	int power_mod(int a, int b) {
		int ret = 1;
		for(a %= P; b; b>>=1,a=1ll*a*a%P)
			if(b&1) ret = 1ll*ret*a%P;
		return ret;
	}
	void ntt_init() {
		for(int i = 1; i < 19; i++) {
			w[i][0] = power_mod(G, P-1>>i);
			w[i][1] = power_mod(w[i][0], P-2);
		}
	}
	void ntt(int *y, int len, int on) {
		static int r[N], nl, ww, wn, u, v;
		int i, j, k, l = __builtin_ctz(len) - 1;
		if(nl != len) {
			for(i = 0, nl = len; i < len; i++)
				r[i] = (r[i>>1]>>1)|(i&1)<<l;
		}
		for(i = 0; i < len; i++)
			if(i < r[i]) swap(y[i], y[r[i]]);
		for(i = 1, l = 1; i < len; i <<= 1, l++)
			for(j = 0, wn = w[l][on]; j < len; j+=i<<1)
				for(k = j, ww = 1; k < j + i; k++, ww = 1ll*ww*wn%P)
					u = y[k], v = 1ll*y[k+i]*ww%P,
					y[k] = (u + v) % P, y[k+i] = (u - v + P) % P;
		if(on) {
			int invl = power_mod(len, P - 2);
			for(i = 0; i < len; i++) y[i] = 1ll*y[i]*invl%P;
		}
	}
}

int ans, w[N], inv[N];
void conv(vector<int>&a, vector<int>&b) {
	using namespace NTT;
	static int x[N], y[N], i;
	int l1=a.size(), l2=b.size(), l=l1+l2-1;
	while(l&(l-1)) l+=l&-l;
	for(i = 0; i < l1; i++) x[i] = a[i];
	for(i = l1; i < l; i++) x[i] = 0;
	for(i = 0; i < l2; i++) y[i] = b[i];
	for(i = l2; i < l; i++) y[i] = 0;
	ntt(x, l, 0), ntt(y, l, 0);
	for(i = 0; i < l; i++) x[i] = 1ll*x[i]*y[i]%P;
	ntt(x, l, 1);
	for(i = 1; i < l1+l2-1; i++)
		(ans += 2ll*x[i]*inv[i]%P) %= P;
}

vector<int> cnt[N], val[N], E[N];
int vis[N], sz[N], id[N], maxs, rt, rtv;
void getroot(int x, int fa) {
	sz[x] = 1; int t = 0;
	for(int &e : E[x]) {
		if(vis[e] || e==fa) continue;
		getroot(e, x);
		t = max(t, sz[e]);
		sz[x] += sz[e];
	}
	t = max(t, maxs - t);
	if(t < rtv) rtv = t, rt = x;
}
void getpoly(int x, int fa, int dep, int noww,
		vector<int>&c,vector<int>&v) {
	noww = (noww + w[x]) % P;
	if(dep < c.size()) c[dep]++, v[dep] = (v[dep]+noww)%P;
	else c.push_back(1), v.push_back(noww);
	sz[x] = 1;
	for(int &e : E[x]) {
		if(e==fa||vis[e]) continue;
		getpoly(e, x, dep+1, noww, c, v);
		sz[x] += sz[e];
	}
}
void solve(int root) {
	cnt[0].resize(1), val[0].resize(1);
	cnt[0][0] = 1, val[0][0] = w[root];
	int cc = 1;
	for(int &e : E[root]) {
		if(vis[e]) continue;
		cnt[cc].resize(1), val[cc].resize(1);
		cnt[cc][0] = val[cc][0] = 0;
		getpoly(e, root, 1, 0, cnt[cc], val[cc]);
		id[cc] = cc, cc++;
	}
	sort(id, id + cc, [](const int &x, const int &y){
			return cnt[x].size() < cnt[y].size();
		});
	
	for(int i = 1; i < cc; i++) {
		conv(cnt[id[i]], val[id[i-1]]);
		conv(cnt[id[i-1]], val[id[i]]);
		for(int j = 1; j < int(cnt[id[i]].size()); j++)
			(val[id[i]][j] += 1ll*cnt[id[i]][j]*w[root]%P) %= P;
		for(int j = 0; j < int(cnt[id[i-1]].size()); j++) {
			cnt[id[i]][j] += cnt[id[i-1]][j];
			(val[id[i]][j] += val[id[i-1]][j])%=P;
		}
	}
}
void work(int root) {
	if(sz[root] <= 1) return;
	maxs = sz[root], rtv = INT_MAX;
	getroot(root, 0); vis[rt] = 1;
	solve(rt);
	for(int &e : E[rt]) {
		if(vis[e]) continue;
		work(e);
	}
}

int main() {
	NTT::ntt_init();
	inv[1] = 1;
	for(int i = 2; i < N; i++) inv[i] = (ll)inv[P%i]*(P-P/i)%P;
	int n; scanf("%d", &n);
	for(int i = 1; i <= n; i++) scanf("%d", w + i);
	for(int i = 1, u, v; i < n; i++) {
		scanf("%d%d", &u, &v);
		E[u].push_back(v);
		E[v].push_back(u);
	}
	sz[1] = n; work(1);
	printf("%d\n", ans);
	return 0;
}

上一题