NC21604. 出题人的数组
描述
输入描述
第一行输入两个正整数n,m,分别表示数组A,B的长度
第二行输入n个正整数,表示数组A
第二行输入m个正整数,表示数组B
输出描述
一个正整数,表示cost
示例1
输入:
3 3 1 3 5 2 6 4
输出:
75
C++14(g++5.4) 解法, 执行用时: 95ms, 内存消耗: 1964K, 提交时间: 2019-03-30 11:03:39
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int maxn=1e5+5; ll a[maxn],b[maxn],n,m; struct node { ll s; ll num; } ta[maxn],tb[maxn]; bool pd(node a,node b) { //a比b的平均值大返回true return a.s*b.num > a.num*b.s; } int main() { //std::ios::sync_with_stdio(0); cin>>n>>m; int nn=0,mm=0; for(int i=1; i<=n; ++i) { cin>>a[i]; ta[++nn]=node {a[i],1}; while(nn>1) { if(pd(ta[nn],ta[nn-1])) { ta[nn-1].num+=ta[nn].num,ta[nn-1].s+=ta[nn].s; nn--; } else break; } } for(int i=1; i<=m; ++i) { cin>>b[i]; tb[++mm]=node {b[i],1}; while(mm>1) { if(pd(tb[mm],tb[mm-1])) { tb[mm-1].num+=tb[mm].num,tb[mm-1].s+=tb[mm].s; mm--; } else break; } } ll i=1,j=1,cnt=1,xa=1,xb=1,ans=0; while(i<=nn && j<=mm) { if(pd(ta[i],tb[j])) { //a的均值大 for(int k=1; k<=ta[i].num; ++k) { ans+=a[xa++]*(cnt++); } i++; } else { for(int k=1; k<=tb[j].num; ++k) { ans+=b[xb++]*(cnt++); } j++; } } while(i<=nn) { for(int k=1; k<=ta[i].num; ++k) { ans+=a[xa++]*(cnt++); } i++; } while(j<=mm) { for(int k=1; k<=tb[j].num; ++k) { ans+=b[xb++]*(cnt++); } j++; } cout<<ans<<endl; return 0; }
C++11(clang++ 3.9) 解法, 执行用时: 31ms, 内存消耗: 2060K, 提交时间: 2020-08-03 20:36:11
#include<bits/stdc++.h> using namespace std; int L1=0,L2=0,L3=0,a[100005],b[100005],c[200005],c1[100005],c2[100005]; long long ans=0,w1[100005],w2[100005]; void get(int *a,int *b,long long *w,int n,int &L) { for(int i=1;i<=n;i++) { b[++L]=1,w[L]=a[i]; while(L>1&&w[L-1]*b[L]<w[L]*b[L-1])w[L-1]+=w[L],b[L-1]+=b[L],L--; } } void add(int *a,int l,int r) { for(int i=l;i<=r;i++)c[++L3]=a[i]; } int main() { int i,j,n,m,s1=0,s2=0; scanf("%d%d",&n,&m); for(i=1;i<=n;i++)scanf("%d",&a[i]); for(i=1;i<=m;i++)scanf("%d",&b[i]); get(a,c1,w1,n,L1),get(b,c2,w2,m,L2); for(i=j=1;i<=L1;i++) { while(j<=L2&&w1[i]*c2[j]<w2[j]*c1[i])add(b,s2+1,s2+c2[j]),s2+=c2[j++]; add(a,s1+1,s1+c1[i]),s1+=c1[i]; } while(j<=L2)add(b,s2+1,s2+c2[j]),s2+=c2[j++]; for(i=1;i<=L3;i++)ans+=(long long)i*c[i]; printf("%lld\n",ans); return 0; }