列表

详情


NC17069. 空间隧道

描述

    袋鼠先生的王国有n座城市,城市之间有n-1条长度为1的双向道路,任意两座城市都可以通过这些双向道路互通。袋鼠先生王国的商业十分发达,对于任意两座不同的城市A, B,都有一支从A通往B的商队,以及一支从B通往A的商队,商队一定会沿着最短路径行进。
    某一天袋鼠先生王国突然出现了m条空间隧道,每条空间隧道连接两个不同的城市。当一支商队行进到某条空间隧道的一端时,如果隧道的另一端也在这支商队的线路上,那商队会直接穿越隧道到达另一端并继续前进。
    不幸的是,这些空间隧道其实是敌对势力的阴谋,某一天敌对势力对空间隧道发动了奇袭,所有穿过了空间隧道的商队都被俘获了。现在袋鼠先生希望你帮他统计有多少商队还存留着。

输入描述

第一行两个整数n, m(2≤n≤100,000,1≤m≤100,000)代表城市数量以及空间隧道的数量。
接下来n-1行每行两个不同的整数A, B代表城市A, B之间存在一条双向道路。
接下来m行每行两个不同的整数x, y代表城市x, y之间出现了一条空间隧道。

输出描述

输出一个整数表示有多少商队还存留着。

示例1

输入:

5 2
1 2
2 3
1 4
4 5
1 3
4 5

输出:

8

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++14(g++5.4) 解法, 执行用时: 252ms, 内存消耗: 35992K, 提交时间: 2018-07-07 22:46:08

#include<bits/stdc++.h>
#define rep(i,a,b) for (int i=(a); i<=(b); i++)
#define per(i,a,b) for (int i=(a); i>=(b); i--)
using namespace std;
typedef long long LL;

const int maxn = 200005;
struct seg { int l, r, op; } x; vector<seg> q[maxn];
int T[maxn<<2], S[maxn<<2], fa[maxn][18];
int in[maxn], out[maxn], dep[maxn];
int clk, n, m, a, b, c;
vector<int> e[maxn];
LL ans;

void dfs(int u, int pa) {
	in[u] = ++clk; fa[u][0] = pa;
	rep (j, 1, 17)
		fa[u][j] = fa[fa[u][j-1]][j-1];
	for (auto v : e[u])
		if (v != pa) dep[v] = dep[u] + 1, dfs(v, u);
	out[u] = clk;
}

inline int getlca(int a, int b) {
	if (dep[a] < dep[b]) swap(a, b);
	per (j, 17, 0)
		if (dep[fa[a][j]] >= dep[b])
			a = fa[a][j];
	if (a == b) return a;
	per (j, 17, 0)
		if (fa[a][j] != fa[b][j])
			a = fa[a][j], b = fa[b][j];
	return fa[a][0];
}

#define lc (o << 1)
#define rc (o << 1 | 1)
#define mid ((l + r) >> 1)
void update(int o, int l, int r, int x, int y, int z) {
	if (l == x && y == r) {
		T[o] += z;
		if (T[o]) S[o] = r - l + 1;
		else S[o] = S[lc] + S[rc];
		return;
	}
	if (x <= mid) update(lc, l, mid, x, min(y, mid), z);
	if (mid < y) update(rc, mid+1, r, max(mid+1, x), y, z);
	if (T[o]) S[o] = r - l + 1;
	else S[o] = S[lc] + S[rc];
}

LL solve() {
	LL res = 0;
	rep (i, 1, n) {
		while (!q[i].empty()) {
			x = q[i].back(); q[i].pop_back();
			if (x.l <= x.r)
				update(1, 1, n, x.l, x.r, x.op);
		}
		res += S[1];
	}
	return res;
}

int main() {
	scanf("%d%d", &n, &m);
	rep (i, 1, n - 1) {
		scanf("%d%d", &a, &b);
		e[a].push_back(b);
		e[b].push_back(a);
	}
	dep[1] = 1; dfs(1, 0);
	rep (i, 1, m) {
		scanf("%d%d", &a, &b);
		c = getlca(a, b);
		if (c == b) swap(a, b);
		if (c == a) {
			a = b;
			per (j, 17, 0)
				if (dep[fa[a][j]] > dep[c])
					a = fa[a][j];
			//[1,in[a]-1] or [out[a]+1,n], [in[b],out[b]]
			q[1].push_back((seg){in[b], out[b], 1});
			q[in[a]].push_back((seg){in[b], out[b], -1});
			q[out[a]+1].push_back((seg){in[b], out[b], 1});
			
			if (1 < in[a]) {
				q[in[b]].push_back((seg){1, in[a]-1, 1});
				q[out[b]+1].push_back((seg){1, in[a]-1, -1});
			}
			if (out[a] < n) {
				q[in[b]].push_back((seg){out[a]+1, n, 1});
				q[out[b]+1].push_back((seg){out[a]+1, n, -1});
			}
		}
		else {
			//[in[a],out[a]], [in[b],out[b]]
			q[in[a]].push_back((seg){in[b], out[b], 1});
			q[out[a]+1].push_back((seg){in[b], out[b], -1});
			q[in[b]].push_back((seg){in[a], out[a], 1});
			q[out[b]+1].push_back((seg){in[a], out[a], -1});
		}
	}
	cout << 1LL * n * (n - 1) - solve();
	return 0;
}

C++11(clang++ 3.9) 解法, 执行用时: 520ms, 内存消耗: 47708K, 提交时间: 2018-07-17 19:03:48

#include<vector>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=2e5+7;
vector<int>G[N];
int dfsL[N],dfsR[N],deep[N],fa[N][19],tot,n,m;
int sum[N<<2],cnt[N<<2];
void dfs(int x,int f,int depth){
	deep[x]=depth;
	dfsL[x]=++tot;
	for(int i=0;i<(int)G[x].size();++i){
		int v=G[x][i];
		if(v==f) continue;
		fa[v][0]=x;
		dfs(v,x,depth+1);
	}
	dfsR[x]=tot;
}
void up(int x,int l,int r){
	if(cnt[x]>0) sum[x]=r-l+1;
	else sum[x]=sum[x<<1]+sum[x<<1|1]; 
}
void update(int x,int l,int r,int tl,int tr,int v){
	if(tl<=l&&tr>=r) {
		cnt[x]+=v;
		up(x,l,r);
		return;
	}
	int m=(l+r)>>1;
	if(tl<=m) update(x<<1,l,m,tl,tr,v);
	if(tr>m) update(x<<1|1,m+1,r,tl,tr,v);
	up(x,l,r);
}
bool Ju(int x,int y){
	return dfsL[x]<=dfsL[y]&&dfsR[x]>=dfsR[y];
}
int findx(int x,int y){
	int dt=deep[y]-deep[x]-1;
	for(int i=0;i<=18;++i) if((1<<i)&dt) y=fa[y][i];
	return y;
}
vector<pair<int,int> >add[N<<1],del[N<<1];
void mdy(int x1,int y1,int x2,int y2){
	add[x1].push_back(make_pair(x2,y2));
	del[y1].push_back(make_pair(x2,y2));
    add[x2].push_back(make_pair(x1,y1));
    del[y2].push_back(make_pair(x1,y1));
}
int main(){
	int x,y;
	scanf("%d%d",&n,&m);
	for(int i=1;i<n;++i) {
		scanf("%d%d",&x,&y);
		G[x].push_back(y);
		G[y].push_back(x);
	}
    dfs(1,0,0);
	for(int i=1;i<=18;++i) for(int j=1;j<=n;++j) fa[j][i]=fa[fa[j][i-1]][i-1];	
    for(int i=1;i<=m;++i){
    	scanf("%d%d",&x,&y);
    	if(deep[x]>deep[y]) swap(x,y);
    	if(Ju(x,y)) {
    		int p=findx(x,y);
    		if(dfsL[p]>1) mdy(1,dfsL[p]-1,dfsL[y],dfsR[y]);
    		if(dfsR[p]<n) mdy(dfsR[p]+1,n,dfsL[y],dfsR[y]);
    		//printf("%d %d %d %d\n",1,dfsL[p]-1,dfsL[y],dfsR[y]);
    		//printf("%d %d %d %d\n",dfsR[p]+1,n,dfsL[y],dfsR[y]);
		}
		else {
			//printf("%d %d %d %d\n",dfsL[x],dfsR[x],dfsL[y],dfsR[y]);
			mdy(dfsL[x],dfsR[x],dfsL[y],dfsR[y]);
		}
	}
	long long ans=1LL*n*(n-1);
	for(int i=1;i<=n;++i) {
		for(int j=0;j<(int)add[i].size();++j) {
			update(1,1,n,add[i][j].first,add[i][j].second,1);
		}
		ans-=sum[1];
		//printf("%d   %d\n",i,sum[1]);
		for(int j=0;j<(int)del[i].size();++j) {
			update(1,1,n,del[i][j].first,del[i][j].second,-1);
		}
	}
	printf("%lld\n",ans);
}

上一题