NC212812. 生活在树上
描述
输入描述
第一行两个正整数 n, q,分别表示树的大小和牛牛需要帮助计算的人数。
后面 n-1 行每行两个正整数 u, v,表示 u 与 v 之间有一条边相连。后面 q 行每行 4 个正整数 a, b, c, d,表示一个人的参数。
输出描述
q 行,每行一个数,第 i 行表示第 i 个人与家庭、社会之间的落差值。
示例1
输入:
7 4
1 2
2 3
2 4
4 5
1 6
6 7
1 2 6 7
3 5 4 7
3 7 5 6
1 1 1 1
输出:
8
44
49
0
C++14(g++5.4) 解法, 执行用时: 418ms, 内存消耗: 73344K, 提交时间: 2020-10-10 00:20:11
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
using ll = long long;

// Adjacency lists of the tree; vertices are 0-indexed after reading input.
vector<vector<int> > E;

// Range-minimum structure over a static int array.
// Build is O(n log n); query(l, r) on the inclusive range [l, r] is O(1).
struct SparseTable {
    int n = 0;
    vector<vector<int>> dp;  // dp[k][j] = min(a[j], ..., a[j + 2^k - 1])
    vector<int> lg;          // lg[len] = floor(log2(len)) for len >= 1
    SparseTable() = default;
    explicit SparseTable(const vector<int> &a) {
        n = a.size();
        // Portable replacement for the libstdc++-internal __lg.
        lg.assign(n + 1, 0);
        for (int len = 2; len <= n; len++) lg[len] = lg[len / 2] + 1;
        dp.emplace_back(a);
        // `L <= n` (not `L < n`): when n is a power of two, a query spanning
        // the whole array needs the level with blocks of length exactly n.
        for (int i = 1, L = 2; L <= n; L <<= 1, i++) {
            dp.emplace_back(n - L + 1);
            for (int j = 0; j + L - 1 < n; j++) {
                dp[i][j] = min(dp[i - 1][j], dp[i - 1][j + L / 2]);
            }
        }
    }
    int query(int l, int r) {
        int k = lg[r - l + 1];
        return min(dp[k][l], dp[k][r - (1 << k) + 1]);
    }
};

// Lowest common ancestor via range-minimum over an Euler tour.
// id[u] is the tour position at which u is first entered; seq[] stores, for
// each tour position, the id of the vertex visited there; iseq[] maps an id
// back to its vertex. An ancestor is always entered before its descendants,
// so lca(u, v) is the vertex whose id is smallest in seq[id[u] .. id[v]].
struct RMQLCA {
    int n = 0, dfs_clock = 0;
    vector<int> seq, iseq, id;
    SparseTable rmq;
    RMQLCA() {}
    RMQLCA(int _n, int root) {
        n = _n;
        dfs_clock = 0;
        seq.resize(2 * n - 1);  // an Euler tour of n vertices has 2n - 1 entries
        iseq.resize(2 * n - 1);
        id.resize(n);
        dfs(root, -1);
        rmq = SparseTable(seq);
    }
    // Euler tour starting at u, whose parent is fa (-1 for the root).
    // Uses an explicit stack: a recursive walk is as deep as the tree, and a
    // path of a few hundred thousand vertices can overflow the call stack.
    // Emits exactly the same seq/iseq/id as the recursive formulation.
    void dfs(int u, int fa) {
        struct Frame {
            int node, parent;
            size_t next;  // index of the next neighbour of `node` to try
        };
        auto enter = [&](int x) {
            id[x] = dfs_clock;
            seq[dfs_clock++] = id[x];
            iseq[id[x]] = x;
        };
        vector<Frame> stk;
        enter(u);
        stk.push_back({u, fa, 0});
        while (!stk.empty()) {
            Frame &cur = stk.back();
            if (cur.next == E[cur.node].size()) {
                stk.pop_back();
                // Returning from a child re-emits the parent's id.
                if (!stk.empty()) seq[dfs_clock++] = id[stk.back().node];
                continue;
            }
            int child = E[cur.node][cur.next++];
            if (child == cur.parent) continue;
            int from = cur.node;  // `cur` may dangle once we push
            enter(child);
            stk.push_back({child, from, 0});
        }
    }
    int lca(int u, int v) {
        if (id[u] > id[v]) swap(u, v);
        return iseq[rmq.query(id[u], id[v])];
    }
};

RMQLCA eq;
vector<int> dep;  // dep[v] = number of edges from the root to v

// Sets dep[v] = dep[parent] + 1 for every vertex below u, where pre is u's
// parent (-1 for the root); dep[u] itself must already be set.
// Iterative for the same stack-depth reason as RMQLCA::dfs.
void prepare(int u, int pre) {
    vector<pair<int, int>> stk(1, make_pair(u, pre));
    while (!stk.empty()) {
        int x = stk.back().first, p = stk.back().second;
        stk.pop_back();
        for (int v : E[x]) {
            if (v == p) continue;
            dep[v] = dep[x] + 1;
            stk.emplace_back(v, x);
        }
    }
}

// Number of edges on the tree path between u and v.
int dis(int u, int v) { return dep[u] + dep[v] - 2 * dep[eq.lca(u, v)]; }

// Sum of distances from the vertex at offset x (0 <= x <= n) of a path with
// n edges to all n + 1 vertices of that path.
ll calc(int n, int x) { return x * (x + 1ll) / 2 + (n - x) * (n - x + 1ll) / 2; }

// Tetrahedral number S(x) = sum_{i=0..x} i(i+1)/2; S(-1) == S(0) == 0.
ll S(int x) { return x * (x + 1ll) * (x + 2ll) / 6; }

// Prefix sum calc(n, 0) + calc(n, 1) + ... + calc(n, x).
ll calc2(int n, int x) {
    ll ans = S(x) + S(n) - S(n - x - 1);
    return ans;
}

// calc(n, l) + calc(n, l + 1) + ... + calc(n, r) for 0 <= l <= r <= n.
ll calc(int n, int l, int r) {
    if (l == r) return calc(n, l);
    ll ret = calc2(n, r);
    if (l) ret -= calc2(n, l - 1);
    return ret;
}

int lca(int u, int v) { return eq.lca(u, v); }

// Shared part of path(u1, v1) and path(u2, v2): {} when the paths are
// disjoint, otherwise the two endpoints of the common sub-path (equal when
// only one vertex is shared). These endpoints are the two deepest of the
// four cross LCAs.
vector<int> findinsect(int u1, int v1, int u2, int v2) {
    int t[4] = {lca(u1, u2), lca(u1, v2), lca(v1, u2), lca(v1, v2)};
    sort(t, t + 4, [&](int x, int y) { return dep[x] < dep[y]; });
    int r = lca(u1, v1), rr = lca(u2, v2);
    if (dep[t[3]] < max(dep[r], dep[rr])) {
        return {};
    }
    return {t[2], t[3]};
}

// Prints sum over x on path(u1, v1) and y on path(u2, v2) of dis(x, y).
void solve(int u1, int v1, int u2, int v2) {
    auto V = findinsect(u1, v1, u2, v2);
    int n = dis(u1, v1);
    if (V.empty()) {
        // Disjoint paths: every x-y route runs x -> x1 -> x2 -> y, where x1
        // (on path 1) and x2 (on path 2) are the closest pair of vertices, so
        // the total splits into three independent sums.
        int a1 = lca(u1, v1), a2 = lca(u2, v2);
        int x1 = -1, x2 = -1;
        if (a2 == lca(a1, a2)) {
            // a2 is an ancestor of a1: relabel so a1 is the ancestor.
            swap(a1, a2);
            swap(u1, u2);
            swap(v1, v2);
        } else if (a1 != lca(a1, a2)) {
            // Neither top vertex is an ancestor of the other.
            x1 = a1;
            x2 = a2;
        }
        if (x1 == -1) {
            // a1 is an ancestor of a2: path 2 is reached through a2, and
            // path 1 leaves towards a2 at the deeper of these two branch points.
            x2 = a2;
            if (dep[lca(a2, u1)] < dep[lca(a2, v1)]) {
                x1 = lca(a2, v1);
            } else {
                x1 = lca(a2, u1);
            }
        }
        int d = dis(x1, x2);
        ll ans = (calc(dis(u1, v1), dis(x1, u1)) + (ll)d * (dis(u1, v1) + 1)) * (dis(u2, v2) + 1);
        ans += calc(dis(u2, v2), dis(x2, u2)) * (dis(u1, v1) + 1);
        cout << ans << '\n';
        return;
    }
    // Overlapping paths; positions on path 1 are measured from u1.
    int l = dis(V[0], u1), r = dis(V[1], u1);
    if (l > r) {
        swap(l, r);
        swap(V[0], V[1]);
    }
    // Vertices y of path 2 lying on the shared segment contribute calc(n, pos(y)).
    ll res = calc(n, l, r);
    // z1 / z2: sum over path 1 of distances to segment endpoints V[0] / V[1].
    ll z1 = calc(n, l), z2 = calc(n, r);
    int k = 0;
    if (dis(u2, V[0]) > dis(u2, V[1])) {
        // u2's tail hangs off V[1]; keep z1 paired with u2's side.
        swap(z1, z2);
        k ^= 1;
    }
    // The rest of path 2 is two tails of d1 and d2 edges; a tail vertex t
    // edges past its endpoint adds t to each of the n + 1 distances.
    int d1 = dis(V[k], u2), d2 = dis(V[k ^ 1], v2);
    res += z1 * d1 + z2 * d2 + (d2 * (d2 + 1ll) / 2 + d1 * (d1 + 1ll) / 2) * (dis(u1, v1) + 1);
    cout << res << '\n';
}

int main() {
#ifdef local
    freopen("in.txt", "r", stdin);
#endif
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, q;
    cin >> n >> q;
    E.resize(n);
    dep.resize(n);
    for (int i = 1, u, v; i < n; i++) {
        cin >> u >> v;
        --u, --v;
        E[u].emplace_back(v);
        E[v].emplace_back(u);
    }
    eq = RMQLCA(n, 0);
    prepare(0, -1);
    for (int i = 0; i < q; i++) {
        int u1, v1, u2, v2;
        cin >> u1 >> v1 >> u2 >> v2;
        --u1, --v1, --u2, --v2;
        solve(u1, v1, u2, v2);
    }
    return 0;
}
C++11(clang++ 3.9) 解法, 执行用时: 547ms, 内存消耗: 37280K, 提交时间: 2020-10-14 10:55:35
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <vector>
using namespace std;

// Vertex capacity; vertices are numbered 1..n and index 0 is a sentinel
// (fa[root] == 0, dep[0] == 0, sz[0] == 0).
const int N = 3e5 + 5;
int n, q, fa[N], dep[N], sz[N], top[N];
vector<int> E[N];

// Roots the tree at `root` and fills the heavy-light arrays:
//   fa[v]  parent of v (fa[root] = 0),
//   dep[v] number of vertices on the root..v path (dep[root] = 1),
//   sz[v]  subtree size,
//   top[v] first vertex of v's heavy chain.
// A vertex's heavy child is the first neighbour (in E order) with the largest
// subtree. Works in BFS order instead of recursing, so a path-shaped tree of
// ~3e5 vertices cannot overflow the call stack.
void build(int root) {
    vector<int> order;
    order.reserve(n);
    fa[root] = 0;
    dep[root] = 1;
    order.push_back(root);
    for (size_t i = 0; i < order.size(); ++i) {
        int u = order[i];
        sz[u] = 1;
        for (int v : E[u]) {
            if (v == fa[u]) continue;
            fa[v] = u;
            dep[v] = dep[u] + 1;
            order.push_back(v);
        }
    }
    // Each vertex appears after its parent, so a reverse sweep completes every
    // subtree before folding it into its parent. Index 0 (the root) is skipped
    // so the sentinel sz[0] stays 0.
    for (size_t i = order.size(); i-- > 1;) sz[fa[order[i]]] += sz[order[i]];
    top[root] = root;
    for (int u : order) {
        int son = 0;  // sz[0] == 0, so the first child is taken initially
        for (int v : E[u])
            if (v != fa[u] && sz[v] > sz[son]) son = v;
        for (int v : E[u])
            if (v != fa[u]) top[v] = (v == son) ? top[u] : v;
    }
}

// LCA by climbing heavy chains.
int lca(int x, int y) {
    while (top[x] != top[y])
        if (dep[top[x]] > dep[top[y]])
            x = fa[top[x]];
        else
            y = fa[top[y]];
    return dep[x] < dep[y] ? x : y;
}

// C(x + 1, 3): sum of (j - i) over all pairs 1 <= i < j <= x.
long long C3(int x) { return (long long)(x + 1) * x * (x - 1) / 6; }

// l + (l + 1) + ... + r; 0 when r == l - 1.
long long sum(int l, int r) { return (long long)(l + r) * (r - l + 1) >> 1; }

// Sum of dist(a, b) over a on the root..x path and b on the root..y path
// (both inclusive), computed from the depths of x, y and lca(x, y).
long long query(int x, int y) {
    int z = lca(x, y);
    x = dep[x];
    y = dep[y];
    z = dep[z];
    // a on the shared root..z part, b anywhere on root..y.
    long long res = C3(z) + C3(y) - C3(y - z);
    // a on x's private branch: dist = (a's depth - z) + dist(z-vertex, b).
    long long s = sum(1, z - 1) + sum(1, y - z);
    res += s * (x - z) + sum(1, x - z) * y;
    return res;
}

int main() {
    if (scanf("%d%d", &n, &q) != 2) return 0;
    for (int i = 1, x, y; i < n; ++i) {
        if (scanf("%d%d", &x, &y) != 2) return 0;
        E[x].push_back(y);
        E[y].push_back(x);
    }
    build(1);
    for (int i = 1, x1, y1, x2, y2, z1, z2; i <= q; ++i) {
        if (scanf("%d%d%d%d", &x1, &y1, &x2, &y2) != 4) break;
        z1 = lca(x1, y1);
        z2 = lca(x2, y2);
        // As a vertex multiset, path(x, y) = root..x + root..y - root..z - root..fa[z]
        // with z = lca(x, y); a negative entry marks a subtracted root path.
        // root..fa[z] is empty when z is the root (vertex 1).
        vector<int> vec1 = {x1, y1, -z1}, vec2 = {x2, y2, -z2};
        if (z1 > 1) vec1.push_back(-fa[z1]);
        if (z2 > 1) vec2.push_back(-fa[z2]);
        long long ans = 0;
        for (int x : vec1)
            for (int y : vec2)
                ans += query(abs(x), abs(y)) * (x > 0 ? 1 : -1) * (y > 0 ? 1 : -1);
        printf("%lld\n", ans);
    }
    return 0;
}