NC200006. 雷顿女士与平衡树
描述
输入描述
第一行输入一个T(1<=T<=10)表示数据组数。
接下来T组数据。
对于每组数据,第一行输入一个n(1<=n<=500000)。保证n的总和不大于2000000
然后接下来一行输入n个数字,第i个数字表示标号为i的点的权值()。
然后输入n-1行,每行包含两个数u、v,表示一条边连接u与v。
保证输入n-1条边将标号1至标号n组成一棵树。
输出描述
对于每组数据,请你输出一个数表示输入的树的BALANCE值,由于答案可能很大,请将答案对1000000007取模后输出。
请注意行末不要输出多余空格。
示例1
输入:
1 10 9 9 6 2 4 5 8 5 5 6 2 1 3 1 4 3 5 3 6 4 7 2 8 4 9 5 10 3
输出:
179
C++14(g++5.4) 解法, 执行用时: 634ms, 内存消耗: 31604K, 提交时间: 2019-12-07 21:49:36
#include <bits/stdc++.h> using namespace std; typedef long long ll; const ll nmax = 5e5 + 5; const ll mod = 1e9 + 7; struct node { ll x, y, v; }maxNode[nmax], minNode[nmax]; ll n, e; ll a[nmax]; int fa[nmax]; int cnt[nmax]; void init(){ for(int i = 0; i < n; i++){ fa[i] = i; cnt[i] = 1; } } int find(int x){ return x == fa[x] ? x : fa[x] = find(fa[x]); } void uni(int x, int y){ fa[x] = y; cnt[y] += cnt[x]; } int main() { std::ios::sync_with_stdio(false), cin.tie(0), cout.tie(0); int t; cin >> t; while(t--){ cin >> n; e = n - 1; for(int i = 0; i < n; i++) cin >> a[i]; int x, y; for(int i = 0; i < e; i++){ cin >> x >> y; x--; y--; maxNode[i] = node{x, y, max(a[x], a[y])}; minNode[i] = node{x, y, min(a[x], a[y])}; } sort(maxNode, maxNode + e, [](node& a, node& b){ return a.v < b.v; }); sort(minNode, minNode + e, [](node& a, node& b){ return a.v > b.v; }); init(); ll maxsum = 0; for(int i = 0; i < e; i++){ int curx = find(maxNode[i].x); int cury = find(maxNode[i].y); maxsum = (maxNode[i].v * cnt[curx] % mod * cnt[cury] % mod + maxsum) % mod; uni(curx, cury); } init(); ll minsum = 0; for(int i = 0; i < e; i++){ int curx = find(minNode[i].x); int cury = find(minNode[i].y); minsum = (minNode[i].v * cnt[curx] % mod * cnt[cury] % mod + minsum) % mod; uni(curx, cury); } ll ans = (maxsum - minsum + mod) % mod; cout << ans << endl; } return 0; }
C++11(clang++ 3.9) 解法, 执行用时: 1416ms, 内存消耗: 38904K, 提交时间: 2019-12-07 16:39:04
#include <bits/stdc++.h> #define x first #define y second using namespace std; const int maxn=5e5+5; const int mod=1e9+7; typedef long long ll; typedef pair<int,int> pii; int n; pii a[maxn]; int b[maxn]; int f[maxn],num[maxn]; vector<int>G[maxn]; void init(){ for (int i=1;i<=n;i++) f[i]=i,num[i]=1; } int get(int w){ if (f[w]==w) return f[w]; return f[w]=get(f[w]); } ll solve(){ init(); ll cnt=0; for (int i=1;i<=n;i++){ int p=a[i].y; for (auto j:G[p]){ if (b[j]<b[p]){ get(j); get(p); cnt+=1ll*num[f[p]]*num[f[j]]%mod*a[i].x%mod; cnt%=mod; num[f[p]]+=num[f[j]]; f[f[j]]=f[p]; } } } return cnt; } ll ans; int t; int main(){ scanf("%d",&t); while (t--){ ans=0; scanf("%d",&n); for (int i=1;i<=n;i++) G[i].clear(); for (int i=1;i<=n;i++){ scanf("%d",&a[i].x); a[i].y=i; } sort(a+1,a+1+n); for (int i=1;i<=n;i++) b[a[i].y]=i; for (int i=1,u,v;i<n;i++){ scanf("%d%d",&u,&v); G[u].push_back(v); G[v].push_back(u); } ans+=solve(); reverse(a+1,a+1+n); for (int i=1;i<=n;i++) b[i]=-b[i]; ans-=solve(); cout << (ans%mod+mod)%mod << '\n'; } }