列表

详情


NC20620. 蚂蚁开会

描述

一颗n个节点的树,m次操作,有点权(该节点蚂蚁个数)和边权(相邻节点的距离),蚂蚁们喜欢开会,为蚁族的发展做出长远规划。
三种操作:
操作1:1 i x将节点i的点权修改为x。(1 <= i <= n; 1 <= x <= 100000)
操作2:2 i x将第i条边的边权修改为x。(1 <= i < n; 1 <= x <= 100000)
操作3:3 i 节点i发出开会指令,求树上所有蚂蚁到走到节点i的距离和。(1 <= i <= n)

输入描述

第1行2个整数 n, m;
第2行n个整数,第i个整数ai 表示节点i的蚂蚁数量;
接下来n-1行,每行两个整数,bi, ci。表示节点i+1与节点bi 之间有一条长为ci的边相连。
接下来m行表示m个操作。

输出描述

对于每个操作3,输出一个整数,表示树上所有蚂蚁到走到节点i的距离和。

示例1

输入:

3 4
1 1 1
1 1
2 1
3 1
1 2 5
2 1 5
3 3

输出:

3
11

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++14(g++5.4) 解法, 执行用时: 391ms, 内存消耗: 12896K, 提交时间: 2018-10-12 23:34:29

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#define lson o<<1
#define rson o<<1|1
using namespace std;
typedef long long ll;
const int N=100010;
int tim,n,m,tt;
ll ans,sum,gg;
ll c[N*4],tr[N*4],lz[N*4],g[N];
int head[N],to[N*2],nxt[N*2],s[N*2],sz[N],fr[N],ch[N],dfn[N],fa[N],tp[N],id[N],a[N],ed[N];
inline int gi() {
    int x=0,o=1;
    char ch=getchar();
    while(ch<'0'||ch>'9') ch=='-'?o=-1:0,ch=getchar();
    while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
    return o*x;
}
inline void dfs1(int x) {
    sz[x]=1;
    for(int i=head[x];i;i=nxt[i]) {
    int y=to[i];
    dfs1(y),fa[y]=x,sz[x]+=sz[y],fr[y]=s[i];
    if(sz[y]>sz[ch[x]]) ch[x]=y;
    }
}
inline void dfs2(int x,int f) {
    tp[x]=f,dfn[x]=++tim,id[tim]=x;
    if(ch[x]) dfs2(ch[x],f);
    for(int i=head[x];i;i=nxt[i])
    if(to[i]!=ch[x]) dfs2(to[i],to[i]);
    ed[x]=tim;
}
inline void pushdown(int o) {
    int l=lson,r=rson;
    lz[l]+=lz[o],lz[r]+=lz[o];
    tr[l]+=lz[o]*c[l],tr[r]+=lz[o]*c[r],lz[o]=0;
}
inline void pushup(int o) {
    c[o]=c[lson]+c[rson];
    tr[o]=tr[lson]+tr[rson];
}
inline void update(int l,int r,int L,int R,int z,int o) {
    if(L<=l&&r<=R) {
    lz[o]+=z,tr[o]+=c[o]*z,sum+=c[o]*z;
    return;
    }
    if(lz[o]) pushdown(o);
    int mid=(l+r)>>1;
    if(L<=mid) update(l,mid,L,R,z,lson);
    if(R>mid) update(mid+1,r,L,R,z,rson);
    pushup(o);
}
inline ll query(int l,int r,int L,int R,int o) {
    if(L<=l&&r<=R) return sum+=c[o],tr[o];
    if(lz[o]) pushdown(o);
    int mid=(l+r)>>1;
    ll ret=0;
    if(L<=mid) ret=query(l,mid,L,R,lson);
    if(R>mid) ret+=query(mid+1,r,L,R,rson);
    return pushup(o),ret;
}
inline void insert(int l,int r,int p,int z,int o) {
    if(l==r) {
    tr[o]=tr[o]/c[o]*z,c[o]=z;
    return;
    }
    if(lz[o]) pushdown(o);
    int mid=(l+r)>>1;
    if(p<=mid) insert(l,mid,p,z,lson);
    else insert(mid+1,r,p,z,rson);
    pushup(o);
}
inline void build(int l,int r,int o) {
    if(l==r) {c[o]=fr[id[l]];return;}
    int mid=(l+r)>>1;
    build(l,mid,lson),build(mid+1,r,rson);
    pushup(o);
}
inline void ins(int x,int f) {
    while(x<=n) g[x]+=f,x+=(x&-x);
}
inline ll get(int x) {
    ll ret=0;
    while(x) ret+=g[x],x-=(x&-x);
    return ret;
}
inline void add(int x,int y) {
    sum=0,gg+=y,ins(dfn[x],y);
    while(x) {
    update(1,n,dfn[tp[x]],dfn[x],y,1);
    x=fa[tp[x]];
    }
    ans+=sum;
}
inline ll query(int x) {
    sum=0;
    ll ret=0;
    while(x) {
    ret+=query(1,n,dfn[tp[x]],dfn[x],1);
    x=fa[tp[x]];
    }
    return ans+sum*gg-2*ret;
}
int main() {
    cin>>n>>m;
    for(int i=1;i<=n;i++) a[i]=gi();
    for(int i=2,x;i<=n;i++)
    to[++tt]=i,nxt[tt]=head[x=gi()],s[tt]=gi(),head[x]=tt;
    dfs1(1),dfs2(1,1);
    build(1,n,1);
    for(int i=1;i<=n;i++) add(i,a[i]);
    while(m--) {
    int op=gi(),x=gi();
    if(op==3) printf("%lld\n",query(x));
    else {
        int y=gi();
        if(op==1) add(x,y-a[x]),a[x]=y;
        else {
        ans+=(y-fr[x+1])*(get(ed[x+1])-get(dfn[x+1]-1));
        insert(1,n,dfn[x+1],y,1),fr[x+1]=y;
        }
    }
    }
    return 0;
}

C++11(clang++ 3.9) 解法, 执行用时: 436ms, 内存消耗: 13004K, 提交时间: 2019-03-07 17:00:21

#include<bits/stdc++.h>
#define lson o<<1
#define rson o<<1|1
using namespace std;
typedef long long ll;
const int N=100010;
int tim,n,m,tt, head[N],to[N*2],nxt[N*2],s[N*2],sz[N],fr[N],ch[N],dfn[N],fa[N],tp[N],id[N],a[N],ed[N];
ll ans,sum,gg;
ll c[N*4],tr[N*4],lz[N*4],g[N];
inline int gi() {
    int x=0,o=1;
    char ch=getchar();
    while(ch<'0'||ch>'9') ch=='-'?o=-1:0,ch=getchar();
    while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
    return o*x;
}
inline void dfs1(int x) {
    sz[x]=1;
    for(int i=head[x];i;i=nxt[i]) {
		int y=to[i];
		dfs1(y),fa[y]=x,sz[x]+=sz[y],fr[y]=s[i];
		if(sz[y]>sz[ch[x]]) ch[x]=y;
    }
}
inline void dfs2(int x,int f) {
    tp[x]=f,dfn[x]=++tim,id[tim]=x;
    if(ch[x]) dfs2(ch[x],f);
    for(int i=head[x];i;i=nxt[i])
		if(to[i]!=ch[x]) dfs2(to[i],to[i]);
    ed[x]=tim;
}
inline void pushdown(int o) {
    int l=lson,r=rson;
    lz[l]+=lz[o],lz[r]+=lz[o];
    tr[l]+=lz[o]*c[l],tr[r]+=lz[o]*c[r],lz[o]=0;
}
inline void pushup(int o) {
    c[o]=c[lson]+c[rson];
    tr[o]=tr[lson]+tr[rson];
}
inline void build(int l,int r,int o) {
    if(l==r) {c[o]=fr[id[l]];return;}
    int mid=(l+r)>>1;
    build(l,mid,lson),build(mid+1,r,rson);
    pushup(o);
}
inline void update(int l,int r,int L,int R,int z,int o) {
    if(L<=l&&r<=R) {
		lz[o]+=z,tr[o]+=c[o]*z,sum+=c[o]*z;
		return;
    }
    if(lz[o]) pushdown(o);
    int mid=(l+r)>>1;
    if(L<=mid) update(l,mid,L,R,z,lson);
    if(R>mid) update(mid+1,r,L,R,z,rson);
    pushup(o);
}
inline ll query(int l,int r,int L,int R,int o) {
    if(L<=l&&r<=R) return sum+=c[o],tr[o];
    if(lz[o]) pushdown(o);
    int mid=(l+r)>>1;
    ll ret=0;
    if(L<=mid) ret=query(l,mid,L,R,lson);
    if(R>mid) ret+=query(mid+1,r,L,R,rson);
    return ret;
}
inline void insert(int l,int r,int p,int z,int o) {
    if(l==r) {
		tr[o]=tr[o]/c[o]*z,c[o]=z;
		return;
    }
    if(lz[o]) pushdown(o);
    int mid=(l+r)>>1;
    if(p<=mid) insert(l,mid,p,z,lson);
    else insert(mid+1,r,p,z,rson);
    pushup(o);
}
inline void ins(int x,int f) {
    while(x<=n) g[x]+=f,x+=(x&-x);
}
inline ll get(int x) {
    ll ret=0;
    while(x) ret+=g[x],x-=(x&-x);
    return ret;
}
inline void add(int x,int y) {
	sum=0, gg+=y,ins(dfn[x],y);
    while(x) {
		update(1,n,dfn[tp[x]],dfn[x],y,1);
		x=fa[tp[x]];
    }
    ans+=sum;
}
inline ll query(int x) {
    sum=0;
    ll ret=0;
    while(x) {
		ret+=query(1,n,dfn[tp[x]],dfn[x],1);
		x=fa[tp[x]];
    }
    return ans+sum*gg-2*ret;
}
int main() {
    cin>>n>>m;
    for(int i=1;i<=n;i++) a[i]=gi();
    for(int i=2,x;i<=n;i++)
		to[++tt]=i,nxt[tt]=head[x=gi()],s[tt]=gi(),head[x]=tt;
    dfs1(1),dfs2(1,1);
    build(1,n,1);
    for(int i=1;i<=n;i++) add(i,a[i]);
    while(m--) {
		int op=gi(),x=gi();
		if(op==3) printf("%lld\n",query(x));
		else {
			int y=gi();
			if(op==1) add(x,y-a[x]),a[x]=y;
			else{
				ans+=(y-fr[x+1])*(get(ed[x+1])-get(dfn[x+1]-1));
				insert(1,n,dfn[x+1],y,1),fr[x+1]=y;
			}
		}
    }
    return 0;
}

上一题