NC17069. 空间隧道
描述
输入描述
第一行两个整数n, m(2≤n≤100,000,1≤m≤100,000)代表城市数量以及空间隧道的数量。
接下来n-1行每行两个不同的整数A, B代表城市A, B之间存在一条双向道路。
接下来m行每行两个不同的整数x, y代表城市x, y之间出现了一条空间隧道。
输出描述
输出一个整数表示有多少商队还存留着。
示例1
输入:
5 2 1 2 2 3 1 4 4 5 1 3 4 5
输出:
8
C++14(g++5.4) 解法, 执行用时: 252ms, 内存消耗: 35992K, 提交时间: 2018-07-07 22:46:08
#include<bits/stdc++.h> #define rep(i,a,b) for (int i=(a); i<=(b); i++) #define per(i,a,b) for (int i=(a); i>=(b); i--) using namespace std; typedef long long LL; const int maxn = 200005; struct seg { int l, r, op; } x; vector<seg> q[maxn]; int T[maxn<<2], S[maxn<<2], fa[maxn][18]; int in[maxn], out[maxn], dep[maxn]; int clk, n, m, a, b, c; vector<int> e[maxn]; LL ans; void dfs(int u, int pa) { in[u] = ++clk; fa[u][0] = pa; rep (j, 1, 17) fa[u][j] = fa[fa[u][j-1]][j-1]; for (auto v : e[u]) if (v != pa) dep[v] = dep[u] + 1, dfs(v, u); out[u] = clk; } inline int getlca(int a, int b) { if (dep[a] < dep[b]) swap(a, b); per (j, 17, 0) if (dep[fa[a][j]] >= dep[b]) a = fa[a][j]; if (a == b) return a; per (j, 17, 0) if (fa[a][j] != fa[b][j]) a = fa[a][j], b = fa[b][j]; return fa[a][0]; } #define lc (o << 1) #define rc (o << 1 | 1) #define mid ((l + r) >> 1) void update(int o, int l, int r, int x, int y, int z) { if (l == x && y == r) { T[o] += z; if (T[o]) S[o] = r - l + 1; else S[o] = S[lc] + S[rc]; return; } if (x <= mid) update(lc, l, mid, x, min(y, mid), z); if (mid < y) update(rc, mid+1, r, max(mid+1, x), y, z); if (T[o]) S[o] = r - l + 1; else S[o] = S[lc] + S[rc]; } LL solve() { LL res = 0; rep (i, 1, n) { while (!q[i].empty()) { x = q[i].back(); q[i].pop_back(); if (x.l <= x.r) update(1, 1, n, x.l, x.r, x.op); } res += S[1]; } return res; } int main() { scanf("%d%d", &n, &m); rep (i, 1, n - 1) { scanf("%d%d", &a, &b); e[a].push_back(b); e[b].push_back(a); } dep[1] = 1; dfs(1, 0); rep (i, 1, m) { scanf("%d%d", &a, &b); c = getlca(a, b); if (c == b) swap(a, b); if (c == a) { a = b; per (j, 17, 0) if (dep[fa[a][j]] > dep[c]) a = fa[a][j]; //[1,in[a]-1] or [out[a]+1,n], [in[b],out[b]] q[1].push_back((seg){in[b], out[b], 1}); q[in[a]].push_back((seg){in[b], out[b], -1}); q[out[a]+1].push_back((seg){in[b], out[b], 1}); if (1 < in[a]) { q[in[b]].push_back((seg){1, in[a]-1, 1}); q[out[b]+1].push_back((seg){1, in[a]-1, -1}); } if (out[a] < n) { q[in[b]].push_back((seg){out[a]+1, n, 1}); q[out[b]+1].push_back((seg){out[a]+1, n, -1}); } } else { //[in[a],out[a]], [in[b],out[b]] q[in[a]].push_back((seg){in[b], out[b], 1}); q[out[a]+1].push_back((seg){in[b], out[b], -1}); q[in[b]].push_back((seg){in[a], out[a], 1}); q[out[b]+1].push_back((seg){in[a], out[a], -1}); } } cout << 1LL * n * (n - 1) - solve(); return 0; }
C++11(clang++ 3.9) 解法, 执行用时: 520ms, 内存消耗: 47708K, 提交时间: 2018-07-17 19:03:48
#include<vector> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int N=2e5+7; vector<int>G[N]; int dfsL[N],dfsR[N],deep[N],fa[N][19],tot,n,m; int sum[N<<2],cnt[N<<2]; void dfs(int x,int f,int depth){ deep[x]=depth; dfsL[x]=++tot; for(int i=0;i<(int)G[x].size();++i){ int v=G[x][i]; if(v==f) continue; fa[v][0]=x; dfs(v,x,depth+1); } dfsR[x]=tot; } void up(int x,int l,int r){ if(cnt[x]>0) sum[x]=r-l+1; else sum[x]=sum[x<<1]+sum[x<<1|1]; } void update(int x,int l,int r,int tl,int tr,int v){ if(tl<=l&&tr>=r) { cnt[x]+=v; up(x,l,r); return; } int m=(l+r)>>1; if(tl<=m) update(x<<1,l,m,tl,tr,v); if(tr>m) update(x<<1|1,m+1,r,tl,tr,v); up(x,l,r); } bool Ju(int x,int y){ return dfsL[x]<=dfsL[y]&&dfsR[x]>=dfsR[y]; } int findx(int x,int y){ int dt=deep[y]-deep[x]-1; for(int i=0;i<=18;++i) if((1<<i)&dt) y=fa[y][i]; return y; } vector<pair<int,int> >add[N<<1],del[N<<1]; void mdy(int x1,int y1,int x2,int y2){ add[x1].push_back(make_pair(x2,y2)); del[y1].push_back(make_pair(x2,y2)); add[x2].push_back(make_pair(x1,y1)); del[y2].push_back(make_pair(x1,y1)); } int main(){ int x,y; scanf("%d%d",&n,&m); for(int i=1;i<n;++i) { scanf("%d%d",&x,&y); G[x].push_back(y); G[y].push_back(x); } dfs(1,0,0); for(int i=1;i<=18;++i) for(int j=1;j<=n;++j) fa[j][i]=fa[fa[j][i-1]][i-1]; for(int i=1;i<=m;++i){ scanf("%d%d",&x,&y); if(deep[x]>deep[y]) swap(x,y); if(Ju(x,y)) { int p=findx(x,y); if(dfsL[p]>1) mdy(1,dfsL[p]-1,dfsL[y],dfsR[y]); if(dfsR[p]<n) mdy(dfsR[p]+1,n,dfsL[y],dfsR[y]); //printf("%d %d %d %d\n",1,dfsL[p]-1,dfsL[y],dfsR[y]); //printf("%d %d %d %d\n",dfsR[p]+1,n,dfsL[y],dfsR[y]); } else { //printf("%d %d %d %d\n",dfsL[x],dfsR[x],dfsL[y],dfsR[y]); mdy(dfsL[x],dfsR[x],dfsL[y],dfsR[y]); } } long long ans=1LL*n*(n-1); for(int i=1;i<=n;++i) { for(int j=0;j<(int)add[i].size();++j) { update(1,1,n,add[i][j].first,add[i][j].second,1); } ans-=sum[1]; //printf("%d %d\n",i,sum[1]); for(int j=0;j<(int)del[i].size();++j) { update(1,1,n,del[i][j].first,del[i][j].second,-1); } } printf("%lld\n",ans); }