列表

详情


NC212881. 开心消消乐

描述

YanGu有一个 n 个点的树,即由 n-1条边连接起来的联通图,每条边都有颜色 w_i. 定义 path(u,v) 为 u 到 v 的简单路径,不妨设有 m 条边,颜色按 u 到 v 的路径顺序排列分别为 , 且有 p 个连续的颜色段,相邻的颜色段颜色不同,颜色分别为 , 长度分别为 , 显然有 . 对于颜色段定义

其中 d 为颜色, l 为长度.
定义

譬如 u 到 v 按顺序经过的路径颜色为 1,2,2,2,3,4,4,4,4,4,3,3,1, 则此时 , 因为存在连续 3 个 颜色 2 和连续 5个颜色 4 消为0.
YanGu想知道有多少个这样的二元组 (u,v) 满足 .

输入描述

第一行含有一个整数 T 表示测试数据组数. 对于每组测试数据,
第一行含有两个整数 n, k 表示点的个数和阈值,
接下来 n-1行,每行
有三个整数 u,v,w 表示 u 和 v 之间边的颜色为 w

*
*
*
*
*
*
* 至多有50组

输出描述

对于每个测试数据 , 输出一个整数表示答案

示例1

输入:

2
5 4
1 2 3
1 3 5
2 4 5
3 5 4
8 50809177
1 2 700805901
2 3 32145015
3 4 792263333
3 5 538420696
1 6 351870424
2 7 263716407
5 8 818097140

输出:

9
27

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(clang++11) 解法, 执行用时: 1224ms, 内存消耗: 22684K, 提交时间: 2020-10-28 23:57:34

#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = 100050;
template<typename T>inline void read(T&x)
{
	T f = 1,c = 0;char ch = getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){c=c*10+ch-'0';ch=getchar();}
	x = f*c;
}
struct Pair
{
	int x,y;
	Pair(){}
	Pair(int x,int y):x(x),y(y){}
	bool operator < (const Pair&p)const{return x<p.x;}
};
vector<Pair>ve[N];
int T,n,m;
int siz[N],fa[N],mx_siz[N],rt,sum;
bool use[N];
void get_rt(int u,int ff=0)
{
	fa[u] = ff,siz[u] = 1,mx_siz[u] = 0;
	for(int i=0;i<ve[u].size();i++)
	{
		int to = ve[u][i].y;
		if(to==ff||use[to])continue;
		get_rt(to,u);
		siz[u] += siz[to];
		if(siz[to]>mx_siz[u])
			mx_siz[u] = siz[to];
	}
	mx_siz[u] = max(mx_siz[u],sum-siz[u]);
	if(mx_siz[u]<mx_siz[rt])rt=u;
}
int tin[N],tout[N],pla[N],tim,dep[N],col[N],tl;
ll w[N],w1[N],w2[N],ans,tmp[N];
ll sta[N],sta2[N];
ll calc(int d,int l){return l>2?0:1ll*d*l;}
void dfs(int u,int ff,ll s,int d,int l)
{
	tin[u] = ++tim,pla[tim] = u,w[tim] = s+calc(d,l);
	dep[u] = dep[ff]+1;
	if(dep[u]!=l+1)w1[u]=w1[ff];
	else w1[u] = calc(d,l);
	for(int i=0;i<ve[u].size();i++)
	{
		int to = ve[u][i].y;
		if(to==ff||use[to])continue;
		if(!ff)col[to]=ve[u][i].x;
		else col[to]=col[u];
		if(ve[u][i].x!=d)dfs(to,u,s+calc(d,l)-m,ve[u][i].x,1);
		else dfs(to,u,s-m,d,l+1);
	}
	tout[u] = tim;
}
ll calc(int u,int ff,ll s,int d,int l)
{
	dep[u] = (ff!=0);tim = 0;
	dfs(u,ff,s,d,l);
	ll ret = 0;int lw;
	for(int i=1;i<=tim;i++)
		w2[i] = w[i] - w1[pla[i]];
	for(int i=1;i<=tim;i++)
	{
		if(col[pla[i]]==w1[pla[i]])sta[++tl] = w[i],sta2[tl] = w2[i];
		if(i==tim||col[pla[i]]!=col[pla[i+1]])
		{
			sort(sta+1,sta+tl+1);
			int lk = 1,rk = tl+1;//w(1,1)
			for(;lk<=tl;lk++)
			{
				while(rk-1>=1&&sta[lk]+sta[rk-1]>=0)rk--;
				ret += tl+1-max(rk,lk+1);
			}
			sort(sta2+1,sta2+tl+1);
			lk = 1,rk = tl+1;//w2(1,1)
			for(;lk<=tl;lk++)
			{
				while(rk-1>=1&&sta2[lk]+sta2[rk-1]>=0)rk--;
				ret -= tl+1-max(rk,lk+1);
			}
			tl = 0;
		}
	}
	lw = (ff==0);//w(a,b)
	for(int i=lw+1;i<=tim;i++)if(i==tim||col[pla[i]]!=col[pla[i+1]])
	{
		sort(w+lw+1,w+i+1);
		int lk = lw+1,rk = i+1;
		for(;lk<=i;lk++)
		{
			while(rk-1>=lw+1&&w[lk]+w[rk-1]>=0)rk--;
			ret -= i+1-max(rk,lk+1);
		}
		lw = i;
	}
	lw = (ff==0);//w2
	for(int i=lw+1;i<=tim;i++)if(i==tim||col[pla[i]]!=col[pla[i+1]])
	{
		sort(w2+lw+1,w2+i+1);
		int lk = lw+1,rk = i+1;
		for(;lk<=i;lk++)
		{
			while(rk-1>=lw+1&&w2[lk]+w2[rk-1]>=0)rk--;
			ret += i+1-max(rk,lk+1);
		}
		lw = i;
	}
	sort(w+1,w+1+tim);//w
	int lk = 1,rk = tim+1;
	for(;lk<=tim;lk++)
	{
		while(rk-1>=1&&w[lk]+w[rk-1]>=0)rk--;
		ret += tim+1-max(rk,lk+1);
	}
	return ret;
}
void clear()
{
	ans = 0;
	for(int i=1;i<=n;i++)
	{
		use[i] = col[i] = 0;
		ve[i].clear();
	}
}
void work()
{
	ans+=calc(rt,0,0,0,0);
	use[rt] = 1;int u = rt;
	for(int i=0;i<ve[u].size();i++)
	{
		int to = ve[u][i].y;
		if(use[to])continue;
		ans -= calc(to,u,-m,ve[u][i].x,1);
	}
	for(int i=0;i<ve[u].size();i++)
	{
		int to = ve[u][i].y;
		if(use[to])continue;
		sum = siz[to];rt = 0;
		get_rt(to);work();
	}
}
void sol()
{
	clear();
	read(n),read(m);
	for(int u,v,w,i=1;i<n;i++)
		read(u),read(v),read(w),ve[u].push_back(Pair(w,v)),ve[v].push_back(Pair(w,u));
	for(int i=1;i<=n;i++)
		sort(ve[i].begin(),ve[i].end());
	rt=0,sum=n;
	get_rt(1);work();
	printf("%lld\n",ans);
}
int main()
{
	read(T);mx_siz[0] = 0x3f3f3f3f;
	while(T--)sol();
	return 0;
}

上一题