列表

详情


NC249268. 点分治

描述

给定一棵n个点的树,请找到树上k条两两不同的路径,使得它们们的交尽量长。

路径的定义是一组有序的、两两不同的点u,x_{1},x_{2} \dots x_{m},v,其中(u,x_{1}),(x_{1},x_{2}) \dots (x_{m},v)之间均有边直接相连。两条路径不同当且仅当它们包含的点集不同(即(u,v),(v,u)算同一条路径)。路径可以只包含一个点u

路径的长度被定义为路径所包含的点的个数。不难看出,任意多条树上路径的交还是一条路径。

空路径的长度定义为0

输入描述

第一行数据组数T(1\le T\le 10^4),对于每组数据:
第一行两个正整数n,k(1\le n \le 2\cdot 10^5 , 1\le k \le \frac{n(n-1)}{2})
接下来n-1行,每行两个正整数u,v (1\le u,v\le n),代表树上的一条边。
保证所有的n之和不超过2\cdot 10^5

输出描述

对于每组数据,一行一个整数表示答案

示例1

输入:

3
4 2
1 2
1 3
1 4
5 3
1 2
2 3
2 4
3 5
3 1
1 2
2 3

输出:

2
3
3

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(g++ 7.5.0) 解法, 执行用时: 906ms, 内存消耗: 34336K, 提交时间: 2023-03-18 20:53:37

#include <iostream>
#include <map>
#include <algorithm>
#include <set>
#include <cstring>
#include <queue>
#define N 550000
using namespace std;
#define int long long
int t;
int n,cntf;
int dd[N],son[N],sz[N],dpf[N],*f[N];
int panf[N],k;
int edt=1,head[N],to[N],nx[N];
void adde(int u,int v){
	to[++edt]=v,nx[edt]=head[u],head[u]=edt;
}
void dfs1(int u,int fath){
	sz[u]=1;
	panf[u]=1LL*n*n;
	dd[u]=0;
	for(int ed=head[u];ed;ed=nx[ed]){
		int v=to[ed];
		if(v==fath) continue;
		dfs1(v,u);
		sz[u]+=sz[v];
		panf[u]-=1LL*sz[v]*sz[v];
		if(dd[u]<dd[v]+1) son[u]=v,dd[u]=dd[v]+1;
	}
	panf[u]-=1LL*(n-sz[u])*(n-sz[u]);
}
void dfs2(int u,int fath){
	f[u]=dpf+(++cntf);
	if(son[u])dfs2(son[u],u);
	for(int ed=head[u];ed;ed=nx[ed]){
		int v=to[ed];
		if(v==fath||v==son[u]) continue;
		dfs2(v,u);
	}
}
int len;
bool flag;
inline void pan(int Q){
	if(Q>=k) flag=1;
}
void dp(int u,int fath){
	if(son[u]) dp(son[u],u);
	f[u][0]=sz[u];
	if(dd[u]>=len) pan(1LL*(n-sz[son[u]])*f[u][len]);
	for(int ed=head[u];ed;ed=nx[ed]){
		int v=to[ed];
		if(v==fath||v==son[u]) continue;
		dp(v,u);
		for(int i=0;i<=dd[v];i++){
			if(i+1==len) pan(1LL*f[v][i]*(n-sz[v]));
			if(i+1<len&&i+1+dd[u]>=len) pan(1LL*f[v][i]*f[u][len-1-i]);
		} 
		for(int i=0;i<dd[v];i++){
			f[u][i+1]=max(f[u][i+1],f[v][i]);
		}
	}
}
bool check(int mid){
	flag=0;
	len=mid;
	for(int i=1;i<=n;i++) dpf[i]=0;
	dp(1,0);
	return flag;
}
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0);
	int t;
	cin>>t;
	while(t--){
		
		cin>>n>>k;
		edt=1;
		for(int i=0;i<=2*n+10;i++){
			head[i]=son[i]=0;
		}
		for(int i=2;i<=n;i++){
			int u,v;
			cin>>u>>v;
			adde(u,v);adde(v,u);
		}
		dfs1(1,0);
		dfs2(1,0);
		bool flag0=0;
		for(int i=1;i<=n;i++){
			if((panf[i]+1)/2>=k) flag0=1;
		}
		//cout<<"??"<<endl;
		if(flag0){
			int l=1,r=n,mid,res=0;
			while(l<=r){
				mid=(l+r)>>1;
				if(check(mid)) res=mid,l=mid+1;
				else r=mid-1;
			}
			cout<<res+1<<endl;
		}	
		else{
			cout<<"0"<<endl;
		}
	}
	return 0;
}

C++(clang++ 11.0.1) 解法, 执行用时: 149ms, 内存消耗: 22284K, 提交时间: 2023-04-18 10:13:25

#include <bits/stdc++.h>
#define N 200005
#define pb push_back
using namespace std;
typedef long long ll;

int n, ans; ll K; 
int dep[N], son[N], siz[N];
int *dp[N], F[N], tt;
vector<int> E[N];

void dfs1(int x, int fa) {
	siz[x] = 1; ll tot = 1;
	for (auto y : E[x]) if (y != fa) {
		dfs1(y, x); tot += 1ll*siz[x]*siz[y]; siz[x] += siz[y];
		if (dep[y]+1 > dep[x]) dep[x] = dep[y]+1, son[x] = y; 
	}
    tot += 1ll*siz[x]*(n-siz[x]);
	if (tot >= K) ans = 1;
}
void dfs2(int x, int fa) {
	dp[x] = F+(++tt);
	if (son[x]) dfs2(son[x], x);
	for (auto y : E[x]) if (y != fa && y != son[x]) {
		dfs2(y, x);
		for (int i = 0; i <= dep[y]; i++) {
			ll _ = (K-1)/dp[y][i]+1;
			if (_ > n) break;
			int k = upper_bound(dp[x]+1,dp[x]+dep[x]+1,(int)_,greater<int>())-dp[x]-1;
			if (k >= 1) ans = max(ans, i+k+2);
		}
		for (int i = 0; i <= dep[y]; i++) {
			dp[x][i+1] = max(dp[x][i+1], dp[y][i]);
		}
	}
	if (x != 1) {
		dp[x][0] = max(dp[x][0], siz[x]);
		ll _ = (K-1)/(n-siz[x])+1;
		if (_ <= n) {
			int k = upper_bound(dp[x],dp[x]+dep[x]+1,(int)_,greater<int>())-dp[x]-1;
			if (k >= 0) ans = max(ans, k+2);
		}
	} 
}
void CLR() {
	for (int i = 1; i <= n; i++) {
		F[i] = dep[i] = son[i] = siz[i] = 0; 
		E[i].clear();
	}
	ans = tt = 0;
}

void solve() {
	scanf("%d %lld", &n, &K);
	CLR();
	for (int i = 1, u, v; i < n; i++) {
		scanf("%d %d", &u, &v);
		E[u].pb(v); E[v].pb(u);
	}
	dfs1(1, 0); dfs2(1, 0);
	printf("%d\n", ans);
}

int main() {
	int ttt; scanf("%d", &ttt);
	while (ttt--) solve();
	return 0;
}

上一题