NC232289. Bandit Blues
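描述
给定三个整数 n、a、b，求长度为 n 的排列中，恰好有 a 个前缀最大值（比它左边所有元素都大的元素）且恰好有 b 个后缀最大值（比它右边所有元素都大的元素）的排列个数，答案对 998244353 取模。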
输入描述
第一行包含三个整数 n、a、b。
输出描述
输出一个整数表示答案。
示例1
输入:
1 1 1
输出:
1
说明:
唯一可能的排列是 [1]。
示例2
输入:
2 1 1
输出:
0
示例3
输入:
2 2 1
输出:
1
说明:
只有两个大小为 2 的排列：[1,2] 和 [2,1]，其中只有 [1,2] 满足条件。
示例4
输入:
5 2 2
输出:
22
C++(g++ 7.5.0) 解法, 执行用时: 1021ms, 内存消耗: 12672K, 提交时间: 2022-10-13 14:09:13
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

constexpr ll P = 998244353;  // NTT modulus: P - 1 = 119 * 2^23
constexpr int g = 3;         // primitive root modulo P
const int maxn = 1 << 20;    // capacity of the global rev[] / inv[] / fac[] / ifac[] tables

// Binary exponentiation: returns a^k mod p.
int qpow(int a, int k, int p = P) {
    int res = 1;
    while (k) {
        if (k & 1) res = 1ll * res * a % p;
        a = 1ll * a * a % p;
        k >>= 1;
    }
    return res;
}

// Truncated formal power series over Z/PZ, stored as coefficient vectors
// (index = degree).
namespace Polynomial {
using poly = vector<int>;
int rev[maxn];  // bit-reversal permutation scratch used by change()
int inv[maxn];  // inv[i] = i^{-1} mod P; main() fills it for 1 <= i <= n

// Smallest power of two that is >= x (1 for x <= 1).
// Written as a loop because the former `1 << (32 - __builtin_clz(x - 1))`
// is undefined for x == 1 (clz of zero), which is reached when n == 1.
int norm(int x) {
    int len = 1;
    while (len < x) len <<= 1;
    return len;
}

// Pointwise product A[i] *= B[i]; B must be at least as long as A.
poly& dot(poly& A, poly& B) {
    for (size_t i = 0; i < A.size(); i++) A[i] = 1ll * A[i] * B[i] % P;
    return A;
}

// Coefficient-wise sum; the result has max(|A|, |B|) coefficients.
poly& operator+=(poly& A, poly B) {
    A.resize(max(A.size(), B.size()));
    for (size_t i = 0; i < B.size(); i++) A[i] = (A[i] + B[i]) % P;
    return A;
}
poly operator+(poly A, poly B) { return A += B; }

// Scales every coefficient by b.
poly& operator*=(poly& A, int b) {
    for (auto& v : A) v = 1ll * v * b % P;
    return A;
}
poly operator*(poly A, int b) { return A *= b; }
poly operator*(int b, poly A) { return A *= b; }

// Applies the bit-reversal permutation for a transform of length len
// (a power of two).
void change(poly& A, int len) {
    for (int i = 0; i < len; i++) {
        rev[i] = rev[i >> 1] >> 1;
        if (i & 1) rev[i] |= len >> 1;
    }
    for (int i = 0; i < len; i++)
        if (i < rev[i]) swap(A[i], A[rev[i]]);
}

// In-place NTT of length len; A.size() must equal len.
// dir == 1 is the forward transform.
// dir == -1 is the inverse transform, including the 1/len scaling.
void ntt(poly& A, int len, int dir) {
    change(A, len);
    for (int h = 2; h <= len; h <<= 1) {
        int wn = qpow(g, (P - 1) / h);
        for (int j = 0; j < len; j += h) {
            int wi = 1;
            for (int k = j; k < j + h / 2; k++) {
                int u = A[k], v = 1ll * wi * A[k + h / 2] % P;
                A[k] = (u + v >= P ? u + v - P : u + v);
                A[k + h / 2] = (u - v < 0 ? u - v + P : u - v);
                wi = 1ll * wi * wn % P;
            }
        }
    }
    if (dir == -1) {
        // Inverse via the forward transform: reversing A[1..len-1] maps w^k to w^{-k}.
        reverse(A.begin() + 1, A.end());
        int inv_len = qpow(len, P - 2);
        for (auto& v : A) v = 1ll * v * inv_len % P;
    }
}

// Polynomial product.
// Uses schoolbook multiplication when either operand has <= 8 coefficients,
// and NTT otherwise.
poly operator*(poly A, poly B) {
    int n = (int)A.size() + (int)B.size() - 1;
    if (n <= 0) return {};  // empty operand(s): nothing to multiply
    if (A.size() <= 8 || B.size() <= 8) {
        poly C(n, 0);
        for (size_t i = 0; i < A.size(); i++)
            for (size_t j = 0; j < B.size(); j++)
                C[i + j] = (C[i + j] + 1ll * A[i] * B[j]) % P;
        return C;
    }
    int len = norm(n);  // only needed for the NTT path
    A.resize(len), B.resize(len);
    ntt(A, len, 1), ntt(B, len, 1), dot(A, B), ntt(A, len, -1);
    A.resize(n);
    return A;
}

// Inverse of A modulo x^n, where n = A.size() is a power of two.
// Newton iteration: it recurses on the lower half, then doubles precision.
poly inv2k(poly A) {
    int n = A.size(), m = n / 2;
    if (n == 1) return {qpow(A[0], P - 2)};
    poly B = inv2k(poly(A.begin(), A.begin() + m)), C = B;
    B.resize(n);
    ntt(A, n, 1), ntt(B, n, 1), dot(A, B), ntt(A, n, -1);
    for (int i = 0; i < n; i++) A[i] = (i < m ? 0 : P - A[i]);
    ntt(A, n, 1), dot(A, B), ntt(A, n, -1);
    // The low half of the result is the lower-precision inverse itself.
    move(C.begin(), C.end(), A.begin());
    return A;
}

// Inverse of A modulo x^{A.size()}; requires A[0] != 0.
poly Inv(poly A) {
    int n = A.size();
    A.resize(norm(n), 0);
    A = inv2k(A);
    A.resize(n);
    return A;
}

// In-place formal derivative; the result has one coefficient fewer.
// An empty input is returned unchanged (the old pop_back on it was UB).
poly& derivative(poly& A) {
    if (A.empty()) return A;
    for (int i = 1; i < (int)A.size(); i++) A[i - 1] = 1ll * A[i] * i % P;
    A.pop_back();
    return A;
}

// In-place formal integral with zero constant term.
// The result has one coefficient more; reads inv[1 .. original A.size()].
poly& integral(poly& A) {
    A.push_back(0);
    for (int i = (int)A.size() - 1; i > 0; i--) A[i] = 1ll * A[i - 1] * inv[i] % P;
    A[0] = 0;
    return A;
}

// ln(A) modulo x^{A.size()}; requires A[0] == 1.
poly Ln(poly A) {
    int n = A.size();
    // Take the inverse BEFORE derivative() mutates A in place. The former
    // `derivative(A) * Inv(A)` left operand evaluation order unspecified,
    // so Inv() could be handed the already-differentiated polynomial.
    poly inverse = Inv(A);
    A = derivative(A) * inverse;
    A.resize(n - 1);
    return integral(A);
}

// exp(A) modulo x^{A.size()}; requires A[0] == 0.
// Newton iteration doubling the precision h each round.
// Ln(B) here may read inv[] entries above the range main() filled (they are
// zero). That only corrupts coefficients of degree > n, which never feed back
// into lower degrees and are dropped by the final resize.
poly Exp(poly A) {
    int n = A.size(), len = norm(n);
    poly B = {1}, C;
    A.resize(len);
    for (int h = 2; h <= len; h <<= 1) {
        B.resize(h);
        C = Ln(B);
        C.resize(h);
        // C = A - ln(B) + 1  (mod x^h)
        for (int i = 0; i < h; i++) C[i] = A[i] - C[i] + (A[i] < C[i] ? P : 0);
        C[0]++;
        B = B * C;
    }
    B.resize(n);
    return B;
}

// A^k modulo x^{A.size()} via exp(k * ln A); requires A[0] == 1.
poly Pow(poly& A, int k) { return Exp(Ln(A) * k); }
}  // namespace Polynomial

using namespace Polynomial;
ll fac[maxn], ifac[maxn];

// Binomial coefficient C(n, m) mod P; requires 0 <= m <= n <= the filled table range.
ll C(int n, int m) { return fac[n] * ifac[m] % P * ifac[n - m] % P; }

// Column k of the unsigned Stirling numbers of the first kind.
// Returns ans of length n + 1 with ans[i] = s(i, k), using the EGF
// sum_i s(i, k) x^i / i! = (-ln(1 - x))^k / k!.
// Requires k >= 0 and inv[1..n] to be filled.
poly s1(int n, int k) {
    // A = (-ln(1 - x)) / x = sum_{i >= 0} x^i / (i + 1), truncated to n terms.
    poly A(n);
    for (int i = 0; i < n; i++) A[i] = inv[i + 1];
    poly f = Pow(A, k);  // f[j] = [x^{j + k}] (-ln(1 - x))^k for 0 <= j < n
    poly ans(n + 1, 0);
    // f only has n coefficients. Stop once i - k runs past them; for k == 0
    // the former loop read f[n]. The skipped entries stay 0 (s(n, 0) = 0).
    for (int i = k; i <= n && i - k < (int)f.size(); i++) {
        ans[i] = 1ll * f[i - k] * ifac[k] % P * fac[i] % P;
    }
    return ans;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, a, b;
    cin >> n >> a >> b;
    // At least one prefix maximum and one suffix maximum always exist.
    // Non-positive a or b would also index ans[] negatively in s1().
    if (a < 1 || b < 1) {
        cout << "0\n";
        return 0;
    }
    fac[0] = 1;
    for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i % P;
    ifac[n] = qpow(fac[n], P - 2);
    for (int i = n; i >= 1; i--) ifac[i - 1] = ifac[i] * i % P;
    for (int i = 1; i <= n; i++) inv[i] = 1ll * fac[i - 1] * ifac[i] % P;
    // Put the maximum at position l + 1.
    // The l elements to its left must show a - 1 prefix maxima: s(l, a - 1) ways.
    // The n - l - 1 elements to its right must show b - 1 suffix maxima.
    // C(n - 1, l) picks which values go left.
    auto f1 = s1(n, a - 1), f2 = s1(n, b - 1);
    ll ans = 0;
    for (int l = a - 1; l <= n - b; l++) {
        ans = (ans + C(n - 1, l) * f1[l] % P * f2[n - l - 1]) % P;
    }
    cout << ans << '\n';
    return 0;
}
C++ 解法, 执行用时: 500ms, 内存消耗: 26356K, 提交时间: 2022-01-15 11:45:05
#include <bits/stdc++.h>
using namespace std;

namespace {

constexpr int MOD = 998244353;  // NTT-friendly prime; 3 is a primitive root
// Factorial tables cover 0..TABLE_LIMIT.
// NOTE(review): this assumes a + b - 2 <= 200000 whenever binom() is used;
// confirm against the problem limits.
constexpr int TABLE_LIMIT = 200000;

int fact[TABLE_LIMIT + 1];
int inv_fact[TABLE_LIMIT + 1];

// Binary exponentiation: base^exponent mod MOD, for exponent >= 0.
int power_mod(int base, int exponent) {
    long long result = 1;
    long long square = base;
    for (; exponent > 0; exponent >>= 1) {
        if (exponent & 1) result = result * square % MOD;
        square = square * square % MOD;
    }
    return static_cast<int>(result);
}

// In-place cyclic NTT over Z/MOD; coef.size() must be a power of two.
// With invert == true it runs the inverse transform and divides by the length.
void ntt_inplace(vector<int>& coef, bool invert) {
    const int size = static_cast<int>(coef.size());
    // Bit-reversal permutation, maintained incrementally in j.
    for (int i = 1, j = 0; i < size; ++i) {
        int bit = size >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap(coef[i], coef[j]);
    }
    for (int half = 1; half < size; half <<= 1) {
        int step = power_mod(3, (MOD - 1) / (half << 1));
        if (invert) step = power_mod(step, MOD - 2);
        for (int start = 0; start < size; start += half << 1) {
            long long twiddle = 1;
            for (int offset = 0; offset < half; ++offset) {
                int& low = coef[start + offset];
                int& high = coef[start + offset + half];
                const int even = low;
                const int odd = static_cast<int>(twiddle * high % MOD);
                low = even + odd >= MOD ? even + odd - MOD : even + odd;
                high = even - odd < 0 ? even - odd + MOD : even - odd;
                twiddle = twiddle * step % MOD;
            }
        }
    }
    if (invert) {
        const long long scale = power_mod(size, MOD - 2);
        for (int& value : coef) value = static_cast<int>(value * scale % MOD);
    }
}

// Product of two non-empty polynomials (coefficients, lowest degree first).
vector<int> multiply(vector<int> lhs, vector<int> rhs) {
    const size_t needed = lhs.size() + rhs.size() - 1;
    size_t size = 1;
    while (size < needed) size <<= 1;
    lhs.resize(size);
    rhs.resize(size);
    ntt_inplace(lhs, false);
    ntt_inplace(rhs, false);
    for (size_t i = 0; i < size; ++i) {
        lhs[i] = static_cast<int>(1LL * lhs[i] * rhs[i] % MOD);
    }
    ntt_inplace(lhs, true);
    lhs.resize(needed);
    return lhs;
}

// Product of (1 + (v - 2) * y) over v in [lo, hi], by divide and conquer.
// Over [2, n] this is prod_{i=0}^{n-2} (1 + i*y). Its coefficient of y^j
// is the unsigned Stirling number s(n - 1, n - 1 - j).
vector<int> range_product(int lo, int hi) {
    if (lo == hi) return {1, lo - 2};
    const int mid = (lo + hi) / 2;
    return multiply(range_product(lo, mid), range_product(mid + 1, hi));
}

// Fills fact[] and inv_fact[] for 0..TABLE_LIMIT.
void build_factorials() {
    fact[0] = 1;
    for (int i = 1; i <= TABLE_LIMIT; ++i) {
        fact[i] = static_cast<int>(1LL * fact[i - 1] * i % MOD);
    }
    inv_fact[TABLE_LIMIT] = power_mod(fact[TABLE_LIMIT], MOD - 2);
    for (int i = TABLE_LIMIT; i > 0; --i) {
        inv_fact[i - 1] = static_cast<int>(1LL * inv_fact[i] * i % MOD);
    }
}

// Binomial coefficient C(top, bottom) mod MOD; requires 0 <= bottom <= top <= TABLE_LIMIT.
int binom(int top, int bottom) {
    return static_cast<int>(1LL * fact[top] * inv_fact[bottom] % MOD * inv_fact[top - bottom] % MOD);
}

// Reads a decimal integer from stdin.
// Skips any non-digit characters first; a '-' among them makes the result negative.
int read_int() {
    int ch = getchar();
    bool negative = false;
    while (!isdigit(ch)) {
        if (ch == '-') negative = true;
        ch = getchar();
    }
    int value = 0;
    while (isdigit(ch)) {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -value : value;
}

// Writes a non-negative integer followed by a newline.
void write_line(int value) {
    char digits[12];
    int count = 0;
    do {
        digits[count++] = static_cast<char>('0' + value % 10);
        value /= 10;
    } while (value > 0);
    while (count > 0) putchar(digits[--count]);
    puts("");
}

}  // namespace

// Answer = C(a + b - 2, a - 1) * s(n - 1, a + b - 2), read off range_product(2, n).
int main() {
    build_factorials();
    const int n = read_int();
    const int a = read_int();
    const int b = read_int();
    if (a < 1 || b < 1) {
        puts("0");
        return 0;
    }
    if (n == 1) {
        puts(a == 1 && b == 1 ? "1" : "0");
        return 0;
    }
    const vector<int> stirling_row = range_product(2, n);
    int answer = 0;
    if (a + b - 2 <= n - 1) {
        answer = static_cast<int>(1LL * binom(a + b - 2, a - 1) * stirling_row[n - a - b + 1] % MOD);
    }
    write_line(answer);
    return 0;
}