列表

详情


NC205436. 买糖果

描述

家糖果店,第家店卖种糖果。

现在来了个人,每个人都有一些自己喜欢的糖果店,且对于每个糖果店的非空集合,都存在一个人喜欢的恰好是这些糖果店。

现在对于这个人中的每一个,他都会选择一家自己喜欢的糖果店,购买其中的一种糖果。

求出总共有多少种购买糖果的方案。两种方案不同,代表存在一个人,他挑选的糖果店不同,或购买的糖果种类不同。

答案对取模。

输入描述

第一行一个整数

接下来一行个整数

输出描述

输出一个整数,表示答案。

示例1

输入:

2
2 3

输出:

30

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++11(clang++ 3.9) 解法, 执行用时: 2177ms, 内存消耗: 24456K, 提交时间: 2020-06-19 20:21:37

#include<bits/stdc++.h>
#define debug(x) cerr<<#x<<" = "<<x
#define sp <<"  "
#define el <<endl
#define fgx cerr<<"-----------------------------------"<<endl
#define LL long long
#define DB double
using namespace std;
inline LL read(){
	LL nm=0; bool fh=1; char cw=getchar();
	for(;!isdigit(cw);cw=getchar()) fh^=(cw=='-');
	for(;isdigit(cw);cw=getchar()) nm=nm*10+(cw-'0');
	return fh?nm:-nm;
}
#define M 1100010
#define mod 998244353
namespace  CALC{
	inline int add(int x,int y){return (x+y>=mod)?(x+y-mod):(x+y);}
	inline int mns(int x,int y){return (x-y<0)?(x-y+mod):(x-y);}
	inline int mul(LL x,LL y){return x*y%mod;}
	inline void upd(int &x,int y){x=(x+y>=mod)?(x+y-mod):(x+y);}
	inline void dec(int &x,int y){x=(x-y<0)?(x-y+mod):(x-y);}
	inline int qpow(int x,int sq){LL res=1;for(;sq;sq>>=1,x=mul(x,x))if(sq&1)res=mul(res,x);return res;}
}using namespace CALC;
namespace POLY{
	int lg[M],g[40],v[40],od[23][M],iv[40],vv[M];
	void init(int N){
		N=min(N,M-1); int len=2,nw=1; 
		for(;len<=N;len<<=1,nw++){
			lg[len]=nw;
			for(int i=1;i<len;i++) od[nw][i]=(od[nw][i>>1]>>1)|((i&1)<<(nw-1));
		} len>>=1;
		for(int i=1;i<23;i++)v[i]=qpow(g[i]=qpow(3,(mod-1)/(1<<i)),mod-2),iv[i]=qpow(1<<i,mod-2);
		for(int i=1;i<=len;i++) vv[i]=qpow(i,mod-2); vv[0]=0; 
	}
	inline void NTT(int *x,int len,int kd){
		int bas=lg[len];
		for(int i=1;i<len;i++) if(i<od[bas][i]) swap(x[i],x[od[bas][i]]);
		for(int tt=1,tp=1;tt<len;tp++,tt<<=1){
			const int wn=(kd>0)?g[tp]:v[tp];
			for(int st=0;st<len;st+=(tt<<1)){
				for(int now=1,pos=st;pos<st+tt;pos++,now=mul(now,wn)){
					int t1=x[pos],t2=mul(now,x[pos+tt]);
					x[pos]=add(t1,t2),x[pos+tt]=mns(t1,t2);
				}
			}
		} if(kd>0) return;
		for(int i=0;i<len;i++) x[i]=mul(x[i],iv[bas]);
	}
	inline void cpy(int *_dt,int *_ss,int len){memcpy(_dt,_ss,sizeof(int)*len);}
	inline void tms(int *a,int lena,int *b,int lenb,int *res){
		static int A[M],B[M],G[M]; 
		if(min(lena,lenb)<=10){
			memset(G,0,sizeof(int)*(lena+lenb+1));
			for(int i=0;i<=lena;i++) for(int j=0;j<=lenb;j++)
				upd(G[i+j],mul(a[i],b[j]));
		}
		else{
			int len=1; while(len<=lena+lenb) len<<=1;
			memset(A,0,sizeof(int)*len),memset(B,0,sizeof(int)*len);
			cpy(A,a,lena+1),cpy(B,b,lenb+1),NTT(A,len,1),NTT(B,len,1);
			for(int i=0;i<len;i++) G[i]=mul(A[i],B[i]); NTT(G,len,-1);
		} cpy(res,G,lena+lenb+1);
	}
	void get_inv(int *F,int *G,int len){
		static int A[M],B[M];
		if(len==1){G[0]=qpow(F[0],mod-2);return;} get_inv(F,G,len>>1);
		cpy(A,F,len),cpy(B,G,len),len<<=1,NTT(A,len,1),NTT(B,len,1);
		for(int i=0;i<len;i++) G[i]=mns(add(B[i],B[i]),mul(mul(B[i],B[i]),A[i]));
		for(int i=0;i<len;i++) A[i]=B[i]=0; NTT(G,len,-1),len>>=1;
		for(int i=len;i<len+len;i++) G[i]=0;
	}
	void get_mod(int *F,int *G,int n,int m,int *Q,int *R){
		if(m>n){cpy(R,F,n+1),Q[0]=0;return;}
		static int C[M],W[M]; int len=1; while(len<n) len<<=1;
		reverse(F,F+n+1),reverse(G,G+m+1),cpy(C,G,m+1),get_inv(C,W,len);
		cpy(C,F,n+1),tms(C,n,W,n-m,W);
		for(int i=0;i<=n-m;i++) Q[i]=W[i]; reverse(Q,Q+n-m+1);
		for(int i=0;i<=n-m;i++) W[i]=Q[i];
		reverse(F,F+n+1),reverse(G,G+m+1),tms(W,n-m,G,m,C);
		for(int i=0;i<m;i++) R[i]=mns(F[i],C[i]);
		memset(C,0,sizeof(int)*len),memset(W,0,sizeof(int)*len);
	}
}
using POLY::NTT;
using POLY::tms;
namespace EVAL{
	int *p[M<<2],*st,q[M*20],t[M],g1[M],g2[M],w[M];
	void calc(int x,int l,int r,int *A){
		p[x]=st,st=st+r-l+3; if(l==r){p[x][1]=1,p[x][0]=mns(0,A[l]);return;}
		int mid=((l+r)>>1); calc(x<<1,l,mid,A),calc(x<<1|1,mid+1,r,A);
		tms(p[x<<1],mid-l+1,p[x<<1|1],r-mid,p[x]);
	}
	void solve(int x,int *g,int l,int r,int *res){
		if(l==r){res[l]=g[0];return;} int mid=((l+r)>>1);
		POLY::get_mod(g,p[x<<1],r-l,mid-l+1,w,g1);
		POLY::get_mod(g,p[x<<1|1],r-l,r-mid,w,g2);
		POLY::cpy(g,g1,mid-l+1),POLY::cpy(g+mid-l+1,g2,r-mid);
		solve(x<<1,g,l,mid,res),solve(x<<1|1,g+mid-l+1,mid+1,r,res);
	}
	inline void getans(int *F,int n,int *A,int m,int *res){
		st=q,calc(1,1,m,A);
		POLY::get_mod(F,p[1],n,m,w,t),solve(1,t,1,m,res);
	}
}
int n,m,A[M],X[M],a[60],F[M],Y[M];
inline void pre(int *f,int l,int r){
	if(l==r){f[0]=A[l],f[1]=1;return;} int mid=((l+r)>>1);
	pre(f,l,mid),pre(f+mid-l+3,mid+1,r);
	tms(f,mid-l+1,f+mid-l+3,r-mid,f);
}
int main(){
	POLY::init(M);
	n=read(),m=n,m>>=1;
	for(int i=0;i<n;i++) a[i]=read();
	for(int sta=0;sta<(1<<m);sta++){
		for(int k=0;k<m;k++) if((sta>>k)&1) A[sta+1]+=a[k];
		for(int k=0;k<m;k++) if((sta>>k)&1) X[sta]+=a[m+k];
	}
	pre(F,1,(1<<m)); int ans=1;
	for(int i=2;i<=(1<<m);i++) ans=mul(ans,A[i]);
	EVAL::getans(F,(1<<m),X,(1<<m)-1,Y);
	for(int i=1;i<(1<<m);i++) ans=mul(ans,Y[i]);
	printf("%d\n",ans);
	return 0;
}

C++14(g++5.4) 解法, 执行用时: 315ms, 内存消耗: 24124K, 提交时间: 2020-09-23 23:10:56

#include <bits/stdc++.h>
using namespace std;
const int mod = 998244353, G = 3;
const int N = (1<<18) + 5;
int up, w[N], rev[N], inv[N];
int fpw(int a, int b)
{
	int ans = 1;
	while(b)
	{
		if(b&1) ans = 1ll*ans*a%mod;
		a = 1ll*a*a%mod;
		b >>= 1;
	}
	return ans;
}
namespace poly
{
	void init(int n)
	{
		inv[0] = inv[1] = 1;
		for(int i=2; i<=n; i++) inv[i] = 1ll*(mod-mod/i)*inv[mod%i]%mod;
		up = 1; int l = 0;
		while(up<=n) up <<= 1, l++;
		for(int i=0; i<up; i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<(l-1));
		int wn = fpw(G, mod>>l); w[up>>1] = 1;
		for(int i=(up>>1)+1; i<up; i++) w[i] = 1ll*w[i-1]*wn%mod;
		for(int i=(up>>1)-1; i>=1; i--) w[i] = w[i<<1];
	}
	void clear(int *a, int n) { memset(a, 0, n<<2); }
	int getlen(int n) { return 1<<(32-__builtin_clz(n)); }
	inline void mul(int *a, int n, int x, int *b) { while(n--) *b++ = 1ll**a++*x%mod; }
	inline void dot(int *a, int *b, int n, int *c) { while(n--) *c++ = 1ll**a++**b++%mod; }
	void DFT(int *a, int l)
	{
		static unsigned long long tmp[N];
		int u = __builtin_ctz(up/l), t;
		for(int i=0; i<l; i++) tmp[i] = a[rev[i]>>u];
		for(int i=1; i^l; i<<=1)
			for(int j=0, d=i<<1; j^l; j+=d)
				for(int k=0; k<i; k++)
					t = tmp[i|j|k]*w[i|k]%mod, tmp[i|j|k] = tmp[j|k]+mod-t, tmp[j|k] += t;
		for(int i=0; i<l; i++) a[i] = tmp[i]%mod;
	}
	void IDFT(int *a, int l)
	{
		reverse(a+1, a+l); DFT(a, l);
		mul(a, l, mod-mod/l, a);
	}
	inline void conv(int *a, int *b, int l) { DFT(a, l); DFT(b, l); dot(a, b, l, a); IDFT(a, l); }
    void Inv(const int *a, int *b, int n)
	{
		static int c[N], l;
		if(n==0) { b[0] = fpw(a[0], mod-2); return; }
		Inv(a, b, n>>1); l = getlen(n<<1);
		for(int i=0; i<=n; i++) c[i] = a[i];
		for(int i=n+1; i<l; i++) c[i] = 0;
		DFT(c, l); DFT(b, l);
		for(int i=0; i<l; i++) b[i] = (2ll-1ll*c[i]*b[i]%mod+mod)%mod*b[i]%mod;
		IDFT(b, l);
		for(int i=n+1; i<l; i++) b[i] = 0;
	}
    int *f[N], *g[N], buf[N<<5], *np(buf);
    void mul(int *a, int n, int *b, int m, int *c, int deg, int st)
    {
        static int A[N], B[N], l;
        l = getlen(deg), copy(a, a+n+1, A), copy(b, b+m+1, B);
        conv(A, B, l); copy(A+st, A+deg+1, c);
        clear(A, l), clear(B, l);
    }
    void eval_init(int p, int l, int r, int *a)
    {
        g[p] = np, np += r-l+2, f[p] = np, np += r-l+2;
        if(l==r) { g[p][0] = (mod-a[l])%mod, g[p][1] = 1; return; } 
        int lc = p<<1, rc = lc|1, mid = (l+r)>>1, up1 = mid-l+1, up2 = r-mid;
        eval_init(lc, l, mid, a); eval_init(rc, mid+1, r, a);
        mul(g[lc], up1, g[rc], up2, g[p], up1+up2, 0);
    }
    void eval_work(int p, int l, int r, int *a)
    {
        if(l==r) { a[l] = f[p][0]; return; }
        int lc = p<<1, rc = lc|1, mid = (l+r)>>1, up1 = mid-l+1, up2 = r-mid;
        mul(f[p], r-l, g[rc], up2, f[lc], r-l, up2);
        eval_work(lc, l, mid, a);
        mul(f[p], r-l, g[lc], up1, f[rc], r-l, up1);
        eval_work(rc, mid+1, r, a);
    }
    void eval(int *a, int n, int *b, int m, int *c)
    {
        static int invg[N], q[N];
        eval_init(1, 1, m, b);
        reverse(g[1], g[1]+m+1);
        Inv(g[1], invg, m);
        reverse(invg, invg+m+1);
        mul(a, n, invg, m, q, n+m, 0);
        copy(q+n+1, q+n+m+1, f[1]); 
        eval_work(1, 1, m, c);
        for(int i=1; i<=m; i++) c[i] = (1ll*c[i]*b[i]%mod+a[0])%mod;
    }
} 
int n, a[35];
int c[N], d[N], v[N];
int *f[N], pool[N<<5], *ptr(pool);
void solve(int p, int l, int r)
{
    f[p] = ptr, ptr += r-l+2;
    if(l==r) { f[p][0] = d[l], f[p][1] = 1; return; }
    int lc = p<<1, rc = lc|1, mid = (l+r)>>1;
    solve(lc, l, mid); solve(rc, mid+1, r);
    poly::mul(f[lc], mid-l+1, f[rc], r-mid, f[p], r-l+1, 0);
}
int main()
{
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> n;
    for(int i=0; i<n; i++) cin >> a[i];
    int L = n/2, R = n - L, ans = 1;
    poly::init((1<<(R+1))-1);
    for(int i=1; i<(1<<L); i++)
    {
        int cur = 0;
        for(int j=0; j<L; j++)
            if((i>>j)&1) cur += a[j];
        ans = 1ll*ans*cur%mod;
        c[i] = cur;
    }
    for(int i=1; i<(1<<R); i++)
    {
        int cur = 0;
        for(int j=0; j<R; j++)
            if((i>>j)&1) cur += a[L+j];
        ans = 1ll*ans*cur%mod;
        d[i] = cur;
    }
    solve(1, 1, (1<<R)-1);
   // for(int i=0; i<(1<<R); i++) cout << f[1][i] << ' '; cout << '\n';
    poly::eval(f[1], (1<<R)-1, c, (1<<R)-1, v);
    for(int i=1; i<(1<<L); i++) ans = 1ll*ans*v[i]%mod;
    cout << ans << '\n';
    return 0;
}

上一题