列表

详情


NC236185. 智乃的树分治(模板)

描述

给定一颗大小为N的无根树,节点编号从1N,定义树上两点间的距离dis(u,v)为从uv的唯一最短路径上边的数目。
特别的,我们认为一个节点距离它自身的距离为0,即
定义无根树上的点集, 现在智乃给定d的值为一个常数。她想要知道对于时,集合的尺寸各是多少。

输入描述

第一行是两个整数表示节点数以及距离参数
接下来输入n-1行,每行两个正整数表示树的一条边。

输出描述

输出一行n个整数,分别表示,整数之间用一个空格隔开,行末没有多余空格。

示例1

输入:

7 1
1 2
2 3
2 4
2 5
5 6
5 7

输出:

2 5 2 2 4 2 2

示例2

输入:

7 0
1 2
2 3
2 4
2 5
5 6
5 7

输出:

1 1 1 1 1 1 1

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(g++ 7.5.0) 解法, 执行用时: 254ms, 内存消耗: 16464K, 提交时间: 2022-09-19 11:05:18

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <queue>

using namespace std ;

using ll = long long ;
using pii = pair<ll,ll> ;

const int N = 2e5 + 100 ,M = 2 * N ;

int n,m ;
int h[N],e[M],ne[M],idx ;
int dep[N],sz[N],tp[N],tp2[N],hh1,hh2 ;
int cha[N],ans[N] ;
bool st[N] ;

void add(int a,int b){
    e[idx] = b,ne[idx] = h[a],h[a] = idx ++ ;
}

int getsz(int u,int fa){
    int tot = 1 ;
    for(int i = h[u] ; ~ i ; i = ne[i]){
        int j = e[i] ;
        if(j == fa || st[j]) continue ;
        tot += getsz(j,u) ;
    }
    return tot ;
}

int getgravity(int u,int fa,int sum){
    int gr = -1 ;
    sz[u] = 1 ;
    int mx = 0 ;
    for(int i = h[u] ; ~ i ; i = ne[i]){
        int j = e[i] ;
        if(j == fa || st[j]) continue ;
        int nx = getgravity(j,u,sum) ;
        if(nx != -1) gr = nx ;
        sz[u] += sz[j] ;
        mx = max(mx,sz[j]) ;
    }
    mx = max(mx,sum - sz[u]) ;
    if(mx * 2 <= sum) gr = u ;
    return gr ;
}

void dfs(int u,int fa,int depth){
    tp[++hh1] = tp2[++hh2] = u ;
    dep[u] = depth ;
    for(int i = h[u] ; ~ i ; i = ne[i]){
        int j = e[i] ;
        if(j == fa || st[j]) continue ;
        dfs(j,u,depth+1) ;
    }
}

void calc(int tp[],int len,int fg){
    sort(tp+1,tp+len+1,[](int a,int b){
        return dep[a] < dep[b] ;
    }) ;
    for(int i = 1,r = len ; i <= len ; i ++){
        while(r >= 1 && dep[tp[i]] + dep[tp[r]] > m) r -- ;
        if(r < i) break ;
        ans[tp[i]] += (r - i + 1) * fg ;
        cha[i+1] += fg,cha[r+1] -= fg ;
    }
    for(int i = 1 ; i <= len ; i ++){
        cha[i] += cha[i-1] ;
        ans[tp[i]] += cha[i] ;
    }

    for(int i = 1 ; i <= len + 1 ; i ++) cha[i] = 0 ;
}

void divtree(int u,int fa){
    int tot = getsz(u,fa) ;
    int gr = getgravity(u,fa,tot) ;
    st[gr] = 1 ;

    hh1 = 0 ;
    tp[++hh1] = gr ;
    dep[gr] = 0 ;
    for(int i = h[gr] ; ~ i ; i = ne[i]){
        int j = e[i] ;
        if(j == fa || st[j]) continue ;
        hh2 = 0 ;
        dfs(j,-1,1) ;
        calc(tp2,hh2,-1) ;
    }
    calc(tp,hh1,1) ;

    for(int i = h[gr] ; ~ i ; i = ne[i]){
        int j = e[i] ;
        if(j == fa || st[j]) continue ;
        divtree(j,-1) ;
    }
}

int main(){
    scanf("%d%d",&n,&m) ;
    memset(h,-1,sizeof h) ;
    for(int i = 1 ; i <= n - 1 ; i ++){
        int a,b ;
        scanf("%d%d",&a,&b) ;
        add(a,b),add(b,a) ;
    }
    divtree(1,-1) ;
    for(int i = 1;  i <= n ; i ++) printf("%d ",ans[i]) ;
    return 0 ;
}

C++ 解法, 执行用时: 185ms, 内存消耗: 9680K, 提交时间: 2022-05-13 09:03:39

#include<bits/stdc++.h>
#define N 100005
using namespace std ;
int n,d ;
struct Edge 
{
	int nxt,to ;
}e[N<<1] ;
struct node
{
	int dis,id ;
}q[N] ;
int head[N],tot=0 ;
int siz[N],mxt[N] ;
int rt,sum,num ;
int ans[N] ;
bool vis[N] ;
bool cmp(const node &A,const node &B)
{
	return A.dis<B.dis ;
}
void add(int from,int to)
{
	e[++tot].to=to ; e[tot].nxt=head[from] ;
	head[from]=tot ;
}
void Getroot(int x,int f)
{
	siz[x]=1 ; mxt[x]=0 ;
	for(int i=head[x];i;i=e[i].nxt)
	{
		int y=e[i].to ;
		if(vis[y]||y==f) continue ;
		Getroot(y,x) ;
		siz[x]+=siz[y] ;
		mxt[x]=max(mxt[x],siz[y]) ;
	}
	mxt[x]=max(mxt[x],sum-siz[x]) ;
	if(rt==-1||mxt[x]<mxt[rt]) rt=x ;
}
void Getdis(int x,int f,int dis)
{
	q[++num]={dis,x} ;
	for(int i=head[x];i;i=e[i].nxt)
	{
		int y=e[i].to ;
		if(vis[y]||y==f) continue ;
		Getdis(y,x,dis+1) ;
	}
}
void calc(int x,int dis,int op)
{
	num=0 ;
	Getdis(x,-1,dis) ;
	sort(q+1,q+num+1,cmp) ;
	for(int l=1,r=num;l<=num;++l)
	{
		while(r&&q[l].dis+q[r].dis>d) r-- ;
		if(l<=r) ans[q[l].id]+=op*(r-1) ;
		else ans[q[l].id]+=op*r ;
	}
}
void solve(int x)
{
	calc(x,0,1) ; vis[x]=1 ;
	for(int i=head[x];i;i=e[i].nxt)
	{
		int y=e[i].to ;
		if(vis[y]) continue ;
		calc(y,1,-1) ;
		rt=-1 ; sum=siz[y] ;
		Getroot(y,-1) ; solve(rt) ;
	}
}
int main()
{
	scanf("%d%d",&n,&d) ;
	for(int i=1,u,v;i<n;++i)
	{
		scanf("%d%d",&u,&v) ;
		add(u,v) ; add(v,u) ;
	}
	sum=n ; rt=-1 ;
	Getroot(1,-1) ; solve(rt) ;
	for(int i=1;i<=n;++i) printf("%d ",ans[i]+1) ;
	return 0 ;
}

上一题