列表

详情


NC236069. Permutation on Tree

描述

Given a tree with $n$ vertices where vertex $r$ is the root, we say a permutation $p_1, p_2, \dots, p_n$ of $1, 2, \dots, n$ is good if it satisfies the following constraint:
Let $a_x$ be the index of $x$ in the permutation (that is, $p_{a_x} = x$). For all $1 \le u, v \le n$, if vertex $u$ is an ancestor of vertex $v$ in the tree, then $a_u < a_v$.
Define the score of a permutation to be $\sum_{i=1}^{n-1} |p_i - p_{i+1}|$, where $|x|$ is the absolute value of $x$. Calculate the sum of scores of all different good permutations.

输入描述

There is only one test case in each test file.
The first line contains two integers $n$ and $r$ ($1 \le r \le n$; the bound on $n$ is missing from this copy of the statement) indicating the size of the tree and the root.
For the following $(n - 1)$ lines, the $i$-th line contains two integers $u_i$ and $v_i$ ($1 \le u_i, v_i \le n$) indicating an edge connecting vertex $u_i$ and $v_i$ in the tree.

输出描述

For each test case, output one line containing one integer: the sum of scores of all different good permutations. As the answer may be large, output it modulo $10^9 + 7$.

示例1

输入:

4 2
1 2
2 3
1 4

输出:

15

示例2

输入:

3 1
1 2
2 3

输出:

2

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++ 解法, 执行用时: 308ms, 内存消耗: 46980K, 提交时间: 2022-07-16 11:52:28

#include<bits/stdc++.h>
// Inclusive ascending loop: i runs from a up to b; the upper bound is evaluated once.
#define For(i,a,b) for(int i=(a),i##END=(b);i<=i##END;i++)
// Inclusive descending loop: i runs from b down to a; the lower bound is evaluated once.
#define Rof(i,b,a) for(int i=(b),i##END=(a);i>=i##END;i--)
// Every `int` in the rest of the file is really a 64-bit long long, so the product of two
// residues below p fits before it is reduced mod p. This is also why main is declared `signed`.
// (An unused `go(u)` edge-iteration macro that referred to the undeclared arrays head/nxt
// was dropped; the tree is stored in vector<int> g[] instead.)
#define int long long
using namespace std;
// Fast reader for one decimal integer.
// Skips every non-digit character; a '-' seen while skipping makes the result negative.
// Reads from `in` (stdin by default, so existing read() calls are unchanged).
// If the stream ends before any digit appears, returns 0 instead of looping forever on EOF.
inline int read(FILE* in=stdin){
    int x=0,f=1;
    int ch=getc(in);  // int, not char: EOF must stay distinguishable from every byte value
    while(ch<'0'||ch>'9'){
        if(ch==EOF)return 0;
        if(ch=='-')f=-1;
        ch=getc(in);
    }
    while(ch>='0'&&ch<='9'){
        x=x*10+(ch-'0');
        ch=getc(in);
    }
    return x*f;
}
// N: capacity of every vertex- or size-indexed table (2000 plus a margin of 10).
// p: the prime modulus used for all arithmetic.
const int N=2010,p=1000000007;
// g[u]: neighbours of vertex u; each tree edge is stored in both directions.
vector<int> g[N];
// inv(a, b): a^b mod p by binary exponentiation.
// With the default exponent b = p-2, this is the modular inverse of a (Fermat's little
// theorem, p prime). Note that inv(0) evaluates to 0.
inline int inv(int a,int b=p-2){
	int res=1;
	for(;b;b>>=1){
		if(b&1)res=res*a%p;
		a=a*a%p;
	}
	return res;
}
int C[N][N],iC[N][N];
void init(int n=N-10){
	For(i,0,n){C[i][0]=1;For(j,1,i)C[i][j]=(C[i-1][j-1]+C[i-1][j])%p;}
	For(i,0,n)For(j,0,i)iC[i][j]=inv(C[i][j]);
}
// n, rt: number of vertices and the root, read in main().
int n,rt;
// Per-vertex data filled by dfs1 (indexed by vertex id):
//   dp[u]   - number of orderings of u's subtree in which every vertex comes after all of
//             its ancestors, mod p
//   sz[u]   - number of vertices in u's subtree
//   fa[u]   - parent of u (0 for the root, since main starts with dfs1(rt,0))
//   ivdp[u] - modular inverse of dp[u]
int dp[N],sz[N],fa[N],ivdp[N];
// dfs1(u, f): visit u's subtree (u's parent is f) and fill sz, dp, fa and ivdp.
// dp is built child by child: after child v is merged, its sz[v] vertices are interleaved
// among the sz[u]-1 non-root positions accumulated so far, i.e. C[sz[u]-1][sz[v]] ways.
void dfs1(int u,int f){
	fa[u]=f;
	sz[u]=1;
	dp[u]=1;
	for(int child:g[u]){
		if(child==f)continue;
		dfs1(child,u);
		sz[u]+=sz[child];
		int ways=dp[child]*C[sz[u]-1][sz[child]]%p;
		dp[u]=dp[u]*ways%p;
	}
	ivdp[u]=inv(dp[u]);
}
// f: memo table for get(), keyed on a vertex pair. ans: per-vertex answer accumulators.
int f[N][N],ans[N];
// get(x, y, u): memoised helper used by solve() when counting orderings of u's subtree in
// which vertex x and vertex y are treated as fused into a single node. In every call from
// solve(), x and y lie in different child subtrees of u, so u is their lowest common ancestor.
//
// The memo table f is keyed on (x, y) only. u is not part of the key, which relies on each
// pair always being queried with the same u (true for the callers: it is the LCA).
// A stored 0 means "not computed yet", so a result that is genuinely 0 mod p would simply be
// recomputed on each call (costs time, not correctness).
int get(int x,int y,int u){
	if(f[x][y])return f[x][y];
	// w aliases the memo slot. Recursive calls only touch slots of strictly shallower pairs
	// in the fixed-size array, so the reference stays valid and w starts from 0 here.
	int& w=f[x][y];
	// Base case: x and y are both direct children of u.
	// Start from dp[u]. Divide out x's and y's internal orderings (ivdp) and the binomials that
	// placed their subtrees among u's sz[u]-1 non-root slots. Then interleave the fused block
	// (sz[x]+sz[y]-1 vertices) among the sz[u]-2 remaining non-root positions.
	if(fa[x]==u&&fa[y]==u)w=dp[u]*ivdp[x]%p*ivdp[y]%p*iC[sz[u]-1][sz[x]]%p*iC[sz[u]-1-sz[x]][sz[y]]%p*C[sz[u]-2][sz[x]+sz[y]-1]%p;
	// Step x up one level toward u.
	// - Orderings of fa[x]'s subtree with x's branch removed: dp[fa[x]] / dp[x] / C(sz[fa[x]]-1, sz[x]).
	// - Times C(sz[fa[x]]+sz[y]-2, sz[fa[x]]-sz[x]-1) for interleaving fa[x]'s other
	//   sz[fa[x]]-sz[x]-1 descendants.
	// - Times the count for the pair (fa[x], y).
	// NOTE(review): the two climbing terms are summed; presumably they enumerate which side's
	// ancestor chain is extended at each step. The derivation is not spelled out in the source.
	if(fa[x]!=u)(w+=dp[fa[x]]*ivdp[x]%p*iC[sz[fa[x]]-1][sz[x]]%p*C[sz[fa[x]]+sz[y]-2][sz[fa[x]]-sz[x]-1]%p*get(fa[x],y,u)%p)%=p;
	// Symmetric step: move y up one level toward u.
	if(fa[y]!=u)(w+=dp[fa[y]]*ivdp[y]%p*iC[sz[fa[y]]-1][sz[y]]%p*C[sz[fa[y]]+sz[x]-2][sz[fa[y]]-sz[y]-1]%p*get(x,fa[y],u)%p)%=p;
	return w;
}
// solve(u, f): for u's subtree (u's parent is f), add to ans[u] the contribution of every
// adjacent pair whose meeting point is u. These are u directly followed by one of its
// children, or two vertices taken from different child subtrees of u. Pairs are counted over
// orderings of u's subtree only; dfs2 later scales them up.
// Returns the vertices of u's subtree: the children's subtrees in adjacency order, then u.
vector<int> solve(int u,int f){
	vector<int> seen;                  // vertices of the child subtrees processed so far
	const int invKids=inv(sz[u]-1);    // 1/(sz[u]-1) mod p; inv(0) gives 0 for a leaf
	for(int v:g[u]){
		if(v==f)continue;
		// Orderings of u's subtree in which v comes right after u: dp[u] * sz[v] / (sz[u]-1).
		int edgeWays=dp[u]*invKids%p*sz[v]%p;
		ans[u]=(ans[u]+edgeWays*abs(u-v))%p;
		vector<int> sub=solve(v,u);
		// Pair every vertex from earlier child subtrees with every vertex of v's subtree.
		// The factor 2 presumably covers both orders (a before b, b before a).
		for(int a:seen){
			for(int b:sub){
				int coef=2*abs(a-b)%p*dp[a]%p*dp[b]%p;
				coef=coef*C[sz[a]+sz[b]-2][sz[a]-1]%p;
				ans[u]=(ans[u]+coef*get(a,b,u))%p;
			}
		}
		seen.insert(seen.end(),sub.begin(),sub.end());
	}
	seen.push_back(u);
	return seen;
}
void dfs2(int u,int f){
	int iv=inv(sz[u]-1);for(int v:g[u])if(v!=f)
		dfs2(v,u),(ans[u]+=dp[u]*ivdp[v]%p*iC[sz[u]-1][sz[v]]%p*C[sz[u]-2][sz[v]-1]%p*ans[v])%=p;
}
// Entry point: read n, the root rt and the n-1 tree edges, run the three passes
// (dfs1 sizes/counts, solve per-LCA pair sums, dfs2 lifting), then print ans[rt] mod p.
signed main(){
	n=read(),rt=read();
	// Every binomial index used later is bounded by a subtree size (at most n). The tables
	// therefore only need rows 0..n, which avoids building all N-10 rows when n is smaller.
	// The fixed-size tables still assume n <= N-10.
	init(n);
	For(i,2,n){
		int u=read(),v=read();
		g[u].push_back(v),g[v].push_back(u);
	}
	dfs1(rt,0);
	solve(rt,0);
	dfs2(rt,0);
	cout<<ans[rt]<<endl;
	return 0;
}

上一题