NC236069. Permutation on Tree
描述
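Given a tree with $n$ vertices rooted at vertex $r$, consider a permutation $P_1, P_2, \dots, P_n$ of $1, 2, \dots, n$. Let $pos_i$ be the index of $i$ in the permutation (that is, $P_{pos_i} = i$). The permutation is good if, for all $1 \le u, v \le n$, whenever vertex $u$ is an ancestor of vertex $v$ in the tree, $pos_u < pos_v$. The score of a permutation is $\sum_{i=1}^{n-1} |P_i - P_{i+1}|$ (this definition was reconstructed; it reproduces both sample outputs below).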
输入描述
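There is only one test case in each test file. The first line contains two integers $n$ and $r$ ($1 \le r \le n$; the upper bound on $n$ was lost in extraction, and the reference solution below sizes its tables for $n \le 2000$) indicating the size of the tree and the root. For the following $n-1$ lines, the $i$-th line contains two integers $u_i$ and $v_i$ ($1 \le u_i, v_i \le n$) indicating an edge connecting vertex $u_i$ and vertex $v_i$ in the tree.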
输出描述
For each test case output one line containing one integer indicating the sum of scores of all different good permutations. As the answer may be large, output the answer modulo $10^9 + 7$.
示例1
输入:
4 2
1 2
2 3
1 4
输出:
15
示例2
输入:
3 1
1 2
2 3
输出:
2
C++ 解法, 执行用时: 308ms, 内存消耗: 46980K, 提交时间: 2022-07-16 11:52:28
// NC236069 "Permutation on Tree" -- reference solution.
//
// Input : n, root r, then n-1 undirected edges over vertices 1..n.
// Output: the sum, over every ordering of the vertices in which each
//         ancestor precedes all of its descendants, of sum |a - b| over
//         adjacent pairs (a, b) of that ordering, modulo 1e9+7.
//
// Outline (every count is taken mod p):
//   dfs1   - subtree sizes sz[], parents fa[], and dp[u] = number of valid
//            orderings of subtree(u) (u first, child subtrees interleaved).
//   solve  - for each vertex u, adds to ans[u] the contribution of adjacent
//            pairs that meet at u: u directly followed by one of its
//            children, and pairs (a, b) taken from two different child
//            subtrees of u.
//   dfs2   - bottom-up, carries each child's total ans[v] into its parent,
//            so ans[root] ends up holding the answer.
//
// The tables cap n at N - 10 = 2000 vertices; DFS recursion depth is up to n.
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <vector>

using namespace std;
using ll = long long;

constexpr int N = 2000 + 10;     // max vertices (2000) plus slack
constexpr ll p = 1000000007LL;   // answer modulus (prime)

// Reads one (optionally negative) decimal integer from stdin, skipping any
// non-digit characters before it. Returns 0 if EOF is reached first instead
// of spinning forever.
inline ll read() {
    ll x = 0, sign = 1;
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    while ((ch < '0' || ch > '9') && ch != EOF) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * sign;
}

// Returns a^b mod p by binary exponentiation. With the default exponent
// p-2 this is the modular inverse of a (Fermat; p is prime); inv(0) == 0.
inline ll inv(ll a, ll b = p - 2) {
    ll x = 1;
    a %= p;
    while (b) {
        if (b & 1) x = x * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return x;
}

vector<int> g[N];       // adjacency lists
ll C[N][N], iC[N][N];   // C[i][j] = binom(i, j) mod p; iC[i][j] = its inverse

// Fills C (Pascal's triangle) and iC for 0 <= j <= i <= lim.
// The inverses use binom(i, j)^{-1} = j! (i-j)! / i!  (i < p, so i! != 0),
// which needs a single modular exponentiation instead of one per entry.
void init(int lim = N - 10) {
    static ll fact[N], ifact[N];
    fact[0] = 1;
    for (int i = 1; i <= lim; ++i) fact[i] = fact[i - 1] * i % p;
    ifact[lim] = inv(fact[lim]);
    for (int i = lim; i >= 1; --i) ifact[i - 1] = ifact[i] * i % p;
    for (int i = 0; i <= lim; ++i) {
        C[i][0] = 1;
        for (int j = 1; j <= i; ++j) C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % p;
        for (int j = 0; j <= i; ++j) iC[i][j] = ifact[i] * fact[j] % p * fact[i - j] % p;
    }
}

int n, rt;
int sz[N], fa[N];    // subtree size; parent (0 for the root)
ll dp[N], ivdp[N];   // valid orderings of subtree(u), and its inverse
ll ans[N];           // per-subtree score totals (see solve / dfs2)

// Roots the tree at u (whose parent is `par`) and fills sz, fa, dp, ivdp.
// Each child v multiplies dp[u] by dp[v] times the number of ways to place
// v's sz[v] vertices among the sz[u]-1 non-root slots seen so far.
void dfs1(int u, int par) {
    sz[u] = 1;
    dp[u] = 1;
    fa[u] = par;
    for (int v : g[u]) {
        if (v == par) continue;
        dfs1(v, u);
        sz[u] += sz[v];
        dp[u] = dp[u] * (dp[v] * C[sz[u] - 1][sz[v]] % p) % p;
    }
    ivdp[u] = inv(dp[u]);
}

ll adjMemo[N][N];
bool adjDone[N][N];  // explicit flag: a memoised value may legitimately be 0 mod p

// For x and y in two different child subtrees of u (so u = lca(x, y)):
// the base case (both are children of u) returns the orderings of u's other
// children, dp[u] / (dp[x] dp[y] C(sz[u]-1, sz[x]) C(sz[u]-1-sz[x], sz[y])),
// times C(sz[u]-2, sz[x]+sz[y]-1) placements of the merged x/y block with
// one glued pair. Otherwise it sums the two ways of lifting x or y one step
// toward u. NOTE(review): the combinatorial reading of the lifting terms is
// inferred from the formulas, which are kept verbatim.
// Memoising on (x, y) alone is sound because (x, y) determines u.
ll adjWays(int x, int y, int u) {
    if (adjDone[x][y]) return adjMemo[x][y];
    ll w = 0;
    if (fa[x] == u && fa[y] == u) {
        w = dp[u] * ivdp[x] % p * ivdp[y] % p * iC[sz[u] - 1][sz[x]] % p
            * iC[sz[u] - 1 - sz[x]][sz[y]] % p * C[sz[u] - 2][sz[x] + sz[y] - 1] % p;
    }
    if (fa[x] != u) {
        const int px = fa[x];
        w = (w + dp[px] * ivdp[x] % p * iC[sz[px] - 1][sz[x]] % p
                 * C[sz[px] + sz[y] - 2][sz[px] - sz[x] - 1] % p * adjWays(px, y, u) % p) % p;
    }
    if (fa[y] != u) {
        const int py = fa[y];
        w = (w + dp[py] * ivdp[y] % p * iC[sz[py] - 1][sz[y]] % p
                 * C[sz[py] + sz[x] - 2][sz[py] - sz[y] - 1] % p * adjWays(x, py, u) % p) % p;
    }
    adjDone[x][y] = true;
    return adjMemo[x][y] = w;
}

// Returns the vertices of subtree(u). Along the way it adds to ans[u] the
// score, summed over all valid orderings of subtree(u), of adjacent pairs
// that meet at u:
//   * u directly followed by child v: u is always first, and v is second in
//     dp[u] * sz[v] / (sz[u]-1) of the orderings;
//   * a next to b, with a from an already-visited child subtree and b from
//     the current one. dp[a] dp[b] C(sz[a]+sz[b]-2, sz[a]-1) counts merged
//     orderings of the two subtrees with a directly before b; the leading 2
//     presumably covers the mirrored order (b before a).
vector<int> solve(int u, int par) {
    vector<int> seen, sub;
    const ll ivRest = inv(sz[u] - 1);  // only used when u has children (sz[u] >= 2)
    for (int v : g[u]) {
        if (v == par) continue;
        ans[u] = (ans[u] + dp[u] * ivRest % p * sz[v] % p * abs(u - v)) % p;
        sub = solve(v, u);
        for (int a : seen)
            for (int b : sub)
                ans[u] = (ans[u] + 2LL * abs(a - b) % p * dp[a] % p * dp[b] % p
                          * C[sz[a] + sz[b] - 2][sz[a] - 1] % p * adjWays(a, b, u)) % p;
        seen.insert(seen.end(), sub.begin(), sub.end());
    }
    seen.push_back(u);
    return seen;
}

// Post-order pass. Once ans[v] is complete for subtree(v), every adjacent
// pair it counts is carried into subtree(u) like this:
//   * glue the pair, leaving sz[v]-1 units;
//   * interleave those units with the other sz[u]-1-sz[v] non-root vertices,
//     giving C(sz[u]-2, sz[v]-1) ways;
//   * multiply by the orderings of u's other children,
//     dp[u] / (dp[v] C(sz[u]-1, sz[v])).
void dfs2(int u, int par) {
    for (int v : g[u]) {
        if (v == par) continue;
        dfs2(v, u);
        ans[u] = (ans[u] + dp[u] * ivdp[v] % p * iC[sz[u] - 1][sz[v]] % p
                  * C[sz[u] - 2][sz[v] - 1] % p * ans[v]) % p;
    }
}

// Reads the tree from stdin and prints the answer followed by a newline.
int main() {
    init();
    n = static_cast<int>(read());
    rt = static_cast<int>(read());
    for (int i = 2; i <= n; ++i) {
        const int u = static_cast<int>(read());
        const int v = static_cast<int>(read());
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs1(rt, 0);
    solve(rt, 0);
    dfs2(rt, 0);
    cout << ans[rt] << '\n';
    return 0;
}