NC50525. 涂抹果酱
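描述
有一个 N 行 M 列的矩形区域，每个格子要涂上三种果酱（编号 1、2、3）中的一种，且有公共边的两个相邻格子不能涂同一种果酱。第 K 行的涂法已经确定，求整个区域的可行方案总数，结果对 10^6 取模。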
输入描述
输入共三行。
第一行:N,M;
第二行:K;
第三行:M个整数,表示第K行的方案。
字母的详细含义见题目描述,其他参见样例。
输出描述
输出仅一行,为可行的方案总数。
示例1
输入:
2 2
1
2 3
输出:
3
说明:
第 1 行固定为 (2, 3)。第 2 行每个格子既要与上方格子不同，也要与左右相邻格子不同，可选 (1, 2)、(3, 1)、(3, 2)，共 3 种。
Java(javac 1.8) 解法, 执行用时: 1917ms, 内存消耗: 48104K, 提交时间: 2021-03-02 12:08:07
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.StreamTokenizer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

/**
 * NC50525 "涂抹果酱" (spreading jam).
 *
 * <p>An {@code n x m} grid is painted with three jams (colours 1..3) so that no two cells sharing
 * an edge get the same jam, and row {@code k} is given. Rows above the fixed row and rows below it
 * never interact. The answer is therefore the product of two factors, taken modulo 10^6:
 * <ul>
 *   <li>the ways to extend {@code k-1} rows away from the fixed row;</li>
 *   <li>the ways to extend {@code n-k} rows away from it.</li>
 * </ul>
 * Each row is encoded as an m-digit base-3 number. A single row-by-row DP starting at the fixed
 * row yields both factors.
 */
public class Main {
    /** The problem asks for the answer modulo 10^6. */
    static final long MOD = 1_000_000L;

    /**
     * Reads {@code n m k} followed by the m colours of row k from stdin and prints the count.
     *
     * @throws IOException if stdin ends early or contains a non-numeric token
     */
    public static void main(String[] args) throws IOException {
        StreamTokenizer in = new StreamTokenizer(
                new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8)));
        int n = nextInt(in);
        int m = nextInt(in);
        int k = nextInt(in);
        if (m < 1) {
            throw new IllegalArgumentException("m must be positive: " + m);
        }
        int[] row = new int[m];
        for (int j = 0; j < m; j++) {
            row[j] = nextInt(in);
        }
        System.out.println(count(n, m, k, row));
    }

    /** Reads the next numeric token, failing loudly on truncated or malformed input. */
    private static int nextInt(StreamTokenizer in) throws IOException {
        if (in.nextToken() != StreamTokenizer.TT_NUMBER) {
            throw new IOException("expected an integer token");
        }
        return (int) in.nval;
    }

    /**
     * Counts valid paintings modulo {@link #MOD}.
     *
     * <p>The state space is 3^m, so this is intended for small m.
     *
     * @param n   number of rows, at least 1
     * @param m   number of columns, at least 1
     * @param k   1-based index of the fixed row, {@code 1 <= k <= n}
     * @param row colours of the fixed row, each expected in 1..3
     * @return number of valid paintings modulo 10^6. Returns 0 if the fixed row can never appear
     *     (a colour outside 1..3, or two equal neighbours).
     * @throws IllegalArgumentException if n, m, k or the row length are inconsistent
     */
    static long count(int n, int m, int k, int[] row) {
        if (n < 1 || m < 1 || k < 1 || k > n || row.length != m) {
            throw new IllegalArgumentException(
                    "bad input: n=" + n + ", m=" + m + ", k=" + k + ", row length=" + row.length);
        }

        // Encode the fixed row; reject it early if it already breaks the rules.
        int fixed = 0;
        for (int j = 0; j < m; j++) {
            if (row[j] < 1 || row[j] > 3 || (j > 0 && row[j] == row[j - 1])) {
                return 0;
            }
            fixed = fixed * 3 + (row[j] - 1);
        }

        int total = 1;
        for (int j = 0; j < m; j++) {
            total *= 3;
        }

        // Compact list of rows with no two horizontally adjacent equal colours.
        int[] states = new int[total];
        int[] indexOf = new int[total]; // encoded row -> position in states, or -1 if invalid
        Arrays.fill(indexOf, -1);
        int size = 0;
        for (int s = 0; s < total; s++) {
            if (noAdjacentEqual(s, m)) {
                indexOf[s] = size;
                states[size++] = s;
            }
        }

        // next[i] = positions of rows that may sit directly next to (above/below) row i.
        int[][] next = new int[size][];
        int[] buffer = new int[size];
        for (int i = 0; i < size; i++) {
            int len = 0;
            for (int j = 0; j < size; j++) {
                if (columnsDiffer(states[i], states[j], m)) {
                    buffer[len++] = j;
                }
            }
            next[i] = Arrays.copyOf(buffer, len);
        }

        int above = k - 1;
        int below = n - k;
        // A side with no rows contributes a factor of exactly 1 (no sentinel needed).
        long waysAbove = 1;
        long waysBelow = 1;
        long[] ways = new long[size];
        ways[indexOf[fixed]] = 1;
        int steps = Math.max(above, below);
        for (int step = 1; step <= steps; step++) {
            long[] nextWays = new long[size];
            for (int i = 0; i < size; i++) {
                if (ways[i] == 0) {
                    continue;
                }
                for (int j : next[i]) {
                    nextWays[j] = (nextWays[j] + ways[i]) % MOD;
                }
            }
            ways = nextWays;
            if (step == above) {
                waysAbove = sumMod(ways);
            }
            if (step == below) {
                waysBelow = sumMod(ways);
            }
        }
        return waysAbove * waysBelow % MOD;
    }

    /** True when no two neighbouring base-3 digits of the m-digit {@code state} are equal. */
    private static boolean noAdjacentEqual(int state, int m) {
        for (int j = 1; j < m; j++) {
            if (state % 3 == state / 3 % 3) {
                return false;
            }
            state /= 3;
        }
        return true;
    }

    /** True when {@code a} and {@code b} differ in every one of their m base-3 digits. */
    private static boolean columnsDiffer(int a, int b, int m) {
        for (int j = 0; j < m; j++, a /= 3, b /= 3) {
            if (a % 3 == b % 3) {
                return false;
            }
        }
        return true;
    }

    /** Sum of all entries modulo {@link #MOD}. */
    private static long sumMod(long[] values) {
        long sum = 0;
        for (long v : values) {
            sum = (sum + v) % MOD;
        }
        return sum;
    }
}
Kotlin 解法, 执行用时: 507ms, 内存消耗: 18052K, 提交时间: 2022-12-01 15:18:25
import java.io.BufferedReader
import java.util.StringTokenizer

// NC50525 "涂抹果酱": count paintings of an m x n grid with three jams where edge-adjacent
// cells differ and row k is fixed. Rows above and below the fixed row are independent, so the
// answer is (ways k-1 rows away) * (ways m-k rows away) modulo 10^6.

/** The problem asks for the answer modulo 10^6. */
private const val MODL = 1_000_000L

private val reader: BufferedReader = System.`in`.bufferedReader()
private var tokenizer = StringTokenizer("")

/** Returns the next whitespace-separated token, failing loudly if the input ends early. */
private fun read(): String {
    while (!tokenizer.hasMoreTokens()) {
        val line = reader.readLine() ?: throw NoSuchElementException("unexpected end of input")
        tokenizer = StringTokenizer(line)
    }
    return tokenizer.nextToken()
}

private fun readInt(): Int = read().toInt()

fun main() {
    val m = readInt()                    // number of rows (N in the statement)
    val n = readInt()                    // number of columns (M in the statement)
    val k = readInt()                    // 1-based index of the fixed row
    val row = List(n) { readInt() - 1 }  // fixed row, colours shifted to 0..2
    println(jamCount(m, n, k, row))
}

/**
 * Counts the paintings of an [m] x [n] grid whose row [k] (1-based) equals [row].
 *
 * @param row colours of the fixed row, 0-based (0..2), one per column.
 * @return the number of valid paintings modulo 10^6. Returns 0 when the fixed row itself is
 *   invalid (a colour outside 0..2, or two equal neighbours).
 * @throws IllegalArgumentException if [k] is not in 1..[m], [n] < 1, or `row.size != n`.
 */
fun jamCount(m: Int, n: Int, k: Int, row: List<Int>): Long {
    require(n >= 1 && k in 1..m && row.size == n) { "bad input: m=$m, n=$n, k=$k, row=$row" }
    // An out-of-range colour could otherwise alias a different base-3 state.
    if (row.any { it !in 0..2 } || (1 until n).any { row[it] == row[it - 1] }) return 0L

    val states = (0 until pow3(n)).filter { noAdjacentEqual(it, n) }
    val indexOf = states.withIndex().associate { (i, s) -> s to i }
    val start = indexOf.getValue(encode(row))
    // next[i] = indices of rows that may sit directly above/below states[i].
    val next = Array(states.size) { i ->
        states.indices.filter { j -> columnsDiffer(states[i], states[j], n) }.toIntArray()
    }

    val above = k - 1
    val below = m - k
    var c1 = 1L  // an empty side contributes a factor of 1
    var c2 = 1L
    var ways = LongArray(states.size).also { it[start] = 1L }
    for (step in 1..maxOf(above, below)) {
        val nextWays = LongArray(states.size)
        for (i in states.indices) {
            val w = ways[i]
            if (w == 0L) continue
            for (j in next[i]) nextWays[j] = (nextWays[j] + w) % MODL
        }
        ways = nextWays
        // Each entry is < 10^6 and there are at most 3 * 2^(n-1) entries, so the plain sum fits.
        if (step == above) c1 = ways.sum() % MODL
        if (step == below) c2 = ways.sum() % MODL
    }
    return c1 * c2 % MODL
}

/** 3 raised to [e], in integer arithmetic. */
private fun pow3(e: Int): Int {
    var p = 1
    repeat(e) { p *= 3 }
    return p
}

/** Encodes colours 0..2 as a base-3 number, first column most significant. */
private fun encode(row: List<Int>): Int = row.fold(0) { acc, c -> acc * 3 + c }

/** True when no two neighbouring base-3 digits of the [n]-digit [state] are equal. */
private fun noAdjacentEqual(state: Int, n: Int): Boolean {
    var x = state
    for (col in 1 until n) {
        if (x % 3 == x / 3 % 3) return false
        x /= 3
    }
    return true
}

/** True when [a] and [b] differ in every one of their [n] base-3 digits. */
private fun columnsDiffer(a: Int, b: Int, n: Int): Boolean {
    var x = a
    var y = b
    for (col in 0 until n) {
        if (x % 3 == y % 3) return false
        x /= 3
        y /= 3
    }
    return true
}
C++14(g++5.4) 解法, 执行用时: 394ms, 内存消耗: 8276K, 提交时间: 2019-08-03 23:27:35
#include<bits/stdc++.h> #define ll long long #define mod 1000000 using namespace std; ll n,m,k,a[6],f[10010][100],mode,ans1,ans2; vector<ll> v; void init(int x,int y)//初始化合法状态 { if(x==m) { v.push_back(y); return ; } else if(x==0) { init(1,0); init(1,1); init(1,2); return ; } if((int)(y/pow(3,x-1))%3==0) { init(x+1,y+1*pow(3,x)); init(x+1,y+2*pow(3,x)); } else if((int)(y/pow(3,x-1))%3==1) { init(x+1,y+0*pow(3,x)); init(x+1,y+2*pow(3,x)); } else { init(x+1,y+0*pow(3,x)); init(x+1,y+1*pow(3,x)); } } bool judge(int x,int y)//判断x状态和y状态分别为上下两行时是否合法 { int wei = m; while(wei--) { int a = x%3; int b = y%3; if(a==b) return false; x/=3; y/=3; } return true; } int main() { cin>>n>>m>>k; for(int i = 0;i<m;i++) { cin>>a[i]; mode+=pow(3,m-1-i)*(a[i]-1);//mode记录第k行状态 } init(0,0); int index = find(v.begin(),v.end(),mode)-v.begin(); if(index==v.size())//如果第k行的状态不合法 { cout<<0; return 0; } if(k==1)//如果k为1,那么第一阶段方案数为1 { ans1 = 1; f[k][index] = 1; } if(k!=1) { for(int i = 0;i<v.size();i++) f[1][i] = 1; } for(int i = 2;i<=n;i++) { if(i==k) { for(int l = 0;l<v.size();l++) { if(judge(mode,v[l])) { f[i][index]+=f[i-1][l]; f[i][index]%=mod; } } ans1 = f[i][index]%mod; f[i][index] = 1; continue; } for(int j = 0;j<v.size();j++) { for(int l = 0;l<v.size();l++) { if(judge(v[j],v[l])&&f[i-1][l]) {//cout<<v[j]<<" "<<v[l]<<endl; f[i][j]+=f[i-1][l]; f[i][j]%=mod; } } } } for(int j = 0;j<v.size();j++) { ans2 = (ans2+f[n][j])%mod; } cout<<(ans1*ans2)%mod; return 0; }
C++11(clang++ 3.9) 解法, 执行用时: 63ms, 内存消耗: 16268K, 提交时间: 2020-06-01 10:17:46
#include<iostream> #define mod 1000000 using namespace std; long long n,m,K,c[10000],cnt,a[10000],f[10001][200],s,num,correct[10000][10000],ans1,ans2; inline int check(int x) { for(int i=m-1;i;i--) if((x%c[i+1]/c[i])==(x%c[i]/c[i-1])) return 0; return 1; } inline int check2(int x,int y) { for(int i=m-1;i>=0;i--) if((x%c[i+1]/c[i])==(y%c[i+1]/c[i])) return 0; return 1; } int main() { cin>>n>>m>>K; c[0]=1; for(int i=1;i<=m;i++) { cin>>s; num*=3; num+=s-1; c[i]=c[i-1]*3; } if(!check(num)) { cout<<0; return 0; } for(int i=0;i<c[m];i++) { if(check(i)) a[++cnt]=i; if(i==num) f[0][cnt]=1; } for(int i=1;i<=cnt;i++) for(int j=1;j<=cnt;j++) if(check2(a[i],a[j])) correct[i][j]=1; int x=max(K-1,n-K); int y=min(K-1,n-K); for(int i=1;i<=x;i++) for(int j=1;j<=cnt;j++) for(int k=1;k<=cnt;k++) if(correct[j][k]) f[i][j]=(f[i][j]+f[i-1][k])%mod; for(int i=1;i<=cnt;i++) { ans1+=f[x][i]; ans1%=mod; ans2+=f[y][i]; ans2%=mod; } cout<<ans1*ans2%mod; return 0; }