列表

详情


NC248620. 有向树

描述

小松鼠喜欢在树上玩,玩着玩着他想出了一个问题!

有一棵 n 个点的树,每个点有点权 a_i
给出 n-1 条可以改变方向的有向边,显然一共有 种状态。
在某种状态下,若 u 能到达 v,则此状态的权值增加
种状态的权值之和,答案对 998244353 取模。

输入描述

输入共  行。
第一行一个数表示 ,接下来 n-1 行每行两个数表示树上的边结构,最后一行一共 n 个数表示每个点的点权

输出描述

输出共一行一个数表示  种状态的权值之和取模后的结果。

示例1

输入:

3
1 2
2 3
3 1 2

输出:

14

说明:

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(clang++ 11.0.1) 解法, 执行用时: 652ms, 内存消耗: 18456K, 提交时间: 2023-03-05 21:56:34

 #define  _CRT_SECURE_NO_WARNINGS 1
#include<bits/stdc++.h>
#include<random>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
#define x first
#define y second
#define bit(x) (1<<x)
#define lowbit(x) (x&-x)
const int INF = 0x3f3f3f3f;
const LL inf = 1e18;
const int mod = 998244353;
const int N = 2e5 + 10, M = 1e6+10;
const int inv2=(mod+1)/2;
int n, m;
vector<int>g[N];
bool st[N];
PII p[N];
int a[N];
LL ans=0;
int get_size(int u,int fa)
{
	if(st[u])return 0;
	int res=1;
	for(auto v:g[u])
	{
		if(v==fa)continue;
		res+=get_size(v,u);
	}
	return res;
}
int get_wc(int u,int fa,int tot,int &wc)
{
	if(st[u])return 0;
	int maxv=0, sum=1;
	for(auto v:g[u])
	{
		if(v==fa)continue;
		int t=get_wc(v,u,tot,wc);
		maxv=max(maxv,t);
		sum+=t;
	}
	maxv=max(maxv,tot-sum);
	if(maxv<=tot/2)wc=u;
	return sum;
}
void get_dist(int u,int fa,int cnt,int &pt)
{
	if(st[u])return;
	p[pt++]={a[u],cnt};
	for(auto v:g[u])
	{
		if(v==fa)continue;
		get_dist(v,u,(LL)cnt*inv2%mod,pt);
	}
}
LL get(PII a[], int k)
{
    sort(a, a + k);
    LL sl=0,sr=0;
    for(int i=0;i<k;i++)sr=(sr+a[i].y)%mod;
    int ans=0;
    for(int i=1;i<k;i++)
    {
    	sl=(sl+a[i-1].y)%mod;
    	sr=((sr-a[i-1].y)%mod+mod)%mod;
    	ans=(ans+(LL)sl*sr%mod*(a[i].x-a[i-1].x)%mod)%mod;
	}
	return ans;
}
void calc(int u)
{
	if(st[u])return;
	get_wc(u,-1,get_size(u,-1),u);
	int pt=0;
	get_dist(u,-1,1,pt);
	ans=(ans+get(p,pt))%mod;
	for(auto v:g[u])
	{
		pt=0;
		get_dist(v,u,inv2,pt);
		ans=((ans-get(p,pt))%mod+mod)%mod;
	}
	st[u]=true;
	for(auto v:g[u]) calc(v);
	
}
void solve()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++)g[i].clear();
	for(int i=1;i<n;i++)
	{
		int u,v;scanf("%d%d",&u,&v);
		g[u].push_back(v);g[v].push_back(u);
	}
	for(int i=1;i<=n;i++)scanf("%d",&a[i]);
	calc(1);
	for(int i=1;i<=n;i++)ans=ans*2%mod;
	cout<<ans<<'\n';
}

signed main()
{
	int T = 1;
	//scanf("%d", &T);
	while (T--)solve();
	return 0;
}

C++(g++ 7.5.0) 解法, 执行用时: 573ms, 内存消耗: 22000K, 提交时间: 2023-03-05 21:34:15

#include <bits/stdc++.h>
#define MAXN ((int) 2e5)
#define MOD 998244353
using namespace std;
typedef pair<int, int> pii;

int n, A[MAXN + 10];
long long ans;

vector<int> e[MAXN + 10];
int alls, core, sz[MAXN + 10], mx[MAXN + 10];
bool vis[MAXN + 10];
vector<pii> vec;

long long P[MAXN + 10];

void dfs1(int sn, int fa) {
    sz[sn] = 1; mx[sn] = 0;
    for (int fn : e[sn]) if (fn != fa && !vis[fn]) {
        dfs1(fn, sn);
        sz[sn] += sz[fn];
        mx[sn] = max(mx[sn], sz[fn]);
    }
    mx[sn] = max(mx[sn], alls - sz[sn]);
    if (core == 0 || mx[sn] < mx[core]) core = sn;
}

void dfs2(int sn, int fa, int D) {
    vec.push_back(pii(A[sn], D));
    for (int fn : e[sn]) if (fn != fa && !vis[fn]) dfs2(fn, sn, D + 1);
}

long long calc(int sn, int D) {
    vec.clear(); dfs2(sn, 0, D);
    sort(vec.begin(), vec.end());

    long long ret = 0, X = 0, Y = 0;
    for (pii p : vec) {
        ret = (ret + X * p.first % MOD * P[p.second] % MOD - Y * P[p.second] % MOD + MOD) % MOD;
        X = (X + P[p.second]) % MOD;
        Y = (Y + p.first * P[p.second]) % MOD;
    }
    return ret;
}

void divide(int sn, int tot) {
    alls = tot; core = 0; dfs1(sn, 0);
    ans = (ans + calc(core, 0)) % MOD;
    vis[core] = true;
    for (int fn : e[core]) if (!vis[fn]) {
        ans = (ans - calc(fn, 1) + MOD) % MOD;
        divide(fn, sz[fn]);
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int x, y; scanf("%d%d", &x, &y);
        e[x].push_back(y); e[y].push_back(x);
    }
    for (int i = 1; i <= n; i++) scanf("%d", &A[i]);

    P[0] = 1; P[1] = MOD - MOD / 2;
    for (int i = 2; i <= n; i++) P[i] = P[i - 1] * P[1] % MOD;
    divide(1, n);
    for (int i = 1; i <= n; i++) ans = ans * 2 % MOD;
    printf("%lld\n", ans);
    return 0;
}

上一题