NC20358. [SDOI2013]方程
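描述
给定方程 $X_1+X_2+\cdots+X_n=m$。对前 $n_1$ 个变量有上界限制 $X_i\le A_i\ (1\le i\le n_1)$,对接下来的 $n_2$ 个变量有下界限制 $X_i\ge A_i\ (n_1<i\le n_1+n_2)$。求在满足这些限制的前提下,该方程正整数解的个数,答案对 $p$ 取模。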
输入描述
输入含有多组数据,第一行两个正整数T,p。T表示这个测试点内的数据组数,p的含义见题目描述。对于每组数据,第一行四个非负整数n,n1,n2,m。第二行n1+n2个正整数,表示 $A_1, A_2, \dots, A_{n_1+n_2}$。请注意,如果n1+n2等于0,那么这一行会成为一个空行。
输出描述
共T行,每行一个正整数表示取模后的答案。
示例1
输入:
3 10007
3 1 1 6
3 3
3 0 0 5

3 1 1 3
3 3
输出:
3
6
0
说明:
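【样例说明】
第一组:$X_1\le 3$,$X_2\ge 3$,$X_1+X_2+X_3=6$ 的正整数解为 (1,3,2)、(1,4,1)、(2,3,1),共 3 组。
第二组:没有任何限制,将 5 拆成 3 个正整数之和的方案数为 $\binom{4}{2}=6$。
第三组:由 $X_2\ge 3$ 得 $X_1+X_2+X_3\ge 5>3$,无解,答案为 0。

C++14(g++5.4) 解法, 执行用时: 92ms, 内存消耗: 4456K, 提交时间: 2020-02-19 10:39:28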
// Solution 1 (C++14).
// Lower bounds x_i >= A_i are removed by substituting x_i' = x_i - (A_i - 1);
// upper bounds x_i <= A_i are handled by inclusion-exclusion over the n1
// bounded variables. Each term is a stars-and-bars binomial coefficient taken
// modulo the (possibly composite) p, evaluated with extended Lucas for every
// prime power dividing p and recombined with the Chinese remainder theorem.
#include <cctype>
#include <cstdio>
#include <utility>
#include <vector>

typedef long long ll;
typedef std::pair<ll, ll> pll;

// Reads one (optionally negative) decimal integer from stdin, skipping any
// non-digit characters before it. If EOF is hit first, returns what has been
// accumulated instead of looping forever.
ll read(ll x = 0)
{
    int c = std::getchar();
    ll f = 1;
    for (; !std::isdigit(c); c = std::getchar())
    {
        if (c == EOF) return f * x;
        if (c == '-') f = -f;
    }
    for (; std::isdigit(c); c = std::getchar()) x = x * 10 + (c - '0');
    return f * x;
}

ll mod;  // the modulus p read from the input

// Number-theory helpers used by the binomial code below.
struct EasyMath
{
    // a^b mod c by binary exponentiation (b >= 0, c >= 1).
    ll fastpow(ll a, ll b, ll c)
    {
        ll t = a % c, ans = 1;
        for (; b; b >>= 1, t = t * t % c)
            if (b & 1) ans = ans * t % c;
        return ans;
    }

    // Extended Euclid: sets x, y so that a*x + b*y = gcd(a, b).
    void exgcd(ll a, ll b, ll &x, ll &y)
    {
        if (!b) { x = 1; y = 0; return; }
        ll xx, yy;
        exgcd(b, a % b, xx, yy);
        x = yy;
        y = xx - a / b * yy;
    }

    // Inverse of a modulo p via exgcd; requires gcd(a, p) = 1 (p need not be prime).
    ll inv2(ll a, ll p)
    {
        ll x, y;
        exgcd(a, p, x, y);
        return (x % p + p) % p;
    }

    // Chinese remainder theorem; the moduli m[i] must be pairwise coprime.
    ll CRT(const std::vector<ll> &a, const std::vector<ll> &m)
    {
        ll M = 1, ans = 0;
        const std::size_t n = a.size();
        for (std::size_t i = 0; i < n; i++) M *= m[i];
        for (std::size_t i = 0; i < n; i++)
            ans = (ans + a[i] * (M / m[i]) % M * inv2(M / m[i], m[i])) % M;
        return ans;
    }
} em;

// Binomial coefficients modulo an arbitrary modulus (extended Lucas + CRT).
struct CombinatorialNumber_mod
{
    // Factorisation of the modulus, 1-based: mod = prod t[i], t[i] = p[i]^q[i].
    ll p[20], q[20], t[20], tot;
    // fact[i][j] = product of all k <= j with p[i] not dividing k, mod t[i].
    std::vector<ll> fact[20];

    // Factorises P and builds the per-prime-power tables. Each table holds
    // exactly t[i] + 1 entries (indices up to t[i] are read by fact_mod), so
    // the size follows the modulus instead of a fixed compile-time bound.
    void init(ll P)
    {
        tot = 0;
        for (ll i = 2; i * i <= P; i++)
            if (P % i == 0)
            {
                p[++tot] = i;
                q[tot] = 0;
                while (P % i == 0) q[tot]++, P /= i;
            }
        if (P > 1) p[++tot] = P, q[tot] = 1;
        for (ll i = 1; i <= tot; i++)
        {
            t[i] = 1;
            for (ll j = 1; j <= q[i]; j++) t[i] *= p[i];
            fact[i].assign(t[i] + 1, 1);
            for (ll j = 1; j <= t[i]; j++)
                fact[i][j] = (j % p[i] == 0) ? fact[i][j - 1] : fact[i][j - 1] * j % t[i];
        }
    }

    // Returns (e, r) such that n! = p[id]^e * r' with r' coprime to p[id] and
    // r = r' mod t[id].
    pll fact_mod(ll n, ll id)
    {
        if (n < p[id]) return pll(0, fact[id][n]);
        pll ans(n / p[id], em.fastpow(fact[id][t[id]], n / t[id], t[id]));
        // Must be "%=": a bare "%" discards the reduction and lets the
        // product grow toward overflow for larger prime powers.
        ans.second = ans.second * fact[id][n % t[id]] % t[id];
        pll nex = fact_mod(n / p[id], id);
        ans.first += nex.first;
        ans.second = ans.second * nex.second % t[id];
        return ans;
    }

    // C(n, m) mod t[id] for 0 <= m <= n.
    ll exlucas(ll n, ll m, ll id)
    {
        pll a = fact_mod(n, id), b = fact_mod(m, id), c = fact_mod(n - m, id);
        ll e = a.first - b.first - c.first;  // exponent of p[id] in C(n, m)
        ll r = a.second * em.inv2(b.second * c.second % t[id], t[id]) % t[id];
        return em.fastpow(p[id], e, t[id]) * r % t[id];
    }

    // C(n, m) mod the full modulus; 0 when m > n or either argument is negative.
    ll calc(ll n, ll m)
    {
        if (m > n || m < 0 || n < 0) return 0;
        std::vector<ll> a, v(t + 1, t + tot + 1);
        for (ll i = 1; i <= tot; i++) a.push_back(exlucas(n, m, i));
        return em.CRT(a, v);
    }
} Cmod;

ll ans, n1, n2, n, M, a[20];

// Inclusion-exclusion over the n1 upper-bounded variables. Choosing to
// violate x_pos <= A_pos (i.e. forcing x_pos >= A_pos + 1) subtracts A_pos
// from the remaining sum and flips the sign f. At the leaves the number of
// positive solutions of the shifted equation is C(M - sm - 1, n - 1).
void dfs(ll pos, ll f, ll sm)
{
    if (pos > n1)
    {
        ans = (ans + f * Cmod.calc(M - sm - 1, n - 1)) % mod;
        return;
    }
    dfs(pos + 1, f, sm);
    dfs(pos + 1, -f, sm + a[pos]);
}

int main()
{
    ll T = read();
    mod = read();
    Cmod.init(mod);
    while (T--)
    {
        n = read(), n1 = read(), n2 = read(), M = read();
        for (ll i = 1; i <= n1; i++) a[i] = read();
        // Lower bounds x_i >= A_i: shift each such variable down by A_i - 1.
        ll s = 0;
        for (ll i = n1 + 1; i <= n1 + n2; i++) a[i] = read(), s += a[i] - 1;
        ans = 0;
        dfs(1, 1, s);
        std::printf("%lld\n", (ans + mod) % mod);  // ans may be negative here
    }
    return 0;
}
C++11(clang++ 3.9) 解法, 执行用时: 289ms, 内存消耗: 584K, 提交时间: 2020-09-19 12:21:24
// Solution 2 (C++11).
// Same reduction as the problem requires: lower-bounded variables are shifted
// in main, and the upper bounds are handled by inclusion-exclusion inside
// each prime power of p before the CRT recombination.
#include <cstdio>
#include <vector>

typedef long long ll;

int T, n1, n2;
ll n, m, p, a[20];
// f[i] = product of all k <= i with pi not dividing k, mod pk, for the prime
// power (pi, pk) most recently passed to build_table.
std::vector<ll> f;

// Reads one (optionally negative) integer into x. On EOF before any digit,
// stores 0 and returns instead of spinning forever.
template <class Int> void read(Int &x)
{
    int c, sign = 1;
    while ((c = std::getchar()) > '9' || c < '0')
    {
        if (c == EOF) { x = 0; return; }
        if (c == '-') sign = -1;
    }
    x = c - '0';
    while ((c = std::getchar()) >= '0' && c <= '9') x = x * 10 + (c - '0');
    x *= sign;
}

// a^b mod `mod` by binary exponentiation.
ll quickpow(ll a, ll b, ll mod)
{
    ll ret = 1;
    while (b)
    {
        if (b & 1) ret = ret * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ret;
}

// Extended Euclid: sets x, y with a*x + b*y = gcd(a, b) and returns the gcd.
ll exgcd(ll a, ll b, ll &x, ll &y)
{
    if (!b) { x = 1; y = 0; return a; }
    ll d = exgcd(b, a % b, x, y);
    ll t = x;
    x = y;
    y = t - a / b * y;
    return d;
}

// Inverse of a modulo p; requires gcd(a, p) = 1.
ll inv(ll a, ll p)
{
    ll x, y;
    exgcd(a, p, x, y);
    return (x % p + p) % p;
}

// Fills f for the prime power pk = pi^q. Must run before fac/C are used with
// that (pi, pk); the table depends only on (pi, pk), so it is built once per
// prime power rather than on every binomial evaluation.
void build_table(ll pi, ll pk)
{
    f.assign(pk + 1, 1);
    for (ll i = 1; i <= pk; i++) f[i] = (i % pi != 0) ? f[i - 1] * i % pk : f[i - 1];
}

// The pi-free part of n! modulo pk (all factors of pi removed).
ll fac(ll n, ll pi, ll pk)
{
    if (!n) return 1LL;
    if (n < pi) return f[n];
    return quickpow(f[pk - 1], n / pk, pk) * f[n % pk] % pk * fac(n / pi, pi, pk) % pk;
}

// C(n, m) mod pk; 0 when m > n or m < 0. Assumes build_table(pi, pk) was called.
ll C(ll n, ll m, ll pi, ll pk)
{
    if (n < m || m < 0) return 0;  // m < 0 would otherwise index f[-1]
    ll jn = fac(n, pi, pk), jm = fac(m, pi, pk), jnm = fac(n - m, pi, pk);
    // k = exponent of pi in C(n, m) (Legendre's formula).
    ll k = 0;
    for (ll i = n; i; i /= pi) k += i / pi;
    for (ll i = m; i; i /= pi) k -= i / pi;
    for (ll i = n - m; i; i /= pi) k -= i / pi;
    return jn * inv(jm, pk) % pk * inv(jnm, pk) % pk * quickpow(pi, k, pk) % pk;
}

// Lifts a residue modulo pk into its CRT contribution modulo p.
ll crt(ll a, ll pk)
{
    ll x = p / pk;
    return a * x % p * inv(x, pk) % p;  // inv(x, pk): inverse of p/pk modulo pk
}

// Inclusion-exclusion over the n1 upper bounds, modulo one prime power:
// each subset of violated bounds subtracts its A values from n and
// contributes (-1)^|subset| * C(nown, m). The result may be negative.
ll solve(ll n, ll m, ll pi, ll pk)
{
    build_table(pi, pk);
    ll ret = 0;
    for (int s = 0, full = (1 << n1); s < full; ++s)
    {
        int opt = 1;
        ll nown = n;
        for (int j = 0; j < n1; ++j)
            if (s >> j & 1) opt = -opt, nown -= a[j + 1];
        ret = (ret + opt * C(nown, m, pi, pk)) % pk;
    }
    return ret;
}

// Answer modulo P (= p): factorise P, solve per prime power, combine by CRT.
ll exlucas(ll n, ll m, ll P)
{
    if (n < m) return 0;
    ll ret = 0;
    for (ll i = 2; i * i <= P; ++i)
    {
        if (P % i == 0)
        {
            ll pk = 1;
            while (P % i == 0) pk *= i, P /= i;
            ret = (ret + crt(solve(n, m, i, pk), pk)) % p;
        }
    }
    if (P != 1) ret = (ret + crt(solve(n, m, P, P), P)) % p;
    return (ret % p + p) % p;
}

int main()
{
    read(T);
    read(p);
    while (T--)
    {
        read(n);
        read(n1);
        read(n2);
        read(m);
        for (int i = 1; i <= n1 + n2; ++i) read(a[i]);
        // Lower bounds x_i >= A_i: shift each such variable down by A_i - 1.
        for (int i = n1 + 1; i <= n1 + n2; ++i) m -= (a[i] - 1);
        std::printf("%lld\n", exlucas(m - 1, n - 1, p));
    }
    return 0;
}