列表

详情


NC21604. 出题人的数组

描述

出题人有两个数组,A,B,请你把两个数组归并起来使得最小
归并要求原数组的数的顺序在新数组中不改变

输入描述

第一行输入两个正整数n,m,分别表示数组A,B的长度
第二行输入n个正整数,表示数组A
第二行输入m个正整数,表示数组B

输出描述

一个正整数,表示cost

示例1

输入:

3 3
1 3 5
2 6 4

输出:

75

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++14(g++5.4) 解法, 执行用时: 95ms, 内存消耗: 1964K, 提交时间: 2019-03-30 11:03:39

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e5+5;
ll a[maxn],b[maxn],n,m;
struct node {
	ll s;
	ll num;
} ta[maxn],tb[maxn];
bool pd(node a,node b) { //a比b的平均值大返回true
	return a.s*b.num > a.num*b.s;
}
int main() {
	//std::ios::sync_with_stdio(0);
	cin>>n>>m;
	int nn=0,mm=0;
	for(int i=1; i<=n; ++i) {
		cin>>a[i];
		ta[++nn]=node {a[i],1};
		while(nn>1) {
			if(pd(ta[nn],ta[nn-1])) {
				ta[nn-1].num+=ta[nn].num,ta[nn-1].s+=ta[nn].s;
				nn--;
			} else	break;
		}
	}
	for(int i=1; i<=m; ++i) {
		cin>>b[i];
		tb[++mm]=node {b[i],1};
		while(mm>1) {
			if(pd(tb[mm],tb[mm-1])) {
				tb[mm-1].num+=tb[mm].num,tb[mm-1].s+=tb[mm].s;
				mm--;
			} else	break;
		}
	}
	ll i=1,j=1,cnt=1,xa=1,xb=1,ans=0;
	while(i<=nn && j<=mm) {
		if(pd(ta[i],tb[j])) { //a的均值大
			for(int k=1; k<=ta[i].num; ++k) {
				ans+=a[xa++]*(cnt++);
			}
			i++;
		} else {
			for(int k=1; k<=tb[j].num; ++k) {
				ans+=b[xb++]*(cnt++);
			}
			j++;
		}
	}
	while(i<=nn) {
		for(int k=1; k<=ta[i].num; ++k) {
			ans+=a[xa++]*(cnt++);
		}
		i++;
	}
	while(j<=mm) {
		for(int k=1; k<=tb[j].num; ++k) {
			ans+=b[xb++]*(cnt++);
		}
		j++;
	}
	cout<<ans<<endl;
	return 0;
}

C++11(clang++ 3.9) 解法, 执行用时: 31ms, 内存消耗: 2060K, 提交时间: 2020-08-03 20:36:11

#include<bits/stdc++.h>
using namespace std;

int L1=0,L2=0,L3=0,a[100005],b[100005],c[200005],c1[100005],c2[100005];
long long ans=0,w1[100005],w2[100005];
void get(int *a,int *b,long long *w,int n,int &L)
{
	for(int i=1;i<=n;i++)
	{
		b[++L]=1,w[L]=a[i];
		while(L>1&&w[L-1]*b[L]<w[L]*b[L-1])w[L-1]+=w[L],b[L-1]+=b[L],L--;
	}
}
void add(int *a,int l,int r)
{
	for(int i=l;i<=r;i++)c[++L3]=a[i];
}
int main()
{
    int i,j,n,m,s1=0,s2=0;
    scanf("%d%d",&n,&m);
    for(i=1;i<=n;i++)scanf("%d",&a[i]);
    for(i=1;i<=m;i++)scanf("%d",&b[i]);
    get(a,c1,w1,n,L1),get(b,c2,w2,m,L2);
    for(i=j=1;i<=L1;i++)
    {
    	while(j<=L2&&w1[i]*c2[j]<w2[j]*c1[i])add(b,s2+1,s2+c2[j]),s2+=c2[j++];
    	add(a,s1+1,s1+c1[i]),s1+=c1[i];
	}
	while(j<=L2)add(b,s2+1,s2+c2[j]),s2+=c2[j++];
	for(i=1;i<=L3;i++)ans+=(long long)i*c[i];
	printf("%lld\n",ans);
    return 0;
}

上一题