NC19314. 颓红警
描述
小可爱率领的部队现在面对的是敌军在这一地区的驻军,敌国战争机器的运作很大程度上依赖指挥,所以敌军内部是严明分级的,就是说,全部敌军可以看作一棵树,每只敌军部队(树上每个节点)有其战斗力。你可以对任意敌军部队发动进攻,小可爱的部队有战斗力p,意味着他的每次进攻将使得被进攻的这支部队的战斗力减少p,对上级指挥系统的打击同时会影响其下级部队。具体来说,当他对点i发动进攻,部队i的战力减少p的同时,对于其子树内点j,部队j的战力减少Max(0,p−dis(i,j)2)(dis(i,j)表示点i,j间简单路径的长度)。如果某支部队战力小于0,那么这支部队就被消灭了,一支部队被消灭不会改变敌军编制(即这棵树的结构不会改变)。
小可爱想知道,你的部队最少发动几次进攻,才能全歼敌军
由于小可爱还要爆手速发展自己实力,所以把这个问题交给了你。输入描述
第一行两个正整数n,p,分别表示德军部队数目和你部战斗力
第二行n个正整数,表示德军各部战斗力mi
第三行到第n+1行,每行两个正整数i,j,表示i,j两支部队存在从属关系(i为j的上级)
输出描述
输出一个整数,表示最少进攻次数
示例1
输入:
7 3 1 1 3 7 5 3 3 1 2 2 3 1 4 2 5 4 6 1 7
输出:
8
说明:
对一号、七号部队各发动一次进攻,对三号、四号、五号部队各发动两次进攻C++14(g++5.4) 解法, 执行用时: 792ms, 内存消耗: 172640K, 提交时间: 2020-06-22 17:10:32
#include<bits/stdc++.h> #define ll long long using namespace std; const int MX=1e6+9; int n; ll s[MX],s1[MX],s2[MX],val[MX],p,lim; ll tot[MX],cnt[MX],ans; vector<int> vec[MX]; void dfs(int u,ll di){ if( di-lim>0 ){ ll temp=di-lim; s[u]-=tot[temp]; s1[u]-=tot[temp]*temp; s2[u]-=tot[temp]*temp*temp; } ll sum=s[u]*(p-di*di)-s2[u]+2*s1[u]*di; ll rem=0; if( val[u]>sum ){ rem=(val[u]-sum)/p+1; ans+=rem; } tot[di]=rem; for( int i=0 ; i<vec[u].size() ; i++ ){ int v=vec[u][i]; s[v]=s[u]+rem; s1[v]=s1[u]+di*rem; s2[v]=s2[u]+di*di*rem; dfs(v,di+1); } } int main() { // freopen("input.txt","r",stdin); scanf("%d %lld",&n,&p); for( int i=1 ; i<=n ; i++ ) scanf("%lld",&val[i]); for( int i=1,a,b ; i<=n-1 ; i++ ){ scanf("%d %d",&a,&b); vec[a].push_back(b); } lim=sqrt(p)+1; // 注意 dfs(1,1); printf("%lld\n",ans); return 0; }
C++11(clang++ 3.9) 解法, 执行用时: 1146ms, 内存消耗: 172220K, 提交时间: 2020-07-05 19:51:03
#include<bits/stdc++.h> using namespace std; int p,t,a[1000005],V[1000005]; long long ans=0,b[1000005]={0},c[1000005]={0},sum[1000005]={0}; vector<int>R[1000005]; void DFS(int x,long long dep) { long long i=dep-t,j; if(i>0)sum[x]-=V[i],b[x]-=V[i]*i,c[x]-=V[i]*i*i; V[dep]=0,j=sum[x]*(p-dep*dep)+2*dep*b[x]-c[x]; if(a[x]>=j)V[dep]=(a[x]-j)/p+1,ans+=V[dep]; for(i=0;i<R[x].size();i++) { j=R[x][i]; sum[j]=sum[x]+V[dep],b[j]=b[x]+dep*V[dep],c[j]=c[x]+V[dep]*dep*dep; DFS(j,dep+1); } } int main() { int i,j,k,n; scanf("%d%d",&n,&p),t=sqrt(p)+1; for(i=1;i<=n;i++)scanf("%d",&a[i]); for(i=1;i<n;i++)scanf("%d%d",&j,&k),R[j].push_back(k); DFS(1,1); printf("%lld\n",ans); return 0; }
pypy3(pypy3.6.1) 解法, 执行用时: 3893ms, 内存消耗: 289332K, 提交时间: 2020-07-20 03:48:21
#!/usr/bin/env python3 # # 颓红警 # import sys, os from collections import deque def read_ints(): return list(map(int, input().split())) n, p = read_ints() m = read_ints() g = [[] for _ in range(n)] ind = [0] * n for _ in range(n - 1): u, v = read_ints() u -= 1; v -= 1 g[u].append(v) ind[v] += 1 root = -1 for i, c in enumerate(ind): if c == 0: root = i ans = 0 q = deque() q.append([root, []]) while q: u, a = q.popleft() r, t = m[u] + 1, [] for l, c in a: r -= max(0, p - l * l) * c if p - (l + 1) * (l + 1) > 0: t.append([l + 1, c]) c = ((r - 1) // p + 1) if r > 0 else 0 if c > 0: t.append([1, c]) ans += c for v in g[u]: q.append([v, t]) print(ans)