NC247437. 不可道
描述
输入描述
一行三个正整数 n, m, k（按此顺序给出）。
输出描述
一行一个整数，表示答案对 998244353 取模后的结果。
示例1
输入:
3 3 2
输出:
221832080
示例2
输入:
5 3 3
输出:
369720132
C++(g++ 7.5.0) 解法, 执行用时: 1036ms, 内存消耗: 21516K, 提交时间: 2022-12-24 16:24:53
// NC247437, solution 1: formal power series over Z/998244353 with an NTT toolkit.
//
// Reads three positive integers n, m, k and prints one residue modulo 998244353.
// With A(x) = sum_{i<k} x^i/i! and B(x) = sum_{i<=n} x^i/i!, solution() builds A^m
// and B^m as exp(m * ln(.)), takes F = (B^m - A^m) / x^k and G = (B - A) / x^k
// (coefficients k..n+k), and prints
//   ([x^n] ((F * B mod x^{n+1}) * G^{-1}) - m * [x^n] A^m) * n! / m^n.
#define _USE_MATH_DEFINES
#include <bits/stdc++.h>

// Contest shorthands (not all of them are used by this program).
#define ff first
#define ss second
#define typet typename T
#define typeu typename U
#define types typename... Ts
#define tempt template <typet>
#define tempu template <typeu>
#define temps template <types>
#define tandu template <typet, typeu>
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) do { } while (false)
#endif  // LOCAL

using i64 = int64_t;
using u32 = uint32_t;
using u64 = uint64_t;
using pii = std::pair<int, int>;
using vi = std::vector<int>;
using vl = std::vector<i64>;
using vs = std::vector<std::string>;
using vvi = std::vector<vi>;
using vp = std::vector<pii>;
tempt using heap = std::priority_queue<T, std::vector<T>, std::greater<T>>;

#define lowbit(x) ((x) & -(x))
#define all(x) std::begin(x), std::end(x)
#define null_func [] (...) { }

// Assign y to x if that lowers / raises x; returns whether x changed.
tandu bool Min(T& x, const U& y) { return x > y ? x = y, true : false; }
tandu bool Max(T& x, const U& y) { return x < y ? x = y, true : false; }

constexpr int mod = 998244353;

// add/sub expect residues in [0, mod) and return a residue in [0, mod).
constexpr int add(int x, int y) { return x + y < mod ? x + y : x + y - mod; }
constexpr int sub(int x, int y) { return x < y ? mod + x - y : x - y; }
// Both parameters are int on purpose: with an i64 first parameter, mul(a, b) on two
// ints resolved to the variadic mul(Ts...) below (an exact match beats the int -> i64
// conversion), so every product -- including each NTT butterfly -- paid an extra
// multiplication and `% mod`. Like add/sub, this non-template now wins the tie.
constexpr int mul(int x, int y) { return static_cast<i64>(x) * y % mod; }
constexpr void Add(int& x, int y) { x = add(x, y); }
constexpr void Sub(int& x, int y) { x = sub(x, y); }
constexpr void Mul(int& x, int y) { x = mul(x, y); }
// Returns z * x^y modulo mod by binary exponentiation (y >= 0).
constexpr int pow(int x, int y, int z = 1) {
  for (; y; y /= 2) {
    if (y & 1) Mul(z, x);
    Mul(x, x);
  }
  return z;
}
// Sum / product of any number of values modulo mod. For exactly two int arguments
// the non-template add/mul above are selected instead.
temps constexpr int add(Ts... x) { int y = 0; (..., Add(y, x)); return y; }
temps constexpr int mul(Ts... x) { int y = 1; (..., Mul(y, x)); return y; }

namespace power_serial {
constexpr int N = 1 << 19;  // largest supported transform length

// Smallest power of two >= n (n >= 1).
int ceil2(int n) { return 1 << (31 - __builtin_clz(n) + (n != lowbit(n))); }

int rev[N], revn;  // cached bit-reversal table and the length it was built for

// Applies the bit-reversal permutation of length n to a[0, n), rebuilding the cached
// table only when n differs from the previous call.
template <typename It>
void butterfly(It a, int n) {
  if (revn != n) {
    revn = n;
    for (int i = 0; i != n; ++i) {
      rev[i] = rev[i / 2] / 2;
      if (i & 1) rev[i] |= n / 2;
    }
  }
  for (int i = 0; i != n; ++i)
    if (i < rev[i]) std::swap(a[i], a[rev[i]]);
}
}  // namespace power_serial

constexpr int maxn = power_serial::N;
int inv[maxn], fac[maxn], fiv[maxn];  // 1/i, i!, 1/i! modulo mod (filled by initialize)

namespace power_serial::ntt_space {
// Number-theoretic transform modulo P with primitive root R, for power-of-two lengths
// up to N. wn[i + j] (i a power of two, 0 <= j < i) holds the j-th power of a primitive
// (2i)-th root of unity. NOTE: the butterflies use the global add/sub/mul, which reduce
// by `mod`, so only P == mod gives correct results.
template <int P = mod, int R = 3>
struct ntt_t {
  int wn[N];
  ntt_t() {
    const int w = pow(R, (P - 1) / N);  // primitive N-th root of unity
    wn[N / 2] = 1;
    for (int i = N / 2 + 1; i != N; ++i) wn[i] = mul(wn[i - 1], w);
    for (int i = N / 2 - 1; i; --i) wn[i] = wn[i * 2];
  }
  // In-place transform of a[0, n). op == true: forward; op == false: inverse, done by
  // reversing a[1, n) before the forward pass and scaling by 1/n afterwards.
  void operator()(int* a, int n, bool op = true) const {
    if (!op) std::reverse(a + 1, a + n);
    butterfly(a, n);
    for (int i = 1, t; i != n; i *= 2) {
      const int* w = wn + i;
      for (int* b = a; b != a + n; b += i)
        for (int j = 0; j != i; ++j, ++b) {
          b[i] = sub(*b, t = mul(b[i], w[j]));
          Add(*b, t);
        }
    }
    if (op) return;
    // P - (P - 1) / n is the inverse of n modulo P whenever n divides P - 1.
    for (int i = 0, inv = P - (P - 1) / n; i != n; ++i) Mul(a[i], inv);
  }
};

ntt_t<> ntt;

// f := f * g mod (x^n - 1), where g is ALREADY in transformed (NTT) form.
void cycle_conv_xxx(int* f, const int* g, int n) {
  ntt(f, n);
  for (int i = 0; i != n; ++i) Mul(f[i], g[i]);
  ntt(f, n, false);
}

// f := f * g mod (x^n - 1). g is overwritten with its transform, so it can be reused
// afterwards with cycle_conv_xxx.
void cycle_conv(int* f, int* g, int n) {
  ntt(g, n);
  cycle_conv_xxx(f, g, n);
}

// Full (acyclic) product f * g: schoolbook when either factor has fewer than 20
// coefficients, NTT otherwise. Returns an empty vector if either factor is empty.
vi operator*(const vi& f, const vi& g) {
  int n = f.size(), m = g.size();
  if (!n or !m) return {};
  vi h(n + m - 1);
  if (n < 20 or m < 20) {
    for (int i = 0; i != n; ++i)
      for (int j = 0; j != m; ++j) Add(h[i + j], mul(f[i], g[j]));
  } else {
    int s = ceil2(n + m - 1);
    vi a(s), b(s);  // zero-initialised, so the padding is already 0
    std::copy(all(f), a.begin());
    std::copy(all(g), b.begin());
    cycle_conv(a.data(), b.data(), s);
    std::copy_n(a.begin(), n + m - 1, h.begin());
  }
  return h;
}

// g := 1 / f mod x^n (f[0] must be non-zero). The first min(n, 16) coefficients come
// from the direct recurrence; each Newton round then doubles the precision, so n must
// be <= 16 or a power of two. Author's cost estimate: ~10 n log n.
void ntt_inv(const int* f, int* g, int n) {
  int m = std::min(n, 16);
  std::fill_n(g, m, 0);
  g[0] = pow(f[0], mod - 2);
  for (int i = 1; i != m; ++i) {
    for (int j = 0; j != i; ++j) Sub(g[i], mul(f[i - j], g[j]));
    Mul(g[i], g[0]);
  }
  vi a(n), b(n);
  for (; m != n; m *= 2) {
    std::copy_n(f, m * 2, a.begin());
    std::copy_n(g, m, b.begin());
    std::fill_n(b.begin() + m, m, 0);
    cycle_conv(a.data(), b.data(), m * 2);  // a = f*g; b now holds NTT(g)
    std::fill_n(a.begin(), m, 0);           // keep only the error term in [m, 2m)
    cycle_conv_xxx(a.data(), b.data(), m * 2);
    for (int i = m; i != m * 2; ++i) g[i] = sub(0, a[i]);
  }
}

// g := ln f mod x^n, with g[0] = 0. Assumes f[0] == 1 (the small-n recurrence
// i*g_i = i*f_i - sum_{j<i} j*g_j*f_{i-j} relies on it); for n >= 32, n must be a power
// of two. Author's cost estimate: ~16 n log n.
void ntt_log(const int* f, int* g, int n) {
  if (n < 32) {
    g[0] = 0;
    for (int i = 1; i != n; ++i) {
      int s = 0;  // local accumulator: does not rely on g being zeroed by the caller
      for (int j = 1; j != i; ++j) Add(s, mul(mul(j, g[j]), f[i - j]));
      g[i] = sub(f[i], mul(inv[i], s));
    }
  } else {
    vi a(n * 2), b(n * 2);  // zero-initialised: padding of a and b is already 0
    for (int i = 1; i != n; ++i) a[i - 1] = mul(i, f[i]);  // a = f'
    ntt_inv(f, b.data(), n);                               // b = 1/f
    cycle_conv(a.data(), b.data(), n * 2);                 // a = f'/f = (ln f)'
    g[0] = 0;
    for (int i = 1; i != n; ++i) g[i] = mul(inv[i], a[i - 1]);
  }
}

// g := exp(f) mod x^n; f[0] is ignored (treated as 0). The first min(n, 32)
// coefficients come from the recurrence i*g_i = sum_{j=1..i} j*f_j*g_{i-j}; afterwards
// each round doubles the known prefix of g while h tracks 1/g to the current precision.
// n must be <= 32 or a power of two. Author's cost estimate: ~22 n log n.
void ntt_exp(const int* f, int* g, int n) {
  int m = std::min(n, 32);
  std::fill_n(g, m, 0);
  g[0] = 1;
  for (int i = 1; i != m; ++i) {
    for (int j = 1; j <= i; ++j) g[i] = add(g[i], mul(mul(j, f[j]), g[i - j]));
    Mul(g[i], inv[i]);
  }
  if (m == n) return;
  vi buf(n * 5);  // five scratch arrays of n entries each
  int *a = buf.data(), *b = a + n, *c = b + n, *h = c + n, *t = h + n;
  ntt_inv(g, h, m);
  for (int i = 1; i != n; ++i) t[i - 1] = mul(i, f[i]);  // t = f'
  for (;; m *= 2) {
    for (int i = 1; i != m; ++i) a[i - 1] = mul(i, g[i]);
    a[m - 1] = 0;
    std::copy_n(h, m, c);
    cycle_conv(a, c, m);
    int x = sub(t[m - 1], a[m - 1]);
    for (int i = 0; i != m - 1; ++i) Sub(a[i], add(t[i], t[i + m]));
    std::copy_n(g, m, b);
    cycle_conv_xxx(b, c, m);
    Sub(b[0], 1);
    std::fill_n(b + m, m, 0);
    std::copy_n(t, m, c);
    c[m - 1] = 0;
    std::fill_n(c + m, m, 0);
    cycle_conv(b, c, m * 2);
    for (int i = 1; i != m; ++i) c[i] = mul(inv[i + m], sub(b[i - 1], a[i - 1]));
    c[0] = mul(inv[m], x);
    std::fill_n(c + m, m, 0);
    std::copy_n(g, m, b);
    std::fill_n(b + m, m, 0);
    cycle_conv(c, b, m * 2);
    std::copy_n(c, m, g + m);  // g now known to 2m coefficients
    if (m * 2 == n) break;
    // Extend h = 1/g from m to 2m coefficients (same Newton step as ntt_inv).
    std::copy_n(g, m * 2, a);
    std::copy_n(h, m, b);
    std::fill_n(b + m, m, 0);
    cycle_conv(a, b, m * 2);
    std::fill_n(a, m, 0);
    cycle_conv_xxx(a, b, m * 2);
    for (int i = m; i != m * 2; ++i) h[i] = sub(0, a[i]);
  }
}

// Applies a power-series routine func(const int* in, int* out, int len) to f: pads f to
// the next power of two, runs func on zeroed output, and truncates back to f.size().
template <typename Func>
vi ntt_poly(Func func, vi f) {
  int n = f.size(), m = ceil2(n);
  f.resize(m);
  vi g(m);
  func(f.data(), g.data(), m);
  g.resize(n);
  return g;
}
}  // namespace power_serial::ntt_space

// vs split(const std::string& s, const std::string& w = "\\s+") {
//   std::regex reg(w);
//   return vs(std::sregex_token_iterator(all(s), reg, -1), std::sregex_token_iterator());
// }

// Fast I/O setup and tables: inv[i] = 1/i, fac[i] = i!, fiv[i] = 1/i! modulo mod for
// i < maxn. inv uses inv[i] = -(mod / i) * inv[mod % i]; the `+ (i == 1)` seeds
// inv[1] = 1 because inv[0] is 0.
void initialize() {
  std::cin.tie(nullptr)->sync_with_stdio(false);
  std::cout << std::fixed << std::setprecision(10);
  fac[0] = fiv[0] = 1;
  for (int i = 1; i != maxn; ++i) {
    inv[i] = mul(inv[mod % i], mod - mod / i) + (i == 1);
    fac[i] = mul(fac[i - 1], i);
    fiv[i] = mul(fiv[i - 1], inv[i]);
  }
}

int cas;  // 1-based index of the current test case

// Solves one test case: reads n, m, k and prints the answer described at the top.
void solution() {
  using namespace power_serial::ntt_space;
  int n, m, k;
  std::cin >> n >> m >> k;
  vi A(n + k + 1), B(n + k + 1);
  for (int i = 0; i != k; ++i) A[i] = fiv[i];  // A = sum_{i<k} x^i / i!
  for (int i = 0; i <= n; ++i) B[i] = fiv[i];  // B = sum_{i<=n} x^i / i!
  // A^m = exp(m * ln A) and B^m = exp(m * ln B), both mod x^{n+k+1}.
  vi Am = ntt_poly(ntt_log, A);
  for (int i = 1; i <= n + k; ++i) Mul(Am[i], m);
  Am = ntt_poly(ntt_exp, Am);
  vi Bm = ntt_poly(ntt_log, B);
  for (int i = 1; i <= n + k; ++i) Mul(Bm[i], m);
  Bm = ntt_poly(ntt_exp, Bm);
  // F = (B^m - A^m) / x^k and G = (B - A) / x^k, keeping n + 1 coefficients each.
  vi f(n + 1), g(n + 1);
  for (int i = k; i <= n + k; ++i) {
    f[i - k] = sub(Bm[i], Am[i]);
    g[i - k] = sub(B[i], A[i]);
  }
  B.resize(n + 1);
  f = f * B;
  f.resize(n + 1);
  // G[0] = 1/k! when k <= n. NOTE(review): for k > n, G[0] == 0 and G has no inverse --
  // presumably excluded by the problem constraints; TODO confirm.
  g = ntt_poly(ntt_inv, g);
  f = f * g;
  int ans = sub(f[n], mul(m, Am[n]));
  // pow(m, mod - 1 - n, z) = z * m^{-n} by Fermat's little theorem.
  std::cout << pow(m, mod - 1 - n, mul(ans, fac[n])) << '\n';
}

int main() {
  initialize();
  int T = 1;
  // std::cin >> T;
  for (cas = 1; cas <= T; ++cas) solution();
  return 0;
}
C++(clang++ 11.0.1) 解法, 执行用时: 1536ms, 内存消耗: 47744K, 提交时间: 2022-12-23 23:22:41
// NC247437, solution 2: the same answer via a recursive NTT toolkit on global buffers.
//
// Reads n, m, k and prints one residue modulo 998244353. Writing e^{cx} for the
// exponential series truncated to degree n, main() computes
//   fe  = (sum_{i<k} x^i/i!) * e^{-x}        (mod x^{n+1})
//   fem = fe^m = exp(m * ln fe)
//   P   = (fe / x^k)^{-1}                    (mod x^{n+1})
//   a   = (m * fem - (fem / x^k) * P) / x^k, then a *= P, a *= sum_{k<=i<=n} x^i/i!,
//         a *= e^{(m-1)x}                    (each product truncated to degree n)
// and prints n! * [x^n] a / m^n. All stored values are kept as residues in [0, mod).
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define rep(i, x, y) for (int i = x; i <= y; ++i)
#define repd(i, x, y) for (int i = x; i >= y; --i)
using namespace std;
typedef long long LL;

// Buffer length. get_inv/get_ln/get_exp touch 2 * len entries, where len is the
// smallest power of two > 2n; 1 << 21 covers every n < 2^19. (2000005 is below 2^21
// and was overrun as soon as n >= 2^18 made len = 2^20.)
const int N = 1 << 21, mod = 998244353;
int n, m, k, len, bin[N];  // bin: bit-reversal table, rebuilt by every FFT call
LL I[N], flv[N];           // I[i] = 1/i for i < len; flv[i] = i! for i <= n
// a, b, c: work polynomials of main(); G, A0, A, B: scratch for get_inv / get_exp /
// get_ln; inv[i] = 1/i! for i <= n.
LL a[N], b[N], c[N], G[N], A0[N], A[N], B[N], inv[N];

// Returns a^x modulo mod by binary exponentiation.
LL getmi(LL a, LL x) {
  LL rt = 1;
  while (x) {
    if (x & 1) rt = rt * a % mod;
    a = a * a % mod, x >>= 1;
  }
  return rt;
}

// In-place NTT of a[0, len) (len a power of two). tp == 1: forward transform with
// root 3; tp == -1: inverse transform, including the final 1/len scaling.
void FFT(LL a[], int len, int tp) {
  rep(i, 0, len - 1) bin[i] = bin[i >> 1] >> 1 | ((i & 1) * (len >> 1));
  rep(i, 0, len - 1) if (i < bin[i]) swap(a[i], a[bin[i]]);
  for (int i = 1; i < len; i <<= 1) {
    LL wn = getmi(3, (mod - 1) / (i << 1));
    if (tp == -1) wn = getmi(wn, mod - 2);
    for (int j = 0; j < len; j += i << 1) {
      LL w = 1, x, y;
      rep(k, 0, i - 1) {
        x = a[j + k], y = a[i + j + k] * w % mod, w = w * wn % mod;
        a[j + k] = (x + y) % mod, a[i + j + k] = (x - y + mod) % mod;
      }
    }
  }
  if (tp == -1) {
    LL x = getmi(len, mod - 2);
    rep(i, 0, len - 1) a[i] = a[i] * x % mod;
  }
}

// b := 1 / a mod x^n (n a power of two, a[0] != 0) by recursive Newton iteration
// b <- b * (2 - a * b). Uses b[0, 2n) and G[0, 2n); expects b[n, 2n) to be zero on
// entry (callers pass zeroed buffers) and leaves it zero on return.
void get_inv(LL a[], LL b[], int n) {
  if (n == 1) {
    b[0] = getmi(a[0], mod - 2);
    return;
  }
  get_inv(a, b, n >> 1);
  rep(i, 0, n - 1) G[i] = a[i];
  rep(i, n, 2 * n - 1) G[i] = 0;
  FFT(G, n << 1, 1), FFT(b, n << 1, 1);
  rep(i, 0, 2 * n - 1) b[i] = b[i] * (2 - G[i] * b[i] % mod + mod) % mod;
  FFT(b, n << 1, -1);
  rep(i, n, 2 * n - 1) b[i] = 0;
}

// b := ln a mod x^n (n a power of two) as the integral of a' / a, with b[0] = 0.
// Uses A and B as scratch and the table I[i] = 1/i for i < n.
void get_ln(LL a[], LL b[], int n) {
  rep(i, 0, 2 * n - 1) A[i] = B[i] = 0;
  get_inv(a, A, n);
  rep(i, 0, n - 2) B[i] = a[i + 1] * (i + 1) % mod;
  FFT(A, n << 1, 1), FFT(B, n << 1, 1);
  rep(i, 0, 2 * n - 1) A[i] = A[i] * B[i] % mod;
  FFT(A, n << 1, -1), b[0] = 0;
  rep(i, 1, n - 1) b[i] = A[i - 1] * I[i] % mod;
}

// b := exp(a) mod x^n (n a power of two) by recursive Newton iteration
// b <- b * (1 + a - ln b). Uses A0 as scratch; b[n, 2n) is zero on return.
void get_exp(LL a[], LL b[], int n) {
  if (n == 1) {
    b[0] = 1;
    return;
  }
  get_exp(a, b, n >> 1);
  rep(i, 0, 2 * n - 1) A0[i] = 0;
  get_ln(b, A0, n);
  rep(i, 0, n - 1) A0[i] = (a[i] + mod - A0[i]) % mod;
  ++A0[0], FFT(A0, n << 1, 1), FFT(b, n << 1, 1);
  rep(i, 0, 2 * n - 1) b[i] = b[i] * A0[i] % mod;
  FFT(b, n << 1, -1);
  rep(i, n, 2 * n - 1) b[i] = 0;
}

LL f[N], g[N], em[N], inve[N], fem[N], invb[N];
LL fe[N], _fem[N], ans;

int main() {
  scanf("%d%d%d", &n, &m, &k);
  flv[0] = 1;
  rep(i, 1, n) flv[i] = flv[i - 1] * i % mod;
  inv[n] = getmi(flv[n], mod - 2);
  repd(i, n, 1) inv[i - 1] = inv[i] * i % mod;
  rep(i, k, n) g[i] = inv[i];      // g = sum_{k<=i<=n} x^i / i!
  rep(i, 0, k - 1) f[i] = inv[i];  // f = sum_{i<k} x^i / i!
  LL pw = 1;
  rep(i, 0, n) {
    inve[i] = em[i] = inv[i];
    // inve = e^{-x}: negate odd terms modulo mod, keeping the residue in [0, mod)
    // instead of relying on negative values surviving every later `% mod`.
    if (i & 1) inve[i] = (mod - inve[i]) % mod;
    em[i] = em[i] * pw % mod;  // em = e^{(m-1)x}
    pw = pw * (m - 1) % mod;
  }
  for (len = 1; len <= n + n; len <<= 1);
  rep(i, 0, len - 1) I[i] = getmi(i, mod - 2);
  FFT(f, len, 1), FFT(inve, len, 1);
  FFT(g, len, 1);  // g stays in transformed form from here on
  rep(i, 0, len - 1) fe[i] = f[i] * inve[i] % mod;
  FFT(fe, len, -1);
  rep(i, n + 1, len - 1) fe[i] = 0;  // fe = f * e^{-x} mod x^{n+1}
  rep(i, 0, n) fem[i] = fe[i];
  get_ln(fem, _fem, len);
  rep(i, 0, len - 1) _fem[i] = _fem[i] * m % mod, fem[i] = 0;
  get_exp(_fem, fem, len);  // fem = fe^m
  rep(i, 0, n) a[i] = fem[i] * m % mod;
  rep(i, 0, n) b[i] = fe[i + k], c[i] = fem[i + k];
  get_inv(b, invb, len);
  rep(i, n + 1, len - 1) invb[i] = 0;
  FFT(invb, len, 1), FFT(c, len, 1);  // invb stays in transformed form from here on
  rep(i, 0, len - 1) c[i] = invb[i] * c[i] % mod;
  FFT(c, len, -1);
  rep(i, n + 1, len - 1) c[i] = 0;
  rep(i, 0, n) a[i] = (a[i] - c[i] + mod) % mod;  // keep the residue non-negative
  rep(i, 0, n) a[i] = a[i + k];                   // divide by x^k
  FFT(a, len, 1);
  rep(i, 0, len - 1) a[i] = a[i] * invb[i] % mod;
  FFT(a, len, -1);
  rep(i, n + 1, len - 1) a[i] = 0;
  FFT(a, len, 1);
  rep(i, 0, len - 1) a[i] = a[i] * g[i] % mod;
  FFT(a, len, -1);
  rep(i, n + 1, len - 1) a[i] = 0;
  FFT(a, len, 1), FFT(em, len, 1);
  rep(i, 0, len - 1) a[i] = a[i] * em[i] % mod;
  FFT(a, len, -1);
  rep(i, n + 1, len - 1) a[i] = 0;
  rep(i, 0, n) a[i] = a[i] * flv[i] % mod;
  ans = a[n] * getmi(getmi(m, mod - 2), n) % mod;  // n! [x^n] a / m^n
  printf("%lld\n", (ans + mod) % mod);
  return 0;
}