列表

详情


NC20373. [SDOI2015]序列统计

描述

小C有一个集合S,里面的元素都是小于M的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为N的数列,数列中的每个数都属于集合S。 
小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:
给定整数x,求所有可以生成出的,且满足数列中所有数的乘积mod M的值等于x的不同的数列的有多少个。
小C认为,两个数列{Ai}和{Bi}不同,当且仅当至少存在一个整数i,满足Ai≠Bi
另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案mod 1004535809的值就可以了。

输入描述

一行,四个整数,N、M、x、|S|,其中|S|为集合S中元素个数。
第二行,|S|个整数,表示集合S中的所有元素。

输出描述

一行,一个整数,表示你求出的种类数mod 1004535809的值。

示例1

输入:

4 3 1 2
1 2

输出:

8

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++11(clang++ 3.9) 解法, 执行用时: 85ms, 内存消耗: 1408K, 提交时间: 2020-10-09 23:35:44

#include<cstdio>
#include<cstring>
#include<iostream>
#define N 100005
#define mod 1004535809
using namespace std;
int a[N],p[N],b[N],f[N],rev[N],w[N],winv[N],g[N],tmp[N];
int qp(int a,int x,int p){
	int i=1;
	for(;x;x/=2,a=1ll*a*a%p)if(x&1)i=1ll*i*a%p;
	return i;
}
void NTT(int n,int f[],int sgn){
	int i,j,k,tmp;
	for(i=0;i<n;i++)if(i<rev[i])swap(f[i],f[rev[i]]);
	if(sgn<0)swap(w,winv);
	for(k=2;k<=n;k<<=1)
		for(i=0;i<n;i+=k)
			for(j=0;j<(k>>1);j++){
				tmp=1ll*w[n/k*j]*f[i+j+(k>>1)]%mod;
				f[i+j+(k>>1)]=(f[i+j]-tmp+mod)%mod;
				f[i+j]=(f[i+j]+tmp)%mod;
			}
	if(sgn<0)swap(w,winv);
}
void solve(int n,int m){
	int M=m,i,j,minv;
	for(j=0;(1<<j)<2*m;j++);
	for(m=1<<j,i=w[0]=winv[0]=1;i<m;i++){
		rev[i]=(rev[i>>1]>>1)|((i&1)<<j-1);
		winv[i]=qp(3,mod-1-1ll*(mod-1)/m*i%(mod-1),mod);
		w[i]=qp(3,1ll*(mod-1)/m*i%(mod-1),mod);
	}minv=qp(m,mod-2,mod);
	for(g[0]=1;n;n/=2){
		if(n&1){
			NTT(m,g,1);NTT(m,f,1);
			for(i=0;i<m;i++)g[i]=1ll*g[i]*f[i]%mod;
			NTT(m,f,-1);NTT(m,g,-1);
			for(i=0;i<m;i++){
				g[i]=1ll*g[i]*minv%mod;
				f[i]=1ll*f[i]*minv%mod;
			}
			for(i=0;i<M;i++)g[i]=(g[i]+g[i+M])%mod;
			for(i=M;i<m;i++)g[i]=0;
		}
		NTT(m,f,1);
		for(i=0;i<m;i++)f[i]=1ll*f[i]*f[i]%mod;
		NTT(m,f,-1);
		for(i=0;i<m;i++)f[i]=1ll*f[i]*minv%mod;
		for(i=0;i<M;i++)f[i]=(f[i]+f[i+M])%mod;
		for(i=M;i<m;i++)f[i]=0;
	}
}
int main(){
	int n,m,x,s,mg,i,j;
	//freopen("input.in","r",stdin);
	scanf("%d%d%d%d",&n,&m,&x,&s);
	for(i=2,j=m-1;i<=j/i;i++)if(!(j%i)){
		p[++p[0]]=i;
		while(!(j%i))j/=i;
	}
	if(j>1)p[++p[0]]=j;
	for(i=2;i<m;i++){
		for(j=1;j<=p[0]&&qp(i,(m-1)/p[j],m)!=1;j++);
		if(j>p[0])break;
	}mg=i;
	for(i=1,j=mg;i<m-1;i++,j=1ll*mg*j%m)a[j]=i;
	for(i=1,x=a[x];i<=s;i++){
		scanf("%d",&j);
		if(j)f[a[j]]=1;
	}
	solve(n,--m);
	printf("%d",g[x]);
	return 0;
}

上一题