列表

详情


NC217977. 牛客推荐系统开发之标签重复度

描述

牛客讨论区的每一篇帖子都会有一些标签,比如说“编程语言”、“秋招”、“招聘”等。有些标签是其他标签的子标签,比如“秋招”是“招聘”的子标签。于是这种父子标签的关系可以用一棵树来表示,即某个标签一定是另一个标签的子标签或者这个标签是这棵树的树根节点,每一个标签都有一个权值w_i
定义两个的标签的重复度(这两个标签可以是相同的)为这在这棵树上这两个标签组成的简单路径中的最大权值乘以最小权值。
定义一棵标签树的标签重复度为这棵树上每两个标签之间的重复度之和。
你作为优秀的牛客算法工程师,需要去求某一棵有个标签的标签树的重复度,求出答案对998244353取模后的结果即可。

输入描述

第1行为一个整数,表示总共有个标签;
第2行为n个整数,即,表示第个标签的权值;
接下来n-1行,每行两个整数,表示标签和标签互为父子标签关系。

输出描述

一个整数,为所求答案模998244353后的结果。

示例1

输入:

5
9 12 18 10 15
5 3
3 1
3 2
2 4

输出:

2704

示例2

输入:

1
1

输出:

1

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++ 解法, 执行用时: 578ms, 内存消耗: 16296K, 提交时间: 2021-06-14 00:37:10

#include<bits/stdc++.h>
#define eb emplace_back
#define pii pair<int,int>
#define fi first
#define se second
using namespace std;
const int mod=998244353;
const int M=1e5+9;
int n,m,rt,ans,mx;
int w[M],c[M],siz[M],bi[M],su[M];
vector<int>g[M];
vector<pii>p;
bool vis[M];
void frt(int u,int fa,int sz){
	if(fa==0)mx=1e9;
	siz[u]=1;
	int ma=0;
	for(auto v:g[u]){
		if(v!=fa&&!vis[v]){
			frt(v,u,sz);
			siz[u]+=siz[v];
			ma=max(ma,siz[v]);
		}
	}
	ma=max(ma,sz-siz[u]);
	if(mx>ma)mx=ma,rt=u;
}
void solve(int u,int fa,int a,int b){
	a=min(a,w[u]);
	b=max(b,w[u]);
	p.eb(b,a);
	for(auto v:g[u]){
		if(v!=fa&&!vis[v]){
			solve(v,u,a,b);
		}
	}
}
void add(int&x,int y){x+=y;x>=mod?x-=mod:0;x<0?x+=mod:0;}
void change(int x,int v){
	for(int y=x;x<=n;x+=x&-x)bi[x]+=v,su[x]=(1ll*su[x]+v*c[y]+mod)%mod;
}
int sum(int x,int *h,int rex=0){
	for(;x;x-=x&-x)add(rex,h[x]);
	return rex;
}
void dfs(int u,int sz){
	frt(u,0,sz);
	u=rt;
	vis[u]=1;
	vector<pii>q;
	q.eb(w[u],w[u]);
	for(auto v:g[u]){
		if(!vis[v]){
			p.clear();
			solve(v,u,w[u],w[u]);
			sort(p.begin(),p.end());
			for(auto o:p){
				int a=o.fi,b=o.se;
				change(b,1);
				add(ans,-1ll*sum(b,su)*c[a]%mod);
				add(ans,-1ll*(sum(n,bi)-sum(b,bi)+mod)%mod*c[a]%mod*c[b]%mod);
				q.eb(o);
			}
			for(auto o:p){
				int b=o.se;
				change(b,-1);
			}
		}
	}
	sort(q.begin(),q.end());
	for(auto o:q){
		int a=o.fi,b=o.se;
		change(b,1);
		add(ans,1ll*sum(b,su)*c[a]%mod);
		add(ans,1ll*(sum(n,bi)-sum(b,bi))*c[a]%mod*c[b]%mod);
	}
	for(auto o:q){
		int b=o.se;
		change(b,-1);
	}
	
	for(auto v:g[u]){
		if(!vis[v])dfs(v,siz[v]);
	}
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;++i)scanf("%d",&w[i]),c[i]=w[i];
	sort(c+1,c+n+1);
	m=unique(c+1,c+n+1)-c-1;
	for(int i=1;i<=n;++i)w[i]=lower_bound(c+1,c+m+1,w[i])-c;
	for(int i=1,u,v;i<n;++i){
		scanf("%d%d",&u,&v);
		g[u].eb(v);
		g[v].eb(u);
	}
	dfs(1,n);
	printf("%d\n",(ans+mod)%mod);
	return 0;
} 
/*
2 
1 2
1 2
*/

上一题