NC244327. 树剖分剖树
描述
输入描述
第一行,包含两个整数 和 。
第二行,共 个整数,表示每个结点的点权。
接下来的 行,每行包含两个整数 ,代表结点 和结点 之间有一条边相连。
输出描述
输出一行,包含一个整数,如题意所示。
示例1
输入:
5 2 1 2 1 2 5 1 2 2 3 3 4 4 5
输出:
3
示例2
输入:
5 3 1 2 2 2 3 1 2 2 3 3 4 4 5
输出:
0
C++(g++ 7.5.0) 解法, 执行用时: 1137ms, 内存消耗: 154348K, 提交时间: 2022-10-22 08:31:02
#include <bits/stdc++.h> using namespace std; typedef long long LL; typedef unsigned long long ULL; const int INF = 0x3f3f3f3f; const LL mod = 1e9 + 7; const int N = 300005; ULL a[N], b[N], wc; vector<int> G[N]; map<ULL, int> mp[N]; LL ans; void dfs(int u, int fa, ULL now) { mp[u][now]++; if (a[u] == wc) ans++; for (auto v : G[u]) { if (v == fa) continue; dfs(v, u, now + a[v]); if (mp[v].size() > mp[u].size()) swap(mp[u], mp[v]); for (auto [x, y] : mp[v]) { ULL t = wc + 2 * now - x - a[u]; if (mp[u].find(t) != mp[u].end()) ans += (LL)mp[u][t] * y; } for (auto [x, y] : mp[v]) { mp[u][x] += y; } } } int main() { random_device rd; mt19937_64 g(rd()); int n, m; scanf("%d%d", &n, &m); //if (m == 1) return puts("0"), 0; for (int i = 1; i <= n; i++) { b[i] = g(); if (i <= m) wc += b[i]; } for (int i = 1, x; i <= n; i++) { scanf("%d", &x); a[i] = b[x]; } for (int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); G[v].push_back(u); } dfs(1, 0, 0); printf("%lld\n", ans); return 0; }
C++(clang++ 11.0.1) 解法, 执行用时: 1097ms, 内存消耗: 154332K, 提交时间: 2022-11-11 09:27:50
#include<bits/stdc++.h> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef pair<int,int> PII; const int N = 3e5 + 10; // 多重集 哈希 int n, k; vector<int> G[N]; ull a[N], w[N], sum, now, ans; map<ull, int> mp[N]; void dfs(int u, int fa, ull now) { if (w[u] == sum) ans ++ ; mp[u][now] ++ ; for (auto v: G[u]) { if (v == fa) continue; dfs(v, u, now + w[v]); if (mp[u].size() < mp[v].size()) swap(mp[u], mp[v]); for (auto [x, y]: mp[v]) { ull t = sum + 2 * now - w[u] - x; if (mp[u].find(t) != mp[u].end()) ans += 1ull * y * mp[u][t]; } for (auto [x, y]: mp[v]) mp[u][x] += y; // 合并 } } int main() { cin >> n >> k; random_device rd; mt19937_64 g(rd()); for (int i = 1; i <= n; i ++ ) { scanf("%lld", &w[i]); a[i] = g(); } for (int i = 1; i <= n; i ++ ) w[i] = a[w[i]]; for (int i = 1; i <= k; i ++ ) sum += a[i]; for (int i = 1; i <= n - 1; i ++ ) { int u, v; scanf("%d%d", &u, &v); G[u].push_back(v); G[v].push_back(u); } dfs(1, 0, w[1]); cout << ans; return 0; }