NC17443. F. Squirtle
The input starts with one line containing exactly one integer t which is the number of test cases. (1 ≤ t ≤ 20)
For each test case, the first line contains exactly one integer n which is the number of leaf nodes. (2 ≤ n ≤ 2000)
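Each of the next n-1 lines contains a string s_i of length 16, describing internal node i. It is guaranteed that s_i consists only of 0 and 1 and contains at least one 1. If the j-th character of s_i (numbered from 0) is 1, then operator j ∈ S_i; otherwise operator j ∉ S_i. Operator j is the two-input boolean function whose output on left operand a and right operand b is bit 2a+b of j (this is the encoding the solutions below use); for example, operator 8 is AND and operator 14 is OR.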
Each of the next 2n-2 lines contains an integer ai, which represents the father of i+1. It is guaranteed that the tree is valid. That is, the tree is a binary tree whose nodes have either 0 or 2 children. Nodes from 1 to n-1 are guaranteed to be non-leaf nodes. The child with the smaller index is regarded as the left operand. (1 ≤ ai ≤ i)
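For each test case, output "Case #x: y" in one line (without quotes), where x is the test case number (starting from 1) and y is the maximum possible sum, i.e. over all ways of choosing one operator from S_i for every internal node i, the maximum number of the 2^n assignments of 0/1 to the leaves for which the root evaluates to 1.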
Sample input:
2
2
0000000010000010
1
1
3
0000000010000010
0000000010000011
1
1
2
2
Sample output:
Case #1: 3
Case #2: 8
Java (javac 1.8) solution. Run time: 1973 ms, memory: 179668K, submitted: 2018-08-04 15:59:52
import java.math.*;
import java.util.*;

/**
 * NC17443 "Squirtle".
 *
 * <p>Each internal node of a full binary tree may use any two-input boolean operator from its
 * allowed set. For every test case this prints the maximum, over all operator choices, of the
 * number of 0/1 leaf assignments for which the root evaluates to 1.
 */
public class Main {
    /** op[u]: 16-character 0/1 mask of the operators allowed at internal node u (1..n-1). */
    public static String[] op;
    /** ch[u][0] / ch[u][1]: left (smaller index) / right child of node u. */
    public static int[][] ch;
    /** p[u]: number of children of node u; 0 marks a leaf. */
    public static int[] p;
    /**
     * cnt[0][u] / cnt[1][u]: minimum / maximum number of leaf assignments of u's subtree for
     * which u evaluates to 1, over all operator choices inside that subtree.
     */
    public static BigInteger[][] cnt;
    /** all[u]: total number of leaf assignments of u's subtree (2 per leaf, multiplied). */
    public static BigInteger[] all;

    private static final BigInteger TWO = BigInteger.valueOf(2);

    /**
     * Fills {@code cnt[0][v]}, {@code cnt[1][v]} and {@code all[v]} for every node v in the
     * subtree rooted at u, children before parents.
     *
     * @param u root of the subtree to evaluate
     */
    public static void dfs(int u) {
        if (p[u] == 0) {
            // A lone leaf is 1 in exactly one of its two assignments.
            cnt[0][u] = BigInteger.ONE;
            cnt[1][u] = BigInteger.ONE;
            all[u] = TWO;
            return;
        }
        int l = ch[u][0];
        int r = ch[u][1];
        dfs(l);
        dfs(r);
        all[u] = all[l].multiply(all[r]);

        // The number of ones at u is bilinear in the children's counts of ones, so its extremes
        // are reached when each child sits at its own minimum or maximum.
        BigInteger bestZeros = BigInteger.ZERO; // max #assignments with u == 0
        BigInteger bestOnes = BigInteger.ZERO;  // max #assignments with u == 1
        for (int i = 0; i < 2; i++) {
            BigInteger l1 = cnt[i][l];
            BigInteger l0 = all[l].subtract(l1);
            for (int j = 0; j < 2; j++) {
                BigInteger r1 = cnt[j][r];
                BigInteger r0 = all[r].subtract(r1);
                // w[2a + b] = number of assignments with left == a and right == b.
                BigInteger[] w = {l0.multiply(r0), l0.multiply(r1), l1.multiply(r0), l1.multiply(r1)};
                for (int k = 0; k < 16; k++) {
                    if (op[u].charAt(k) != '1') {
                        continue;
                    }
                    BigInteger ones = BigInteger.ZERO;
                    BigInteger zeros = BigInteger.ZERO;
                    for (int t = 0; t < 4; t++) {
                        // Bit t = 2a + b of operator k is its output on (a, b).
                        if ((k >> t & 1) == 1) {
                            ones = ones.add(w[t]);
                        } else {
                            zeros = zeros.add(w[t]);
                        }
                    }
                    if (zeros.compareTo(bestZeros) > 0) {
                        bestZeros = zeros;
                    }
                    if (ones.compareTo(bestOnes) > 0) {
                        bestOnes = ones;
                    }
                }
            }
        }
        // Most zeros is the same as fewest ones, since zeros + ones == all[u].
        cnt[0][u] = all[u].subtract(bestZeros);
        cnt[1][u] = bestOnes;
    }

    /** Reads all test cases from standard input and prints one "Case #x: y" line each. */
    public static void main(String[] args) {
        Scanner cin = new Scanner(System.in);
        int T = cin.nextInt();
        for (int cas = 1; cas <= T; cas++) {
            int n = cin.nextInt();
            op = new String[2 * n];
            ch = new int[2 * n][2];
            p = new int[2 * n];
            cnt = new BigInteger[2][2 * n];
            all = new BigInteger[2 * n];
            for (int i = 1; i <= n - 1; i++) {
                op[i] = cin.next();
            }
            // Parents are read in child order, so ch[a][0] always gets the smaller index.
            for (int i = 2; i <= 2 * n - 1; i++) {
                int a = cin.nextInt();
                ch[a][p[a]++] = i;
            }
            dfs(1);
            System.out.println("Case #" + cas + ": " + cnt[1][1]);
        }
    }
}
Python (2.7.3) solution. Run time: 1960 ms, memory: 7372K, submitted: 2018-08-04 19:41:25
import sys


def calc(op, x, y):
    """Apply two-input boolean operator ``op`` (0..15) to bits ``x`` (left) and ``y`` (right).

    Bit ``2*x + y`` of ``op`` is the operator's output for that input pair, so op 8 is
    AND, op 14 is OR and op 6 is XOR. Returns 0 or 1.
    """
    return (op >> (2 * x + y)) & 1


def solve(text):
    """Solve every test case in ``text`` and return the "Case #x: y" output lines.

    ``text`` is the whole input; tokens may be separated by any whitespace, so both the
    one-token-per-line format and a flattened single line are accepted.
    """
    it = iter(text.split())
    out = []
    tes = int(next(it))
    for tt in range(1, tes + 1):
        n = int(next(it))
        s = [""] + [next(it) for _ in range(n - 1)]
        e = [[] for _ in range(n * 2)]
        for i in range(2, n * 2):
            e[int(next(it))].append(i)

        # BFS order from the root; reversing it handles children before parents.
        que = [1]
        head = 0
        while head < len(que):
            que.extend(e[que[head]])
            head += 1

        # dp[x][lp] = [#assignments with x == 0, #assignments with x == 1] for the operator
        # choice in x's subtree that maximises the count of value lp.
        dp = [None] * (n * 2)
        for x in reversed(que):
            if not e[x]:
                dp[x] = [[1, 1], [1, 1]]
                continue
            l, r = sorted(e[x])  # the smaller index is the left operand
            best = [[0, 0], [0, 0]]
            for op in range(16):
                if s[x][op] == '0':
                    continue
                for lp in range(2):
                    for rp in range(2):
                        now = [0, 0]
                        for j in range(2):
                            for k in range(2):
                                now[calc(op, j, k)] += dp[l][lp][j] * dp[r][rp][k]
                        if now[0] > best[0][0]:
                            best[0] = now
                        if now[1] > best[1][1]:
                            best[1] = now
            dp[x] = best
        out.append("Case #%d: %d" % (tt, dp[1][1][1]))
    return out


def main():
    """Read the whole of stdin and print one answer line per test case."""
    for line in solve(sys.stdin.read()):
        print(line)


if __name__ == "__main__":
    main()
Python 3 solution. Run time: 1566 ms, memory: 8336K, submitted: 2021-10-06 17:11:12
import sys


def calc(u, G, f, sz, s):
    """Fill ``f[u]`` and ``sz[u]`` for node ``u``; its children must already be done.

    f[u][r] is the maximum number of 0/1 assignments to the leaves of u's subtree for
    which u evaluates to r, and sz[u] is the number of those leaves. G[u] lists u's
    children (left operand first) and s[u] is u's 16-character operator mask.
    """
    if not G[u]:
        # A lone leaf takes each value in exactly one of its two assignments.
        f[u], sz[u] = [1, 1], 1
        return
    L, R = G[u]
    sz[u] = sz[L] + sz[R]
    f[u] = [0, 0]
    for i in range(2):
        # x[v]: left-subtree assignments giving v when that subtree maximises value i.
        x = [0, 0]
        x[i], x[i ^ 1] = f[L][i], (1 << sz[L]) - f[L][i]
        for j in range(2):
            y = [0, 0]
            y[j], y[j ^ 1] = f[R][j], (1 << sz[R]) - f[R][j]
            for k in range(16):
                if s[u][k] == '0':
                    continue
                z = [0, 0]
                for r in range(4):
                    # Bit r = 2*a + b of operator k is its output on (a, b).
                    z[k >> r & 1] += x[r >> 1 & 1] * y[r & 1]
                for r in range(2):
                    f[u][r] = max(f[u][r], z[r])


def solve(text):
    """Solve every test case in ``text`` and return the "Case #x: y" output lines.

    Tokens may be separated by any whitespace (one per line or flattened).
    """
    it = iter(text.split())
    out = []
    for case in range(1, int(next(it)) + 1):
        n = int(next(it))
        m = 2 * n - 1
        s = [""] * (n + 1)
        for i in range(1, n):
            s[i] = next(it)
        G = [[] for _ in range(m + 1)]
        for i in range(2, m + 1):
            G[int(next(it))].append(i)
        sz = [0] * (m + 1)
        f = [[] for _ in range(m + 1)]
        # Every parent has a smaller index than its children (a_i <= i), so a reverse
        # sweep finishes both children before their parent.
        for u in range(m, 0, -1):
            calc(u, G, f, sz, s)
        out.append("Case #%d: %d" % (case, f[1][1]))
    return out


def main():
    """Read the whole of stdin and print one answer line per test case."""
    for line in solve(sys.stdin.read()):
        print(line)


if __name__ == "__main__":
    main()