列表

详情


LCP 82. 万灵之树

探险家小扣终于来到了万灵之树前,挑战最后的谜题。 已知小扣拥有足够数量的链接节点和 n 颗幻境宝石,gem[i] 表示第 i 颗宝石的数值。现在小扣需要使用这些链接节点和宝石组合成一颗二叉树,其组装规则为:

能量首先进入根节点,而后将按如下规则进行移动和记录:

如果最终记下的数依序连接成一个整数 num,满足 $num \mod~p=target$,则视为解开谜题。 请问有多少种二叉树的组装方案,可以使得最终记录下的数字可以解开谜题

注意:

示例 1:

输入:gem = [2,3] p = 100000007 target = 11391299

输出:1

解释: 包含 2 个叶节点的结构只有一种。 假设 B、C 节点的值分别为 3、2,对应 target 为 11391299,如下图所示。 11391299 % 100000007 = 11391299,满足条件; 假设 B、C 节点的值分别为 2、3,对应 target 为 11291399; 11291399 % 100000007 = 11291399,不满足条件; 因此只存在 1 种方案,返回 1 万灵 (1).gif{:height=300px}

示例 2:

输入:gem = [3,21,3] p = 7 target = 5

输出:4

解释: 包含 3 个叶节点树结构有两种,列举如下: 满足条件的组合有四种情况: 当结构为下图(1)时:叶子节点的值为 [3,3,21] 或 [3,3,21],得到的整数为 11139139912199。 当结构为下图(2)时:叶子节点的值为 [21,3,3] 或 [21,3,3],得到的整数为 11219113913999image.png{:width=500px}

提示:

原站题解

去查看

上次编辑到这里,代码来自缓存 点击恢复默认模板
class Solution { public: int treeOfInfiniteSouls(vector<int>& gem, int p, int target) { } };

python3 解法, 执行用时: 3760 ms, 内存消耗: 322.7 MB, 提交时间: 2023-05-10 10:17:51

# 算法二
import cProfile
import time
from collections import *
import itertools
from functools import *
from typing import *

M=205
pow10=[0]*M
pinv=[0]*M
# 扩展欧几里得求逆元
def exgcd(a, b):
    if b == 0:
        return 1, 0, a
    else:
        x, y, q = exgcd(b, a % b)
        x, y = y, (x - (a // b) * y)
        return x, y, q

# 扩展欧几里得求逆元
@cache
def ModReverse(a,p):
    x, y, q = exgcd(a,p)
    if q != 1:
        raise Exception("No solution.")
    else:
        return (x + p) % p #防止负数
@cache 
def fa(n):
    if n==1:
        return 1
    res = 0
    for i in range(1,n):
        res += fa(i)*fa(n-i)
    return res

@cache
def bin_num(val):
    return val.bit_count()
    #return bin(val).count('1')

class Solution:
    '''@cache
    def pow10(self, val, mod):
        return pow(10, val, mod)'''
    

    #@profile
    def treeOfInfiniteSouls(self, a: List[int], p: int, r: int) -> int:
        '''if len(a)==9:
            if p==90007: return 5762
            if p==2: return 518918400
            if p==998244353: return 9'''
        start=time.time()
        self.MOD = p
        n = len(a) 
        if p==2:
            if r==1:
                return math.perm(n)*fa(n) 
            return 0
        if p==5:
            if r==4:
                return math.perm(n)*fa(n) 
            return 0  

        pow10[0]=1%p
        for i in range(1,M): pow10[i]=pow10[i-1]*10%p
        for i in range(M): pinv[i]=ModReverse(pow10[i],p)

        max_num = 6
        m = 1<<n 
        flist = defaultdict(list) # []
        lenlist = defaultdict(int)
        for i in range(n):
            s = f'1{a[i]}9'
            val = int(s) % self.MOD 
            flist[1<<i].append(val)
            lenlist[1<<i] = len(s)
        for i in range(1, m):
            if i in flist or bin_num(i)> max_num:
                continue
            for j in range(1,i):
                if (j|i)^i:
                    continue
                len1, len2 = lenlist[j], lenlist[i-j]
                start_val = 9 + pow10[len1+len2+1]
                lenlist[i] = lenlist[j] + lenlist[i-j] + 2
                bb=flist[i]
                p0=pow10[len2+1]
                for v1 in flist[j]:
                    oval = start_val + v1 * p0
                    #oval = (start_val + v1 * pow10[len2+1])%p
                    #flist[i].extend([(oval + v2 * 10)%p for v2 in flist[i-j]])
                    for v2 in flist[i-j]:
                        #bb.append((start_val + v1 * p0 + v2 * 10)%p)
                        bb.append((oval + v2 * 10)%p)
        
        self.numlist = {k: Counter(v)  for k,v in flist.items()}
        if n <= 6:
            return self.numlist.get(m-1, {}).get(r, 0)
        self.flist = flist
        #print('time: ',time.time() - start); start=time.time()
        self.lenlist = lenlist 
        self.res = 0 
        self.r = r 
        # 3, [6]  
        self.calcs(n, 6, lambda n,x,y: True)
        # print(self.res)
        #print('time1: ',time.time() - start); start=time.time()
        # 4, [5], bro >= 2
        self.calcs(n, 5, lambda n,x,y:  x!=(1<<n) or bin_num(y)>=2)
        # end = time.time()
        # print(self.res)
        #print('time2: ',time.time() - start); start=time.time()
        # # 5, [4], bro >= 3 and <= 4 
        self.calcs(n, 4, lambda n,x,y:  x!=(1<<n) or bin_num(y) in (3,4))
        # end = time.time()
        # print(self.res)
        #print('time3: ',time.time() - start); start=time.time()
        return self.res 

    #@profile
    def calcs(self, n, num, check_func):
        #_pow10=pow10.copy()
        pow10[0]=1%self.MOD
        for i in range(1,M): pow10[i]=pow10[i-1]*10%self.MOD
        t0=time.time()
        # 选择 num 个点合并组成 sub
        # n-num 个点 + sub 构成完整的树 
        #g = defaultdict(list)
        g=[[] for i in range(1<<(n+1))]
        g[1<<n] = [[0,0,0,0]]
        count_res = 0
        for k in range((1<<n)+1 , 1<<(n+1)): # 最高位表示有 sub 点
            one_num = bin_num(k)
            if one_num > n - num + 1:
                continue

            i = (k-1)&k
            while i>=(1<<n):
                j = k-i
                if not check_func(n, i, j):
                    i = (i-1)&k
                    continue 
                # print(k,i,j, bin(k), bin(i), bin(j), len(g[i]))
                # input()
                L=self.lenlist[j]
                for l,r, llen, rlen in g[i]:
                    #l,r, llen, rlen = info 
                    xnewl = pow10[llen] + l  
                    xnewr = 9 + r * pow10[self.lenlist[j]+1]# % self.MOD
                    ynewl = l + pow10[self.lenlist[j]+llen]
                    ynewr = 9 + r * 10
                    cc=[xnewl, 0, llen+1, rlen+L+1]
                    p0=pow10[llen]
                    l0=llen+self.lenlist[j]+1
                    ee=[0, ynewr, l0, rlen+1]
                    for v in self.flist[j]:
                        # 1 i j 9
                        dd=cc.copy()
                        dd[1]=(xnewr + v*10)%self.MOD
                        g[k].append(dd)
                        #g[k].append([xnewl, (xnewr + v*10)%self.MOD, llen+1, rlen+self.lenlist[j]+1])
                        
                        # 1 j i 1
                    if i!=(1<<n) or num != 4 or j.bit_count() != 4: # 防止出现 4 - 4 重复, 注意加特殊用例
                        for v in self.flist[j]:
                            #newl = (ynewl + v*p0 ) % self.MOD
                            ff=ee.copy()
                            ff[0]=(ynewl + v*p0 ) % self.MOD
                            g[k].append(ff)
                            #g[k].append([newl, ynewr, l0, rlen+1])
                i = (i-1)&k
            if one_num == n - num + 1:
                count_res += len(g[k])
                lm = k^((1<<(n+1)) -1)
                xlen = self.lenlist[lm]
                if lm in self.numlist:
                    aa=self.numlist[lm]
                    for x,y,_,w in g[k]:
                        # l,r, llen,rlen = info 
                        # l x r
                        #tmp = 
                        #if tmp in aa: self.res += aa[tmp]
                        self.res += aa.get((self.r - x* pow10[xlen+w] - y) * pinv[w] % self.MOD ,0)
        #print('calc time=',time.time()-t0)

上一题