NC203503. 路径积
描述
输入描述
第一个参数n代表节点个数
第二个参数m代表查询次数
第三个参数vector a代表每个节点的权值
第四、五个参数vector u,v各自包含n-1个元素代表树上的边,与相连
第六、七个参数vector x,y各自包含m个元素,代表查询到的最短路径上的点权乘积
对于每个查询,答案存储在vector中并按照查询的时间顺序输出
示例1
输入:
6,2,[1,2,3,4,5,6],[1,2,2,4,4],[2,3,4,5,6],[3,6],[5,5]
输出:
[120,120]
说明:
C++11(clang++ 3.9) 解法, 执行用时: 170ms, 内存消耗: 25572K, 提交时间: 2020-08-14 08:14:42
class Solution { public: static const int N=1e5+10,M=22,mod=1e9+7; typedef long long ll; ll power(ll a,ll b=mod-2) { ll c=1; while(b) { if(b&1) c=c*a%mod; b /= 2; a=a*a%mod; } return c; } int f[N][22],dep[N]; vector<int> e[N]; ll w[N]; void add(int x,int y) {e[x].push_back(y);} void dfs(int x) { for(int y:e[x]) if(y^f[x][0]) { dep[y]=dep[x]+1; f[y][0]=x; for(int i=1;i<=20;i++) f[y][i]=f[f[y][i-1]][i-1]; w[y] = w[y]*w[x]%mod; dfs(y); } } int lca(int x,int y) { if(dep[x]<dep[y]) swap(x,y); for(int k=dep[x]-dep[y],i=0; k;i++) if(k>>i&1) x=f[x][i],k^=1<<i; if(x==y) return x; for(int i=20;~i;i--) if(f[x][i]^f[y][i]) x=f[x][i],y=f[y][i]; return f[x][0]; } vector<int> solve(int n, int m, vector<int>& a, vector<int>& u, vector<int>& v, vector<int>& x, vector<int>& y) { w[0]=1; for(int i=1;i<=n;i++) w[i]=a[i-1],e[i].clear();; for(int i=0;i<n-1;i++) { int s=u[i],t=v[i]; add(s,t); add(t,s); } memset(f,0,sizeof f); dfs(1); vector<int> res; for(int i=0;i<m;i++) { int s=x[i],t=y[i],z=lca(s,t); res.push_back((ll)w[s]*w[t]%mod*power(w[z])%mod*power(w[f[z][0]])%mod); } return res; } };
Java(javac 1.8) 解法, 执行用时: 821ms, 内存消耗: 120048K, 提交时间: 2020-08-14 17:26:05
import java.util.*; public class Solution { /** * 路径积 * @param n int整型 * @param m int整型 * @param a int整型一维数组 * @param u int整型一维数组 * @param v int整型一维数组 * @param x int整型一维数组 * @param y int整型一维数组 * @return int整型一维数组 */ int mod=(int)(1e9+7); ArrayList<Integer> map[]; int[] deep; int[] pre; void dfs(int node,int last){ deep[node]=deep[last]+1; pre[node]=last; for(int next: map[node]){ if(next==last)continue; dfs(next,node); } } int getnum(int[] a,int x,int y){ long ans=a[x-1]*a[y-1]%mod; int last=1; while(x!=y){ ans=ans*last%mod; if(deep[x]>deep[y]){ x=pre[x];last=a[x-1]; }else { y=pre[y];last=a[y-1]; } } return (int)ans; } public int[] solve (int n, int m, int[] a, int[] u, int[] v, int[] x, int[] y) { // write code here map=new ArrayList[n+1]; for(int i=0;i<=n;i++)map[i]=new ArrayList<>(); deep=new int[n+1]; pre=new int[n+1]; deep[1]=-1; for(int i=0;i<n-1;i++){ map[u[i]].add(v[i]); map[v[i]].add(u[i]); } dfs(1,1); int[] res=new int[m]; for(int i=0;i<m;i++){ res[i]=getnum(a,x[i],y[i]); } return res; } }
Go(1.14.4) 解法, 执行用时: 381ms, 内存消耗: 23264K, 提交时间: 2020-08-14 01:32:18
package main // github.com/EndlessCheng/codeforces-go func solve(n, q int, a, vs, ws, x, y []int) []int { const mod int = 1e9 + 7 g := make([][]int, n) for i, v := range vs { v-- w := ws[i] - 1 g[v] = append(g[v], w) g[w] = append(g[w], v) } p := make([]int, n) for i := range p { p[i] = i } var find func(int) int find = func(x int) int { if p[x] != x { p[x] = find(p[x]) } return p[x] } pow := func(x int) int { res := 1 for n := mod - 2; n > 0; n >>= 1 { if n&1 == 1 { res = res * x % mod } x = x * x % mod } return res } ans := make([]int, q) type query struct{ w, i int } qs := make([][]query, n) for i, v := range x { v-- if w := y[i] - 1; v != w { qs[v] = append(qs[v], query{w, i}) qs[w] = append(qs[w], query{v, i}) } else { ans[i] = a[v] } } mul := make([]int, n) vis := make([]int8, n) var f func(v, m int) f = func(v, m int) { mul[v] = m vis[v] = 1 for _, w := range g[v] { if vis[w] == 0 { f(w, m*a[w]%mod) p[w] = v } } for _, q := range qs[v] { if w := q.w; vis[w] == 2 { lca := find(w) ans[q.i] = mul[v] * mul[w] % mod * pow(mul[lca]*mul[lca]%mod) % mod * a[lca] % mod } } vis[v] = 2 } f(0, a[0]) return ans }