列表

详情


NC53634. 树的直径

描述

小a觉得求一棵树的直径太naive了,于是在树上划掉了几条边,想求所有连通块的直径。
给定一棵n个点的树,每个点有权值f[i]。
q次询问,每次询问给定k条边,求删除这k条边之后,剩下k+1个连通块的直径。你只需要输出这k+1个连通块直径的异或和即可。
记dis(u,v)为的简单路径上,所有边的边权和。(u,v)间距离DIS(u,v)的定义为:
某个连通块S的直径定义为:
询问间互不影响,即当前删除某些边对下一次询问没有影响。

输入描述

第一行两个整数n,q,表示树的点数及询问次数。
接下来一行有n个整数,表示每个点的点权f[i]。
接下来n-1行,每行三个整数u,v,w,表示u,v间有一条权值为w的无向边。
接下来2q行,每两行格式如:第一行一个整数,表示要删掉的边的数量;第二行2k个整数u_i,v_i,表示要删掉的第i条边是u_i,v_i之间的边。保证每对二元组(u_i,v_i)互不相同。

输出描述

输出q行。对于每次询问,输出一个整数表示k+1个连通块的直径的异或和。

示例1

输入:

7 2
1 2 3 1 2 3 4
1 2 2
2 3 1
2 4 3
2 5 2
1 6 2
1 7 2
1
1 2
2
2 4 1 7

输出:

3
11

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++14(g++5.4) 解法, 执行用时: 2607ms, 内存消耗: 114260K, 提交时间: 2019-11-18 23:43:32

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<ll,int> pli;
typedef pair<int,int> pii;
mt19937 gen(time(0));
const int maxn=200010;
const int maxm=1000010;
const int INF=0x7fffffff;
const int max_logn=20;
struct E
{
    int to,next,w;
}edge[maxm<<1];
int head[maxn],tol;
inline void Addedge(int u,int v,int w)
{
    edge[tol].to=v;edge[tol].w=w;edge[tol].next=head[u];head[u]=tol++;
}
int a[maxn],fa[maxn];
int id[maxn],vs[maxn<<1],depth[maxn<<1],mm[maxn<<1],dp[max_logn][maxn<<1],tot1;
int dfn[maxn],rdfn[maxn],sz[maxn],dis[maxn],tot2;
pii D[max_logn][maxn];
void dfs(int u,int d)
{
    id[u]=++tot1;
    vs[tot1]=u;
    depth[tot1]=d;
    dfn[u]=++tot2;
    rdfn[tot2]=u;
    sz[u]=1;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].to;
        if(v==fa[u]) continue;
        fa[v]=u;
        dis[v]=dis[u]+edge[i].w;
        dfs(v,d+1);
        sz[u]+=sz[v];
        vs[++tot1]=u;
        depth[tot1]=d;
    }
}
void init1(int n)
{
    for(int i=1;i<=n;++i){
        mm[i]=mm[i-1];
        if(i==1<<mm[i]+1) ++mm[i];
    }
    dfs(1,0);
    for(int i=1;i<=n;++i) dp[0][i]=i;
    for(int j=1;j<=mm[n];++j){
        for(int i=1;i+(1<<j)-1<=n;++i){
            int x=dp[j-1][i],y=dp[j-1][i+(1<<(j-1))];
            if(depth[x]<depth[y]) dp[j][i]=x;
            else dp[j][i]=y;
        }
    }
}
inline int Rmq(int l,int r)
{
    int k=mm[r-l+1],x=dp[k][l],y=dp[k][r-(1<<k)+1];
    return depth[x]<depth[y]?x:y;
}
inline int Lca(int u,int v)
{
    if(id[u]>id[v]) swap(u,v);
    return vs[Rmq(id[u],id[v])];
}
inline int Dis(int u,int v)
{
    if(u==v) return -INF;
    return dis[u]+dis[v]-(dis[Lca(u,v)]<<1)+a[u]+a[v];
}
inline int Dis(const pii&A)
{
    return Dis(A.first,A.second);
}
inline pii Max(const pii&A,const pii&B)
{
    return Dis(A.first,A.second)>=Dis(B.first,B.second)?A:B;
}
inline pii Union(const pii&A,const pii&B)
{
    if(A.first==0) return B;
    pii res=Max(Max(make_pair(A.first,B.first),make_pair(A.first,B.second)),Max(make_pair(A.second,B.first),make_pair(A.second,B.second)));
    return Max(res,Max(A,B));
}
void init2(int n)
{
    for(int i=1;i<=n;++i) D[0][i]=make_pair(rdfn[i],rdfn[i]);
    for(int j=1;j<=mm[n];++j){
        for(int i=1;i+(1<<j)-1<=n;++i){
            D[j][i]=Union(D[j-1][i],D[j-1][i+(1<<(j-1))]);
        }
    }
}
inline pii RmqD(int l,int r)
{
    int k=mm[r-l+1];
    return Union(D[k][l],D[k][r-(1<<k)+1]);
}
pii r[maxm];
int stk[maxm],ans;
vector<int>G[maxm];
void solve(int u)
{
    if(G[u].empty()){
        ans^=max(0,Dis(RmqD(r[u].first,r[u].second)));
        return;
    }
    pii res=make_pair(0,0);
    int l=r[u].first;
    reverse(G[u].begin(),G[u].end());
    for(int v:G[u]){
        solve(v);
        if(l<=r[v].first-1) res=Union(res,RmqD(l,r[v].first-1));
        l=r[v].second+1;
    }
    if(l<=r[u].second) res=Union(res,RmqD(l,r[u].second));
    ans^=max(0,Dis(res));
    G[u].clear();
}
inline int read()
{
    char ch=getchar();
    if(ch==EOF) return 0;
    int ans=0,f=1;
    while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') ans=(ans<<3)+(ans<<1)+ch-'0',ch=getchar();
    return ans*f;
}
int main()
{
    #ifdef local
    freopen("a.in","r",stdin);
    #endif // local
    memset(head,-1,sizeof(head));
    int n=read(),q=read(),u,v,k,w;
    for(int i=1;i<=n;++i) a[i]=read();
    for(int i=1;i<n;++i){
        u=read(),v=read(),w=read();
        Addedge(u,v,w);
        Addedge(v,u,w);
    }
    init1(n+n-1);
    init2(n);
//    for(int i=1;i<=n;++i) cout<<dfn[i]<<endl;
    while(q--){
        k=read();
        for(int i=1;i<=k;++i){
            u=read(),v=read();
            if(u==fa[v]) swap(u,v);
            r[i]=make_pair(dfn[u],dfn[u]+sz[u]-1);
        }
        r[0]=make_pair(1,n);
        sort(r,r+k+1,[&](const pii&A,const pii&B){
                if(A.second!=B.second) return A.second<B.second;
                return A.first>B.first;
             });
        int top=0;
        for(int i=0;i<=k;++i){
            while(top&&r[i].first<=r[stk[top]].first){
                G[i].emplace_back(stk[top--]);
            }
            stk[++top]=i;
        }
        assert(top==1);
        ans=0;
        solve(stk[1]);
        printf("%d\n",ans);
    }
    return 0;
}

C++11(clang++ 3.9) 解法, 执行用时: 1322ms, 内存消耗: 101592K, 提交时间: 2019-11-15 22:06:26

#include <cstdio>
#include <cctype>
#include <vector>
#include <cstring>
#include <cassert>
#include <algorithm>
#define BIT 18
#define gc() getchar()
typedef long long LL;
const int N=2e5+5,M=1e6+5;
const int INF=2e9+1;

inline int read()
{
	int now=0,f=1;register char c=gc();
	for(;!isdigit(c);c=='-'&&(f=-1),c=gc());
	for(;isdigit(c);now=now*10+c-48,c=gc());
	return now*f;
}
namespace G
{
	int f[N],Enum,H[N],to[N<<1],nxt[N<<1],len[N<<1],dis[N],ref[N],L[N],R[N],pos[N],Log[N<<1],st[BIT+1][N<<1];
	inline void AE(int u,int v,int w)
	{
		to[++Enum]=v, nxt[Enum]=H[u], H[u]=Enum, len[Enum]=w;
		to[++Enum]=u, nxt[Enum]=H[v], H[v]=Enum, len[Enum]=w;
	}
	inline int LCA_dis(int l,int r)
	{
		l=pos[l], r=pos[r];
		if(l>r) std::swap(l,r);
		int k=Log[r-l+1];
		return std::min(st[k][l],st[k][r-(1<<k)+1])<<1;
	}
	inline LL Dis(int u,int v)
	{
		return u==v?-INF:(LL)f[u]+f[v]+dis[u]+dis[v]-LCA_dis(u,v);
	}
	void DFS(int x,int fa)
	{
		static int Index=0,tot=0;
		ref[L[x]=++Index]=x, st[0][pos[x]=++tot]=dis[x];
		for(int i=H[x],v; i; i=nxt[i])
			if((v=to[i])!=fa) dis[v]=dis[x]+len[i], DFS(v,x), st[0][++tot]=dis[x];
		R[x]=Index;
	}
	void Init(int n)
	{
		for(int i=1; i<=n; ++i) f[i]=read();
		for(int i=1,u,v; i<n; ++i) u=read(),v=read(),AE(u,v,read());
		DFS(1,1);
		for(int i=2,m=n<<1; i<=m; ++i) Log[i]=Log[i>>1]+1;
		for(int j=1,m=n<<1; j<=Log[m]; ++j)
			for(int t=1<<j-1,i=m-t; i; --i)
				st[j][i]=std::min(st[j-1][i],st[j-1][i+t]);
	}
}
struct Node
{
	int x,y; LL w;
	inline Node operator +(const Node &a)
	{
		int nx=x,ny=y; LL nw=w,tmp;
		if(a.w>nw) nx=a.x, ny=a.y, nw=a.w;
		if((tmp=G::Dis(x,a.x))>nw) nx=x, ny=a.x, nw=tmp;
		if((tmp=G::Dis(x,a.y))>nw) nx=x, ny=a.y, nw=tmp;
		if((tmp=G::Dis(y,a.x))>nw) nx=y, ny=a.x, nw=tmp;
		if((tmp=G::Dis(y,a.y))>nw) nx=y, ny=a.y, nw=tmp;
		return (Node){nx,ny,nw};
	}
};
namespace RMQ
{
	Node st[BIT][N];
	using G::Log;
	inline Node Query(int l,int r)
	{
		int k=Log[r-l+1];
		return st[k][l]+st[k][r-(1<<k)+1];
	}
	void Init(int n)
	{
		for(int i=1; i<=n; ++i) st[0][i]=(Node){G::ref[i],G::ref[i],-INF};
		for(int j=1; j<=Log[n]; ++j)
			for(int t=1<<j-1,i=n-t; i; --i)
				st[j][i]=st[j-1][i]+st[j-1][i+t];
	}
}

std::vector<int> vec[N];
struct Segment
{
	int l,r;
	inline bool operator <(const Segment &x)const
	{
		return r==x.r?l>x.l:r<x.r;
	}
}seg[M];

inline bool Include(int i,int j)//seg[i] includes seg[j]
{
	return seg[i].l<=seg[j].l&&seg[j].r<=seg[i].r;
}
LL Solve(int x)
{
	LL sum=0; Node sta=(Node){-1,-1,0};
	int las=seg[x].l;
	for(int i=vec[x].size()-1; ~i; --i)
	{
		int p=vec[x][i]; sum^=Solve(p);
		if(las<seg[p].l)
			if(sta.x==-1) sta=RMQ::Query(las,seg[p].l-1);
			else sta=sta+RMQ::Query(las,seg[p].l-1);
		las=seg[p].r+1;
	}
	if(las<=seg[x].r)
		if(sta.x==-1) sta=RMQ::Query(las,seg[x].r);
		else sta=sta+RMQ::Query(las,seg[x].r);
	return sum^std::max(0ll,sta.w);
}

int main()
{
	static int sk[N];

	int n=read(),Q=read(); G::Init(n), RMQ::Init(n);
	LL s=0;
	while(Q--)
	{
		int k=read(); s+=k;
		for(int i=1; i<=k; ++i)
		{
			int u=read(),v=read(); if(G::dis[u]<G::dis[v]) std::swap(u,v);
			seg[i]=(Segment){G::L[u],G::R[u]};
		}
		seg[++k]=(Segment){1,n}, std::sort(seg+1,seg+1+k);
		for(int i=1,top=0; i<=k; ++i)
		{
			while(top && Include(i,sk[top])) vec[i].push_back(sk[top--]);
			sk[++top]=i;
		}
		printf("%lld\n",Solve(k));
		for(int i=1; i<=k; ++i) std::vector<int>().swap(vec[i]);
	}
	assert(s<=1e6); fprintf(stderr,"sum_k:%d\n",s);

	return 0;
}

上一题