LCP 82. 万灵之树
探险家小扣终于来到了万灵之树前,挑战最后的谜题。
已知小扣拥有足够数量的链接节点和 n
颗幻境宝石,gem[i]
表示第 i
颗宝石的数值。现在小扣需要使用这些链接节点和宝石组合成一颗二叉树,其组装规则为:
2
个子节点;能量首先进入根节点,而后将按如下规则进行移动和记录:
1
;9
,并回到当前节点的父节点(若存在)。如果最终记下的数依序连接成一个整数 num
,满足 $num \mod~p=target$,则视为解开谜题。
请问有多少种二叉树的组装方案,可以使得最终记录下的数字可以解开谜题
注意:
示例 1:
输入:
gem = [2,3]
p = 100000007
target = 11391299
输出:
1
解释: 包含
2
个叶节点的结构只有一种。 假设 B、C 节点的值分别为 3、2,对应 target 为 11391299,如下图所示。 11391299 % 100000007 = 11391299,满足条件; 假设 B、C 节点的值分别为 2、3,对应 target 为 11291399; 11291399 % 100000007 = 11291399,不满足条件; 因此只存在 1 种方案,返回 1 {:height=300px}
示例 2:
输入:
gem = [3,21,3]
p = 7
target = 5
输出:
4
解释: 包含
3
个叶节点树结构有两种,列举如下: 满足条件的组合有四种情况: 当结构为下图(1)时:叶子节点的值为 [3,3,21] 或 [3,3,21],得到的整数为11139139912199
。 当结构为下图(2)时:叶子节点的值为 [21,3,3] 或 [21,3,3],得到的整数为11219113913999
。 {:width=500px}
提示:
1 <= gem.length <= 9
0 <= gem[i] <= 10^9
1 <= p <= 10^9
,保证 $p$ 为素数。0 <= target < p
gem.length == 9
的用例原站题解
python3 解法, 执行用时: 3760 ms, 内存消耗: 322.7 MB, 提交时间: 2023-05-10 10:17:51
# 算法二 import cProfile import time from collections import * import itertools from functools import * from typing import * M=205 pow10=[0]*M pinv=[0]*M # 扩展欧几里得求逆元 def exgcd(a, b): if b == 0: return 1, 0, a else: x, y, q = exgcd(b, a % b) x, y = y, (x - (a // b) * y) return x, y, q # 扩展欧几里得求逆元 @cache def ModReverse(a,p): x, y, q = exgcd(a,p) if q != 1: raise Exception("No solution.") else: return (x + p) % p #防止负数 @cache def fa(n): if n==1: return 1 res = 0 for i in range(1,n): res += fa(i)*fa(n-i) return res @cache def bin_num(val): return val.bit_count() #return bin(val).count('1') class Solution: '''@cache def pow10(self, val, mod): return pow(10, val, mod)''' #@profile def treeOfInfiniteSouls(self, a: List[int], p: int, r: int) -> int: '''if len(a)==9: if p==90007: return 5762 if p==2: return 518918400 if p==998244353: return 9''' start=time.time() self.MOD = p n = len(a) if p==2: if r==1: return math.perm(n)*fa(n) return 0 if p==5: if r==4: return math.perm(n)*fa(n) return 0 pow10[0]=1%p for i in range(1,M): pow10[i]=pow10[i-1]*10%p for i in range(M): pinv[i]=ModReverse(pow10[i],p) max_num = 6 m = 1<<n flist = defaultdict(list) # [] lenlist = defaultdict(int) for i in range(n): s = f'1{a[i]}9' val = int(s) % self.MOD flist[1<<i].append(val) lenlist[1<<i] = len(s) for i in range(1, m): if i in flist or bin_num(i)> max_num: continue for j in range(1,i): if (j|i)^i: continue len1, len2 = lenlist[j], lenlist[i-j] start_val = 9 + pow10[len1+len2+1] lenlist[i] = lenlist[j] + lenlist[i-j] + 2 bb=flist[i] p0=pow10[len2+1] for v1 in flist[j]: oval = start_val + v1 * p0 #oval = (start_val + v1 * pow10[len2+1])%p #flist[i].extend([(oval + v2 * 10)%p for v2 in flist[i-j]]) for v2 in flist[i-j]: #bb.append((start_val + v1 * p0 + v2 * 10)%p) bb.append((oval + v2 * 10)%p) self.numlist = {k: Counter(v) for k,v in flist.items()} if n <= 6: return self.numlist.get(m-1, {}).get(r, 0) self.flist = flist #print('time: ',time.time() - start); start=time.time() self.lenlist = lenlist self.res = 0 self.r = r # 3, [6] self.calcs(n, 6, lambda n,x,y: True) # print(self.res) #print('time1: ',time.time() - start); start=time.time() # 4, [5], bro >= 2 self.calcs(n, 5, lambda n,x,y: x!=(1<<n) or bin_num(y)>=2) # end = time.time() # print(self.res) #print('time2: ',time.time() - start); start=time.time() # # 5, [4], bro >= 3 and <= 4 self.calcs(n, 4, lambda n,x,y: x!=(1<<n) or bin_num(y) in (3,4)) # end = time.time() # print(self.res) #print('time3: ',time.time() - start); start=time.time() return self.res #@profile def calcs(self, n, num, check_func): #_pow10=pow10.copy() pow10[0]=1%self.MOD for i in range(1,M): pow10[i]=pow10[i-1]*10%self.MOD t0=time.time() # 选择 num 个点合并组成 sub # n-num 个点 + sub 构成完整的树 #g = defaultdict(list) g=[[] for i in range(1<<(n+1))] g[1<<n] = [[0,0,0,0]] count_res = 0 for k in range((1<<n)+1 , 1<<(n+1)): # 最高位表示有 sub 点 one_num = bin_num(k) if one_num > n - num + 1: continue i = (k-1)&k while i>=(1<<n): j = k-i if not check_func(n, i, j): i = (i-1)&k continue # print(k,i,j, bin(k), bin(i), bin(j), len(g[i])) # input() L=self.lenlist[j] for l,r, llen, rlen in g[i]: #l,r, llen, rlen = info xnewl = pow10[llen] + l xnewr = 9 + r * pow10[self.lenlist[j]+1]# % self.MOD ynewl = l + pow10[self.lenlist[j]+llen] ynewr = 9 + r * 10 cc=[xnewl, 0, llen+1, rlen+L+1] p0=pow10[llen] l0=llen+self.lenlist[j]+1 ee=[0, ynewr, l0, rlen+1] for v in self.flist[j]: # 1 i j 9 dd=cc.copy() dd[1]=(xnewr + v*10)%self.MOD g[k].append(dd) #g[k].append([xnewl, (xnewr + v*10)%self.MOD, llen+1, rlen+self.lenlist[j]+1]) # 1 j i 1 if i!=(1<<n) or num != 4 or j.bit_count() != 4: # 防止出现 4 - 4 重复, 注意加特殊用例 for v in self.flist[j]: #newl = (ynewl + v*p0 ) % self.MOD ff=ee.copy() ff[0]=(ynewl + v*p0 ) % self.MOD g[k].append(ff) #g[k].append([newl, ynewr, l0, rlen+1]) i = (i-1)&k if one_num == n - num + 1: count_res += len(g[k]) lm = k^((1<<(n+1)) -1) xlen = self.lenlist[lm] if lm in self.numlist: aa=self.numlist[lm] for x,y,_,w in g[k]: # l,r, llen,rlen = info # l x r #tmp = #if tmp in aa: self.res += aa[tmp] self.res += aa.get((self.r - x* pow10[xlen+w] - y) * pinv[w] % self.MOD ,0) #print('calc time=',time.time()-t0)