列表

详情


NC203503. 路径积

描述

给定一棵n个节点的无根树(n个结点,n-1条边的无环连通图),每个节点有一个权值a_i

一共有m次查询,每次查询x_iy_i的最短路径上所有点权的乘积。

为了防止答案过大,答案对1e9+7取模。


输入描述



第一个参数n代表节点个数
第二个参数m代表查询次数
第三个参数vector a代表每个节点的权值
第四、五个参数vector u,v各自包含n-1个元素代表树上的边,u_iv_i相连
第六、七个参数vector x,y各自包含m个元素,代表查询x_iy_i的最短路径上的点权乘积

对于每个查询,答案存储在vector中并按照查询的时间顺序输出

示例1

输入:

6,2,[1,2,3,4,5,6],[1,2,2,4,4],[2,3,4,5,6],[3,6],[5,5]

输出:

[120,120]

说明:

3到5的最短路径为3->2->4->5,路径积为3*2*4*5=120
6到5的最短路径为6->4->5,路径积为6*4*5=120

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++11(clang++ 3.9) 解法, 执行用时: 170ms, 内存消耗: 25572K, 提交时间: 2020-08-14 08:14:42

class Solution {
public:
    static const int N=1e5+10,M=22,mod=1e9+7;
    typedef long long ll;
    ll power(ll a,ll b=mod-2) {
        ll c=1;
        while(b) {
            if(b&1) c=c*a%mod;
            b /= 2; a=a*a%mod;
        }
        return c;
    }
    int f[N][22],dep[N];
    vector<int> e[N];
    ll w[N];
    void add(int x,int y) {e[x].push_back(y);}
    void dfs(int x) {
        for(int y:e[x]) if(y^f[x][0]) {
            dep[y]=dep[x]+1; f[y][0]=x;
            for(int i=1;i<=20;i++) f[y][i]=f[f[y][i-1]][i-1];
            w[y] = w[y]*w[x]%mod; dfs(y);
        }
    }
    int lca(int x,int y) {
        if(dep[x]<dep[y]) swap(x,y);
        for(int k=dep[x]-dep[y],i=0; k;i++)
           if(k>>i&1) x=f[x][i],k^=1<<i;
        if(x==y) return x;
        for(int i=20;~i;i--)
              if(f[x][i]^f[y][i])
                      x=f[x][i],y=f[y][i];
         return f[x][0];
    }
    vector<int> solve(int n, int m, vector<int>& a, vector<int>& u, vector<int>& v, vector<int>& x, vector<int>& y) {
        w[0]=1; for(int i=1;i<=n;i++) w[i]=a[i-1],e[i].clear();;
        for(int i=0;i<n-1;i++) {
            int s=u[i],t=v[i];
            add(s,t); add(t,s);
        }
        memset(f,0,sizeof f); dfs(1);
        vector<int> res;
        for(int i=0;i<m;i++) {
            int s=x[i],t=y[i],z=lca(s,t);
            res.push_back((ll)w[s]*w[t]%mod*power(w[z])%mod*power(w[f[z][0]])%mod);
        }
        return res;
    }
};

Java(javac 1.8) 解法, 执行用时: 821ms, 内存消耗: 120048K, 提交时间: 2020-08-14 17:26:05

import java.util.*;


public class Solution {
    /**
     * 路径积
     * @param n int整型
     * @param m int整型
     * @param a int整型一维数组
     * @param u int整型一维数组
     * @param v int整型一维数组
     * @param x int整型一维数组
     * @param y int整型一维数组
     * @return int整型一维数组
     */
    int mod=(int)(1e9+7);
    ArrayList<Integer> map[];
    int[] deep;
    int[] pre;
    void dfs(int node,int last){
        deep[node]=deep[last]+1;
        pre[node]=last;
        for(int next: map[node]){
            if(next==last)continue;
            dfs(next,node);
        }
    }
    int getnum(int[] a,int x,int y){
        long ans=a[x-1]*a[y-1]%mod;
        int last=1;
        while(x!=y){
            ans=ans*last%mod;
            if(deep[x]>deep[y]){
                x=pre[x];last=a[x-1];
            }else {
                y=pre[y];last=a[y-1];
            }
        }
        return (int)ans;
    }
    public int[] solve (int n, int m, int[] a, int[] u, int[] v, int[] x, int[] y) {
        // write code here
        map=new ArrayList[n+1];
        for(int i=0;i<=n;i++)map[i]=new ArrayList<>();
        deep=new int[n+1];
        pre=new int[n+1];
        deep[1]=-1;
        for(int i=0;i<n-1;i++){
            map[u[i]].add(v[i]);
            map[v[i]].add(u[i]);
        }
        dfs(1,1);
        int[] res=new int[m];
        for(int i=0;i<m;i++){
            res[i]=getnum(a,x[i],y[i]);
        }
        return res;
    }
}

Go(1.14.4) 解法, 执行用时: 381ms, 内存消耗: 23264K, 提交时间: 2020-08-14 01:32:18

package main

// github.com/EndlessCheng/codeforces-go
func solve(n, q int, a, vs, ws, x, y []int) []int {
	const mod int = 1e9 + 7
	g := make([][]int, n)
	for i, v := range vs {
		v--
		w := ws[i] - 1
		g[v] = append(g[v], w)
		g[w] = append(g[w], v)
	}
	p := make([]int, n)
	for i := range p {
		p[i] = i
	}
	var find func(int) int
	find = func(x int) int {
		if p[x] != x {
			p[x] = find(p[x])
		}
		return p[x]
	}
	pow := func(x int) int {
		res := 1
		for n := mod - 2; n > 0; n >>= 1 {
			if n&1 == 1 {
				res = res * x % mod
			}
			x = x * x % mod
		}
		return res
	}

	ans := make([]int, q)
	type query struct{ w, i int }
	qs := make([][]query, n)
	for i, v := range x {
		v--
		if w := y[i] - 1; v != w {
			qs[v] = append(qs[v], query{w, i})
			qs[w] = append(qs[w], query{v, i})
		} else {
			ans[i] = a[v]
		}
	}

	mul := make([]int, n)
	vis := make([]int8, n)
	var f func(v, m int)
	f = func(v, m int) {
		mul[v] = m
		vis[v] = 1
		for _, w := range g[v] {
			if vis[w] == 0 {
				f(w, m*a[w]%mod)
				p[w] = v
			}
		}
		for _, q := range qs[v] {
			if w := q.w; vis[w] == 2 {
				lca := find(w)
				ans[q.i] = mul[v] * mul[w] % mod * pow(mul[lca]*mul[lca]%mod) % mod * a[lca] % mod
			}
		}
		vis[v] = 2
	}
	f(0, a[0])
	return ans
}

上一题