列表

详情


NC21533. Sortable Path on Tree

描述

Chiaki has a tree with $n$ nodes numbered $1$ to $n$. Each node has a positive integer weight $w_i$ on it.

Chiaki would like to know the number of unordered pairs $(u, v)$ (pairs with $u = v$ are included) such that:
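Let $t_1, t_2, \dots, t_k$ ($t_1 = u$, $t_k = v$) be the shortest path from $u$ to $v$. Then the sequence $w_{t_1}, w_{t_2}, \dots, w_{t_k}$ or the sequence $w_{t_k}, w_{t_{k-1}}, \dots, w_{t_1}$ can be sorted into nondecreasing order using several circular shift operations.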
Note that a circular shift is the operation of rearranging the entries in a sequence, either by moving the final entry to the first position, while shifting all other entries to the next position, or by performing the inverse operation.

输入描述

There are multiple test cases. The first line of the input contains an integer T, indicating the number of test cases. For each test case:
The first line contains an integer $n$ ($1 \le n \le 10^5$) -- the number of nodes in the tree.
The second line contains $n$ integers $w_1, w_2, \dots, w_n$ ($1 \le w_i \le 10^5$).
Each of the next n-1 lines contains two integers u and v (1 ≤ u, v ≤ n, u ≠ v) denoting an edge on tree.
It's guaranteed that the sum of $n$ over all test cases will not exceed $10^5$.

输出描述

For each test case, output an integer denoting the answer.

示例1

输入:

1
4
3 4 1 2
1 2
2 3
3 4

输出:

10

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++14(g++5.4) 解法, 执行用时: 352ms, 内存消耗: 6008K, 提交时间: 2018-12-26 20:33:28

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=1e5+5;
// Test-case counter and node count of the current test case.
int T,n;
int w[N];                          // node weights; init() replaces them by ranks 1..lm
int tp[N];                         // scratch copy of the weights used for compression
int hd[N],xnt,to[N<<1],nxt[N<<1];  // forward-star adjacency: head, edge count, target, next
int siz[N],rt,mn;                  // getrt(): subtree sizes, chosen centroid, its best score
int lm;                            // number of distinct weights after compression
// Fenwick trees over weight rank, one per (descents, ascents) class with both
// counts capped at 2; see dfs() for the direction in which they are counted.
int f[3][3][N];
ll ans;                            // pairs with distinct endpoints counted so far
bool vis[N];                       // nodes already used as a centroid
// Reads the next integer from stdin, skipping every non-digit character before
// it; a '-' among the skipped characters makes the result negative.
// Returns 0 if the input ends before any digit is found.
int rdn()
{
  int ret=0;bool fx=1;
  int ch=getchar();  // int, not char: EOF must stay distinguishable from data bytes
  // Stop at EOF too; with a char and no EOF test, truncated input spins forever.
  while(ch!=EOF&&(ch>'9'||ch<'0')){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
// Larger of two ints.
int Mx(int a,int b)
{
  if(a>b)return a;
  return b;
}
// Smaller of two ints.
int Mn(int a,int b)
{
  if(a<b)return a;
  return b;
}
void init()
{
  xnt=0;for(int i=1;i<=n;i++)hd[i]=0;//memset(hd,0,sizeof hd);
  ans=0;for(int i=1;i<=n;i++)vis[i]=0;//memset(vis,0,sizeof vis);
  sort(tp+1,tp+n+1);lm=unique(tp+1,tp+n+1)-tp-1;///
  for(int i=1;i<=n;i++)w[i]=lower_bound(tp+1,tp+lm+1,w[i])-tp;
}
// Appends the directed edge x -> y to the forward-star adjacency list.
void add(int x,int y)
{
  ++xnt;
  to[xnt]=y;
  nxt[xnt]=hd[x];
  hd[x]=xnt;
}
/*
void init_dfs(int cr,int fa)
{
  siz[cr]=1;
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa)init_dfs(v,cr),siz[cr]+=siz[v];
}
*/
void getrt(int cr,int fa,int s)
{
  int mx=0;siz[cr]=1;
  for(int i=hd[cr],v;i;i=nxt[i])
    if(!vis[v=to[i]]&&v!=fa)
      {
	getrt(v,cr,s);siz[cr]+=siz[v];
	mx=Mx(mx,siz[v]);
      }
  mx=Mx(mx,s-siz[cr]);if(mx<mn)mn=mx,rt=cr;
}
// Fenwick point update: adds k at rank x (1..lm) of tree f[s0][s1].
void add(int x,int k,int s0,int s1)
{
  int *tree=f[s0][s1];
  while(x<=lm)
  {
    tree[x]+=k;
    x+=x&-x;
  }
}
// Fenwick prefix sum: total stored at ranks 1..x of tree f[s0][s1] (0 when x is 0).
int qry(int x,int s0,int s1)
{
  const int *tree=f[s0][s1];
  int acc=0;
  while(x!=0)
  {
    acc+=tree[x];
    x-=x&-x;
  }
  return acc;
}
// Fenwick suffix sum: total stored at ranks x..lm of tree f[s0][s1].
int qry_s(int x,int s0,int s1)
{
  const int total=qry(lm,s0,s1);
  const int below=qry(x-1,s0,s1);
  return total-below;
}
// s0 / s1 are the descents / ascents of the whole path "stored node -> rt ->
// query node" (the query node has rank tw). Returns how many stored nodes in
// tree f[i][j] make that path sortable by rotation in at least one direction.
int cal(int s0,int s1,int i,int j,int tw)
{
  const bool manyDown=s0>1;
  const bool manyUp=s1>1;
  if(manyDown&&manyUp)return 0;
  // One descent: the stored endpoint (first element) must be >= the query's.
  if(manyUp)return s0==1?qry_s(tw,i,j):qry(lm,i,j);
  // One ascent: read backwards, the query endpoint must be >= the stored one.
  if(manyDown)return s1==1?qry(tw,i,j):qry(lm,i,j);
  return qry(lm,i,j);
}
// Adds to ans the pair (rt, query node) plus every stored node, across all nine
// (descents, ascents) classes, that completes a sortable path with the query
// node of rank tw whose path from rt has s0 descents and s1 ascents.
void calc(int tw,int s0,int s1)
{
  ++ans;  // the pair formed with rt itself
  for(int a=0;a<3;++a)
    for(int b=0;b<3;++b)
      ans+=cal(s0+a,s1+b,a,b,tw);
}
// Walks the part of the current component reached through cr. fa is cr's
// neighbour on the way back to the centroid rt, and lst is that neighbour's weight.
// s0 / s1 count descents / ascents seen so far along the path.
//   op==1: the path is read from rt towards cr (w[cr]>lst is an ascent); each
//          surviving node is matched against the Fenwick trees by calc().
//   op==2 / op==3: the path is read from cr towards rt (lst>w[cr] is an
//          ascent); the node is inserted into / removed from f[][] by weight,
//          with both counters capped at 2.
// Assumes the global rt equals the centroid passed to solve() (solve's callers
// set it via getrt just before the call). op 2 and op 3 prune identically, so
// every insertion is later undone.
void dfs(int cr,int fa,int lst,int s0,int s1,int op)
{
  if(op==1){ if(w[cr]>lst)s1++; if(w[cr]<lst)s0++; }
  else{ if(lst>w[cr])s1++; if(lst<w[cr])s0++; }
  // Two or more descents and two or more ascents: neither reading direction is a
  // rotation of a sorted sequence, and the counts only grow deeper in the subtree.
  if(s0>1&&s1>1)return;
  // Exactly one descent (several ascents): the full path needs first >= last.
  // Given w[rt], that is impossible for cr and its whole subtree in these cases.
  if(s0==1&&s1>1){ if(op==1&&w[cr]>w[rt])return; if(op>1&&w[cr]<w[rt])return; }
  // Mirror case: exactly one ascent, so the reversed path needs the same check.
  if(s1==1&&s0>1){ if(op==1&&w[cr]<w[rt])return; if(op>1&&w[cr]>w[rt])return; }
  if(op==1)calc(w[cr],s0,s1);//,printf("cr=%d[%d,%d]ans=%lld\n",cr,s0,s1,ans);
  if(op==2)add(w[cr],1,s0>1?2:s0,s1>1?2:s1);
  if(op==3)add(w[cr],-1,s0>1?2:s0,s1>1?2:s1);
  for(int i=hd[cr],v;i;i=nxt[i])
    if(!vis[v=to[i]]&&v!=fa)dfs(v,cr,w[cr],s0,s1,op);
}
// Centroid-decomposition step. cr is the centroid (its callers have just set the
// global rt to cr via getrt) of an unvisited component of s nodes. Counts every
// qualifying pair whose path passes through cr, then recurses into each
// remaining sub-component.
void solve(int cr,int s)
{
  vis[cr]=1;// printf("cr=%d s=%d\n",cr,s);
  // Per neighbour subtree: first match its nodes against cr and the subtrees
  // inserted earlier (op 1), then insert them (op 2). Each cross pair is thus
  // counted exactly once.
  for(int i=hd[cr],v;i;i=nxt[i])
    if(!vis[v=to[i]])
      {
	dfs(v,cr,w[cr],0,0,1);dfs(v,cr,w[cr],0,0,2);
      }
  //  printf("ans=%lld\n",ans);
  // Remove every entry inserted above (op 3), restoring f[][].
  for(int i=hd[cr],v;i;i=nxt[i])
    if(!vis[v=to[i]])dfs(v,cr,w[cr],0,0,3);
  for(int i=hd[cr],v,ts;i;i=nxt[i])
    if(!vis[v=to[i]])
      {
	// siz[] is from the getrt pass that chose cr, rooted elsewhere in this
	// component. A neighbour below cr has its true size in siz[v]; the neighbour
	// above cr (siz[v] >= siz[cr]) owns the remaining s-siz[cr] nodes.
	mn=N;ts=(siz[v]<siz[cr]?siz[v]:s-siz[cr]);
	getrt(v,cr,ts);solve(rt,ts);
      }
}
// Reads T test cases. For each, prints the number of qualifying unordered pairs:
// pairs with distinct endpoints come from the decomposition, plus n for u == v.
int main()
{
  T=rdn();
  while(T--)
  {
    n=rdn();
    lm=0;
    for(int i=1;i<=n;++i)
    {
      w[i]=rdn();
      tp[i]=w[i];
    }
    init();
    for(int e=1;e<n;++e)
    {
      const int a=rdn();
      const int b=rdn();
      add(a,b);
      add(b,a);
    }
    mn=N;
    getrt(1,0,n);
    solve(rt,n);
    printf("%lld\n",ans+n);
  }
  return 0;
}

C++11(clang++ 3.9) 解法, 执行用时: 733ms, 内存消耗: 10572K, 提交时间: 2019-02-03 19:29:57

#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define N 100010
int t, n;          // number of test cases; node count of the current test
int w[N];          // node weights, 1-based
vector<int> G[N];  // adjacency lists
ll res;            // pairs with distinct endpoints counted so far (main adds n)

// Fenwick tree over weights 1..100000 counting stored nodes per weight.
struct BIT
{
	int a[N];
	void init() { memset(a, 0, sizeof a); }
	// Adds val at position x (1 <= x <= 100000).
	void update(int x, int val)
	{
		while (x <= 100000)
		{
			a[x] += val;
			x += x & -x;
		}
	}
	// Sum over positions 1..x (0 when x <= 0).
	int query(int x)
	{
		int acc = 0;
		while (x > 0)
		{
			acc += a[x];
			x -= x & -x;
		}
		return acc;
	}
	// Sum over positions l..r, or 0 for an empty range.
	int query(int l, int r) { return r < l ? 0 : query(r) - query(l - 1); }
} bit[3][3]; // indexed [descents][ascents] from the centroid, each capped at 2

int vis[N];  // 1 once a node has been used as a centroid
int root;    // best centroid candidate found by getroot (0 = none yet)
int sum;     // size of the component getroot is searching
int sze[N];  // subtree sizes from the latest getroot traversal
int f[N];    // f[x]: largest part left if x were removed; f[0] is the sentinel
// Fills sze[] and f[] for the unvisited component containing u (entered from fa),
// keeping in root the node with the smallest f; sum is the component size.
void getroot(int u, int fa)
{
	sze[u] = 1;
	int heaviest = 0;  // largest child subtree of u
	for (const int v : G[u])
	{
		if (v == fa || vis[v]) continue;
		getroot(v, u);
		sze[u] += sze[v];
		heaviest = max(heaviest, sze[v]);
	}
	// The part above u also counts toward the largest remaining piece.
	f[u] = max(heaviest, sum - sze[u]);
	if (f[u] < f[root]) root = u;
}

int big[N];    // ascents on the path centroid -> node (getdeep caps it at 2)
int small[N];  // descents on the path centroid -> node (getdeep caps it at 2)
void getdeep(int u, int fa)
{
	if (big[u] > 2) big[u] = 2;
	if (small[u] > 2) small[u] = 2;
	int x = big[u], y = small[u];
	for (int i = 0; i <= 2; ++i)
		for (int j = 0; j <= 2; ++j)
		{
			int nx = x + i;
			int ny = y + j;
			if (nx >= 2 && ny >= 2) continue;
			if (!nx || !ny || (nx == 1 && ny == 1)) res += bit[i][j].query(100000);
			else if (nx == 1) res += bit[i][j].query(1, w[u]);
			else if (ny == 1) res += bit[i][j].query(w[u], 100000);
		}
	for (auto v : G[u]) if (v != fa && !vis[v])
	{
		big[v] = big[u] + (w[v] > w[u]);
		small[v] = small[u] + (w[v] < w[u]);
		getdeep(v, u); 
	}
}

// Inserts (flag = 1) or removes (flag = -1) every node of u's subtree (entered
// from fa) into bit[descents][ascents], keyed by weight.
void add(int u, int fa, int flag)
{
	bit[small[u]][big[u]].update(w[u], flag);
	for (const int v : G[u])
	{
		if (v == fa || vis[v]) continue;
		add(v, u, flag);
	}
}

// Centroid-decomposition step for the centroid u of a component whose size the
// caller left in `sum` (main and the recursion below both set it before getroot).
// Counts every qualifying pair with distinct endpoints whose path passes through
// u, then recurses into each remaining sub-component.
void solve(int u)
{
	// Recursive calls overwrite `sum`, so remember this component's size now.
	const int comp = sum;
	vis[u] = 1;
	// u itself is a partner with no ascents and no descents.
	bit[0][0].update(w[u], 1);
	for (auto v : G[u]) if (!vis[v])
	{
		big[v] = (w[v] > w[u]);
		small[v] = (w[v] < w[u]);
		// Match this subtree against u and the earlier subtrees, then insert it,
		// so each cross pair is counted exactly once.
		getdeep(v, u);
		add(v, u, 1);
	}
	// Undo every insertion so the trees are clean for later calls.
	for (auto v : G[u]) if (!vis[v]) add(v, u, -1);
	bit[0][0].update(w[u], -1);
	for (auto v : G[u]) if (!vis[v])
	{
		// sze[] comes from the getroot pass that picked u, rooted elsewhere in
		// this component. A neighbour below u has its true size in sze[v]. The
		// neighbour above u (sze[v] > sze[u]) would report a size that includes u
		// and its other subtrees; its real component has comp - sze[u] nodes.
		sum = f[0] = (sze[v] < sze[u]) ? sze[v] : comp - sze[u];
		root = 0;
		getroot(v, 0);
		solve(root);
	}
}

// Reads t test cases. For each, prints the number of qualifying unordered pairs:
// res counts pairs with distinct endpoints; the n pairs with u == v are added.
int main()
{
	scanf("%d", &t);
	while (t--)
	{
		scanf("%d", &n);
		for (int i = 1; i <= n; ++i)
		{
			G[i].clear();
			vis[i] = 0;
		}
		for (int i = 1; i <= n; ++i)
			scanf("%d", &w[i]);
		for (int e = 1; e < n; ++e)
		{
			int a, b;
			scanf("%d%d", &a, &b);
			G[a].push_back(b);
			G[b].push_back(a);
		}
		res = 0;
		sum = n;
		f[0] = n;
		root = 0;
		getroot(1, 0);
		solve(root);
		printf("%lld\n", res + n);
	}
	return 0;
}

上一题