列表

详情


NC50485. 树上操作

描述

有一棵点数为N的树,以点1为根,且树有点权。然后有M个操作,分为三种:
  1. 把某个节点x的点权增加a。
  2. 把某个节点x为根的子树中所有点的点权都增加a。
  3. 询问某个节点x到根的路径中所有点的点权和。

输入描述

第一行包含两个整数N,M。表示点数和操作数。
接下来一行N个整数,表示树中节点的初始权值。
接下来N-1 行每行两个正整数,表示该树中存在一条边
再接下来M行,每行分别表示一次操作。其中第一个数表示该操作的种类(1-3),之后接这个操作的参数(x或者xa)。

输出描述

对于每个询问操作,输出该询问的答案。答案之间用换行隔开。

示例1

输入:

5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3

输出:

6
9
13

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++14(g++5.4) 解法, 执行用时: 496ms, 内存消耗: 17128K, 提交时间: 2020-01-28 16:38:17

#include<bits/stdc++.h>
using namespace std;

const int MAXN = 1e5+1;
long long sum[MAXN*4];
long long delay_sum[MAXN*4];
void add(int o, int l, int r, int L, int R, long long v);
void push_down(int o, int l, int mid, int r) {
  if (delay_sum[o]==0) return;
  add(o*2,  l,     mid, l,   mid, delay_sum[o]);
  add(o*2+1,mid+1, r,   mid+1, r, delay_sum[o]);
  delay_sum[o] = 0;
}
void add(int o, int l, int r, int L, int R, long long v) {
  if (R < l || r < L) return;
  if (L <= l && r <= R) {
    sum[o] += (r-l+1)*v;
    delay_sum[o] += v;
    return;
  }
  int mid = (l+r)/2;
  push_down(o, l, mid, r);
  add(o*2,  l,  mid, L, R, v);
  add(o*2+1,mid+1,r, L, R, v);
  sum[o] = sum[o*2] + sum[o*2+1];
}
long long query_sum(int o, int l, int r, int L, int R) {
  // cout << "query_sum" << o << l << r << "|";
  // cout << L << " " << R << " " << sum[o] << " " << delay_sum[o] << endl;
  if (R < l || r < L) return 0;
  if (L <= l && r <= R) return sum[o];
  int mid = (l + r) / 2;
  push_down(o, l, mid, r);
  return query_sum(o*2, l, mid, L, R)+query_sum(o*2+1, mid+1, r, L, R);
}

int N;
vector<int> G[MAXN];
int fa[MAXN];
int dep[MAXN];
int siz[MAXN];
int top[MAXN];
int son[MAXN]; 
int dfn[MAXN];
int rnk[MAXN];
int dfe[MAXN];
int clk;
int w[MAXN];
void dfs1(int s, int f) {
  dep[s] = dep[f]+1;
  siz[s] = 1;
  fa[s] = f;
  for (auto t: G[s]) if (t != f) {
    dfs1(t, s); 
    siz[s] += siz[t];
    if (siz[son[s]] < siz[t]) son[s] = t;
  }
}
void dfs2(int s, int f, int tp) {
  top[s] = tp;
  dfn[s] = ++clk;
  rnk[clk] = s;
  if (son[s]) dfs2(son[s], s, tp);
  for (auto t: G[s]) if (t != f && t != son[s]) dfs2(t, s, t);
  dfe[s] = clk;
}


int main() {
  int M; cin >> N >> M;
  for (int i = 1; i <= N; i++) cin >> w[i];
  for (int i = 1; i < N; i++) {
    int u, v; cin >> u >> v;
    G[u].push_back(v);
    G[v].push_back(u);
  }
  dfs1(1, 0);
  dfs2(1, 0, 1);
  for (int i = 1; i <= N; i++) add(1, 1, N, dfn[i], dfn[i], w[i]);
  // cout << "Init Done" << endl;
  while (M--) {
    int op; cin >> op;
    if (op == 1) {
      int x, a; cin >> x >> a;
      add(1, 1, N, dfn[x], dfn[x], a);
    } else if (op == 2) {
      int x, a; cin >> x >> a;
      add(1, 1, N, dfn[x], dfe[x], a);
    } else if (op == 3) {
      int x; cin >> x;
      long long ans = 0;
      while (top[x] != 1) {
        // cout << x << " " << top[x] << endl;
        ans += query_sum(1, 1, N, dfn[top[x]], dfn[x]);
        x = fa[top[x]];
      }
      // cout << x << endl;
      ans += query_sum(1, 1, N, 1, dfn[x]);
      cout << ans << endl;
    }
  }
}



C++(clang++11) 解法, 执行用时: 150ms, 内存消耗: 11768K, 提交时间: 2021-03-31 17:14:43

#include<iostream>
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<string>
#include<queue>
#include<vector>
#define ll long long

using namespace std;
const int INF = 0x3f3f3f3f;
const int maxn = 1e5+5;
vector<int>G[maxn];
int N,M,in[maxn],out[maxn],deep[maxn],a[maxn];
ll c[maxn][2];
int lowbit(int x) { return x&(-x);}
void update(int x,ll v,int y) {
	while(x<=N) {
		c[x][y]+=v;
		x+=lowbit(x);
	}
}
ll getsum(int x,int y) {
	ll sum = 0;
	while(x>=1) {
		sum+=c[x][y];
		x-=lowbit(x);
	} 
	return sum;
}
ll query(int x) {
	return getsum(in[x],0)+deep[x]*getsum(in[x],1);
}int cnt = 0;
void dfs(int u,int pre) {
	in[u] = ++cnt;
	deep[u] = deep[pre]+1;
	for(int i = 0; i < G[u].size(); i++) {
		int v = G[u][i];
		if(v == pre) continue;
		dfs(v,u);
	}
	out[u] = cnt;
}
int main()
{
	scanf("%d%d",&N,&M);
	for(int i = 1; i<= N; i++) scanf("%d",&a[i]);
	for(int i = 1; i <= N-1; i++) {
		int u,v; scanf("%d%d",&u,&v);
		G[u].push_back(v);	G[v].push_back(u); 
	}
	dfs(1,0);
	for(int i = 1; i <= N; i++) {
		update(in[i],a[i],0);
		update(out[i]+1,-a[i],0);
	}
	for(int i = 1; i <= M; i++) {
		int q,x,a;
		scanf("%d%d",&q,&x);
		if(q == 3) {
			printf("%lld\n",query(x));
		}
		else  {
			scanf("%d",&a);
			if(q == 1) {
				update(in[x],a,0);
				update(out[x]+1,-a,0);
			}else {
				update(in[x],-1ll*(deep[x]-1)*a,0);
				update(out[x]+1,1ll*(deep[x]-1)*a,0);
				update(in[x],a,1);
				update(out[x]+1,-a,1);
			}
		}
	}
	return 0;
} 

上一题