列表

详情


NC15206. 博弈论与概率统计

描述

Alice 和 Bob 在玩一个双人游戏。每一轮中,Alice 有 p 的概率胜利,1-p 的概率失败,不会出现平局。
双方初始时各有 0 分,当一个人胜利的时候,他会获得一分,失败则扣掉一分。遗憾的是,博弈论世界的人目前是无法理解负数的,因此,如果某个人输掉一轮比赛的时候他只有 0 分,那么他就不会被扣分(对方会照常加一分)。游戏一共要进行 N+M 轮,Alice 想请你帮她算算在游戏结束时她的得分的数学期望。
“这算啥,我小 L 分分钟搞定!”。比小 L 更熟练的你当然也是随手就算出来了,但就在你打算告诉 Alice 答案之前,博弈论世界之神——temporaryDO 出现了,他给大家带来了一个重要信息:这 N+M 轮游戏中, Alice 恰好赢了 N 轮!
熟知条件概率那套理论的你立刻注意到,你需要修改自己的计算方法来得到正确的答案了。
为了避免精度问题,请将结果对 109+7 取模。即,我们的数据保证答案是一个有理数p/q,且有 109+7|q,你只需要找到一个整数 x∈ [0, 109+7) 使得 qx≡p (mod 109+7) 即可。

输入描述

输入的第一行包含两个正整数 T, P',其中 T 表示数据组数,P'/1000表示 p ,即 Alice 在每轮游戏中的获胜概率。
接下来 T 行,每行两个非负整数 N,M,表示一组数据。

输出描述

输出 T 行,每行一个整数,表示对应数据的答案。

示例1

输入:

3 500
1 1
2 3
4 4

输出:

500000004
200000002
728571435

说明:

每一轮游戏 Alice 均有 1/2 的概率胜利。
- 对于第一组数据,Alice 的胜利可能在第一轮或第二轮,并且概率相等。若她在第一轮胜利,则最终得分为 0,否则她的得分为 1。故期望为1/2,验证发现 2 x 500000004≡1 (mod 109+7)。
- 对于第二组数据,所求期望为3/5。
- 对于第三组数据,所求期望为93/70。

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++14(g++5.4) 解法, 执行用时: 1602ms, 内存消耗: 22360K, 提交时间: 2020-02-09 19:08:40

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
#define maxn 300010
#define cl(x) memset(x,0,sizeof(x))
#define rep(i,a,b) for(i=a;i<=b;i++)
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
#define fi first
#define se second
#define de(x) cerr<<#x<<" = "<<x<<endl
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
ll read(ll x=0)
{
    ll c, f(1);
    for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-f;
    for(;isdigit(c);c=getchar())x=x*10+c-0x30;
    return f*x;
}
#define mod 1000000007ll
ll fact[maxn], _fact[maxn], id[maxn], l[maxn], r[maxn], n[maxn], m[maxn], ans[maxn];
ll fastpow(ll a, ll b)
{
    ll t(a%mod), ans(1ll);
    for(;b;b>>=1,t=t*t%mod)if(b&1)ans=ans*t%mod;
    return ans;
}
ll C(ll n, ll m)
{
    if(n<0 or m<0 or m>n)return 0;
    return fact[n]*_fact[m]%mod*_fact[n-m]%mod;
}
int main()
{
    ll T=read(), p=read(), i, L, R, now, _2=500000004;
    rep(i,1,T)n[i]=read(), m[i]=read(), id[i]=i;
    rep(i,1,T)l[i]=min(m[i],n[i])-1, r[i]=n[i]+m[i];
    fact[0]=_fact[0]=1;
    rep(i,1,maxn-1)fact[i]=fact[i-1]*i%mod, _fact[i]=fastpow(fact[i],mod-2);
    ll S=sqrt(2e5);
    sort(id+1,id+T+1,[S](ll a, ll b){return l[a]/S == l[b]/S ? r[a]<r[b] : l[a]/S < l[b]/S; });
    L=R=0, now=1;
    // rep(i,1,T)
    // {
    //     ll ans=0, j;
    //     // rep(j,0,m[i]-1)(ans+=C(n[i]+m[i],j))%=mod;
    //     // rep(j,max(0ll,m[i]-n[i]),m[i]-1)(ans+=C(n[i]+m[i],m[i]-j-1))%=mod;
    //     rep(j,0,min(m[i],n[i])-1)(ans+=C(n[i]+m[i],j))%=mod;
    //     // printf("ans=%lld\n",ans);
    //     // rep(j,0,m[i]-1)(printf("C=%lld\n",C(n[i]+m[i],m[i]+j+1)));
    //     // printf("ans=%lld C=%lld\n",ans,C(n[i]+m[i],n[i]));
    //     ans=ans*fastpow( C(n[i]+m[i],n[i]), mod-2 )%mod;
    //     (ans+=max(0ll,m[i]-n[i]))%=mod;
    //     ans=(ans+n[i]-m[i]+mod)%mod;
    //     printf("%lld\n",ans);
    // }
    // return 0;
    rep(i,1,T)
    {
        for(;L>l[id[i]];L--)
        {
            now = (now-C(R,L))%mod;
        }
        for(;R<r[id[i]];R++)
        {
            now = (now*2-C(R,L))%mod;
        }
        for(;L<l[id[i]];L++)
        {
            now = (now+C(R,L+1))%mod;
        }
        for(;R>r[id[i]];R--)
        {
            now = (now+C(R-1,L))*_2%mod;
        }
        ans[id[i]] = ( n[id[i]]-m[id[i]]+max(0ll,m[id[i]]-n[id[i]]) + now*fastpow(C(n[id[i]]+m[id[i]],n[id[i]]),mod-2) )%mod;
    }
    rep(i,1,T)printf("%lld\n",(ans[i]+mod)%mod);
    return 0;
}

C++ 解法, 执行用时: 830ms, 内存消耗: 15480K, 提交时间: 2021-08-09 18:42:14

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#define SF scanf
#define PF printf
#define MAXN 250010
#define MOD 1000000007
using namespace std;
typedef long long ll;
ll fac[MAXN],inv[MAXN],blo[MAXN];
int n,m,t;
ll ans[MAXN],ans2[MAXN];
struct node{
	int x,y;
	int	id;
	bool operator <(const node &a) const {
		if(blo[x]!=blo[a.x])
			return blo[x]<blo[a.x];
		return y<a.y;
	}
}que[MAXN];
ll ans1;
ll C(int x,int y){
	return fac[x]*inv[y]%MOD*inv[x-y]%MOD;
}
const ll inv2=(MOD+1ll)>>1ll;
void change(int &x,int &y,int adx,int ady){
	if(adx==1){
		ans1=(ans1+C(y,x+1))%MOD;
		x++;
	}
	if(ady==1){
		ans1=((ans1*2ll-C(y,x))%MOD+MOD)%MOD;	
		y++;
	}
	if(adx==-1){
		ans1=(ans1-C(y,x)+MOD)%MOD;
		x--;
	}
	if(ady==-1){
		ans1=(ans1+C(y-1,x))%MOD*inv2%MOD;
		y--;
	}
}
void prepare(){
	fac[0]=1;
	for(int i=1;i<=250000;i++) fac[i]=fac[i-1]*i%MOD;	
	inv[0]=inv[1]=1;
	for(int i=2;i<=250000;i++) inv[i]=inv[MOD%i]*(MOD-MOD/i)%MOD;
	for(int i=1;i<=250000;i++) inv[i]=inv[i-1]*inv[i]%MOD;
	
	int siz=700;
	for(int i=1;i<=250000;i++)
		blo[i]=i/siz+1;
}
ll fsp(ll x,int y){
	ll res=1;
	while(y){
		if(y&1)
			res=res*x%MOD;
		x=x*x%MOD;
		y>>=1;
	}
	return res;
}
int main(){
	prepare();
	int p;
	SF("%d%d",&t,&p);
	for(int i=1;i<=t;i++){
		SF("%d%d",&n,&m);
		que[i].x=min(n-1,m-1);
		ans2[i]=fsp(C(n+m,n),MOD-2);
		que[i].y=n+m;
		if(n>m)
			ans[i]+=(n-m)*C(n+m,n)%MOD;
		que[i].id=i;
	}
	sort(que+1,que+1+t);
	int l=0,r=0,now=0;
	ans1=1;
	while(++now<=t){
		while(r<que[now].y)
			change(l,r,0,1);
		while(l<que[now].x)
			change(l,r,1,0);
		while(l>que[now].x)
			change(l,r,-1,0);
		while(r>que[now].y)
			change(l,r,0,-1);
		(ans[que[now].id]+=ans1)%=MOD;
	}
	for(int i=1;i<=t;i++)
		PF("%lld\n",ans[i]*ans2[i]%MOD);
}

C++(g++ 7.5.0) 解法, 执行用时: 1083ms, 内存消耗: 22372K, 提交时间: 2022-09-02 19:46:54

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int MOD=1e9+7;
const int N=3e5;
const int M=5e5+10;
int K;//块的大小 
int inv[M],res[M];
int inv2;
struct Node{
	int n,m;
	int l,r;
	int ans,pos;
}Q[N];
bool cmp(Node x,Node y){
	if(x.l/K==y.l/K)return x.r<y.r;
	else return x.l<y.l;
}
bool cmp1(Node x,Node y){
	return x.pos<y.pos;
}
int fast(int a,int b){
	int ans=1;
	while(b){
		if(b&1)ans=ans*a%MOD;
		b>>=1;
		a=a*a%MOD;
	}
	return ans;
}
int C(int x,int y){
	if(y>x)return 0;
	if(x==y||y==0)return 1;
	if(y==1)return x;
	return res[x]*inv[y]%MOD*inv[x-y]%MOD;
}
void solve(){
	int n,m,t,p;
	scanf("%lld%lld",&t,&p); 
	for(int i=1;i<=t;i++){
		scanf("%lld%lld",&Q[i].n,&Q[i].m);
		Q[i].pos=i;
		Q[i].l=Q[i].n+Q[i].m;
		Q[i].r=min(Q[i].n,Q[i].m)-1;
	}
	K=500;
	sort(Q+1,Q+t+1,cmp);
//	int ql=Q[1].l,qr=Q[1].r,sum=0;
//	for(int i=0;i<=Q[1].r;i++){
//		sum=(sum+C(Q[1].l,i))%MOD;
//	}
//    Q[1].ans=sum;
	int ql=1,qr=0,sum=1;
	for(int i=1;i<=t;i++){
		int l=Q[i].l,r=Q[i].r;
		while(qr<r){
			qr++;
			sum=(sum+C(ql,qr))%MOD;
		}
		while(qr>r){
			sum=(sum-C(ql,qr)+MOD)%MOD;
			qr--;
		}
		while(ql<l){
			sum=(2*sum%MOD-C(ql,qr)+MOD)%MOD;
			ql++;
		}
		while(ql>l){
			ql--;
			sum=(sum+C(ql,qr))%MOD*inv2%MOD;
		}
		Q[i].ans=sum;
	}
	sort(Q+1,Q+1+t,cmp1);
	for(int i=1;i<=t;i++){
		if(Q[i].n>Q[i].m){
			Q[i].ans+=(Q[i].n-Q[i].m)*C(Q[i].l,Q[i].n)%MOD;
			Q[i].ans%=MOD;
		}
		Q[i].ans=Q[i].ans*fast(C(Q[i].l,Q[i].n),MOD-2)%MOD;
		printf("%lld\n",Q[i].ans);
	}
}
signed main(){
	res[0]=1;
	for(int i=1;i<M;i++){
		res[i]=res[i-1]*i%MOD;
		inv[i]=fast(res[i],MOD-2);
	}
	inv2=fast(2,MOD-2);
	solve();
	return 0;
}

上一题