NC248620. 有向树
描述
输入描述
输入共 $n+1$ 行。
第一行一个数表示 $n$,接下来 $n-1$ 行每行两个数表示树上的边结构,最后一行一共 $n$ 个数表示每个点的点权 $a_i$。
输出描述
输出共一行一个数,表示 $2^{n-1}$ 种状态的权值之和对 $998244353$ 取模后的结果。
示例1
输入:
3
1 2
2 3
3 1 2
输出:
14
说明:
C++(clang++ 11.0.1) 解法, 执行用时: 652ms, 内存消耗: 18456K, 提交时间: 2023-03-05 21:56:34
#define _CRT_SECURE_NO_WARNINGS 1  // lets MSVC accept scanf without deprecation errors
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Reads a tree with n vertices (n, then n-1 undirected edges, then n vertex
// values) from stdin and prints
//
//     sum over unordered pairs {u, v} of |a[u] - a[v]| * 2^(n - dist(u, v))
//
// modulo 998244353, where dist is the number of edges on the tree path.
//
// The sum is evaluated with centroid decomposition: for every centroid c, each
// vertex of the current component gets the weight 2^(-depth) (depth measured
// from c), so a pair whose path runs through c contributes
// weight(u) * weight(v) * |a[u] - a[v]|.  Pairs that lie inside the same child
// subtree are counted by that formula as well and are subtracted again.  The
// common factor 2^n is applied once at the very end.
//
// NOTE(review): the traversals are recursive, so a path-shaped tree recurses
// about n frames deep; this relies on the judge providing enough stack.
// NOTE(review): the modular arithmetic assumes non-negative vertex values --
// TODO confirm against the problem constraints.

namespace {

using LL = long long;
// first = vertex value, second = weight of the vertex modulo kMod.
using WeightedValue = std::pair<int, int>;

constexpr int kMod = 998244353;
constexpr int kInvTwo = (kMod + 1) / 2;  // modular inverse of 2
constexpr int kCapacity = 200000 + 10;   // number of slots in the static arrays

int n;
std::vector<int> adj[kCapacity];  // adjacency lists, vertices are 1-based
bool removed[kCapacity];          // true once a vertex was used as a centroid
WeightedValue buffer[kCapacity];  // scratch space filled by collect()
int value[kCapacity];             // vertex values
LL total = 0;                     // running answer before the final * 2^n

// Returns the number of not-yet-removed vertices reachable from u without
// going back through `parent` (-1 means "no parent").
int component_size(int u, int parent) {
    if (removed[u]) return 0;
    int size = 1;
    for (int v : adj[u]) {
        if (v != parent) size += component_size(v, u);
    }
    return size;
}

// Returns the subtree size of u inside the current component of `tot`
// vertices.  Every vertex whose largest remaining piece has at most tot / 2
// vertices is written to `centroid`; the last such vertex visited wins.
int find_centroid(int u, int parent, int tot, int& centroid) {
    if (removed[u]) return 0;
    int heaviest = 0;
    int size = 1;
    for (int v : adj[u]) {
        if (v == parent) continue;
        const int child = find_centroid(v, u, tot, centroid);
        heaviest = std::max(heaviest, child);
        size += child;
    }
    heaviest = std::max(heaviest, tot - size);  // the piece on the parent side
    if (heaviest <= tot / 2) centroid = u;
    return size;
}

// Appends (value, weight) for u and every vertex below it to buffer[count...].
// The weight is halved (multiplied by the inverse of 2) for each edge walked.
void collect(int u, int parent, int weight, int& count) {
    if (removed[u]) return;
    buffer[count++] = WeightedValue(value[u], weight);
    const int next = static_cast<int>(static_cast<LL>(weight) * kInvTwo % kMod);
    for (int v : adj[u]) {
        if (v != parent) collect(v, u, next, count);
    }
}

// Sorts items[0..k) and returns, modulo kMod,
//     sum over i < j of weight_i * weight_j * (value_j - value_i).
// It walks the gaps between consecutive sorted values: a gap is paid once by
// every pair that straddles it, i.e. (weight left of it) * (weight right of it).
LL weighted_gap_sum(WeightedValue items[], int k) {
    std::sort(items, items + k);

    LL right = 0;
    for (int i = 0; i < k; ++i) right = (right + items[i].second) % kMod;

    LL left = 0;
    LL sum = 0;
    for (int i = 1; i < k; ++i) {
        const int moved = items[i - 1].second;
        left = (left + moved) % kMod;
        right = ((right - moved) % kMod + kMod) % kMod;
        // Widen before subtracting so the difference cannot overflow int.
        const LL gap = static_cast<LL>(items[i].first) - items[i - 1].first;
        sum = (sum + left * right % kMod * gap % kMod) % kMod;
    }
    return sum;
}

// Handles the component containing u: picks its centroid, adds the
// contribution of all pairs weighted through the centroid, removes the pairs
// that stay inside one child subtree, then recurses into the pieces.
void decompose(int u) {
    if (removed[u]) return;

    int centroid = u;
    find_centroid(u, -1, component_size(u, -1), centroid);

    int count = 0;
    collect(centroid, -1, 1, count);
    total = (total + weighted_gap_sum(buffer, count)) % kMod;

    for (int v : adj[centroid]) {
        count = 0;
        // Start at weight 1/2: v is already one edge away from the centroid.
        collect(v, centroid, kInvTwo, count);
        total = ((total - weighted_gap_sum(buffer, count)) % kMod + kMod) % kMod;
    }

    removed[centroid] = true;
    for (int v : adj[centroid]) decompose(v);
}

}  // namespace

int main() {
    // Reject input that would index outside the static arrays.
    if (std::scanf("%d", &n) != 1 || n < 1 || n >= kCapacity) {
        std::fprintf(stderr, "invalid vertex count\n");
        return 1;
    }
    for (int i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        if (std::scanf("%d%d", &u, &v) != 2 || u < 1 || u > n || v < 1 || v > n) {
            std::fprintf(stderr, "invalid edge\n");
            return 1;
        }
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for (int i = 1; i <= n; ++i) {
        if (std::scanf("%d", &value[i]) != 1) {
            std::fprintf(stderr, "invalid vertex value\n");
            return 1;
        }
    }

    decompose(1);

    // Apply the factor 2^n that was left out of the per-vertex weights.
    for (int i = 1; i <= n; ++i) total = total * 2 % kMod;

    std::printf("%lld\n", total);
    return 0;
}
C++(g++ 7.5.0) 解法, 执行用时: 573ms, 内存消耗: 22000K, 提交时间: 2023-03-05 21:34:15
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Reads a tree with n vertices (n, then n-1 undirected edges, then n vertex
// values) from stdin and prints
//
//     sum over unordered pairs {u, v} of |A[u] - A[v]| * 2^(n - dist(u, v))
//
// modulo 998244353, where dist is the number of edges on the tree path.
//
// Centroid decomposition is used: around each chosen core every vertex at
// depth d carries the weight 2^(-d), pairs are combined through the core, and
// the pairs that stay inside a single child subtree are subtracted again.
// The common factor 2^n is applied at the end.
//
// NOTE(review): the traversals are recursive, so a path-shaped tree recurses
// about n frames deep; this relies on the judge providing enough stack.
// NOTE(review): the modular arithmetic assumes non-negative vertex values --
// TODO confirm against the problem constraints.

namespace {

// first = vertex value, second = depth below the current core.
using ValueDepth = std::pair<int, int>;

constexpr long long kMod = 998244353;
constexpr int kCapacity = 200000 + 10;  // number of slots in the static arrays

int n;
int value[kCapacity];                  // vertex values, 1-based
long long answer;                      // running answer before the final * 2^n
std::vector<int> adj[kCapacity];       // adjacency lists, vertices are 1-based
int component_total;                   // size of the component being measured
int best;                              // best core candidate found by measure()
int subtree[kCapacity];                // subtree sizes from the last measure()
int heaviest[kCapacity];               // largest piece left if the vertex is cut
bool removed[kCapacity];               // true once a vertex was used as a core
std::vector<ValueDepth> samples;       // scratch list filled by gather()
long long inv_pow2[kCapacity];         // inv_pow2[d] = 2^(-d) mod kMod

// Computes subtree[] and heaviest[] for the component containing u (rooted at
// u's first caller) and keeps in `best` the vertex with the smallest
// heaviest[] value.  Vertex 0 is never a real vertex and serves as "none".
void measure(int u, int parent) {
    subtree[u] = 1;
    heaviest[u] = 0;
    for (int v : adj[u]) {
        if (v == parent || removed[v]) continue;
        measure(v, u);
        subtree[u] += subtree[v];
        heaviest[u] = std::max(heaviest[u], subtree[v]);
    }
    heaviest[u] = std::max(heaviest[u], component_total - subtree[u]);
    if (best == 0 || heaviest[u] < heaviest[best]) best = u;
}

// Appends (value, depth) of u and of every reachable, not-removed vertex
// below it to `samples`; depth grows by one per edge.
void gather(int u, int parent, int depth) {
    samples.push_back(ValueDepth(value[u], depth));
    for (int v : adj[u]) {
        if (v != parent && !removed[v]) gather(v, u, depth + 1);
    }
}

// Collects the vertices reachable from `start` (which sits at depth
// `start_depth`) and returns, modulo kMod,
//     sum over pairs of 2^(-d_i) * 2^(-d_j) * |value_i - value_j|.
// After sorting by value, each vertex is combined with the running sums of
// the weights and of value * weight over all smaller-or-equal entries.
long long pair_sum(int start, int start_depth) {
    samples.clear();
    gather(start, 0, start_depth);
    std::sort(samples.begin(), samples.end());

    long long result = 0;
    long long weight_sum = 0;          // sum of weights seen so far
    long long weighted_value_sum = 0;  // sum of value * weight seen so far
    for (const ValueDepth& s : samples) {
        const long long w = inv_pow2[s.second];
        result = (result + weight_sum * s.first % kMod * w % kMod
                  - weighted_value_sum * w % kMod + kMod) % kMod;
        weight_sum = (weight_sum + w) % kMod;
        weighted_value_sum = (weighted_value_sum + s.first * w) % kMod;
    }
    return result;
}

// Processes the component of `tot` vertices that contains `start`.
void divide(int start, int tot) {
    component_total = tot;
    best = 0;
    measure(start, 0);

    // Keep a local copy: the recursive calls below overwrite `best`.
    const int core = best;

    answer = (answer + pair_sum(core, 0)) % kMod;
    removed[core] = true;

    for (int v : adj[core]) {
        if (removed[v]) continue;
        answer = (answer - pair_sum(v, 1) + kMod) % kMod;

        // subtree[] was measured with `start` as the root.  For a child of
        // the core it is exactly the size of the piece behind v.  For the one
        // neighbour on the root side it still includes the core's own subtree,
        // so the piece behind it is everything outside that subtree instead.
        const int piece =
            subtree[v] < subtree[core] ? subtree[v] : tot - subtree[core];
        divide(v, piece);
    }
}

}  // namespace

int main() {
    // Reject input that would index outside the static arrays.
    if (std::scanf("%d", &n) != 1 || n < 1 || n >= kCapacity) {
        std::fprintf(stderr, "invalid vertex count\n");
        return 1;
    }
    for (int i = 1; i < n; ++i) {
        int x = 0;
        int y = 0;
        if (std::scanf("%d%d", &x, &y) != 2 || x < 1 || x > n || y < 1 || y > n) {
            std::fprintf(stderr, "invalid edge\n");
            return 1;
        }
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    for (int i = 1; i <= n; ++i) {
        if (std::scanf("%d", &value[i]) != 1) {
            std::fprintf(stderr, "invalid vertex value\n");
            return 1;
        }
    }

    // kMod - kMod / 2 == (kMod + 1) / 2 is the modular inverse of 2.
    inv_pow2[0] = 1;
    inv_pow2[1] = kMod - kMod / 2;
    for (int i = 2; i <= n; ++i) inv_pow2[i] = inv_pow2[i - 1] * inv_pow2[1] % kMod;

    divide(1, n);

    // Apply the factor 2^n that was left out of the per-vertex weights.
    for (int i = 1; i <= n; ++i) answer = answer * 2 % kMod;

    std::printf("%lld\n", answer);
    return 0;
}