列表

详情


NC247437. 不可道

描述

n个随机变量 $x_1,x_2,\ldots,x_n$,每个随机变量从 $[1,m]$ 之间的整数中均匀随机选取。求最小的、至少出现了 $k$ 次的数字的期望(如果这个数字不存在,则视为0)

例:若 $n=3,m=3,k=2$,则 $x=(3,1,3)$ 是一种可能的情况,其最小的至少出现了 $k$ 次的数字为3;$x=(1,2,3)$ 也是一种可能的情况,其不存在至少出现 $k$ 次的数字,故视为0

答案是一个有理数 $\frac{P}{Q}$,请输出 $P\cdot Q^{-1} \bmod 998244353$(例如示例1的期望为 $\frac{14}{9}$,对应输出 $221832080$)

输入描述

一行三个正整数n,m,k

输出描述

一行一个整数表示答案

示例1

输入:

3 3 2

输出:

221832080

示例2

输入:

5 3 3

输出:

369720132

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(g++ 7.5.0) 解法, 执行用时: 1036ms, 内存消耗: 21516K, 提交时间: 2022-12-24 16:24:53

#define _USE_MATH_DEFINES
#include <bits/stdc++.h>

// Shorthands for std::pair members.
#define ff first
#define ss second

// Template-header shorthands:
//   tempt = template <typename T>, tempu = template <typename U>,
//   temps = template <typename... Ts>, tandu = template <typename T, typename U>.
#define typet typename T
#define typeu typename U
#define types typename... Ts
#define tempt template <typet>
#define tempu template <typeu>
#define temps template <types>
#define tandu template <typet, typeu>

// Local builds pull in a debug helper header; otherwise debug(...) expands to
// an empty statement.
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...) do { } while (false)
#endif // LOCAL

// Short aliases for common integer and container types.
using i64 = int64_t;
using u32 = uint32_t;
using u64 = uint64_t;
using pii = std::pair<int, int>;
using vi = std::vector<int>;
using vl = std::vector<i64>;
using vs = std::vector<std::string>;
using vvi = std::vector<vi>;
using vp = std::vector<pii>;
// Min-heap: priority_queue ordered with std::greater (smallest element on top).
tempt using heap = std::priority_queue<T, std::vector<T>, std::greater<T>>;

// Lowest set bit of x.
#define lowbit(x) ((x) & -(x))
// Expands to the begin/end iterator pair of x, for passing a whole range to an algorithm.
#define all(x)    std::begin(x), std::end(x)
// A lambda that accepts any arguments and does nothing.
#define null_func [] (...) { }

// Lowers x to y when x > y; returns true exactly when x was changed.
template <typename T, typename U> bool Min(T& x, const U& y) {
  if (x > y) {
    x = y;
    return true;
  }
  return false;
}
// Raises x to y when x < y; returns true exactly when x was changed.
template <typename T, typename U> bool Max(T& x, const U& y) {
  if (!(x < y)) return false;
  x = y;
  return true;
}

// Arithmetic modulo the NTT-friendly prime 998244353 (primitive root 3).
// All helpers expect operands already reduced into [0, mod).
constexpr int mod = 998244353;
// (x + y) mod `mod`.
constexpr int add(int x, int y) { return x + y < mod ? x + y : x + y - mod; }
// (x - y) mod `mod`.
constexpr int sub(int x, int y) { return x < y ? mod + x - y : x - y; }
// (x * y) mod `mod`; x is 64-bit so the product cannot overflow.
constexpr int mul(std::int64_t x, int y) { return x * y % mod; }
// Exact-match overload for two ints. Without it, a call mul(int, int) selects
// the variadic mul template below (exact match beats the int -> int64_t
// conversion), which performs two modular reductions instead of one; that
// happened in every NTT butterfly.
constexpr int mul(int x, int y) { return mul(static_cast<std::int64_t>(x), y); }
// In-place variants: x op= y (mod `mod`).
constexpr void Add(int& x, int y) { x = add(x, y); }
constexpr void Sub(int& x, int y) { x = sub(x, y); }
constexpr void Mul(int& x, int y) { x = mul(x, y); }
// Returns z * x^y mod `mod` by binary exponentiation (y >= 0, z defaults to 1).
constexpr int pow(int x, int y, int z = 1) {
  for (; y; y /= 2) {
    if (y & 1) Mul(z, x);
    Mul(x, x);
  }
  return z;
}
// Variadic sum / product of residues: add() == 0, mul() == 1.
template <typename... Ts> constexpr int add(Ts... x) { int y = 0; (... , Add(y, x)); return y; }
template <typename... Ts> constexpr int mul(Ts... x) { int y = 1; (... , Mul(y, x)); return y; }

namespace power_serial {

// Maximum transform length (a power of two); also sizes the global tables.
constexpr int N = 1 << 19;

// Smallest power of two that is >= n; returns 1 for n <= 1.
// (The previous formula evaluated __builtin_clz(0) for n == 0, which is
// undefined behaviour.)
int ceil2(int n) {
  if (n <= 1) return 1;
  return 1 << (32 - __builtin_clz(static_cast<unsigned>(n - 1)));
}

// Bit-reversal table for the length `revn` it was last built for.
int rev[N], revn;

// Applies the bit-reversal permutation to a[0..n) in place, as needed before
// the iterative NTT passes. n must be a power of two with n <= N. The table is
// rebuilt only when n differs from the previous call. Written as an explicit
// template because an `auto` function parameter is standard only from C++20.
template <typename Iter> void butterfly(Iter a, int n) {
  if (revn != n) {
    revn = n;
    for (int i = 0; i != n; ++i) {
      rev[i] = rev[i / 2] / 2;
      if (i & 1) rev[i] |= n / 2;
    }
  }
  for (int i = 0; i != n; ++i) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
}

} // namespace power_serial

// Size of the global lookup tables below (same as the maximum NTT length).
constexpr int maxn = power_serial::N;

// Filled by initialize(): inv[i] = i^{-1}, fac[i] = i!, fiv[i] = (i!)^{-1}
// modulo `mod`.
int inv[maxn];
int fac[maxn];
int fiv[maxn];

namespace power_serial::ntt_space {

template <int P = mod, int R = 3> struct ntt_t {
  int wn[N];
  ntt_t() {
    int w = pow(R, (P - 1) / N); wn[N / 2] = 1;
    for (int i = N / 2 + 1; i != N; ++i) wn[i] = mul(wn[i - 1], w);
    for (int i = N / 2 - 1; i; --i) wn[i] = wn[i * 2];
  }
  void operator()(int* a, int n, bool op = true) const {
    if (!op) std::reverse(a + 1, a + n);
    butterfly(a, n);
    for (int i = 1, t; i != n; i *= 2) {
      const int *w = wn + i;
      for (int *b = a; b != a + n; b += i) for (int j = 0; j != i; ++j, ++b) {
        b[i] = sub(*b, t = mul(b[i], w[j])); Add(*b, t);
      }
    }
    if (op) return;
    for (int i = 0, inv = P - (P - 1) / n; i != n; ++i) Mul(a[i], inv);
  }
};

ntt_t ntt;
void cycle_conv_xxx(int* f, const int* g, int n) {
  ntt(f, n);
  for (int i = 0; i != n; ++i) Mul(f[i], g[i]);
  ntt(f, n, false);
}
void cycle_conv(int* f, int* g, int n) { ntt(g, n); cycle_conv_xxx(f, g, n); } // f * g % (x^n - 1)

vi operator*(const vi& f, const vi& g) {
  int n = f.size(), m = g.size();
  if (!n or !m) return {};
  vi h(n + m - 1);
  if (n < 20 or m < 20) {
    for (int i = 0; i != n; ++i) for (int j = 0; j != m; ++j) Add(h[i + j], mul(f[i], g[j]));
  } else {
    int s = ceil2(n + m - 1);
    int *a = new int[s];
    int *b = new int[s];
    std::copy(all(f), a);
    std::copy(all(g), b);
    std::fill(a + n, a + s, 0);
    std::fill(b + m, b + s, 0);
    cycle_conv(a, b, s);
    std::copy_n(a, n + m - 1, h.begin());
    delete[] a;
    delete[] b;
  }
  return h;
}

void ntt_inv(const int* f, int* g, int n) {
  int m = std::min(n, 16);
  std::fill_n(g, m, 0);
  g[0] = pow(f[0], mod - 2);
  for (int i = 1; i != m; ++i) {
    for (int j = 0; j != i; ++j) Sub(g[i], mul(f[i - j], g[j]));
    Mul(g[i], g[0]);
  }
  int *a = new int[n];
  int *b = new int[n];
  for (; m != n; m *= 2) {
    std::copy_n(f, m * 2, a);
    std::copy_n(g, m, b);
    std::fill_n(b + m, m, 0);
    cycle_conv(a, b, m * 2);
    std::fill_n(a, m, 0);
    cycle_conv_xxx(a, b, m * 2);
    for (int i = m; i != m * 2; ++i) g[i] = sub(0, a[i]);
  }
  delete[] a;
  delete[] b;
} // 10nlogn

void ntt_log(const int* f, int* g, int n) {
  if (n < 32) {
    for (int i = 1; i != n; ++i) {
      for (int j = 1; j != i; ++j) Add(g[i], mul(mul(j, g[j]), f[i - j]));
      g[i] = sub(f[i], mul(inv[i], g[i]));
    }
  } else {
    int *a = new int[n * 2];
    int *b = new int[n * 2];
    for (int i = 1; i != n; ++i) a[i - 1] = mul(i, f[i]);
    a[n - 1] = 0;
    std::fill_n(a + n, n, 0);
    ntt_inv(f, b, n);
    std::fill_n(b + n, n, 0);
    cycle_conv(a, b, n * 2);
    g[0] = 0;
    for (int i = 1; i != n; ++i) g[i] = mul(inv[i], a[i - 1]);
    delete[] a;
    delete[] b;
  }
} // 16nlogn

void ntt_exp(const int* f, int* g, int n) {
  int m = std::min(n, 32);
  std::fill_n(g, m, 0);
  g[0] = 1;
  for (int i = 1; i != m; ++i) {
    for (int j = 1; j <= i; ++j) g[i] = add(g[i], mul(mul(j, f[j]), g[i - j]));
    Mul(g[i], inv[i]);
  }
  if (m == n) return;
  int *ptr = new int[n * 5];
  int *a = ptr, *b = a + n, *c = b + n, *h = c + n, *t = h + n;
  ntt_inv(g, h, m);
  for (int i = 1; i != n; ++i) t[i - 1] = mul(i, f[i]);
  for (; ; m *= 2) {
    for (int i = 1; i != m; ++i) a[i - 1] = mul(i, g[i]);
    a[m - 1] = 0;
    std::copy_n(h, m, c);
    cycle_conv(a, c, m);
    int x = sub(t[m - 1], a[m - 1]);
    for (int i = 0; i != m - 1; ++i) Sub(a[i], add(t[i], t[i + m]));
    std::copy_n(g, m, b);
    cycle_conv_xxx(b, c, m);
    Sub(b[0], 1);
    std::fill_n(b + m, m, 0);
    std::copy_n(t, m, c);
    c[m - 1] = 0;
    std::fill_n(c + m, m, 0);
    cycle_conv(b, c, m * 2);
    for (int i = 1; i != m; ++i) c[i] = mul(inv[i + m], sub(b[i - 1], a[i - 1]));
    c[0] = mul(inv[m], x);
    std::fill_n(c + m, m, 0);
    std::copy_n(g, m, b);
    std::fill_n(b + m, m, 0);
    cycle_conv(c, b, m * 2);
    std::copy_n(c, m, g + m);
    if (m * 2 == n) break;
    std::copy_n(g, m * 2, a);
    std::copy_n(h, m, b);
    std::fill_n(b + m, m, 0);
    cycle_conv(a, b, m * 2);
    std::fill_n(a, m, 0);
    cycle_conv_xxx(a, b, m * 2);
    for (int i = m; i != m * 2; ++i) h[i] = sub(0, a[i]);
  }
  delete[] ptr;
} // 22nlogn

vi ntt_poly(auto func, vi f) {
  int n = f.size(), m = ceil2(n);
  f.resize(m);
  vi g(m);
  func(f.data(), g.data(), m);
  g.resize(n);
  return g;
}

} // namespace power_serial::ntt_space

// vs split(const std::string& s, const std::string& w = "\\s+") {
//   std::regex reg(w);
//   return vs(std::sregex_token_iterator(all(s), reg, -1), std::sregex_token_iterator());
// }

// One-time setup:
// - untie/unsync the C++ streams and use fixed 10-digit output;
// - build inv[i] = i^{-1}, fac[i] = i!, fiv[i] = (i!)^{-1} (mod `mod`) for
//   1 <= i < maxn, with inv[0] left at 0 and fac[0] = fiv[0] = 1.
void initialize() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  std::cout << std::fixed << std::setprecision(10);

  // Linear-time inverses: i^{-1} = -(mod / i) * (mod % i)^{-1}.
  inv[1] = 1;
  for (int i = 2; i < maxn; ++i) inv[i] = mul(inv[mod % i], mod - mod / i);

  fac[0] = 1;
  fiv[0] = 1;
  for (int i = 1; i < maxn; ++i) {
    fac[i] = mul(fac[i - 1], i);
    fiv[i] = mul(fiv[i - 1], inv[i]);
  }
}

int cas = 0;  // 1-based index of the test case being solved; driven by main()

// Reads n, m, k and prints the expected smallest value occurring at least k
// times (0 if none) among n uniform draws from [1, m], modulo 998244353.
//
// Use the truncated exponential generating functions
//   A(x) = sum_{i<k}  x^i / i!  (a value occurring fewer than k times),
//   B(x) = sum_{i<=n} x^i / i!  (a value occurring any number of times).
// Then
//   sum_v v * A^{v-1} (B - A) B^{m-v} = B (B^m - A^m) / (B - A) - m A^m,
// and the answer is n! / m^n times [x^n] of that.
// B - A starts at x^k, so numerator and denominator are both shifted down by
// x^k before the series division.
void solution() {
  using namespace power_serial::ntt_space;
  int n, m, k;
  std::cin >> n >> m >> k;
  // With only n draws no value can occur k > n times, so the answer is 0.
  // Without this guard, (B - A) / x^k has constant term 0, its "inverse"
  // comes out as 0, and the program prints -m (mod 998244353) instead of 0.
  if (k > n) {
    std::cout << 0 << '\n';
    return;
  }
  vi A(n + k + 1), B(n + k + 1);
  for (int i = 0; i != k; ++i) {
    A[i] = fiv[i];
  }
  for (int i = 0; i <= n; ++i) {
    B[i] = fiv[i];
  }
  // A^m and B^m as exp(m * log(.)); both series have constant term 1.
  vi Am = ntt_poly(ntt_log, A);
  for (int i = 1; i <= n + k; ++i) {
    Mul(Am[i], m);
  }
  Am = ntt_poly(ntt_exp, Am);
  vi Bm = ntt_poly(ntt_log, B);
  for (int i = 1; i <= n + k; ++i) {
    Mul(Bm[i], m);
  }
  Bm = ntt_poly(ntt_exp, Bm);
  // f = (B^m - A^m) / x^k and g = (B - A) / x^k, kept to degree n.
  vi f(n + 1), g(n + 1);
  for (int i = k; i <= n + k; ++i) {
    f[i - k] = sub(Bm[i], Am[i]);
    g[i - k] = sub(B[i], A[i]);
  }
  B.resize(n + 1);
  f = f * B;
  f.resize(n + 1);
  g = ntt_poly(ntt_inv, g);
  f = f * g;
  int ans = sub(f[n], mul(m, Am[n]));
  // Multiply by n! / m^n; m^{-n} = m^{mod-1-n} by Fermat's little theorem.
  std::cout << pow(m, mod - 1 - n, mul(ans, fac[n])) << '\n';
}

// Entry point: builds the tables, then solves the single test case.
int main() {
  initialize();
  const int test_count = 1;  // the input holds exactly one test case
  cas = 1;
  while (cas <= test_count) {
    solution();
    ++cas;
  }
  return 0;
}

C++(clang++ 11.0.1) 解法, 执行用时: 1536ms, 内存消耗: 47744K, 提交时间: 2022-12-23 23:22:41

#include<algorithm>
#include<cstring>
#include<cctype>
#include<cstdio>
// Inclusive loops: rep runs i = x, x+1, ..., y; repd runs i = x, x-1, ..., y.
#define rep(i,x,y) for(int i=x; i<=y; ++i)
#define repd(i,x,y) for(int i=x; i>=y; --i)

using namespace std;
using LL = long long;
// Capacity of every global array, and the NTT prime (primitive root 3).
constexpr int N = 2000005;
constexpr int mod = 998244353;
// Input values, and the transform length chosen in main().
int n, m, k, len;
// Bit-reversal permutation table rebuilt by every FFT() call.
int bin[N];
// I[i] = i^{-1}; flv[i] = i!.
LL I[N], flv[N];
// Work arrays used directly by main().
LL a[N], b[N], c[N];
// Scratch arrays: G for get_inv(), A0 for get_exp(), A and B for get_ln().
LL G[N], A0[N], A[N], B[N];
// inv[i] = 1 / i! (inverse factorials, despite the name).
LL inv[N];

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it (a leading '-' is skipped too).
// Returns 0 when end of input is reached before any digit; the previous
// version spun forever on EOF.
int getint()
{
	// int, not char: it must hold EOF, and isdigit() is only defined for
	// unsigned-char values and EOF.
	int ch;
	while((ch=getchar())!=EOF && !isdigit(ch));
	if(ch==EOF) return 0;
	int x=ch-48;
	while(isdigit(ch=getchar())) x=x*10+ch-48;
	return x;
}

// Binary exponentiation: returns a value congruent to a^x modulo `mod`
// (x >= 0). The result may be negative when a is negative, because C++ '%'
// keeps the sign of the dividend.
LL getmi(LL a,LL x)
{
	LL result=1;
	for(LL base=a; x; x>>=1)
	{
		if(x&1) result=result*base%mod;
		base=base*base%mod;
	}
	return result;
}

// In-place iterative NTT of a[0..len) modulo `mod` with primitive root 3.
// tp == 1: forward transform. tp == -1: inverse transform, including the
// division by len. len must be a power of two. Rebuilds the global bin[]
// bit-reversal table on every call.
void FFT(LL a[],int len,int tp)
{
	for(int i=0; i<len; ++i) bin[i]=bin[i>>1]>>1|((i&1)*(len>>1));
	for(int i=0; i<len; ++i) if(i<bin[i]) swap(a[i],a[bin[i]]);
	for(int half=1; half<len; half<<=1)
	{
		const int step=half<<1;
		LL root=getmi(3,(mod-1)/step);
		if(tp==-1) root=getmi(root,mod-2);
		for(int start=0; start<len; start+=step)
		{
			LL w=1;
			for(int t=0; t<half; ++t)
			{
				LL &lo=a[start+t];
				LL &hi=a[start+t+half];
				LL x=lo, y=hi*w%mod;
				w=w*root%mod;
				lo=(x+y)%mod;
				hi=(x-y+mod)%mod;
			}
		}
	}
	if(tp==-1)
	{
		const LL scale=getmi(len,mod-2);
		for(int i=0; i<len; ++i) a[i]=a[i]*scale%mod;
	}
}

// Newton iteration for the power-series inverse. On return b[0..n) holds
// a^{-1} mod x^n and b[n..2n) is zeroed. n must be a power of two, since the
// recursion halves it down to 1. The FFTs below run on length 2n, so b[n..2n)
// is presumably expected to be zero on entry — TODO confirm for new callers.
// Clobbers the global scratch array G.
void get_inv(LL a[],LL b[],int n)
{
	if(n==1)
	{
		b[0]=getmi(a[0],mod-2);
		return;
	}
	get_inv(a,b,n>>1);
	// G = a mod x^n, zero-padded to length 2n.
	rep(i,0,n-1) G[i]=a[i];
	rep(i,n,2*n-1) G[i]=0;
	FFT(G,n<<1,1),FFT(b,n<<1,1);
	// Newton step b <- b * (2 - a*b), applied pointwise in the transformed domain.
	rep(i,0,2*n-1) b[i]=b[i]*(2-G[i]*b[i]%mod+mod)%mod;
	FFT(b,n<<1,-1);
	// Drop the terms of degree >= n.
	rep(i,n,2*n-1) b[i]=0;
}

// b[0..n) = ln(a) mod x^n, computed as the integral of a' * a^{-1}.
// It is intended for a[0] == 1 (b[0] is set to 0), with n a power of two.
// Needs I[i] = i^{-1} for i < n. Clobbers the scratch arrays A and B, and G
// through get_inv().
void get_ln(LL a[],LL b[],int n)
{
	for(int i=0; i<2*n; ++i) A[i]=0, B[i]=0;
	get_inv(a,A,n);
	// B = a' (degree n-2), zero-padded to length 2n.
	for(int i=1; i<n; ++i) B[i-1]=a[i]*i%mod;
	FFT(A,n<<1,1);
	FFT(B,n<<1,1);
	for(int i=0; i<2*n; ++i) A[i]=A[i]*B[i]%mod;
	FFT(A,n<<1,-1);
	// Integrate: b_i = A_{i-1} / i.
	b[0]=0;
	for(int i=1; i<n; ++i) b[i]=A[i-1]*I[i]%mod;
}

// Newton iteration for the power-series exponential. On return b[0..n) holds
// exp(a) mod x^n and b[n..2n) is zeroed. It is intended for a[0] == 0 (the base
// case sets b[0] = 1), with n a power of two. The FFTs below run on length 2n,
// so b[n..2n) is presumably expected to be zero on entry — TODO confirm for
// new callers. Clobbers the scratch array A0, and also A, B and G through
// get_ln()/get_inv().
void get_exp(LL a[],LL b[],int n)
{
	if(n==1)
	{
		b[0]=1;
		return;
	}
	get_exp(a,b,n>>1);
	rep(i,0,2*n-1) A0[i]=0;
	get_ln(b,A0,n);
	// A0 = 1 + a - ln(b); the Newton step is b <- b * A0.
	rep(i,0,n-1) A0[i]=(a[i]+mod-A0[i])%mod;
	++A0[0],FFT(A0,n<<1,1),FFT(b,n<<1,1);
	rep(i,0,2*n-1) b[i]=b[i]*A0[i]%mod;
	FFT(b,n<<1,-1);
	// Drop the terms of degree >= n.
	rep(i,n,2*n-1) b[i]=0;
}

// Series built by main(). They are also transformed in place by FFT()
// during main(); the descriptions give the coefficient-form contents.
LL f[N];     // sum_{i<k} x^i / i!
LL g[N];     // sum_{k<=i<=n} x^i / i!
LL em[N];    // sum_{i<=n} (m-1)^i x^i / i!
LL inve[N];  // sum_{i<=n} (-1)^i x^i / i!
LL fem[N];   // fe^m, computed as exp(m * ln fe)
LL invb[N];  // series inverse of b (fe shifted down by k)
LL fe[N];    // f * inve, truncated to degree n
LL _fem[N];  // m * ln(fe). NOTE(review): names starting with '_' are reserved at global scope
LL ans;      // final value, printed as (ans+mod)%mod

// Reads n, m, k and prints the expected smallest value occurring at least k
// times among n uniform draws from [1, m] (0 if none), modulo 998244353.
int main()
{
	scanf("%d%d%d",&n,&m,&k);
	// No value can occur k > n times among n draws, so the answer is 0.
	// Handling it here also keeps the writes to f[0..k) and the reads of
	// fe[i+k] below inside their N-element arrays when k is large.
	if(k>n)
	{
		puts("0");
		return 0;
	}
	// flv[i] = i!, inv[i] = 1 / i! for 0 <= i <= n.
	flv[0]=1;
	rep(i,1,n) flv[i]=flv[i-1]*i%mod;
	inv[n]=getmi(flv[n],mod-2);
	repd(i,n,1) inv[i-1]=inv[i]*i%mod;
	// g = sum_{k<=i<=n} x^i/i!, f = sum_{i<k} x^i/i!.
	rep(i,k,n) g[i]=inv[i];
	rep(i,0,k-1) f[i]=inv[i];
	// inve = e^{-x}, em = e^{(m-1)x}, both truncated to degree n.
	LL pw=1;
	rep(i,0,n)
	{
		inve[i]=em[i]=inv[i];
		if(i&1) inve[i]*=-1;
		em[i]=em[i]*pw%mod;
		pw=pw*(m-1)%mod;
	}
	// Transform length: the smallest power of two greater than 2n.
	for(len=1; len<=n+n; len<<=1);
	rep(i,0,len-1) I[i]=getmi(i,mod-2);
	// fe = f * e^{-x} mod x^{n+1}.
	FFT(f,len,1),FFT(inve,len,1);
	FFT(g,len,1);
	rep(i,0,len-1) fe[i]=f[i]*inve[i]%mod;
	FFT(fe,len,-1);
	rep(i,n+1,len-1) fe[i]=0;
	// fem = fe^m = exp(m * ln fe); a = m * fe^m.
	rep(i,0,n) fem[i]=fe[i];
	get_ln(fem,_fem,len);
	rep(i,0,len-1) _fem[i]=_fem[i]*m%mod,fem[i]=0;
	get_exp(_fem,fem,len);
	rep(i,0,n) a[i]=fem[i]*m%mod;
	// b = fe / x^k, c = fe^m / x^k; c <- c / b (mod x^{n+1}).
	rep(i,0,n) b[i]=fe[i+k],c[i]=fem[i+k];
	get_inv(b,invb,len);
	rep(i,n+1,len-1) invb[i]=0;
	FFT(invb,len,1),FFT(c,len,1);
	rep(i,0,len-1) c[i]=invb[i]*c[i]%mod;
	FFT(c,len,-1);
	rep(i,n+1,len-1) c[i]=0;
	// a[0..n] -= c[0..n], then a is shifted down by k.
	rep(i,0,n) a[i]=a[i]-c[i];
	rep(i,0,n) a[i]=a[i+k];
	// a <- a / b, a <- a * g, a <- a * e^{(m-1)x}, each truncated to degree n.
	FFT(a,len,1);
	rep(i,0,len-1) a[i]=a[i]*invb[i]%mod;
	FFT(a,len,-1);
	rep(i,n+1,len-1) a[i]=0;
	FFT(a,len,1);
	rep(i,0,len-1) a[i]=a[i]*g[i]%mod;
	FFT(a,len,-1);
	rep(i,n+1,len-1) a[i]=0;
	FFT(a,len,1),FFT(em,len,1);
	rep(i,0,len-1) a[i]=a[i]*em[i]%mod;
	FFT(a,len,-1);
	rep(i,n+1,len-1) a[i]=0;
	// Convert from EGF coefficients to counts and divide by m^n.
	rep(i,0,n) a[i]=a[i]*flv[i]%mod;
	ans=a[n]*getmi(getmi(m,mod-2),n)%mod;
	printf("%lld\n",(ans+mod)%mod);
	return 0;
}

上一题