NC21533. Sortable Path on Tree
描述
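(The opening of the statement is missing from this copy. Judging from the solutions below, which print the number of qualifying unordered vertex pairs plus $n$ for the single-vertex paths, the task is to count the pairs $(u, v)$ with $u \le v$ whose path satisfies the condition below, where $w_i$ is the weight of vertex $i$.) Let $t_1 = u, t_2, \ldots, t_k = v$ be the shortest path from $u$ to $v$. Then the sequence $w_{t_1}, w_{t_2}, \ldots, w_{t_k}$ or the sequence $w_{t_k}, w_{t_{k-1}}, \ldots, w_{t_1}$ can be sorted into nondecreasing order using several circular shift operations. Note that a circular shift is the operation of rearranging the entries in a sequence, either by moving the final entry to the first position while shifting all other entries to the next position, or by performing the inverse operation.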
输入描述
There are multiple test cases. The first line of the input contains an integer T, indicating the number of test cases. For each test case:
The first line contains an integer $n$ ($1 \le n \le 10^5$) -- the number of nodes in the tree.
The second line contains $n$ integers $w_1, w_2, \ldots, w_n$ ($1 \le w_i \le 10^5$).
Each of the next $n-1$ lines contains two integers $u$ and $v$ ($1 \le u, v \le n$, $u \ne v$) denoting an edge of the tree.
It's guaranteed that the sum of $n$ over all test cases will not exceed $10^5$.
输出描述
For each test case, output an integer denoting the answer.
示例1
输入:
1
4
3 4 1 2
1 2
2 3
3 4
输出:
10
C++14(g++5.4) 解法, 执行用时: 352ms, 内存消耗: 6008K, 提交时间: 2018-12-26 20:33:28
// Solution 1 (C++14): centroid decomposition.
//
// A weight sequence can be rotated into nondecreasing order iff it has no strict
// descent, or exactly one strict descent and last <= first. A path qualifies when its
// sequence or the reversed sequence does. So for a path through a centroid it is enough
// to know the number of strict descents/ascents on each half (capped at 2) and the two
// endpoint weights. Halves from different subtrees are combined with nine Fenwick trees.
#include <cstdio>
#include <cstring>
#include <algorithm>

using ll = long long;

const int N = 1e5 + 5;

// Adjacency lists (hd/nxt/to, xnt = number of stored directed edges).
// Weights w are compressed to ranks 1..lm by init(); tp is the scratch copy used for that.
// siz, rt and mn are centroid-search state.
int T, n, w[N], tp[N], hd[N], xnt, to[N << 1], nxt[N << 1], siz[N], rt, mn, lm;
// f[a][b]: Fenwick tree over weight ranks counting stored endpoints whose path
// endpoint -> centroid has a strict descents and b strict ascents (both capped at 2).
int f[3][3][N];
ll ans;       // qualifying pairs with u != v found so far; main() adds n for u == v
bool vis[N];  // vertices already used as a centroid

// Reads a decimal integer (a '-' anywhere before the digits negates it).
// Returns 0 if the input ends before any digit, instead of looping forever on EOF.
int rdn() {
    int ch = std::getchar();
    bool positive = true;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') positive = false;
        ch = std::getchar();
    }
    int ret = 0;
    while (ch >= '0' && ch <= '9') {
        ret = ret * 10 + (ch - '0');
        ch = std::getchar();
    }
    return positive ? ret : -ret;
}

// Resets per-test state for vertices 1..n.
// Replaces each w[i] by its rank among the distinct weights (1..lm), so the Fenwick
// trees only need lm slots.
void init() {
    xnt = 0;
    ans = 0;
    for (int i = 1; i <= n; i++) {
        hd[i] = 0;
        vis[i] = false;
    }
    std::sort(tp + 1, tp + n + 1);
    lm = static_cast<int>(std::unique(tp + 1, tp + n + 1) - tp - 1);
    for (int i = 1; i <= n; i++)
        w[i] = static_cast<int>(std::lower_bound(tp + 1, tp + lm + 1, w[i]) - tp);
}

// Appends the directed edge x -> y to the adjacency lists.
void add(int x, int y) {
    to[++xnt] = y;
    nxt[xnt] = hd[x];
    hd[x] = xnt;
}

// Fills siz[] for the unvisited component containing cr (total size s, DFS parent fa).
// Stores in rt the vertex whose largest remaining piece is smallest; the caller must
// preset mn to N.
void getrt(int cr, int fa, int s) {
    int mx = 0;
    siz[cr] = 1;
    for (int i = hd[cr], v; i; i = nxt[i])
        if (!vis[v = to[i]] && v != fa) {
            getrt(v, cr, s);
            siz[cr] += siz[v];
            mx = std::max(mx, siz[v]);
        }
    mx = std::max(mx, s - siz[cr]);
    if (mx < mn) {
        mn = mx;
        rt = cr;
    }
}

// Fenwick tree f[s0][s1]: adds k at rank x.
void add(int x, int k, int s0, int s1) {
    for (; x <= lm; x += x & -x) f[s0][s1][x] += k;
}

// Fenwick tree f[s0][s1]: total over ranks [1, x] (0 for x == 0).
int qry(int x, int s0, int s1) {
    int ret = 0;
    for (; x > 0; x -= x & -x) ret += f[s0][s1][x];
    return ret;
}

// Fenwick tree f[s0][s1]: total over ranks [x, lm].
int qry_s(int x, int s0, int s1) { return qry(lm, s0, s1) - qry(x - 1, s0, s1); }

// Number of endpoints stored in f[i][j] that form a qualifying path with the current
// vertex (weight rank tw).
// s0/s1 are the combined strict descent/ascent counts along
// stored endpoint -> centroid -> current vertex.
int cal(int s0, int s1, int i, int j, int tw) {
    if (s0 > 1 && s1 > 1) return 0;                 // neither direction can be rotated into order
    if (s0 == 1 && s1 > 1) return qry_s(tw, i, j);  // one descent: need last (tw) <= first
    if (s0 > 1 && s1 == 1) return qry(tw, i, j);    // reversed has one descent: need first <= tw
    return qry(lm, i, j);                           // otherwise every stored endpoint qualifies
}

// Counts the pairs ending at the current vertex (weight rank tw).
// The path centroid -> current vertex has s0 descents and s1 ascents.
// This includes the pair with the centroid itself (dfs() has already pruned
// non-qualifying halves) and every qualifying endpoint stored from earlier subtrees.
void calc(int tw, int s0, int s1) {
    ans++;  // pair (centroid, current vertex)
    for (int i = 0; i <= 2; i++)
        for (int j = 0; j <= 2; j++) ans += cal(s0 + i, s1 + j, i, j, tw);
}

// Walks one subtree hanging off the current centroid rt. lst is the weight of the
// previous vertex on the path.
//   op == 1: query.  s0/s1 count descents/ascents along centroid -> cr.
//   op == 2: insert; op == 3: remove.  s0/s1 count descents/ascents along cr -> centroid.
// A half-path that cannot be rotated into order in its own direction is pruned together
// with its subtree, since extending it cannot fix that. op 2 and op 3 prune identically,
// so removals mirror insertions.
void dfs(int cr, int fa, int lst, int s0, int s1, int op) {
    if (op == 1) {
        if (w[cr] > lst) s1++;
        if (w[cr] < lst) s0++;
    } else {
        if (lst > w[cr]) s1++;
        if (lst < w[cr]) s0++;
    }
    if (s0 > 1 && s1 > 1) return;
    if (s0 == 1 && s1 > 1) {
        if (op == 1 && w[cr] > w[rt]) return;
        if (op > 1 && w[cr] < w[rt]) return;
    }
    if (s1 == 1 && s0 > 1) {
        if (op == 1 && w[cr] < w[rt]) return;
        if (op > 1 && w[cr] > w[rt]) return;
    }
    if (op == 1) calc(w[cr], s0, s1);
    if (op == 2) add(w[cr], 1, std::min(s0, 2), std::min(s1, 2));
    if (op == 3) add(w[cr], -1, std::min(s0, 2), std::min(s1, 2));
    for (int i = hd[cr], v; i; i = nxt[i])
        if (!vis[v = to[i]] && v != fa) dfs(v, cr, w[cr], s0, s1, op);
}

// Counts every qualifying pair whose path passes through centroid cr (component size s),
// then recurses into the remaining components. rt == cr while the dfs() passes run.
void solve(int cr, int s) {
    vis[cr] = true;
    for (int i = hd[cr], v; i; i = nxt[i])
        if (!vis[v = to[i]]) {
            dfs(v, cr, w[cr], 0, 0, 1);  // pair with endpoints of earlier subtrees
            dfs(v, cr, w[cr], 0, 0, 2);  // then make this subtree's endpoints available
        }
    for (int i = hd[cr], v; i; i = nxt[i])
        if (!vis[v = to[i]]) dfs(v, cr, w[cr], 0, 0, 3);  // undo the insertions
    for (int i = hd[cr], v, ts; i; i = nxt[i])
        if (!vis[v = to[i]]) {
            mn = N;
            // siz[] comes from the getrt() pass that found cr. A child side has
            // siz[v] < siz[cr]; the parent side is everything outside cr's subtree.
            ts = siz[v] < siz[cr] ? siz[v] : s - siz[cr];
            getrt(v, cr, ts);
            solve(rt, ts);
        }
}

int main() {
    T = rdn();
    while (T--) {
        n = rdn();
        if (n < 1) break;  // end of input (rdn() returns 0 at EOF) or invalid size
        lm = 0;
        for (int i = 1; i <= n; i++) w[i] = rdn(), tp[i] = w[i];
        init();
        for (int i = 1, u, v; i < n; i++) {
            u = rdn();
            v = rdn();
            add(u, v);
            add(v, u);
        }
        mn = N;
        getrt(1, 0, n);
        solve(rt, n);
        std::printf("%lld\n", ans + n);  // + n for the single-vertex paths (u == v)
    }
    return 0;
}
C++11(clang++ 3.9) 解法, 执行用时: 733ms, 内存消耗: 10572K, 提交时间: 2019-02-03 19:29:57
// Solution 2 (C++11): centroid decomposition, using raw weights (1..MAXW) as Fenwick
// indices instead of compressed ranks.
//
// A path qualifies iff its weight sequence, or the reversed sequence, has no strict
// descent, or exactly one strict descent with last <= first. For a path through a
// centroid only the capped descent/ascent counts of both halves and the two endpoint
// weights matter.
#include <cstdio>
#include <algorithm>
#include <vector>

using ll = long long;

const int N = 100010;
const int MAXW = 100000;  // upper bound on w_i; the Fenwick trees cover weights 1..MAXW

int t, n;
int w[N];
std::vector<int> G[N];
ll res;  // qualifying pairs with u != v; main() adds n for u == v

// Fenwick tree over weights 1..MAXW (zero-initialised as a global).
struct BIT {
    int a[N];
    void update(int x, int val) {
        for (; x <= MAXW; x += x & -x) a[x] += val;
    }
    int query(int x) {
        int s = 0;
        for (; x > 0; x -= x & -x) s += a[x];
        return s;
    }
    int query(int l, int r) {
        if (r < l) return 0;
        return query(r) - query(l - 1);
    }
};

// bit[i][j] counts stored endpoints whose path centroid -> endpoint has i strict
// descents and j strict ascents (both capped at 2).
BIT bit[3][3];

int vis[N];  // vertices already used as a centroid
int root, sum, sze[N], f[N];

// Fills sze[] for the unvisited component of u (whose size must be in `sum`).
// Keeps in `root` the vertex whose largest remaining piece f[] is smallest.
// The caller presets root = 0 and f[0] = sum.
void getroot(int u, int fa) {
    sze[u] = 1, f[u] = 0;
    for (int v : G[u])
        if (v != fa && !vis[v]) {
            getroot(v, u);
            sze[u] += sze[v];
            f[u] = std::max(f[u], sze[v]);
        }
    f[u] = std::max(f[u], sum - sze[u]);
    if (f[u] < f[root]) root = u;
}

// big[u]/small[u]: strict ascents/descents along centroid -> u (capped at 2 on use).
int big[N], small[N];

// For every vertex u of one subtree, adds to res the number of stored endpoints (from
// earlier subtrees or the centroid itself) forming a qualifying path with u.
void getdeep(int u, int fa) {
    if (big[u] > 2) big[u] = 2;
    if (small[u] > 2) small[u] = 2;
    int x = big[u], y = small[u];
    for (int i = 0; i <= 2; ++i)
        for (int j = 0; j <= 2; ++j) {
            int nx = x + i;  // ascents along stored endpoint -> centroid -> u
            int ny = y + j;  // descents along the same direction
            if (nx >= 2 && ny >= 2) continue;  // neither direction can be rotated into order
            if (!nx || !ny || (nx == 1 && ny == 1))
                res += bit[i][j].query(MAXW);           // always sortable
            else if (nx == 1)
                res += bit[i][j].query(1, w[u]);        // reversed has one descent: stored <= w[u]
            else if (ny == 1)
                res += bit[i][j].query(w[u], MAXW);     // one descent: w[u] <= stored
        }
    for (int v : G[u])
        if (v != fa && !vis[v]) {
            big[v] = big[u] + (w[v] > w[u]);
            small[v] = small[u] + (w[v] < w[u]);
            getdeep(v, u);
        }
}

// Adds `flag` (+1 / -1) for every vertex of one subtree, using the counts set by getdeep().
void add(int u, int fa, int flag) {
    bit[small[u]][big[u]].update(w[u], flag);
    for (int v : G[u])
        if (v != fa && !vis[v]) add(v, u, flag);
}

// Counts pairs through centroid u of a component with `total` vertices, then recurses.
void solve(int u, int total) {
    vis[u] = 1;
    bit[0][0].update(w[u], 1);  // the centroid itself is a valid endpoint
    for (int v : G[u])
        if (!vis[v]) {
            big[v] = (w[v] > w[u]);
            small[v] = (w[v] < w[u]);
            getdeep(v, u);  // pair this subtree with everything stored so far
            add(v, u, 1);   // then make it available to later subtrees
        }
    for (int v : G[u])
        if (!vis[v]) add(v, u, -1);
    bit[0][0].update(w[u], -1);
    for (int v : G[u])
        if (!vis[v]) {
            // sze[] comes from the getroot() pass that selected u. A child side has
            // sze[v] < sze[u]; the parent side holds everything outside u's subtree.
            // (Using sze[v] directly over-counts the parent side.)
            int part = sze[v] < sze[u] ? sze[v] : total - sze[u];
            sum = f[0] = part;
            root = 0;
            getroot(v, 0);
            solve(root, part);
        }
}

int main() {
    if (std::scanf("%d", &t) != 1) return 0;
    while (t--) {
        if (std::scanf("%d", &n) != 1) break;
        for (int i = 1; i <= n; ++i) G[i].clear(), vis[i] = 0;
        for (int i = 1; i <= n; ++i) std::scanf("%d", w + i);
        for (int i = 1, u, v; i < n; ++i) {
            std::scanf("%d%d", &u, &v);
            G[u].push_back(v);
            G[v].push_back(u);
        }
        res = 0;
        sum = f[0] = n;
        root = 0;
        getroot(1, 0);
        solve(root, n);
        std::printf("%lld\n", res + n);  // + n for the single-vertex paths (u == v)
    }
    return 0;
}