NC205436. 买糖果
描述
有家糖果店,第家店卖种糖果。
现在来了个人,每个人都有一些自己喜欢的糖果店,且对于每个糖果店的非空集合,都存在一个人喜欢的恰好是这些糖果店。
现在对于这个人中的每一个,他都会选择一家自己喜欢的糖果店,购买其中的一种糖果。
求出总共有多少种购买糖果的方案。两种方案不同,代表存在一个人,他挑选的糖果店不同,或购买的糖果种类不同。
答案对取模。
输入描述
第一行一个整数。
接下来一行个整数。
输出描述
输出一个整数,表示答案。
示例1
输入:
2 2 3
输出:
30
C++11(clang++ 3.9) 解法, 执行用时: 2177ms, 内存消耗: 24456K, 提交时间: 2020-06-19 20:21:37
#include<bits/stdc++.h> #define debug(x) cerr<<#x<<" = "<<x #define sp <<" " #define el <<endl #define fgx cerr<<"-----------------------------------"<<endl #define LL long long #define DB double using namespace std; inline LL read(){ LL nm=0; bool fh=1; char cw=getchar(); for(;!isdigit(cw);cw=getchar()) fh^=(cw=='-'); for(;isdigit(cw);cw=getchar()) nm=nm*10+(cw-'0'); return fh?nm:-nm; } #define M 1100010 #define mod 998244353 namespace CALC{ inline int add(int x,int y){return (x+y>=mod)?(x+y-mod):(x+y);} inline int mns(int x,int y){return (x-y<0)?(x-y+mod):(x-y);} inline int mul(LL x,LL y){return x*y%mod;} inline void upd(int &x,int y){x=(x+y>=mod)?(x+y-mod):(x+y);} inline void dec(int &x,int y){x=(x-y<0)?(x-y+mod):(x-y);} inline int qpow(int x,int sq){LL res=1;for(;sq;sq>>=1,x=mul(x,x))if(sq&1)res=mul(res,x);return res;} }using namespace CALC; namespace POLY{ int lg[M],g[40],v[40],od[23][M],iv[40],vv[M]; void init(int N){ N=min(N,M-1); int len=2,nw=1; for(;len<=N;len<<=1,nw++){ lg[len]=nw; for(int i=1;i<len;i++) od[nw][i]=(od[nw][i>>1]>>1)|((i&1)<<(nw-1)); } len>>=1; for(int i=1;i<23;i++)v[i]=qpow(g[i]=qpow(3,(mod-1)/(1<<i)),mod-2),iv[i]=qpow(1<<i,mod-2); for(int i=1;i<=len;i++) vv[i]=qpow(i,mod-2); vv[0]=0; } inline void NTT(int *x,int len,int kd){ int bas=lg[len]; for(int i=1;i<len;i++) if(i<od[bas][i]) swap(x[i],x[od[bas][i]]); for(int tt=1,tp=1;tt<len;tp++,tt<<=1){ const int wn=(kd>0)?g[tp]:v[tp]; for(int st=0;st<len;st+=(tt<<1)){ for(int now=1,pos=st;pos<st+tt;pos++,now=mul(now,wn)){ int t1=x[pos],t2=mul(now,x[pos+tt]); x[pos]=add(t1,t2),x[pos+tt]=mns(t1,t2); } } } if(kd>0) return; for(int i=0;i<len;i++) x[i]=mul(x[i],iv[bas]); } inline void cpy(int *_dt,int *_ss,int len){memcpy(_dt,_ss,sizeof(int)*len);} inline void tms(int *a,int lena,int *b,int lenb,int *res){ static int A[M],B[M],G[M]; if(min(lena,lenb)<=10){ memset(G,0,sizeof(int)*(lena+lenb+1)); for(int i=0;i<=lena;i++) for(int j=0;j<=lenb;j++) upd(G[i+j],mul(a[i],b[j])); } else{ int len=1; while(len<=lena+lenb) len<<=1; memset(A,0,sizeof(int)*len),memset(B,0,sizeof(int)*len); cpy(A,a,lena+1),cpy(B,b,lenb+1),NTT(A,len,1),NTT(B,len,1); for(int i=0;i<len;i++) G[i]=mul(A[i],B[i]); NTT(G,len,-1); } cpy(res,G,lena+lenb+1); } void get_inv(int *F,int *G,int len){ static int A[M],B[M]; if(len==1){G[0]=qpow(F[0],mod-2);return;} get_inv(F,G,len>>1); cpy(A,F,len),cpy(B,G,len),len<<=1,NTT(A,len,1),NTT(B,len,1); for(int i=0;i<len;i++) G[i]=mns(add(B[i],B[i]),mul(mul(B[i],B[i]),A[i])); for(int i=0;i<len;i++) A[i]=B[i]=0; NTT(G,len,-1),len>>=1; for(int i=len;i<len+len;i++) G[i]=0; } void get_mod(int *F,int *G,int n,int m,int *Q,int *R){ if(m>n){cpy(R,F,n+1),Q[0]=0;return;} static int C[M],W[M]; int len=1; while(len<n) len<<=1; reverse(F,F+n+1),reverse(G,G+m+1),cpy(C,G,m+1),get_inv(C,W,len); cpy(C,F,n+1),tms(C,n,W,n-m,W); for(int i=0;i<=n-m;i++) Q[i]=W[i]; reverse(Q,Q+n-m+1); for(int i=0;i<=n-m;i++) W[i]=Q[i]; reverse(F,F+n+1),reverse(G,G+m+1),tms(W,n-m,G,m,C); for(int i=0;i<m;i++) R[i]=mns(F[i],C[i]); memset(C,0,sizeof(int)*len),memset(W,0,sizeof(int)*len); } } using POLY::NTT; using POLY::tms; namespace EVAL{ int *p[M<<2],*st,q[M*20],t[M],g1[M],g2[M],w[M]; void calc(int x,int l,int r,int *A){ p[x]=st,st=st+r-l+3; if(l==r){p[x][1]=1,p[x][0]=mns(0,A[l]);return;} int mid=((l+r)>>1); calc(x<<1,l,mid,A),calc(x<<1|1,mid+1,r,A); tms(p[x<<1],mid-l+1,p[x<<1|1],r-mid,p[x]); } void solve(int x,int *g,int l,int r,int *res){ if(l==r){res[l]=g[0];return;} int mid=((l+r)>>1); POLY::get_mod(g,p[x<<1],r-l,mid-l+1,w,g1); POLY::get_mod(g,p[x<<1|1],r-l,r-mid,w,g2); POLY::cpy(g,g1,mid-l+1),POLY::cpy(g+mid-l+1,g2,r-mid); solve(x<<1,g,l,mid,res),solve(x<<1|1,g+mid-l+1,mid+1,r,res); } inline void getans(int *F,int n,int *A,int m,int *res){ st=q,calc(1,1,m,A); POLY::get_mod(F,p[1],n,m,w,t),solve(1,t,1,m,res); } } int n,m,A[M],X[M],a[60],F[M],Y[M]; inline void pre(int *f,int l,int r){ if(l==r){f[0]=A[l],f[1]=1;return;} int mid=((l+r)>>1); pre(f,l,mid),pre(f+mid-l+3,mid+1,r); tms(f,mid-l+1,f+mid-l+3,r-mid,f); } int main(){ POLY::init(M); n=read(),m=n,m>>=1; for(int i=0;i<n;i++) a[i]=read(); for(int sta=0;sta<(1<<m);sta++){ for(int k=0;k<m;k++) if((sta>>k)&1) A[sta+1]+=a[k]; for(int k=0;k<m;k++) if((sta>>k)&1) X[sta]+=a[m+k]; } pre(F,1,(1<<m)); int ans=1; for(int i=2;i<=(1<<m);i++) ans=mul(ans,A[i]); EVAL::getans(F,(1<<m),X,(1<<m)-1,Y); for(int i=1;i<(1<<m);i++) ans=mul(ans,Y[i]); printf("%d\n",ans); return 0; }
C++14(g++5.4) 解法, 执行用时: 315ms, 内存消耗: 24124K, 提交时间: 2020-09-23 23:10:56
#include <bits/stdc++.h> using namespace std; const int mod = 998244353, G = 3; const int N = (1<<18) + 5; int up, w[N], rev[N], inv[N]; int fpw(int a, int b) { int ans = 1; while(b) { if(b&1) ans = 1ll*ans*a%mod; a = 1ll*a*a%mod; b >>= 1; } return ans; } namespace poly { void init(int n) { inv[0] = inv[1] = 1; for(int i=2; i<=n; i++) inv[i] = 1ll*(mod-mod/i)*inv[mod%i]%mod; up = 1; int l = 0; while(up<=n) up <<= 1, l++; for(int i=0; i<up; i++) rev[i] = (rev[i>>1]>>1)|((i&1)<<(l-1)); int wn = fpw(G, mod>>l); w[up>>1] = 1; for(int i=(up>>1)+1; i<up; i++) w[i] = 1ll*w[i-1]*wn%mod; for(int i=(up>>1)-1; i>=1; i--) w[i] = w[i<<1]; } void clear(int *a, int n) { memset(a, 0, n<<2); } int getlen(int n) { return 1<<(32-__builtin_clz(n)); } inline void mul(int *a, int n, int x, int *b) { while(n--) *b++ = 1ll**a++*x%mod; } inline void dot(int *a, int *b, int n, int *c) { while(n--) *c++ = 1ll**a++**b++%mod; } void DFT(int *a, int l) { static unsigned long long tmp[N]; int u = __builtin_ctz(up/l), t; for(int i=0; i<l; i++) tmp[i] = a[rev[i]>>u]; for(int i=1; i^l; i<<=1) for(int j=0, d=i<<1; j^l; j+=d) for(int k=0; k<i; k++) t = tmp[i|j|k]*w[i|k]%mod, tmp[i|j|k] = tmp[j|k]+mod-t, tmp[j|k] += t; for(int i=0; i<l; i++) a[i] = tmp[i]%mod; } void IDFT(int *a, int l) { reverse(a+1, a+l); DFT(a, l); mul(a, l, mod-mod/l, a); } inline void conv(int *a, int *b, int l) { DFT(a, l); DFT(b, l); dot(a, b, l, a); IDFT(a, l); } void Inv(const int *a, int *b, int n) { static int c[N], l; if(n==0) { b[0] = fpw(a[0], mod-2); return; } Inv(a, b, n>>1); l = getlen(n<<1); for(int i=0; i<=n; i++) c[i] = a[i]; for(int i=n+1; i<l; i++) c[i] = 0; DFT(c, l); DFT(b, l); for(int i=0; i<l; i++) b[i] = (2ll-1ll*c[i]*b[i]%mod+mod)%mod*b[i]%mod; IDFT(b, l); for(int i=n+1; i<l; i++) b[i] = 0; } int *f[N], *g[N], buf[N<<5], *np(buf); void mul(int *a, int n, int *b, int m, int *c, int deg, int st) { static int A[N], B[N], l; l = getlen(deg), copy(a, a+n+1, A), copy(b, b+m+1, B); conv(A, B, l); copy(A+st, A+deg+1, c); clear(A, l), clear(B, l); } void eval_init(int p, int l, int r, int *a) { g[p] = np, np += r-l+2, f[p] = np, np += r-l+2; if(l==r) { g[p][0] = (mod-a[l])%mod, g[p][1] = 1; return; } int lc = p<<1, rc = lc|1, mid = (l+r)>>1, up1 = mid-l+1, up2 = r-mid; eval_init(lc, l, mid, a); eval_init(rc, mid+1, r, a); mul(g[lc], up1, g[rc], up2, g[p], up1+up2, 0); } void eval_work(int p, int l, int r, int *a) { if(l==r) { a[l] = f[p][0]; return; } int lc = p<<1, rc = lc|1, mid = (l+r)>>1, up1 = mid-l+1, up2 = r-mid; mul(f[p], r-l, g[rc], up2, f[lc], r-l, up2); eval_work(lc, l, mid, a); mul(f[p], r-l, g[lc], up1, f[rc], r-l, up1); eval_work(rc, mid+1, r, a); } void eval(int *a, int n, int *b, int m, int *c) { static int invg[N], q[N]; eval_init(1, 1, m, b); reverse(g[1], g[1]+m+1); Inv(g[1], invg, m); reverse(invg, invg+m+1); mul(a, n, invg, m, q, n+m, 0); copy(q+n+1, q+n+m+1, f[1]); eval_work(1, 1, m, c); for(int i=1; i<=m; i++) c[i] = (1ll*c[i]*b[i]%mod+a[0])%mod; } } int n, a[35]; int c[N], d[N], v[N]; int *f[N], pool[N<<5], *ptr(pool); void solve(int p, int l, int r) { f[p] = ptr, ptr += r-l+2; if(l==r) { f[p][0] = d[l], f[p][1] = 1; return; } int lc = p<<1, rc = lc|1, mid = (l+r)>>1; solve(lc, l, mid); solve(rc, mid+1, r); poly::mul(f[lc], mid-l+1, f[rc], r-mid, f[p], r-l+1, 0); } int main() { ios_base::sync_with_stdio(false); cin.tie(nullptr); cin >> n; for(int i=0; i<n; i++) cin >> a[i]; int L = n/2, R = n - L, ans = 1; poly::init((1<<(R+1))-1); for(int i=1; i<(1<<L); i++) { int cur = 0; for(int j=0; j<L; j++) if((i>>j)&1) cur += a[j]; ans = 1ll*ans*cur%mod; c[i] = cur; } for(int i=1; i<(1<<R); i++) { int cur = 0; for(int j=0; j<R; j++) if((i>>j)&1) cur += a[L+j]; ans = 1ll*ans*cur%mod; d[i] = cur; } solve(1, 1, (1<<R)-1); // for(int i=0; i<(1<<R); i++) cout << f[1][i] << ' '; cout << '\n'; poly::eval(f[1], (1<<R)-1, c, (1<<R)-1, v); for(int i=1; i<(1<<L); i++) ans = 1ll*ans*v[i]%mod; cout << ans << '\n'; return 0; }