列表

详情


NC253640. 小红的好子序列(hard)

描述

请注意,本题和easy版本唯一的区别是数据范围不同!

小红定义一个数组是“好数组”,当且仅当存在某元素的出现次数不小于数组大小的一半。例如,[1,2,1,3,1,1]、[2,2,3,3]是好数组,但[1,2,1,5,6]则不是好数组。

现在小红拿到了一个数组,她想知道,这个数组有多少个非空子序列是好数组?答案对10^9+7取模。

子序列的定义:数组中不放回的取出若干个元素组成的新数组。

输入描述

第一行输入一个正整数n,代表数组的大小。
第二行输入n个正整数a_i,代表小红拿到的数组。
1\leq n \leq 10^5
1\leq a_i \leq 10^9

输出描述

合法的非空子序列数量。答案对10^9+7取模。

示例1

输入:

3
1 2 3

输出:

6

说明:

除了[1,2,3]这个子序列不合法以外,其他都是合法的。

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(g++ 7.5.0) 解法, 执行用时: 21ms, 内存消耗: 2732K, 提交时间: 2023-07-07 19:44:24

#include <cstdio>
#include <map>
#define int long long
using namespace std;
typedef long long ll;
int read(){
	char c=getchar();int x=0;
	while(c<48||c>57) c=getchar();
	do x=(x<<1)+(x<<3)+(c^48),c=getchar();
	while(c>=48&&c<=57);
	return x;
}
const int N=1000003,P=1000000007;
int n,a[N];
map<int,int> mp;
int qp(int a,int b=P-2){
	int res=1;
	while(b){
		if(b&1) res=(ll)res*a%P;
		a=(ll)a*a%P;b>>=1;
	}
	return res;
}
int fac[N],fiv[N];
int C(int a,int b){
	return (ll)fiv[b]*fiv[a-b]%P*fac[a]%P;
}
void inc(int &x,int v){
	if((x+=v)>=P) x-=P;
}
int s[N];
signed main(){
	n=read();
	for(int i=1;i<=n;++i) ++mp[read()];
	fac[0]=1;
	for(int i=1;i<=n;++i) fac[i]=(ll)fac[i-1]*i%P;
	fiv[n]=qp(fac[n]);
	for(int i=n;i;--i) fiv[i-1]=(ll)fiv[i]*i%P;
	int res=0;
	for(auto [x,cnt]:mp){
		for(int i=1;i<=cnt;++i) inc(s[i],C(cnt,i));
	}
	int tmp=0;
	for(auto [x,cnt]:mp){
		int cur=1;
		for(int i=1;i<=cnt;++i){
			inc(cur,C(n-cnt,i));
			inc(res,(ll)C(cnt,i)*cur%P);
			inc(tmp,(ll)(s[i]-C(cnt,i)+P)*C(cnt,i)%P);
		}
	}
	tmp=(ll)tmp*((P+1)>>1)%P;
	res-=tmp;
	if(res<0) res+=P;
	printf("%d\n",res);
	return 0;
}

C++(clang++ 11.0.1) 解法, 执行用时: 44ms, 内存消耗: 8176K, 提交时间: 2023-07-22 10:32:54

#include<bits/stdc++.h>
#define int long long
#define ffor(i,a,b) for(int i=(a);i<=(b);i++)
#define roff(i,a,b) for(int i=(a);i>=(b);i--)
using namespace std;
const int MAXN=1e5+10,MOD=1e9+7;
int n,ans,a[MAXN],frac[MAXN],inv[MAXN];
map<int,int> mp; vector<int> pos[MAXN];
int qpow(int base,int p) {
	int ans=1;
	while(p) {
		if(p&1) ans=ans*base%MOD;
		base=base*base%MOD,p>>=1;
	}
	return ans;
}
int C(int u,int d) {if(u>d) return 0;return frac[d]*inv[u]%MOD*inv[d-u]%MOD;}
signed main() {
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n; ffor(i,1,n) cin>>a[i],mp[a[i]]++;
	frac[0]=1; ffor(i,1,n) frac[i]=frac[i-1]*i%MOD;
	inv[n]=qpow(frac[n],MOD-2); roff(i,n-1,0) inv[i]=inv[i+1]*(i+1)%MOD;
	for(auto pr:mp) {
		int k=pr.second,tot=1;
		ffor(j,1,k) tot=(tot+C(j,n-k))%MOD,ans=(ans+C(j,k)*tot)%MOD;
		ffor(j,1,k) pos[j].push_back(C(j,k));
	}
	ffor(i,1,n) {
		int _2=(MOD+1)/2,tot=0,sqtot=0;
		for(auto v:pos[i]) tot=(tot+v)%MOD,sqtot=(sqtot+v*v)%MOD;
		tot=tot*tot%MOD,tot=(tot-sqtot)%MOD,tot=(tot*_2)%MOD;
		ans=(ans-tot)%MOD;
	}
	ans=(ans+MOD)%MOD;
	cout<<ans;
	return 0;
}

pypy3 解法, 执行用时: 460ms, 内存消耗: 37784K, 提交时间: 2023-07-07 20:56:05

import sys
input = lambda:sys.stdin.readline().strip()
M = lambda:map(int,input().split())
inf = float('inf')
from collections import Counter
mod = 10**9+7
n = int(input())
arr = [0]+list(M())
a = [1,1]+[0]*(n-1)
b = [1,1]+[0]*(n-1)
for i in range(2,n+1):
    a[i] = a[i - 1] * i%mod
    b[i] = pow(a[i],mod-2,mod)
f = lambda n,m:a[n]*b[m]*b[n - m]%mod
ans = 0;
k = Counter(arr)
for xx in k.items():
    x,y = xx[1],n-xx[1]
    res = f(y,0)
    for i in range(1,x+1):
        if y >= i:res = (res+f(y,i))%mod
        ans = (ans+f(x,i)*res)%mod
        yy = f(x, i)
s = Counter(k.values())
lin = [0]*(n+1)
for i in s:
    for j in range(1,s[i]+1):
        for v in range(1,i+1):
            ans = (ans-lin[v]*f(i,v))%mod
            lin[v] = (lin[v]+f(i,v))%mod
print(ans)


上一题