列表

详情


NC244327. 树剖分剖树

描述

Z有一颗 n 个结点的树 ,结点的编号为 , 结点的点权为 w_i
Z现在给定了一个正整数 k,他希望你告诉他树上有多少个这样的二元组 (u, v) , 满足  且 u 到 v 的最短路径上的点权恰好是一个 的排列。

输入描述

第一行,包含两个整数 n
第二行,共 n 个整数,表示每个结点的点权。
接下来的 n-1 行,每行包含两个整数 ,代表结点 u_i 和结点 v_i 之间有一条边相连。

输出描述

输出一行,包含一个整数,如题意所示。

示例1

输入:

5 2
1 2 1 2 5
1 2
2 3
3 4
4 5

输出:

3

示例2

输入:

5 3
1 2 2 2 3
1 2
2 3
3 4
4 5

输出:

0

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(g++ 7.5.0) 解法, 执行用时: 1137ms, 内存消耗: 154348K, 提交时间: 2022-10-22 08:31:02

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
const int INF = 0x3f3f3f3f;
const LL mod = 1e9 + 7;
const int N = 300005;

ULL a[N], b[N], wc;
vector<int> G[N];
map<ULL, int> mp[N];
LL ans;
void dfs(int u, int fa, ULL now) {
    mp[u][now]++;
    if (a[u] == wc) ans++;
    for (auto v : G[u]) {
        if (v == fa) continue;
        dfs(v, u, now + a[v]);
        if (mp[v].size() > mp[u].size()) swap(mp[u], mp[v]);
        for (auto [x, y] : mp[v]) {
            ULL t = wc + 2 * now - x - a[u];
            if (mp[u].find(t) != mp[u].end()) ans += (LL)mp[u][t] * y;
        }
        for (auto [x, y] : mp[v]) {
            mp[u][x] += y;
        }
    }
}
int main() {
    random_device rd;
    mt19937_64 g(rd());
    int n, m;
    scanf("%d%d", &n, &m);
    //if (m == 1) return puts("0"), 0;
    for (int i = 1; i <= n; i++) {
        b[i] = g();
        if (i <= m) wc += b[i];
    }
    for (int i = 1, x; i <= n; i++) {
        scanf("%d", &x);
        a[i] = b[x];
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(1, 0, 0);
    printf("%lld\n", ans);
    return 0;
}

C++(clang++ 11.0.1) 解法, 执行用时: 1097ms, 内存消耗: 154332K, 提交时间: 2022-11-11 09:27:50

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PII;
const int N = 3e5 + 10;

// 多重集 哈希 

int n, k;
vector<int> G[N];
ull a[N], w[N], sum, now, ans;
map<ull, int> mp[N];

void dfs(int u, int fa, ull now)
{
	if (w[u] == sum) ans ++ ;
	mp[u][now] ++ ;
	for (auto v: G[u])
	{
		if (v == fa) continue;
		dfs(v, u, now + w[v]);
		
		if (mp[u].size() < mp[v].size()) swap(mp[u], mp[v]);
		
		for (auto [x, y]: mp[v])
		{
			ull t = sum + 2 * now - w[u] - x;
			if (mp[u].find(t) != mp[u].end())
				ans += 1ull * y * mp[u][t];
		}
		
		for (auto [x, y]: mp[v])
			mp[u][x] += y; // 合并
	}
}

int main()
{
	cin >> n >> k;
	random_device rd;
	mt19937_64 g(rd());
	for (int i = 1; i <= n; i ++ )
	{
		scanf("%lld", &w[i]);
		a[i] = g();
	}
	for (int i = 1; i <= n; i ++ )
		w[i] = a[w[i]];
	
	for (int i = 1; i <= k; i ++ )
		sum += a[i];
	 
	for (int i = 1; i <= n - 1; i ++ )
	{
		int u, v;
		scanf("%d%d", &u, &v);
		G[u].push_back(v);
		G[v].push_back(u);
	}
	dfs(1, 0, w[1]);
 	cout << ans;
    return 0;
}

上一题