列表

详情


NC231478. The Child and Binary Tree

描述

我们的小朋友很喜欢计算机科学,而且尤其喜欢二叉树。 考虑一个含有 n 个互异正整数的序列 。如果一棵带点权的有根二叉树满足其所有顶点的权值都在集合 中,我们的小朋友就会将其称作神犇的。
并且他认为,一棵带点权的树的权值,是其所有顶点权值的总和。
给出一个整数 m,你能对于任意的 计算出权值为 s 的神犇二叉树的个数吗?请参照样例以更好的理解什么样的两棵二叉树会被视为不同的。 我们只需要知道答案关于 998244353 取模后的值。

输入描述

输入第一行有 2 个整数 n,m 。 
第二行有 n 个用空格隔开的互异的整数

输出描述

输出 m 行,每行有一个整数。第 i 行应当含有权值恰为 i 的神犇二叉树的总数。请输出答案关于 998244353 取模的结果。

示例1

输入:

2 3
1 2

输出:

1
3
9

说明:

有9个权值恰好为3的神犇二叉树:

示例2

输入:

3 10
9 4 3

输出:

0
0
1
1
0
2
4
2
6
15

示例3

输入:

5 10
13 10 6 4 15

输出:

0
0
0
1
0
1
0
2
0
5

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(g++ 7.5.0) 解法, 执行用时: 626ms, 内存消耗: 11652K, 提交时间: 2022-09-17 21:22:19

#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 1 << 19, mod = 998244353, INF = 1e5;
ll qsm(ll a, ll b) {
    ll s = 1;
    while (b) {
        if (b & 1) s = s * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return s;
}
void change(ll y[], int len) {
    int k;
    for (int i = 1, j = len / 2; i < len - 1; i++) {
        if (i < j) std::swap(y[i], y[j]);
        k = len / 2;
        while (j >= k) {
            j = j - k;
            k = k / 2;
        }
        if (j < k) j += k;
    }
}
// on == 1 时是 DFT,on == -1 时是 IDFT
void NTT(ll y[], int len, int on) {
    change(y, len);
    for (int h = 2; h <= len; h <<= 1) {
        ll wn = qsm(3, (mod - 1) / h);
        if (on == -1) wn = qsm(wn, mod - 2);
        for (int j = 0; j < len; j += h) {
            ll w = 1;
            for (int k = j; k < j + h / 2; k++) {
                ll u = y[k];
                ll t = w * y[k + h / 2] % mod;
                y[k] = (u + t) % mod;
                y[k + h / 2] = (u - t % mod + mod) % mod;
                w = w * wn % mod;
            }
        }
    }
    if (on == -1) {
        ll inv_n = qsm(len, mod - 2);
        for (int i = 0; i < len; i++) {
            y[i] = y[i] * inv_n % mod;
        }
    }
}
void polyinv(ll f[], const ll h[], int n) {
    static ll d[N];
    f[0] = qsm(h[0], mod - 2), f[1] = 0;
    for (int w = 2; w / 2 < n; w *= 2) {
        memcpy(d, h, w * 8);
        for (int i = w; i < 2 * w; i++) d[i] = f[i] = 0;
        NTT(f, 2 * w, 1), NTT(d, 2 * w, 1);
        for (int i = 0; i < 2 * w; i++) {
            f[i] = f[i] * (2ll - f[i] * d[i] % mod + mod) % mod;
            f[i] = (f[i] % mod + mod) % mod;
        }
        NTT(f, 2 * w, -1);
        for (int i = w; i < 2 * w; i++) f[i] = 0;
    }
}
void polysqrt(ll f[], const ll h[], int n) {
    static ll t[N], inv[N];
    f[0] = 1;
    inv[0] = inv[1] = f[1] = 0;
    ll inv2 = qsm(2, mod - 2);
    for (int w = 2; w / 2 < n; w *= 2) {
        memcpy(t, h, w * 8);
        polyinv(inv, f, w);
        for (int i = w; i < 2 * w; i++) inv[i] = t[i] = 0;
        NTT(f, 2 * w, 1), NTT(t, 2 * w, 1);
        NTT(inv, 2 * w, 1);
        for (int i = 0; i < 2 * w; i++) {
            f[i] = (f[i] + t[i] * inv[i] % mod) * inv2 % mod;
        }
        NTT(f, 2 * w, -1);
        for (int i = w; i < 2 * w; i++) f[i] = 0;
    }
}
ll f[N], g[N];
int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        int x;
        scanf("%d", &x);
        f[x] = mod - 4;
    }
//    int lim = 1;
//    while (lim <= m) lim <<= 1;
    f[0]++;
    polysqrt(g, f, m);
    g[0]++;
    polyinv(f, g, m);
    for (int i = 1; i <= m; i ++) {
        printf("%lld\n", (f[i] * 2ll % mod + mod) % mod);
    }
    return 0;
}

C++(clang++ 11.0.1) 解法, 执行用时: 574ms, 内存消耗: 16556K, 提交时间: 2023-08-10 12:35:32

#include<bits/stdc++.h>
#define int long long
#define For(i,a,b) for(i=a;i<=b;i++)
using namespace std;
const int N=4e5+10,mod=998244353,G=3,inv_G=332748118,inv_2=499122177;
int a[N],b[N];
int A[N],B[N],C[N],D[N];
int g[N];
int rev[N];
int qz(int x,int y){
	int res=1;
	for(;y;y>>=1){
		if(y&1) res=res*x%mod;
		x=x*x%mod;
	}
	return res;
}
void ntt(int a[],int sign,int tot){
	int i,j,mid;
	For(i,0,tot-1) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(mid=1;mid<tot;mid<<=1){
		auto g1=qz(G,(mod-1)/(mid<<1));
		if(sign==-1) g1=qz(inv_G,(mod-1)/(mid<<1));
		for(i=0;i<tot;i+=(mid<<1)){
			auto gk=1; 
			for(j=0;j<mid;j++,gk=gk*g1%mod){
				auto x=a[i+j],y=gk*a[i+j+mid]%mod;
				a[i+j]=(x+y)%mod,a[i+j+mid]=(x-y+mod)%mod;
			}
		}
	}
	int inv=qz(tot,mod-2);
	if(sign==-1) For(i,0,tot-1) a[i]=a[i]*inv%mod;
}
void inv(int a[],int b[],int n){
	b[0]=qz(a[0],mod-2);
	int len,tot,i;
	for(len=1;len<(n<<1);len<<=1){
		tot=len<<1;
		For(i,0,len-1) A[i]=a[i],B[i]=b[i];
		For(i,0,tot-1) rev[i]=(rev[i>>1]>>1)|((i&1)?len:0);
		ntt(A,1,tot),ntt(B,1,tot);
		For(i,0,tot-1) b[i]=((2ll-1ll*A[i]*B[i]%mod)*B[i]%mod+mod)%mod;
		ntt(b,-1,tot);
		For(i,len,tot-1) b[i]=0; 
	}
	For(i,0,len-1) A[i]=B[i]=0;
	For(i,n,len-1) b[i]=0;
}
void sqrt(int a[],int b[],int n){
	b[0]=1;
	int len,tot,i;
	for(len=1;len<(n<<1);len<<=1){
		tot=len<<1;
		For(i,0,len-1) C[i]=a[i];
		inv(b,D,len);
		For(i,0,tot-1) rev[i]=(rev[i>>1]>>1)|((i&1)?len:0);
		ntt(C,1,tot),ntt(D,1,tot);
		For(i,0,tot-1) C[i]=C[i]*D[i]%mod;
		ntt(C,-1,tot);
		For(i,0,len-1) b[i]=(b[i]+C[i])%mod*inv_2%mod;
		For(i,len,tot-1) b[i]=0;
	}
	For(i,0,len-1) C[i]=D[i]=0;
	For(i,n,len-1) b[i]=0;
}
signed main(){
	int i,n,m,x;
	cin>>n>>m;
	For(i,1,n){
		cin>>x;
		a[x]++;
	}
	For(i,1,1e5) a[i]=(-4*a[i]+mod)%mod;
	a[0]=(a[0]+1)%mod;
	sqrt(a,b,1e5+1);
	b[0]=(b[0]+1)%mod;
	inv(b,g,1e5+1);
	For(i,0,1e5) g[i]=g[i]*2%mod;
	For(i,1,m) cout<<g[i]<<'\n';
	return 0;
}

上一题