NC16847. Counting paths
Description
Niuniu is interested in counting paths. He has a tree with n vertices, all initially painted white. For a vertex set S, Niuniu wants to calculate f(S), which is defined as follows. First, the vertices in S are painted black. If some white vertex lies on the path between two black vertices, f(S) = 0. Otherwise, he chooses a set of paths (the path from x to y is the same as the path from y to x, and a path from x to x is allowed). No path in the set may contain a black vertex. He then paints every vertex on the chosen paths red. f(S) is the number of different path sets for which every vertex adjacent to a black vertex ends up black or red.
You need to calculate the sum of f(S) over all possible vertex sets S, where S must contain at least one element. The answer may be large, so output it modulo 998244353.
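To make the definition of f(S) concrete, here is a literal brute force for tiny trees. It is a checking aid only, written under stated assumptions: vertices are 0-based, n is at most 6 so every set of usable paths fits in a 32-bit subset mask, and `bruteForceSum` is a hypothetical helper name, not part of the original problem or submission.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: sums f(S) over all non-empty S by following the
// definition literally. Assumes 0-based vertices and 1 <= n <= 6.
long long bruteForceSum(int n, const std::vector<std::pair<int, int>>& edges) {
    assert(n >= 1 && n <= 6);
    std::vector<std::vector<int>> adj(n);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    // Every path (x, y) with x <= y, stored as a bitmask of its vertices.
    std::vector<int> from, to;
    std::vector<unsigned> mask;
    for (int x = 0; x < n; ++x) {
        std::vector<int> par(n, -1);  // BFS parents rooted at x
        std::vector<int> queue(1, x);
        par[x] = x;
        for (std::size_t i = 0; i < queue.size(); ++i)
            for (int v : adj[queue[i]])
                if (par[v] == -1) { par[v] = queue[i]; queue.push_back(v); }
        for (int y = x; y < n; ++y) {
            unsigned m = 1u << x;
            for (int v = y; v != x; v = par[v]) m |= 1u << v;
            from.push_back(x); to.push_back(y); mask.push_back(m);
        }
    }
    long long total = 0;
    for (unsigned S = 1; S < (1u << n); ++S) {
        // f(S) = 0 if the path between two black vertices contains a white one.
        bool ok = true;
        for (std::size_t k = 0; k < mask.size(); ++k)
            if ((S >> from[k] & 1) && (S >> to[k] & 1) && (mask[k] & ~S)) ok = false;
        if (!ok) continue;
        // White vertices adjacent to a black vertex: these must end up red.
        unsigned need = 0;
        for (int v = 0; v < n; ++v)
            if (S >> v & 1)
                for (int u : adj[v]) need |= 1u << u;
        need &= ~S;
        // Paths that contain no black vertex.
        std::vector<unsigned> usable;
        for (unsigned m : mask) if (!(m & S)) usable.push_back(m);
        // Count the subsets of usable paths whose union covers `need`.
        long long f = 0;
        for (unsigned sub = 0; sub < (1u << usable.size()); ++sub) {
            unsigned red = 0;
            for (std::size_t k = 0; k < usable.size(); ++k)
                if (sub >> k & 1) red |= usable[k];
            if ((need & ~red) == 0) ++f;
        }
        total += f;
    }
    return total;
}
```

After shifting the 1-based sample vertices down by one, `bruteForceSum(3, {{0, 1}, {1, 2}})` returns 16, matching Example 2. The exponential subset loop is deliberate, because it keeps the code a direct transcription of the statement.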
Input description
The input has the following format:
n
x1 y1
x2 y2
...
x(n-1) y(n-1)
n is the number of vertices (1 ≤ n ≤ 200000). There is an edge between xi and yi (1 ≤ xi, yi ≤ n). It is guaranteed that the graph is a tree.
Output description
You should print exactly one number: the answer modulo 998244353.
Example 1
Input:
2
1 2
Output:
3
Note:
f({1}) = 1
Example 2
Input:
3
1 2
2 3
Output:
16
Note:
f({1}) = f({3}) = 6
Example 3
Input:
5
1 2
2 3
2 4
1 5
Output:
3128
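As a hand check of Example 2 (the path 1-2-3), every non-empty S can be evaluated directly from the definition. For each S the usable paths are those avoiding S, and a path set counts when it covers every white neighbor of S. For S = {1}, there are 2^3 = 8 subsets of the usable paths, and the 2 subsets that use only (3,3) leave vertex 2 uncovered:

```latex
\begin{aligned}
f(\{1\}) &= 2^{3} - 2^{1} = 6 && \text{usable paths } (2,2),(3,3),(2,3);\ \text{vertex 2 must be red}\\
f(\{3\}) &= 6 && \text{by symmetry}\\
f(\{2\}) &= 1 && \text{only } \{(1,1),(3,3)\} \text{ covers both 1 and 3}\\
f(\{1,2\}) &= f(\{2,3\}) = 1 && \text{the lone white neighbor needs its one-vertex path}\\
f(\{1,3\}) &= 0 && \text{white vertex 2 lies between black vertices 1 and 3}\\
f(\{1,2,3\}) &= 1 && \text{no usable paths; only the empty set}\\
\textstyle\sum_{S} f(S) &= 6+6+1+1+1+0+1 = 16
\end{aligned}
```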
C++11 (clang++ 3.9) solution. Run time: 372 ms, memory: 24828 KB, submitted: 2020-02-14 21:14:46
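```cpp
#include<bits/stdc++.h>
using namespace std;
#define N 200005
#define go(i,a,b) for(int i=a;i<=b;i++)
#define ll long long
#define mod 998244353

int n,u,to;
vector<int> path[N];  // adjacency lists of the tree

// Modular exponentiation: x^k mod 998244353.
ll qpow(ll x,ll k) {
    ll res=1;
    for(;k;k/=2,x=x*x%mod) if(k&1) res=res*x%mod;
    return res;
}

void add(int u,int to) {
    path[u].push_back(to);
    path[to].push_back(u);
}

// w: total number of paths, n(n-1)/2 with distinct ends plus n single-vertex paths.
// sz[u]: size of u's subtree; dp[u]: number of paths passing through u.
ll ans,w,sz[N],dp[N];

void dfs(int u,int fa) {
    for(auto to:path[u]) {
        if(to==fa) continue;
        dfs(to,u);
        dp[u]+=sz[u]*sz[to];  // paths joining this child subtree to an earlier one
        sz[u]+=sz[to];
    }
    sz[u]++;
    dp[u]+=sz[u]*(n-sz[u]+1);  // paths from u's subtree (u included) to u itself or to outside
    // Vertex term: 2^(paths avoiding u) - 1.
    ans=(ans+qpow(2,w-dp[u])-1)%mod;
    // Edge term for (u,to): paths avoiding both ends; sz[to]*(n-sz[to]) paths cross the edge.
    for(auto to:path[u]) {
        if(to==fa) continue;
        ans=(ans-qpow(2,w-dp[u]-dp[to]+sz[to]*(n-sz[to]))+1)%mod;
    }
}

int main() {
    cin>>n;
    w=1ll*n*(n-1)/2+n;
    go(i,2,n) scanf("%d%d",&u,&to),add(u,to);
    dfs(1,0);
    // ans may be negative here; the -1/+1 terms over n vertices and n-1 edges leave a net -1, restored by +1.
    cout<<(ans+mod+1)%mod<<endl;
    return 0;
}
```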
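One way to read the submission (this interpretation is ours, not stated by the original author): f(S) is non-zero only when S is connected. In that case, inclusion-exclusion over the uncovered white neighbors T gives f(S) = Σ over T ⊆ N(S) of (-1)^|T| · 2^a(S ∪ T), where N(S) is the set of white neighbors of S and a(X) is the number of paths (single-vertex paths included) that avoid X. Grouping terms by U = S ∪ T, every valid T is a set of leaves of the subtree U. The signs therefore cancel to 0 whenever |U| ≥ 3, leaving +1 for a single vertex and 1 - 1 - 1 = -1 for an edge. The answer is thus Σ_v 2^a(v) - Σ over edges (u,v) of 2^a(u,v), and the submission computes a(v) as w - dp[v]. Below is a hypothetical C++11 rewrite of that formula; the names `countPathSets` and `power` are ours. It swaps the recursive DFS for a BFS order, so a path-shaped tree with 200000 vertices cannot exhaust the call stack. It also counts avoiding paths from component sizes, s(s + 1)/2 per component, instead of through dp. Input parsing and `main` are omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

const long long MOD = 998244353;

// Modular exponentiation: base^exp mod 998244353.
long long power(long long base, long long exp) {
    long long result = 1;
    base %= MOD;
    for (; exp > 0; exp >>= 1, base = base * base % MOD)
        if (exp & 1) result = result * base % MOD;
    return result;
}

// Hypothetical sketch: sum_v 2^a(v) - sum_{(u,v)} 2^a(u,v) mod 998244353,
// where a(X) = number of paths avoiding X. Vertices are 1-based.
long long countPathSets(int n, const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    // BFS from vertex 1: parents always precede their children in `order`.
    std::vector<int> parent(n + 1, 0), order;
    order.reserve(n);
    order.push_back(1);
    parent[1] = -1;
    for (std::size_t i = 0; i < order.size(); ++i)
        for (int v : adj[order[i]])
            if (v != parent[order[i]]) { parent[v] = order[i]; order.push_back(v); }
    assert(static_cast<int>(order.size()) == n);  // input must be a connected tree

    // Subtree sizes, accumulated in reverse BFS order.
    std::vector<long long> sz(n + 1, 1);
    for (int i = n - 1; i > 0; --i) sz[parent[order[i]]] += sz[order[i]];

    auto tri = [](long long s) { return s * (s + 1) / 2; };  // paths inside a component of size s
    // a[v] = paths avoiding v = paths inside the components left after deleting v.
    std::vector<long long> a(n + 1, 0);
    for (int v = 1; v <= n; ++v) {
        a[v] = tri(n - sz[v]);  // the component on the parent side (empty for the root)
        for (int c : adj[v])
            if (c != parent[v]) a[v] += tri(sz[c]);
    }
    long long ans = 0;
    for (int v = 1; v <= n; ++v) {
        ans = (ans + power(2, a[v])) % MOD;
        if (parent[v] > 0) {
            int p = parent[v];
            // Deleting both v and p: drop the component of v that contains p,
            // and the component of p that contains v.
            long long both = a[v] - tri(n - sz[v]) + a[p] - tri(sz[v]);
            ans = (ans - power(2, both) + MOD) % MOD;
        }
    }
    return ans;
}
```

For example, `countPathSets(5, {{1, 2}, {2, 3}, {2, 4}, {1, 5}})` returns 3128, matching Example 3. Keeping ans in [0, MOD) after every step removes the need for the final `+mod` correction used in the submission.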