列表

详情


NC252737. 最便宜的构建

描述

有一天,牛牛获得了一张 n 个结点 m 条边的无向连通图,其中每条边都有边权,第 i 条边的边权记为 w_i

牛妹想考一考牛牛,她选择了 k 个关于结点的集合,第 i 个集合中有 s_i 个结点编号,这 s_i 个结点编号分别记为:q_{i,1},q_{i,2},...,q_{i,s_i}

牛妹让牛牛在图的 m 条边中选择一个子集构成一个子图,使得牛妹选择的 k 个集合都被【满足】,一个集合被【满足】当且仅当集合中结点编号所代表的结点在子图上是连通的。

显然能够使得 k 个集合都被【满足】的边集选择方案可能不止一个,所以牛牛想问你所有可能的边集选择方案中,被选择边中边权最大的那条边最少得是多少?

输入描述

第一行输入两个空格分隔的整数 n\ m

接下来 m 行,第 i 行输入三个空格分隔的整数:u\ v\ w_i,代表图中存在一条连接了 uv 边权为 w_i 的边。

接下来输入一行一个整数代表 k

接下来 k 行,第 i 行输入若干个空格分隔的整数:s_i, q_{i,1},q_{i,2},...,q_{i,s_i},描述了牛妹选择的第 i 个集合。

保证: 
0< n,m \le 10^5
0 < u,v,k,q_{i,j} \le n 
1 < s_i\le n 
0 < w_i \le 10^9 
牛妹选择的 k 个集合中结点个数的和不超过 2\times 10^5,且同一个集合中不会出现重复的结点。 图中无重边,无自环。

输出描述

一行一个整数代表所有符合牛妹要求的边集选择方案中边权最大的那条边最少是多少。

示例1

输入:

5 6
1 5 4
2 5 6
4 5 1
2 4 3
3 5 2
1 3 7
3
2 2 4
2 1 3
3 5 3 1

输出:

4

说明:

选择输入的第 1,4,5 条边即可,其中边权最大为 4。

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

pypy3 解法, 执行用时: 2122ms, 内存消耗: 66732K, 提交时间: 2023-06-09 22:06:54

from sys import stdin
input=lambda:stdin.readline().strip()
n,m=map(int,input().split())
A=[]
MAX=0
for i in range(m):
    u,v,w=map(int,input().split())
    MAX=max(MAX,w)
    A.append((w,u,v))
A.sort(key=lambda x:x[0])
k=int(input())
L=[]
for i in range(k):
    temp=list(map(int,input().split()))
    L.append(temp[1:])
def find(x,node):
    if x==node[x]:
        return x
    node[x]=find(node[x],node)
    return node[x]
def check(mid):
    node=[i for i in range(n+1)]
    for w,u,v in A:
        if w>mid:
            break
        u=find(u,node)
        v=find(v,node)
        if u==v:continue
        if u>v:u,v=v,u
        node[v]=u
    for num in L:
        for i in range(1,len(num)):
            if find(num[i],node)!=find(num[i-1],node):
                return False
    return True
l=0
r=MAX
while l<r:
    mid=l+r>>1
    if check(mid):
        r=mid
    else:
        l=mid+1
print(l)

C++(clang++ 11.0.1) 解法, 执行用时: 339ms, 内存消耗: 27276K, 提交时间: 2023-06-18 19:03:24

#include<bits/stdc++.h>
using namespace std;

const int N = 1e6+10;

int n,m,k;
int p[N];

vector<int>v[N];

struct node{
	int a,b,w;
}e[N];

int find(int x)
{
	return x==p[x]?x:p[x]=find(p[x]);
}

int check(int mid)
{
	for(int i=1;i<=n;i++) p[i]=i;
	for(int i=1;i<=m;i++)
	{
		if(e[i].w>mid) continue;
		p[find(e[i].a)]=find(e[i].b);
	}
	for(int i=1;i<=k;i++)
	{
		int ans=find(v[i][0]);
		for(int j=0;j<v[i].size();j++)
			if(find(v[i][j])!=ans) return 0;
	}
	return 1;
}



int main()
{
	cin>>n>>m;
	for(int i=1;i<=m;i++)
	{
		int x,y,c;
		cin>>x>>y>>c;
	    e[i]={x,y,c};
	}
	cin>>k;
	for(int i=1;i<=k;i++)
	{
		int op;
		cin>>op;
		for(int j=1;j<=op;j++)
		{
			int x;
			cin>>x;
			v[i].push_back(x);
		}
	}
	int l=1,r=1e9;
	while(l<r)
	{
		int mid=(l+r)/2;
		if(check(mid)) r=mid;
		else l=mid+1;
	}
	cout<<l;
	
	return 0;
}

Python3 解法, 执行用时: 1811ms, 内存消耗: 37168K, 提交时间: 2023-06-25 17:27:59

ps = []
fps = {}
def rread(): 
    return int(input())
def fread(): 
    return [int(_x) for _x in input().split()]
def psth(s):
    if fps[s]!=s:
        fps[s]=psth(fps[s])
    return fps[s]
n,m=fread()
for i in range(1,n+1):
    fps[i]=i
for i in range(0,m):
    x,y,w=fread()
    ps.append((w,x,y))
ps.sort(reverse=True)
k=rread()
for i in range(0,k):
    a=list(map(int,input().split()))
    for j in range(2,a[0]+1):
        while psth(a[j])!=psth(a[1]):
            m=m-1
            w,x,y=ps[m]
            cx=psth(x)
            cy=psth(y)
            fps[cx]=cy
print(w)

上一题