列表

详情


NC232289. Bandit Blues

描述

给你三个整数 $n, a, b$。定义 $A$ 为一个排列中是前缀最大值的数的个数,定义 $B$ 为一个排列中是后缀最大值的数的个数,求长度为 $n$ 的排列中满足 $A = a$ 且 $B = b$ 的排列个数。
答案对 998244353 取模。

输入描述

第一行包含三个整数 $n, a, b$。

输出描述

输出一个整数表示答案。

示例1

输入:

1 1 1

输出:

1

说明:

唯一可能的排列是 [1]

示例2

输入:

2 1 1

输出:

0

示例3

输入:

2 2 1

输出:

1

说明:

大小为 2 的排列只有两个:{[1, 2], [2, 1]}。第一个排列的 $A, B$ 值分别是 $2, 1$,第二个排列的这些值分别是 $1, 2$。

示例4

输入:

5 2 2

输出:

22

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(g++ 7.5.0) 解法, 执行用时: 1021ms, 内存消耗: 12672K, 提交时间: 2022-10-13 14:09:13

#include<bits/stdc++.h>
using namespace std;

using ll = long long;

// Prime modulus for all arithmetic below (998244353 = 119 * 2^23 + 1).
constexpr ll P = 998244353;
// Generator used to derive NTT roots of unity via qpow(g, (P-1)/h).
constexpr int g = 3;
// Capacity (2^20 entries) of the global scratch tables.
const int maxn = 1<<20;

// Returns a^k mod p by square-and-multiply; k is expected to be >= 0.
int qpow(int a, int k, int p = P)
{
	long long result = 1, base = a;
	for(; k; k >>= 1, base = base * base % p)
		if(k & 1) result = result * base % p;
	return result;
}

namespace Polynomial{
	// A polynomial modulo P: element i holds the coefficient of x^i.
	using poly = vector<int>;
	
	// rev: bit-reversal scratch table, rebuilt by change() for each transform length.
	// inv: modular inverses, inv[i] = 1/i mod P. Callers may pre-fill it;
	//      integral() computes any entry it needs that is still 0.
	int rev[maxn], inv[maxn];
	
	// Smallest power of two that is >= x; returns 1 for x <= 1.
	int norm(int x){
		// __builtin_clz(0) is undefined, and x == 1 would reach it, so small x is special-cased.
		if(x <= 1) return 1;
		return 1<<(32 - __builtin_clz(x-1));
	}
	
	// Pointwise product: A[i] = A[i] * B[i] mod P for every i < A.size().
	// B must have at least A.size() elements.
	poly& dot(poly &A, poly &B){
		for(int i = 0 ; i < A.size() ; i ++) A[i] = 1ll * A[i] * B[i] % P;
		return A;
	}
	// Coefficient-wise sum A += B; A is extended to the longer of the two lengths.
	poly& operator += (poly& A, poly B){
		A.resize(max(A.size(),B.size()));
		for(int i = 0 ; i < A.size() ; i ++)
			if(i<B.size()) A[i] = (A[i]+B[i])%P;
		return A;
	}
	poly operator + (poly A, poly B){
		return A += B;
	}
	// Multiplies every coefficient by the scalar b (mod P).
	poly& operator *= (poly &A, int b){
		for(auto &v: A) v = 1ll * v * b % P;
		return A;
	}
	poly operator * (poly A, int b){
		return A *= b;
	}
	poly operator * (int b, poly A){
		return A *= b;
	}
	
	// Applies the bit-reversal permutation for length len (a power of two) to A[0..len).
	void change(poly &A, int len){
		for(int i = 0 ; i < len ; i ++)
		{
			rev[i] = rev[i>>1]>>1;
			if(i&1) rev[i] |= len>>1;
		}
		for(int i = 0 ; i < len ; i ++) if(i < rev[i]) swap(A[i], A[rev[i]]);
	}
	
	// In-place iterative NTT of length len (a power of two), roots g^((P-1)/h).
	// type == 1: forward transform.
	// type == -1: inverse transform, obtained from the forward one by reversing
	//             A[1..] and scaling by 1/len. This step walks the whole vector,
	//             so A.size() must equal len.
	void ntt(poly &A, int len, int type){
		change(A,len);
		
		for(int h = 2 ; h <= len ; h <<= 1)
		{
			int wn = qpow(g, (P-1)/h);
			for(int j = 0 ; j < len ; j += h)
			{
				int wi = 1;
				for(int k = j ; k < j + h/2 ; k ++)
				{
					int u = A[k], v = 1ll * wi * A[k+h/2] % P;
					A[k] = (u+v>=P?u+v-P:u+v), A[k+h/2] = (u-v<0?u-v+P:u-v);
					
					wi = 1ll * wi * wn % P;
				}
			}
		}
		
		if(type==-1)
		{
			reverse(A.begin()+1,A.end());
			
			int inv_len = qpow(len,P-2);
			for(auto &v: A) v = 1ll * v * inv_len % P;
		}
	}
	
	// Full product A * B (mod P) with A.size() + B.size() - 1 coefficients.
	// Uses schoolbook multiplication when either factor has <= 8 coefficients, NTT otherwise.
	poly operator * (poly A, poly B){
		int n = (int)A.size() + (int)B.size() - 1;
		// An empty factor is the zero polynomial. Answer with zeros directly so that
		// norm() never receives a non-positive size (and empty*empty never asks
		// for a vector of size -1).
		if(A.empty() || B.empty()) return poly(max(n, 0), 0);
		if(A.size()<=8 || B.size()<=8)
		{
			poly C(n,0);
			for(int i = 0 ; i < A.size() ; i ++)
				for(int j = 0 ; j < B.size() ; j ++)
					C[i+j] = (C[i+j] + 1ll * A[i] * B[j]) % P;
			return C;
		}
		
		int len = norm(n);
		A.resize(len), B.resize(len);
		ntt(A,len,1), ntt(B,len,1), dot(A,B), ntt(A,len,-1);
		
		return A.resize(n),A;
	}
	// Inverse of A modulo x^n, where n = A.size() is a power of two and A[0] != 0.
	// Newton step: B = A^{-1} mod x^(n/2) is kept as the low half. The
	// coefficients in [n/2, n) are those of -B * (A*B - 1).
	poly inv2k(poly A)
	{
		int n = A.size(), m = n/2;
		if(n==1) return {qpow(A[0],P-2)};
		
		poly B = inv2k(poly(A.begin(),A.begin()+m)), C = B;
		B.resize(n), ntt(A,n,1), ntt(B,n,1), dot(A,B), ntt(A,n,-1);
		// Low half of A*B is 1 followed by zeros (or wrap-around noise), so drop it; negate the rest.
		for(int i = 0 ; i < n ; i ++) A[i] = (i<m?0:P-A[i]);
		ntt(A,n,1), dot(A,B), ntt(A,n,-1);
		return move(C.begin(),C.end(),A.begin()),A;
	}
	// Inverse of A modulo x^(A.size()); A[0] must be nonzero. Pads to a power of two internally.
	poly Inv(poly A){
		int n = A.size();
		A.resize(norm(n),0);
		A = inv2k(A);
		
		return A.resize(n),A;
	}

	// Replaces A by its derivative, in place (one coefficient shorter). No-op on an empty A.
	poly& derivative(poly &A){
		if(A.empty()) return A;
		for(int i = 1 ; i < A.size() ; i ++) A[i-1] =  1ll * A[i] * i % P;
		return A.pop_back(), A;
	}
	
	// Replaces A by its integral with constant term 0, in place (one coefficient longer).
	poly& integral(poly &A){
		A.push_back(0);
		for(int i = A.size()-1 ; i > 0 ; i --)
		{
			// Entries beyond what the caller pre-filled would otherwise be read as 0.
			if(!inv[i]) inv[i] = qpow(i, P-2);
			A[i] = 1ll * A[i-1] * inv[i] % P;
		}
		return A[0] = 0, A;
	}
	
	// ln(A) modulo x^(A.size()), computed as the integral of A' / A.
	// Expects A[0] == 1 (not checked).
	poly Ln(poly A){
		int n = A.size();
		if(n == 0) return {};
		// Differentiate a copy. Writing derivative(A) * Inv(A) leaves the order in
		// which the two operands are evaluated unspecified, so Inv could see the
		// already-differentiated A.
		poly dA = A;
		derivative(dA);
		poly res = dA * Inv(A);
		res.resize(n-1);
		return integral(res);
	}
	
	// exp(A) modulo x^(A.size()) by Newton iteration B <- B * (1 - ln B + A).
	// Expects A[0] == 0 (not checked).
	poly Exp(poly A){
		int n = A.size(), len = norm(n);
		poly B = {1}, C;
		A.resize(len);
		for(int h = 2 ; h <= len ; h <<= 1)
		{
			B.resize(h), C = Ln(B), C.resize(h);
			for(int i = 0 ; i < h ; i ++) C[i] = A[i] - C[i] + (A[i]<C[i]?P:0);
			C[0]++, B = B * C;
		}
		return B.resize(n),B;
	}
	
	// A^k modulo x^(A.size()) as exp(k * ln A). Requires A[0] == 1.
	poly Pow(const poly &A, int k){
	    return Exp(Ln(A)*k);
	}
}
using namespace Polynomial;

// Factorials and inverse factorials mod P; populated by main() up to index n.
ll fac[maxn], ifac[maxn];

// Binomial coefficient "n choose m" mod P. Needs 0 <= m <= n and the tables filled up to n.
ll C(int n, int m){
	const ll denominator = ifac[m] * ifac[n-m] % P;
	return fac[n] * denominator % P;
}

// Returns ans[0..n] with ans[i] = [x^i/i!] (-ln(1-x))^k / k! (mod P), which is the
// unsigned Stirling number of the first kind [i, k]; ans[i] = 0 for i < k.
// Requires inv[1..n], fac[0..n] and ifac[0..n] to be filled beforehand.
poly s1(int n, int k){
	// series = -ln(1-x)/x, i.e. series[j] = 1/(j+1), kept to n terms.
	poly series(n);
	for(int j = 0 ; j < n ; j ++) series[j] = inv[j+1];
	// powered[j] is the coefficient of x^(j+k) in (-ln(1-x))^k.
	poly powered = Pow(series, k);

	poly ans(n+1, 0);
	for(int i = k ; i <= n ; i ++)
		ans[i] = 1ll * fac[i] * ifac[k] % P * powered[i-k] % P;
	return ans;
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(0);

    int n,a,b; cin >> n >> a >> b;
    if(a==0 || b==0) return cout << "0\n",0;

    fac[0] = 1;
    for(int i = 1 ; i <= n ; i ++) fac[i] = fac[i-1] * i % P;
    ifac[n] = qpow(fac[n], P-2);
    for(int i = n ; i >= 1 ; i --) ifac[i-1] = ifac[i] * i % P;
    for(int i = 1 ; i <= n ; i ++) inv[i] = 1ll * fac[i-1] * ifac[i] % P;

    auto f1 = s1(n,a-1), f2 = s1(n,b-1);
    
    ll ans = 0;
    for(int l = a-1 ; l <= n-b ; l ++){
    	ans = (ans + C(n-1, l) * f1[l] % P * f2[n-l-1]) % P;
    }
    cout << ans << endl;

	return 0;
}

C++ 解法, 执行用时: 500ms, 内存消耗: 26356K, 提交时间: 2022-01-15 11:45:05

#include<bits/stdc++.h>
#define mod 998244353
using namespace std;
// Input values and the final answer.
int n,a,b,ans;
// Shared NTT state: transform length, log2 of it, and 1/len (mod `mod`).
int len,L;
int lim;
// Scratch buffers for the two NTT operands, plus the bit-reversal table.
int f[400005],g[400005];
int R[400005];
// F[x]: polynomial held at segment-tree node x during solve().
vector<int> F[400005];
// Factorials and inverse factorials modulo `mod`.
int fac[200005],ifac[200005];

// Reads the next decimal integer from stdin, skipping any non-digit characters
// before it. A '-' anywhere in that skipped prefix makes the result negative.
// If stdin ends before a digit is seen, the loop stops and the function returns 0.
// The original loop hung forever in that case.
int read(){
	int x=0;
	bool neg=false;
	// Keep getchar()'s result as int: EOF must stay distinguishable, and
	// isdigit() is only defined for unsigned-char values and EOF.
	int ch=getchar();
	while(ch!=EOF && !isdigit(ch)) neg|=(ch=='-'),ch=getchar();
	while(ch!=EOF && isdigit(ch)) x=x*10+(ch-'0'),ch=getchar();
	return neg?-x:x;
}

// Writes x in decimal to stdout (no newline). Intended for non-negative x.
// Digits are collected least-significant first and then emitted in reverse.
void print(int x){
	char digits[12];
	int count=0;
	while(x>=10){
		digits[count++]=char(x%10+'0');
		x/=10;
	}
	putchar(x%10+'0');
	while(count>0) putchar(digits[--count]);
}

// Modular exponentiation: returns x^y mod `mod` by repeated squaring.
int ksm(int x,int y){
	long long acc=1, base=x;
	for(;y;y/=2){
		if(y&1) acc=acc*base%mod;
		base=base*base%mod;
	}
	return acc;
}

// In-place iterative NTT of x[0..len) using the bit-reversal table R.
// on == 1: forward transform. on == -1: uses inverse roots; the caller must
// still divide every entry by len.
void NTT(int *x,int on){
	for(int i=0;i<len;i++)
		if(i<R[i]) swap(x[i],x[R[i]]);
	for(int step=2;step<=len;step<<=1){
		const int half=step>>1;
		int root=ksm(3,(mod-1)/step);
		if(on==-1) root=ksm(root,mod-2);
		for(int base=0;base<len;base+=step){
			int w=1;
			for(int k=0;k<half;k++){
				int &lo=x[base+k];
				int &hi=x[base+k+half];
				const int u=lo;
				const int v=1ll*w*hi%mod;
				lo=(u+v)%mod;
				hi=(u+mod-v)%mod;
				w=1ll*w*root%mod;
			}
		}
	}
}

// Divide-and-conquer product over leaves l..r of the segment tree rooted at node x.
// Leaf l holds the linear polynomial {1, l-2}, i.e. 1 + (l-2)*x. An internal
// node stores in F[x] the product of its two children, computed with NTT.
// The children's vectors are cleared afterwards. Works through the global scratch
// arrays f, g, R and the globals len/L/lim. f and g are zeroed up to len before
// returning, because later calls rely on them being zero past the copied coefficients.
void solve(int x,int l,int r){
	if(l==r){
		F[x].push_back(1),F[x].push_back(l-2);
		return;
	}
	int mid=(l+r)/2;
	solve(2*x,l,mid);
	solve(2*x+1,mid+1,r);
	// p = degree of the product polynomial.
	int p=F[2*x].size()+F[2*x+1].size()-2;
	// len = smallest power of two greater than p; L = log2(len).
	len=1,L=0;
	while(p>=len) len*=2,L++;
	// Bit-reversal permutation table for this transform length.
	for(int i=0;i<len;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
	for(int i=0;i<F[2*x].size();i++) f[i]=F[2*x][i];
	F[2*x].clear();
	for(int i=0;i<F[2*x+1].size();i++) g[i]=F[2*x+1][i];
	F[2*x+1].clear();
	NTT(f,1),NTT(g,1);
	for(int i=0;i<len;i++) f[i]=1ll*f[i]*g[i]%mod;
	NTT(f,-1);
	// NTT(.., -1) does not divide by len, so scale by 1/len here.
	lim=ksm(len,mod-2);
	for(int i=0;i<len;i++) f[i]=1ll*f[i]*lim%mod;
	for(int i=0;i<=p;i++) F[x].push_back(f[i]);
	// Restore the all-zero scratch invariant for the next call.
	for(int i=0;i<len;i++) f[i]=g[i]=0;
}

// Binomial coefficient "x choose y" modulo `mod`; requires 0 <= y <= x <= 200000.
int C(int x,int y){
	const long long lower=1ll*ifac[y]*ifac[x-y]%mod;
	return 1ll*fac[x]*lower%mod;
}

int main(){
	// Factorials and inverse factorials for indices 0..200000.
	fac[0]=1;
	for(int i=1;i<=200000;i++) fac[i]=1ll*i*fac[i-1]%mod;
	ifac[200000]=ksm(fac[200000],mod-2);
	for(int i=200000;i>=1;i--) ifac[i-1]=1ll*i*ifac[i]%mod;
	n=read(),a=read(),b=read();
	// At least one prefix maximum and one suffix maximum always exist.
	if(a<1||b<1){
		puts("0");
		return 0;
	}
	if(n==1){
		puts(a==1&&b==1?"1":"0");
		return 0;
	}
	// F[1] = prod_{j=0}^{n-2} (1 + j*x).
	solve(1,2,n);
	ans=(a-1+b-1<=n-1)?1ll*C(a+b-2,a-1)*F[1][n-a-b+1]%mod:0;
	print(ans);
	puts("");
	return 0;
}

上一题