NC53556. 字符串
描述
（注：原题面中的条件 1、2 在此副本中缺失，下面只保留了条件 3。）
3.这个前缀-1的个数>q
求这样的串有几个,对998244353取模
输入描述
一行五个整数n,m,p,q,k。
输出描述
输出一个整数表示答案。
示例1
输入:
2 1 0 1 2
输出:
2
C++14(g++5.4) 解法, 执行用时: 646ms, 内存消耗: 41056K, 提交时间: 2019-12-14 19:43:30
// Solution 1: counts the strings with one polynomial inverse and one polynomial
// product, both computed by NTT modulo 998244353.
//
// From the arithmetic in main: the answer is C(n + m, n) minus the number of
// arrangements of n copies of one symbol (presumably +1) and m copies of -1 that
// contain at least one prefix with exactly j copies of -1 and j + k copies of the
// other symbol, for some j in [p, q].
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>

typedef long long ll;

const int mod = 7 * 17 << 23 | 1;  // 998244353 = 119 * 2^23 + 1
const int G = 3;                   // primitive root of mod
const int N = 1 << 20 | 7;         // table size; NTT lengths up to 2^20 are supported

int fac[N], ifac[N];  // fac[i] = i! mod `mod`, ifac[i] = (i!)^{-1} mod `mod`; filled in main

// Binomial coefficient C(n, m) mod `mod`; 0 when m < 0 or m > n. Requires n < N.
inline int C(int n, int m) {
    if (n < m || m < 0) return 0;
    return (ll) fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}

// a^b mod `mod` by binary exponentiation (b >= 0).
inline int power_mod(int a, int b) {
    int res = 1;
    for (; b; b >>= 1, a = (ll) a * a % mod)
        if (b & 1) res = (ll) res * a % mod;
    return res;
}

// (x + y) mod `mod` for x, y in [0, mod).
inline int add(int x, int y) { return (x += y) >= mod ? x - mod : x; }

// Smallest power of two that is >= need (need >= 1).
inline int ceil_pow2(int need) {
    int sz = 1;
    while (sz < need) sz <<= 1;
    return sz;
}

int rev[N], roots[N], nbase;

// In-place NTT of a[0 .. n); n must be a power of two, at most 2^20.
// idft = true computes the inverse transform, including the 1/n scaling.
// rev[] caches the bit-reversal permutation for length 2^nbase; shorter lengths
// reuse it through a right shift. roots[i + t] (i a power of two, 0 <= t < i) must
// hold the t-th power of a primitive (2i)-th root of unity; main fills it.
void NTT(int *a, int n, bool idft = false) {
    int zeros = __builtin_ctz(n);
    if (zeros > nbase) {
        for (int i = 0; i < n; i++)
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (zeros - 1));
        nbase = zeros;
    }
    int shift = nbase - zeros;
    for (int i = 0; i < n; i++)
        if (i < (rev[i] >> shift)) std::swap(a[i], a[rev[i] >> shift]);
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j += i * 2) {
            for (int t = 0; t < i; t++) {
                int z = (ll) a[i + j + t] * roots[i + t] % mod;
                a[i + j + t] = add(a[j + t], mod - z);
                a[j + t] = add(a[j + t], z);
            }
        }
    }
    if (idft) {
        // The inverse transform is the forward one with indices 1..n-1 reversed.
        std::reverse(a + 1, a + n);
        int invn = power_mod(n, mod - 2);
        for (int i = 0; i < n; i++) a[i] = (ll) a[i] * invn % mod;
    }
}

// c[0 .. l1 + l2 - 1) = a[0 .. l1) * b[0 .. l2) (polynomial product mod `mod`).
// Both inputs are copied first, so c may alias a or b.
void multiply(int *a, int *b, int *c, int l1, int l2) {
    static int ta[N], tb[N];
    int need = l1 + l2 - 1;
    // A loop rather than 1 << (32 - __builtin_clz(need - 1)): __builtin_clz(0) is
    // undefined, and need == 1 occurs whenever l1 == l2 == 1.
    int sz = ceil_pow2(need);
    std::memcpy(ta, a, l1 * sizeof(int));
    std::memcpy(tb, b, l2 * sizeof(int));
    std::memset(ta + l1, 0, (sz - l1) * sizeof(int));
    std::memset(tb + l2, 0, (sz - l2) * sizeof(int));
    NTT(ta, sz);
    NTT(tb, sz);
    for (int i = 0; i < sz; i++) ta[i] = (ll) ta[i] * tb[i] % mod;
    NTT(ta, sz, true);
    std::memcpy(c, ta, need * sizeof(int));
}

// b[0 .. n) = inverse of a[0 .. n) modulo x^n (n >= 1, a[0] != 0), by the Newton
// step b <- b * (2 - a * b) applied to the inverse modulo x^ceil(n/2).
void Polyinv(int *a, int *b, int n) {
    static int ta[N], tb[N];
    if (n == 1) {
        b[0] = power_mod(a[0], mod - 2);
        return;
    }
    Polyinv(a, b, (n + 1) >> 1);
    int sza = n, szb = (n + 1) >> 1;
    // b * (2 - a * b) has degree up to sza + 2 * szb - 3. The cyclic NTT product must
    // hold every term: a shorter length (e.g. just sza + szb - 1) wraps the top terms
    // onto the low, already-correct coefficients (n == 3 already breaks that way).
    int sz = ceil_pow2(sza + 2 * szb - 2);
    std::memcpy(ta, a, sza * sizeof(int));
    std::memcpy(tb, b, szb * sizeof(int));
    std::memset(ta + sza, 0, (sz - sza) * sizeof(int));
    std::memset(tb + szb, 0, (sz - szb) * sizeof(int));
    NTT(ta, sz);
    NTT(tb, sz);
    for (int i = 0; i < sz; i++) {
        tb[i] = tb[i] * (2ll - (ll) ta[i] * tb[i] % mod) % mod;
        if (tb[i] < 0) tb[i] += mod;
    }
    NTT(tb, sz, true);
    std::memcpy(b, tb, n * sizeof(int));
}

int n, m, p, q, k;
int f[N], g[N >> 1], ig[N >> 1];

int main() {
#ifdef local
    std::freopen("in.txt", "r", stdin);
#endif
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // Factorials and inverse factorials for C().
    fac[0] = ifac[0] = 1;
    for (int i = 1; i < N; i++) fac[i] = (ll) fac[i - 1] * i % mod;
    ifac[N - 1] = power_mod(fac[N - 1], mod - 2);
    for (int i = N - 2; i; i--) ifac[i] = ifac[i + 1] * (i + 1ll) % mod;

    // Twiddle table for NTT: roots[I + t] = w_{2I}^t for I = 1, 2, ..., 2^19.
    roots[1] = 1;
    for (int i = 1; i < 20; i++) {
        int wn = power_mod(G, (mod - 1) >> (i + 1));
        for (int j = 1 << (i - 1); j < (1 << i); j++) {
            roots[j << 1] = roots[j];
            roots[j << 1 | 1] = (ll) roots[j] * wn % mod;
        }
    }

    std::cin >> n >> m >> p >> q >> k;
    int res = C(n + m, n);

    // A qualifying prefix holds j + k copies of the non-(-1) symbol, so j <= n - k.
    // An empty range [p, q] (this also covers k + p > n) means nothing is subtracted.
    q = std::min(q, n - k);
    if (q < p) {
        std::cout << res << '\n';
        return 0;
    }

    // f[i] = C(k + 2j, j) with j = p + i: prefixes of length k + 2j holding j copies of -1.
    // g[i] = C(2i, i): length-2i blocks with equally many of each symbol.
    // Splitting each counted prefix at its first qualifying point gives f = h * g,
    // where h[i] counts prefixes in which j = p + i is the first qualifying point.
    for (int i = 0; i <= q - p; i++) {
        f[i] = C(k + p * 2 + i * 2, p + i);
        g[i] = C(i * 2, i);
    }
    Polyinv(g, ig, q - p + 1);
    multiply(f, ig, f, q - p + 1, q - p + 1);  // f[0 .. q-p] now holds h

    // Subtract h[i] times the number of ways to arrange the remaining symbols.
    for (int i = 0; i <= q - p; i++) {
        res = (res - (ll) f[i] * C(n + m - (k + p * 2 + i * 2), m - (p + i))) % mod;
    }
    if (res < 0) res += mod;
    std::cout << res << '\n';
    return 0;
}
C++11(clang++ 3.9) 解法, 执行用时: 927ms, 内存消耗: 76004K, 提交时间: 2019-12-13 22:00:44
// Solution 2: the same count as solution 1, in a transposed form. Instead of
// computing h = f * g^{-1} and then sum_j h_j * tail_j, it evaluates
// sum_j f_j * (g^{-1} * tail)_j with the j-indexed sequences stored in descending
// order, so only one inverse and one product are needed.
#include <cstdio>
#include <utility>

typedef long long ll;

const int N = 2500005;   // size of the factorial tables and work arrays
const ll M = 998244353;  // 119 * 2^23 + 1, primitive root 3

ll fac[N], inv[N];  // fac[i] = i! mod M, inv[i] = (i!)^{-1} mod M
ll cat[N];          // cat[i] = C(2i, i)
ll catInv[N];       // inverse of the series `cat`
ll scratch[N];      // work buffer for INV
ll tail[N];         // arrangements of the symbols left after a qualifying prefix
ll negCnt[N];       // j: copies of -1 in the prefix (descending in the index)
ll posCnt[N];       // j + k: copies of the other symbol in the prefix
ll n, m, p, q, k;

// base^e mod M (e >= 0).
ll ksm(ll base, ll e) {
    ll r = 1;
    base %= M;
    while (e) {
        if (e & 1) r = r * base % M;
        base = base * base % M;
        e >>= 1;
    }
    return r;
}

// Permutes y[0 .. len) into bit-reversed index order (len a power of two).
void change(ll *y, int len) {
    for (int i = 1, j = len / 2; i < len - 1; i++) {
        if (i < j) std::swap(y[i], y[j]);
        int b = len / 2;
        while (j >= b) {
            j -= b;
            b /= 2;
        }
        j += b;
    }
}

// In-place NTT of y[0 .. len) (len a power of two dividing 2^23).
// opt == 1: forward transform; opt == -1: inverse transform scaled by 1/len.
void NTT(ll *y, int len, int opt) {
    change(y, len);
    for (int h = 2; h <= len; h <<= 1) {
        ll wn = ksm(3, (M - 1) / h);
        if (opt == -1) wn = ksm(wn, M - 2);
        for (int j = 0; j < len; j += h) {
            ll w = 1;
            for (int t = j; t < j + h / 2; t++) {
                ll u = y[t], v = w * y[t + h / 2] % M;
                y[t] = (u + v) % M;
                y[t + h / 2] = (u - v + M) % M;
                w = w * wn % M;
            }
        }
    }
    if (opt == 1) return;
    ll invLen = ksm(len, M - 2);
    for (int i = 0; i < len; i++) y[i] = y[i] * invLen % M;
}

// out[0 .. cnt) = inverse of a modulo x^cnt (a[0] != 0), by Newton iteration.
// Entries out[cnt .. len) are zeroed, where len is the smallest power of two >= 2*cnt.
// Expects out[] beyond the recursion's reach to start at zero, which holds for the global
// arrays on a single call. Clobbers `scratch`.
void INV(ll *a, ll *out, int cnt) {
    if (cnt == 1) {
        out[0] = ksm(a[0], M - 2);
        return;
    }
    int len = 1;
    while (len < 2 * cnt) len *= 2;  // b * (2 - a * b) has degree <= 2*cnt - 2: no wrap-around
    INV(a, out, (cnt + 1) >> 1);
    for (int i = 0; i < cnt; i++) scratch[i] = a[i];
    for (int i = cnt; i < len; i++) scratch[i] = 0;
    NTT(scratch, len, 1);
    NTT(out, len, 1);
    for (int i = 0; i < len; i++) out[i] = out[i] * (2 - out[i] * scratch[i] % M + M) % M;
    NTT(out, len, -1);
    for (int i = cnt; i < len; i++) out[i] = 0;
}

// Binomial coefficient C(a, b) mod M; 0 when b < 0 or b > a. Requires a < N.
// The b < 0 guard matters: C(n + m - i - j, m - i) is evaluated with i > m when q > m.
ll C(ll a, ll b) {
    if (b < 0 || a < b) return 0;
    return fac[a] * inv[b] % M * inv[a - b] % M;
}

int main() {
    scanf("%lld%lld%lld%lld%lld", &n, &m, &p, &q, &k);

    fac[0] = inv[0] = 1;
    for (int i = 1; i < N; i++) fac[i] = fac[i - 1] * i % M;
    inv[N - 1] = ksm(fac[N - 1], M - 2);
    for (int i = N - 2; i >= 0; i--) inv[i] = inv[i + 1] * (i + 1) % M;

    // Candidate j (copies of -1 in a qualifying prefix), in descending order. The prefix
    // also holds j + k copies of the other symbol, of which only n exist: j <= n - k.
    // Starting at min(q, n - k) gives the same set as scanning from q and skipping j + k > n.
    ll top = q < n - k ? q : n - k;
    int num = 0;
    for (ll i = top; i >= p; i--) {
        negCnt[num] = i;
        posCnt[num] = i + k;
        tail[num] = C(n + m - i - (i + k), m - i);
        num++;
    }
    if (num == 0) {
        printf("%lld\n", C(n + m, m));
        return 0;
    }

    for (int i = 0; i < num; i++) cat[i] = C(2 * i, i);
    int len = 1;
    while (len < num * 2) len *= 2;  // same length INV used, so catInv[num .. len) is zero
    INV(cat, catInv, num);
    NTT(catInv, len, 1);
    NTT(tail, len, 1);
    for (int i = 0; i < len; i++) tail[i] = tail[i] * catInv[i] % M;
    NTT(tail, len, -1);

    ll ans = C(n + m, n);
    for (int i = 0; i < num; i++)
        ans = (ans + M - tail[i] * C(negCnt[i] + posCnt[i], negCnt[i]) % M) % M;
    printf("%lld\n", ans);
    return 0;
}