列表

详情


NC50525. 涂抹果酱

描述

Tyvj两周年庆典要到了,Sam想为Tyvj做一个大蛋糕。蛋糕俯视图是一个N×M的矩形,它被划分成N×M个边长为1的小正方形区域(可以把蛋糕当成N行M列的矩阵)。蛋糕很快做好了,但光秃秃的蛋糕肯定不好看!所以,Sam要在蛋糕的上表面涂抹果酱。果酱有三种,分别是红果酱、绿果酱、蓝果酱,三种果酱的编号分别为1,2,3。为了保证蛋糕的视觉效果,Admin下达了死命令:相邻(有公共边)的区域严禁使用同种果酱。但Sam在接到这条命令之前,已经涂好了蛋糕第K行的果酱,且无法修改。
现在Sam想知道:能令Admin满意的涂果酱方案有多少种。请输出方案数。若不存在满足条件的方案,请输出0。

输入描述

输入共三行。
第一行:N,M;
第二行:K;
第三行:M个整数(每个为1、2或3),表示第K行每个格子所涂果酱的编号。
字母的详细含义见题目描述,其他参见样例。

输出描述

输出仅一行,为可行的方案总数。

示例1

输入:

2 2 
1 
2 3

输出:

3

说明:第1行已涂为(2,3)。第2行的两格分别不能与上方格子相同,且彼此不同,因此只能涂(1,2)、(3,1)或(3,2),共3种方案。

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

Java(javac 1.8) 解法, 执行用时: 1917ms, 内存消耗: 48104K, 提交时间: 2021-03-02 12:08:07

import java.io.*;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.util.*;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.locks.LockSupport;

public class Main {

    static class Task {

        /**
         * Formats {@code result} with exactly {@code scale} digits after the decimal point
         * (half-up rounding, as done by {@link java.util.Formatter} for {@code %f}).
         *
         * <p>Both format calls use {@link Locale#ROOT}. The decimal separator is therefore always
         * {@code '.'}, whatever the JVM's default locale is. Under a locale such as German, the default
         * locale would print {@code ','}.
         *
         * @param result value to format
         * @param scale  number of fractional digits; must be non-negative
         * @return the formatted number
         */
        public static String roundS(double result, int scale) {
            String fmt = String.format(Locale.ROOT, "%%.%df", scale);
            return String.format(Locale.ROOT, fmt, result);
        }

        /**
         * Finds the root of {@code x} in the weighted (XOR) union-find, with path compression.
         *
         * <p>After the call, {@code fa[x]} is the root and {@code dp[x]} holds the XOR of the edge
         * weights on the path from {@code x} to that root. The implementation is recursive, so the
         * recursion depth equals the current path length.
         */
        int rt(int x) {
            if (x != fa[x]) {
                int to = rt(fa[x]);
                // fa[x] is already compressed, so dp[fa[x]] is now the XOR from fa[x] to the root
                dp[x] ^= dp[fa[x]];
                fa[x] = to;
                return to;
            }

            return x;
        }

        /**
         * Merges the sets of {@code x} and {@code y} so that {@code dp[x] ^ dp[y] == val} afterwards.
         *
         * <p>When the sets actually merge, the component counter {@code g} is decremented. If
         * {@code x} and {@code y} are already in the same set, nothing happens; the existing relation
         * is NOT checked against {@code val}. Weights are combined with XOR, so only XOR-style
         * relations are represented.
         */
        void combine(int x, int y, int val) {
            int rt1 = rt(x);
            int rt2 = rt(y);

            if (rt1 == rt2)
                return;

            fa[rt1] = rt2;
            // choose the weight of rt1 -> rt2 so that (x -> rt2) XOR (y -> rt2) == val
            dp[rt1] = dp[x] ^ dp[y] ^ val;
            g--;

        }


        // fa = parent links, dp = XOR weight from a node to its parent (to the root after rt())
        int fa[], dp[];
        // number of disjoint sets; the caller initializes it
        int g;


        // ---- Array-based treap; all state is static and shared by every Task instance ----
        // Node 0 is the null sentinel: size[0], cnt[0] and ch[0][*] are expected to stay 0.
        // NOTE(review): none of these arrays is allocated in the visible code; presumably the
        // caller sizes them to about MAXN + 1 before use. TODO confirm.
        static int MAXN = 10000;
        // Fixed seed, so node priorities (and therefore tree shapes) are reproducible between runs.
        static Random rd = new Random(348957438574659L);

        // ch[node][0|1] = left/right child; val = key; size = subtree size counting multiplicity;
        // rnd = heap priority (a parent's priority is kept greater than its children's);
        // cnt = multiplicity of val in this node.
        static int[] ch[], val, size, rnd, cnt;
        // len = index of the last allocated node; rt = root read by getPre/getNext.
        static int len = 0, rt = 0;

        // Rotates the subtree rooted at s and returns the new subtree root.
        // The child on side (d ^ 1) is lifted, and s becomes that child's child on side d:
        // d == 0 lifts the right child (left rotation), d == 1 lifts the left child (right rotation).
        static int rotate(int s, int d) {
            // the child that becomes the new subtree root
            int x = ch[s][d ^ 1];
            // s adopts x's inner subtree (x's child on side d)
            ch[s][d ^ 1] = ch[x][d];
            // x becomes the parent of s
            ch[x][d] = s;
            // refresh sizes bottom-up: s is now below x, so it must be updated first
            update(s);
            update(x);
            return x;
        }

        /** Recomputes the subtree size of node {@code s} from its two children plus its own multiplicity. */
        static void update(int s) {
            int left = ch[s][0];
            int right = ch[s][1];
            size[s] = cnt[s] + size[left] + size[right];
        }

        /**
         * Compares {@code num} with the key stored at node {@code x}.
         *
         * @return 0 to descend left ({@code num} is smaller), 1 to descend right, -1 when equal
         */
        static int cmp(int x, int num) {
            int here = val[x];
            if (num < here) {
                return 0;
            }
            if (num > here) {
                return 1;
            }
            return -1;
        }

        /**
         * Inserts one copy of {@code num} into the treap rooted at {@code s}.
         *
         * @return the (possibly new) root of this subtree
         */
        static int insert(int s, int num) {
            if (s == 0) {
                // allocate a fresh leaf node
                int node = ++len;
                val[node] = num;
                size[node] = 1;
                rnd[node] = rd.nextInt();
                cnt[node] = 1;
                return node;
            }

            int dir = cmp(s, num);
            if (dir == -1) {
                // key already present: just bump its multiplicity
                cnt[s]++;
                size[s]++;
                return s;
            }

            ch[s][dir] = insert(ch[s][dir], num);
            // restore the heap property: a child with higher priority is rotated above s
            if (rnd[ch[s][dir]] > rnd[s]) {
                return rotate(s, dir ^ 1);
            }
            update(s);
            return s;
        }

        /**
         * Removes one copy of {@code num} from the treap rooted at {@code s}.
         *
         * <p>If {@code num} is not present, the tree is left unchanged. Previously, the recursion
         * descended into the null node 0 and never terminated. When the node holds several copies,
         * its own {@code size} is now decremented too; previously only the ancestors' sizes were
         * refreshed.
         *
         * @return the (possibly new) root of this subtree
         */
        static int del(int s, int num) {
            if (s == 0) {
                // value not found
                return 0;
            }

            int d = cmp(s, num);

            if (d != -1) {
                ch[s][d] = del(ch[s][d], num);
                update(s);
                return s;
            }

            if (cnt[s] > 1) {
                // more copies remain: keep the node and shrink its counters
                cnt[s]--;
                size[s]--;
                return s;
            }

            if (ch[s][0] == 0 || ch[s][1] == 0) {
                // at most one child: splice the node out
                cnt[s] = 0;
                return ch[s][0] + ch[s][1];
            }

            // Two children: lift the higher-priority child (rotate(s, k) lifts ch[s][k ^ 1]),
            // then keep deleting from where s ended up (side k of the new root).
            int k = rnd[ch[s][0]] < rnd[ch[s][1]] ? 0 : 1;
            s = rotate(s, k);
            ch[s][k] = del(ch[s][k], num);
            update(s);
            return s;
        }

        /**
         * Returns the {@code k}-th smallest value (1-based, counting multiplicity) in the subtree at {@code s}.
         *
         * @throws IllegalArgumentException if {@code k} is outside {@code [1, size[s]]}. The previous
         *     recursive version looped through the null node 0 until the stack overflowed.
         */
        static int getKth(int s, int k) {
            int node = s;
            int rank = k;

            while (node != 0) {
                int leftSize = size[ch[node][0]];

                if (rank <= leftSize) {
                    node = ch[node][0];
                } else if (rank <= leftSize + cnt[node]) {
                    return val[node];
                } else {
                    rank -= leftSize + cnt[node];
                    node = ch[node][1];
                }
            }

            throw new IllegalArgumentException("k out of range: " + k);
        }

        /**
         * Returns 1 + the number of stored values (with multiplicity) strictly smaller than {@code value}
         * in the subtree rooted at {@code s}.
         */
        static int getRank(int s, int value) {
            int smaller = 0;
            int node = s;

            while (node != 0) {
                int leftSize = size[ch[node][0]];

                if (value == val[node]) {
                    return smaller + leftSize + 1;
                }

                if (value < val[node]) {
                    node = ch[node][0];
                } else {
                    // everything on the left plus this node's copies are smaller
                    smaller += leftSize + cnt[node];
                    node = ch[node][1];
                }
            }

            return smaller + 1;
        }


        /**
         * Returns the largest stored value strictly less than {@code data}, searching from the global
         * root {@code rt}. Returns {@code -2147483647} when no such value exists.
         */
        static int getPre(int data) {
            int best = -1;

            for (int p = rt; p > 0; ) {
                if (data <= val[p]) {
                    p = ch[p][0];
                    continue;
                }

                if (best == -1 || val[p] > val[best]) {
                    best = p;
                }
                p = ch[p][1];
            }

            if (best == -1) {
                return -2147483647;
            }
            return val[best];
        }

        /**
         * Returns the smallest stored value strictly greater than {@code data}, searching from the global
         * root {@code rt}. Returns {@code 2147483647} when no such value exists.
         */
        static int getNext(int data) {
            int best = -1;
            int p = rt;

            while (p > 0) {
                if (data >= val[p]) {
                    p = ch[p][1];
                    continue;
                }

                if (best == -1 || val[p] < val[best]) {
                    best = p;
                }
                p = ch[p][0];
            }

            return best == -1 ? 2147483647 : val[best];
        }


        /** Reports whether {@code num} is stored in the subtree rooted at {@code s}. */
        static boolean find(int s, int num) {
            for (int node = s; node != 0; ) {
                int dir = cmp(node, num);
                if (dir == -1) {
                    return true;
                }
                node = ch[node][dir];
            }
            return false;
        }
        // Output of findX: the largest value <= num met on the search path, or the sentinel -10000000.
        // It is NOT reset by findX; callers must reset it before each query if needed.
        static int ans = -10000000;

        /**
         * Searches for {@code num} in the subtree rooted at {@code s}, recording in {@link #ans} the
         * largest stored value that is {@code <= num} (its floor).
         *
         * <p>On a BST search path, every later node with value {@code <= num} is larger than every
         * earlier one, so the last such node is the floor. The previous version assigned
         * {@code ans = num}, which only copied the query.
         *
         * @return {@code true} if {@code num} itself is stored
         */
        static boolean findX(int s, int num) {
            while (s != 0) {
                if (val[s] <= num) {
                    ans = val[s];
                }

                int d = cmp(s, num);

                if (d == -1) {
                    return true;
                }
                s = ch[s][d];
            }

            return false;
        }

        /** Euclid's algorithm: returns gcd(a, b), with gcd(a, 0) == a. */
        long gcd(long a, long b) {
            while (b != 0) {
                long remainder = a % b;
                a = b;
                b = remainder;
            }
            return a;
        }


        /**
         * Sorts {@code arr} ascending in place with a stable two-pass LSD radix sort on 16-bit digits.
         *
         * <p>Pass 1 orders by the low 16 bits (unsigned). Pass 2 orders by the high 16 bits with the
         * sign bit flipped, so negative values sort before non-negative ones. The previous version
         * bucketed by {@code u % 65536} and {@code u / 65536}. Those expressions go negative for
         * negative inputs and threw ArrayIndexOutOfBoundsException. It also allocated 65536 raw-typed
         * deques. Counting arrays are used here instead.
         *
         * @param arr array to sort; must not be null
         */
        void linear_sort(int arr[]) {
            final int radix = 1 << 16;
            int n = arr.length;
            int[] buffer = new int[n];
            int[] count = new int[radix + 1];

            // Pass 1: stable counting sort by the low 16 bits, arr -> buffer.
            for (int u : arr) {
                count[(u & 0xFFFF) + 1]++;
            }
            for (int j = 0; j < radix; ++j) {
                count[j + 1] += count[j]; // count[d] = number of elements whose digit is < d
            }
            for (int u : arr) {
                buffer[count[u & 0xFFFF]++] = u;
            }

            // Pass 2: stable counting sort by the high 16 bits (sign flipped), buffer -> arr.
            Arrays.fill(count, 0);
            for (int u : buffer) {
                count[((u >>> 16) ^ 0x8000) + 1]++;
            }
            for (int j = 0; j < radix; ++j) {
                count[j + 1] += count[j];
            }
            for (int u : buffer) {
                arr[count[(u >>> 16) ^ 0x8000]++] = u;
            }
        }

        int cur  = 0;
        int h[];
        int to[];
        int ne[];
        void add(int u, int v) {
            to[cur] = v;
            ne[cur] = h[u];
            h[u] = cur++;
        }


        /**
         * Two-colours the component containing {@code cur} (bipartiteness check).
         *
         * <p>A vertex with no entry in {@code list} is treated as having no neighbours. The previous
         * version threw a NullPointerException for such a vertex.
         *
         * @param cur  vertex to colour
         * @param vis  colours assigned so far (0 or 1); updated in place
         * @param list adjacency lists
         * @param cl   colour for {@code cur}
         * @return {@code false} as soon as two adjacent vertices share a colour, else {@code true}
         */
        public boolean dfs(String cur, HashMap<String, Integer> vis, Map<String, List<String>> list, int cl) {
            vis.put(cur, cl);

            for (String hp : list.getOrDefault(cur, Collections.<String>emptyList())) {
                Integer seen = vis.get(hp);

                if (seen == null) {
                    if (!dfs(hp, vis, list, 1 - cl)) {
                        return false;
                    }
                } else if (seen == cl) {
                    return false;
                }
            }

            return true;
        }


        /** Sorted multiset of ints backed by a TreeMap from value to multiplicity. */
        public class MultiSet {

            // value -> number of copies (always >= 1 for present keys)
            TreeMap<Integer, Integer> map;
            // total number of elements, counting multiplicity
            int ct = 0;

            MultiSet() {
                map = new TreeMap<>();
            }

            /** Removes one copy of {@code key}; does nothing if it is absent. */
            public void remove(int key) {
                Integer times = map.get(key);
                if (times == null) {
                    return;
                }

                if (times == 1) {
                    map.remove(key);
                } else {
                    map.put(key, times - 1);
                }
                ct--;
            }

            /** Adds one copy of {@code key}. */
            public void add(int key) {
                map.merge(key, 1, Integer::sum);
                ct++;
            }

            /** Largest element; throws NoSuchElementException when empty. */
            public int last() {
                return map.lastKey();
            }

            /** Smallest element; throws NoSuchElementException when empty. */
            public int first() {
                return map.firstKey();
            }

            /** Number of elements, counting duplicates. */
            public int size() {
                return ct;
            }

        }

        // Memo for s(l, r, BigInteger[]): a cell equal to BigInteger.valueOf(-1) means "not computed yet".
        // NOTE(review): the caller must pre-fill dp1 with -1; a null cell would throw NullPointerException.
        BigInteger dp1[][];
        // x[l][r] = split index j chosen by s(l, r, ...); ss() reads it to rebuild the split tree.
        int x[][];

        /**
         * Interval DP over a[l..r]. Returns the maximum over split points j in [l, r) of
         * (a[l] + a[r]) * a[j] + s(l, j) + s(j + 1, r), and records the best j in x[l][r].
         * Intervals with l >= r are worth 0. Two-element intervals are computed directly and not memoized.
         */
        BigInteger s(int l, int r, BigInteger a[]) {
            if (l >= r)
                return BigInteger.ZERO;

            if (l + 1 == r) {
                // the only split is j = l; both halves are single elements worth 0
                x[l][r] = l;
                return (a[l].add(a[r])).multiply(a[l]);
            }

            if (!dp1[l][r].equals(BigInteger.valueOf(-1))) {
                return dp1[l][r];
            }

            BigInteger ans  = BigInteger.valueOf(Long.MIN_VALUE);

            for (int j = l; j < r; ++j) {
                BigInteger tp = (a[l].add(a[r])).multiply(a[j]).add(s(l, j, a)).add(s(j + 1, r, a));

                // strict '>' keeps the leftmost split among equal scores
                if (tp.compareTo(ans) > 0) {
                    ans = tp;
                    x[l][r] = j;
                }
            }

            return dp1[l][r] = ans;
        }
        // depth -> 1-based split positions recorded by ss(), in visiting order
        TreeMap<Integer, ArrayList<Integer>> mp = new TreeMap<>();

        /**
         * Walks the split tree stored in {@code x} (filled by the BigInteger interval DP). Each split of
         * [l, r] is appended to {@code mp} under its depth, as the 1-based position {@code x[l][r] + 1}.
         */
        void ss(int l, int r, int dep) {
            if (l == r) {
                return;
            }

            int split = x[l][r];
            mp.computeIfAbsent(dep, key -> new ArrayList<>()).add(split + 1);

            ss(l, split, dep + 1);
            ss(split + 1, r, dep + 1);
        }

        // Tarjan state: dfn = discovery time (0 = unvisited), low = low-link, stack = DFS vertex stack.
        int color[], dfn[], low[], stack[];
        int sccno[];

        // iscut[u] == true once u is identified as a cut vertex (articulation point)
        boolean iscut[];
        int time = 0, top = 0;
        int scc_cnt;
        // number of vertex-biconnected components found so far
        int  dcc_cnt;
        List<Integer> dcc[];
        // DFS start vertex; the caller must set this before each top-level tarjanNonDirect(root) call
        int root = 0;
        // Cut vertices and vertex-biconnected components (v-DCC) of an undirected graph stored in h/ne/to.
        // Every edge is assumed to be added in both directions (see add()).
        // Note: the start vertex stays on `stack` after the top-level call returns, and an isolated
        // start vertex is not counted as a component.
        void tarjanNonDirect(int u) {
            low[u] = dfn[u] = ++time;
            stack[top++] = u;
            int child = 0;

            for (int i = h[u]; i != -1; i = ne[i]) {
                int v = to[i];

                if (dfn[v] == 0) {
                    tarjanNonDirect(v);
                    low[u] = Math.min(low[u], low[v]);

                    if (low[v] >= dfn[u]) {
                        if (u != root || ++child > 1) { // non-root: u is a cut vertex; root: only if it has at least two such children
                            iscut[u] = true;
                        }

                        ++dcc_cnt;
                        // a cut vertex can belong to several biconnected components, so u is not popped
                        int z = -1;

                        do {
                            z = stack[--top];
                            //dcc[dcc_cnt].add(z);
                        } while (z != v);

                        //dcc[dcc_cnt].add(u);
                    }
                } else {
                    low[u] = Math.min(low[u], dfn[v]);
                    // No special case for the edge back to the parent.
                    // Back edge: update low[u] from dfn[v] (not low[v]), because v may be the parent,
                    // whose low value may be smaller.
                }
            }
        }

        // Memo for s(l, r, int[]); sized by the caller.
        long dp11[][];
        // Which dp11 cells hold a computed value. A computed value can legitimately be 0, so 0 cannot
        // serve as the "not computed" sentinel: zero-valued intervals were recomputed on every visit,
        // which made the recursion exponential.
        private boolean[][] dp11Known;
        // The dp11 array that dp11Known describes; assigning a new dp11 array resets the tracking.
        private long[][] dp11KnownFor;

        /**
         * Optimal-play score difference for the take-from-either-end game on a[l..r].
         *
         * <p>The player to move takes a[l] or a[r] and the opponent then plays optimally on the rest:
         * s(l, r) = max(a[l] - s(l+1, r), a[r] - s(l, r-1)). Non-zero values already present in
         * {@code dp11} are still treated as cached, for backward compatibility.
         */
        public long s(int l, int r, int a[]) {
            if (l > r) {
                return 0;
            }

            boolean[] known = dp11KnownRow(l);
            if (known[r] || dp11[l][r] != 0) {
                return dp11[l][r];
            }

            long best;
            if (l == r) {
                best = a[l];
            } else {
                best = Math.max((long) a[l] - s(l + 1, r, a), (long) a[r] - s(l, r - 1, a));
            }

            known[r] = true;
            return dp11[l][r] = best;
        }

        /** Returns (allocating lazily) the "computed" flags for row {@code l} of the current {@code dp11}. */
        private boolean[] dp11KnownRow(int l) {
            if (dp11KnownFor != dp11) {
                dp11KnownFor = dp11;
                dp11Known = new boolean[dp11.length][];
            }
            if (dp11Known[l] == null) {
                dp11Known[l] = new boolean[dp11[l].length];
            }
            return dp11Known[l];
        }




        // Every base-3 row encoding with no two horizontally adjacent cells equal (filled by dfs below).
        List<Long> all = new ArrayList<>();

        /**
         * Enumerates all valid rows of length {@code m} over colours 0..2 and appends their encodings to
         * {@code all}. Column 0 ends up as the most significant base-3 digit.
         *
         * @param cur index of the next column to colour
         * @param st  encoding of the columns chosen so far
         * @param m   row length
         */
        void dfs(int cur, long st, int m) {
            if (cur == m) {
                all.add(st);
                return;
            }

            // colour of the previous column (its lowest digit), or -1 before the first column
            long previous = cur == 0 ? -1 : st % 3;

            for (int colour = 0; colour < 3; ++colour) {
                if (colour != previous) {
                    dfs(cur + 1, st * 3 + colour, m);
                }
            }
        }

        /**
         * Conflict test ("chongtu"): returns {@code true} when rows {@code st} and {@code st1} have the
         * same colour in some column {@code i < m}, where column i is the base-3 digit at place value {@code f[i]}.
         */
        boolean chontu(long st, long st1, int m, long f[]) {
            int column = 0;
            while (column < m) {
                long placeValue = f[column];
                if ((st / placeValue) % 3 == (st1 / placeValue) % 3) {
                    return true;
                }
                column++;
            }
            return false;
        }

        /**
         * Counts the colourings on both sides of the fixed row, modulo {@code mod}.
         *
         * <p>Starting from row {@code s}, it performs {@code k} transitions. At each step a row may be
         * followed by any row of {@code all} that differs from it in every column (no {@code chontu}
         * conflict). The number of ways to stack {@code k1} rows (observed after step {@code k1}) is
         * multiplied by the number of ways to stack {@code k} rows.
         *
         * <p>The previous version used {@code tot2 == 0} to mean "k1 rows not requested". That turned a
         * genuine count that is 0 modulo {@code mod} into 1. An explicit flag is used now.
         *
         * @param s   encoded starting row (the fixed row K)
         * @param k   number of rows on the longer side
         * @param k1  number of rows on the shorter side; a value outside [1, k] contributes a factor of 1
         * @param m   number of columns
         * @param f   place values, {@code f[i] = 3^i}
         * @param mod modulus; the final product is taken in a long, so mod should stay below about 3e9
         * @return (ways for k rows) * (ways for k1 rows) mod {@code mod}
         */
        long g(long s, int k, int k1, int m, long f[], long mod) {
            Map<Long, Long> ways = new HashMap<>();
            ways.put(s, 1L);
            // Successor lists are cached per row, so chontu runs once per (row, candidate) pair
            // instead of once per pair per step.
            Map<Long, List<Long>> successors = new HashMap<>();

            long tot2 = 0;
            boolean tot2Recorded = false;

            for (int i = 0; i < k; ++i) {
                Map<Long, Long> next = new HashMap<>();

                for (Map.Entry<Long, Long> entry : ways.entrySet()) {
                    long from = entry.getKey();
                    long count = entry.getValue();

                    List<Long> targets = successors.get(from);
                    if (targets == null) {
                        targets = compatibleRows(from, m, f);
                        successors.put(from, targets);
                    }

                    for (long target : targets) {
                        next.put(target, (next.getOrDefault(target, 0L) + count) % mod);
                    }
                }

                ways = next;
                if (i == k1 - 1) {
                    tot2 = sumMod(ways, mod);
                    tot2Recorded = true;
                }
            }

            if (!tot2Recorded) {
                // zero rows on the shorter side: exactly one (empty) way
                tot2 = 1 % mod;
            }

            long tot1 = sumMod(ways, mod);
            return (tot1 * tot2) % mod;
        }

        /** Returns every row of {@code all} that differs from {@code row} in each of the {@code m} columns. */
        private List<Long> compatibleRows(long row, int m, long f[]) {
            List<Long> result = new ArrayList<>();
            for (long candidate : all) {
                if (!chontu(candidate, row, m, f)) {
                    result.add(candidate);
                }
            }
            return result;
        }

        /** Sums the counts in {@code ways} modulo {@code mod}. */
        private static long sumMod(Map<Long, Long> ways, long mod) {
            long total = 0;
            for (long v : ways.values()) {
                total = (total + v) % mod;
            }
            return total;
        }


        /**
         * Solves "涂抹果酱" (jam painting): counts the 3-colourings of an N x M grid, where edge-adjacent
         * cells differ and row K is fixed, modulo 10^6.
         *
         * <p>Input: N M, then K, then the M colours (1..3) of row K. Prints 0 when row K itself has two
         * equal neighbours. Otherwise the rows above K and the rows below K are independent given row K.
         * The count for a side depends only on its number of rows, so the answer is
         * ways(k - 1) * ways(n - k), computed by {@link #g}. (Roughly 330 lines of commented-out code
         * from unrelated problems were removed from this method.)
         *
         * @param testNumber unused
         */
        public void solve(int testNumber, InputReader in, PrintWriter out) {
            int n = in.nextInt();
            int m = in.nextInt();
            int k = in.nextInt();

            // f[i] = 3^i: place value of column i in a base-3 row encoding.
            long f[] = new long[m + 1];
            f[0] = 1;
            for (int i = 1; i <= m; ++i) {
                f[i] = f[i - 1] * 3;
            }

            // Enumerate every row (colours 0..2) with no two horizontally adjacent cells equal.
            dfs(0, 0, m);

            long mod = (long) 1e6;

            // Encode the fixed row. Column j goes to digit j here, while dfs puts column 0 in the most
            // significant digit. This mirrors the row, which is harmless: the set of valid rows and the
            // column-wise compatibility test are both symmetric under reversal.
            long s = 0;
            long lst = -1;
            for (int j = 0; j < m; ++j) {
                long v = in.nextInt() - 1;
                if (v == lst) {
                    // the fixed row already violates the rule
                    out.println(0);
                    return;
                }
                s += v * f[j];
                lst = v;
            }

            int mx = Math.max(k - 1, n - k);
            int mi = Math.min(k - 1, n - k);
            long ans1 = g(s, mx, mi, m, f, mod);

            out.println(ans1);
        }


        /**
         * Computes a * b mod p by binary doubling ("Russian peasant"), so the intermediate values stay
         * in the modulus range instead of forming the full product. Returns 0 when {@code b <= 0}.
         */
        static long mul(long a, long b, long p) {
            long res = 0;
            long base = a;

            for (long rest = b; rest > 0; rest >>= 1) {
                if ((rest & 1L) != 0) {
                    res = (res + base) % p;
                }
                base = (base + base) % p;
            }

            return res;
        }

        /** Computes k^n mod p by square-and-multiply, using {@link #mul} for each product. */
        static long mod_pow(long k, long n, long p) {
            long result = 1L;
            long power = k % p;

            for (long e = n; e != 0L; e = e >> 1L) {
                if ((e & 1L) == 1L) {
                    result = mul(result, power, p);
                }
                power = mul(power, power, p);
            }

            return result % p;
        }


        /**
         * Rounds {@code result} to {@code scale} fractional digits, always away from zero
         * ({@link RoundingMode#UP}).
         *
         * <p>Uses {@link BigDecimal#valueOf(double)}, which starts from the canonical decimal string of
         * the double. The previous {@code new BigDecimal(double)} started from the exact binary
         * expansion. For example, 0.1 is stored as 0.1000000000000000055..., which was then rounded
         * UP to 0.2 at scale 1.
         *
         * @throws NumberFormatException if {@code result} is NaN or infinite
         */
        public static double roundD(double result, int scale) {
            BigDecimal bg = BigDecimal.valueOf(result).setScale(scale, RoundingMode.UP);
            return bg.doubleValue();
        }



    }
    /**
     * Reads from standard input, runs the task once and writes to standard output.
     *
     * <p>The writer is closed (and therefore flushed) by try-with-resources, so output already
     * produced is not lost if the task throws.
     */
    private static void solve() {
        InputReader in = new InputReader(System.in);
        try (PrintWriter out = new PrintWriter(System.out)) {
            Task task = new Task();
            task.solve(1, in, out);
        }
    }
    /** Program entry point: runs the solver directly on the main thread. */
    public static void main(String[] args) {
        solve();
    }
    /**
     * Whitespace-tokenizing reader over an input stream, with a 32768-char buffer.
     *
     * <p>{@link #nextLine()} reads straight from the underlying reader and does not see tokens that
     * {@link #next()} has already buffered, so do not mix the two on the same line.
     */
    static class InputReader {
        public BufferedReader reader;
        public StringTokenizer tokenizer;

        public InputReader(InputStream stream) {
            reader = new BufferedReader(new InputStreamReader(stream), 32768);
            tokenizer = null;
        }

        /** Returns the next raw line, or {@code null} at end of input. */
        public String nextLine() {
            try {
                return reader.readLine();
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }

        /**
         * Returns the next whitespace-separated token.
         *
         * @throws NoSuchElementException if the input ends first. The previous version failed here with
         *     a NullPointerException from {@code new StringTokenizer(null)}.
         */
        public String next() {
            while (tokenizer == null || !tokenizer.hasMoreTokens()) {
                String line = nextLine();
                if (line == null) {
                    throw new NoSuchElementException("unexpected end of input");
                }
                tokenizer = new StringTokenizer(line);
            }

            return tokenizer.nextToken();
        }

        public int nextInt() {
            return Integer.parseInt(next());
        }

        public char nextChar() {
            return next().charAt(0);
        }

        /** Reads {@code n} ints into a new array. */
        public int[] nextArray(int n) {
            int res[] = new int[n];

            for (int i = 0; i < n; ++i) {
                res[i] = nextInt();
            }

            return res;
        }

        public long nextLong() {
            return Long.parseLong(next());
        }

        public double nextDouble() {
            return Double.parseDouble(next());
        }
    }
}

Kotlin 解法, 执行用时: 507ms, 内存消耗: 18052K, 提交时间: 2022-12-01 15:18:25

@file:Suppress("NOTHING_TO_INLINE", "EXPERIMENTAL_FEATURE_WARNING", "OVERRIDE_BY_INLINE", "DEPRECATION")
@file:OptIn(ExperimentalStdlibApi::class)

import java.io.PrintWriter
import java.util.StringTokenizer
import kotlin.collections.ArrayDeque
import kotlin.math.*
import kotlin.random.*
import java.util.TreeMap
import java.util.TreeSet
import java.util.PriorityQueue
// import java.math.BigInteger
// import java.util.*

// Standard input stream read by every helper below.
@JvmField val INPUT = System.`in`
// Standard output stream wrapped by `writer`.
@JvmField val OUTPUT = System.out

@JvmField val reader = INPUT.bufferedReader()
/** Reads one raw line, or null at end of input. Tokens already buffered by [read] are not seen. */
fun readLine(): String? = reader.readLine()
/** Reads one raw line; throws if the input is exhausted. */
fun readLn() = reader.readLine()!!
@JvmField var _tokenizer: StringTokenizer = StringTokenizer("")

/**
 * Returns the next whitespace-separated token, or "" at end of input.
 * Uses StringTokenizer's default delimiters (space, tab, CR, LF, form feed). The previous version
 * split on " " only, so tab-separated input produced tokens like "2\t2" that fail to parse.
 */
fun read(): String {
    while (!_tokenizer.hasMoreTokens()) {
        _tokenizer = StringTokenizer(reader.readLine() ?: return "")
    }
    return _tokenizer.nextToken()
}
fun readInt() = read().toInt()
fun readDouble() = read().toDouble()
fun readLong() = read().toLong()
fun readStrings(n: Int = 2) = List(n) { read() }
fun readString() = readStrings(1)[0]
fun readLines(n: Int) = List(n) { readLn() }
fun readInts(n: Int = 2) = List(n) { read().toInt() }
/** Reads [n] ints and converts each from 1-based to 0-based. */
fun readInts1(n: Int = 2) = List(n) { read().toInt() - 1 }
fun readIntArray(n: Int) = IntArray(n) { read().toInt() }
fun readDoubles(n: Int = 2) = List(n) { read().toDouble() }
fun readDoubleArray(n: Int) = DoubleArray(n) { read().toDouble() }
fun readLongs(n: Int = 2) = List(n) { read().toLong() }
fun readLongArray(n: Int) = LongArray(n) { read().toLong() }

// val isLocal = System.getenv("IS_LOCAL_CP") == "true"

// Shared output writer over OUTPUT. PrintWriter(OutputStream) buffers internally, so output only
// appears after flush(); main() flushes it once solve() returns.
@JvmField
val writer = PrintWriter(OUTPUT)

// ----------------------------------------------------------------------------

/** Sum of all elements, reduced modulo [mod] after each addition (accumulated as Long). */
private fun IntArray.modSum(mod: Long = MODL): Long =
    fold(0L) { acc, num -> (acc + num) % mod }

/** Sum of all elements, reduced modulo [mod] after each addition. */
private fun LongArray.modSum(mod: Long = MODL): Long =
    fold(0L) { acc, num -> (acc + num) % mod }

/**
 * Returns (this + other) mod [mod].
 * The two residues are added as Long. Otherwise their sum can overflow Int when [mod] exceeds 2^30.
 * As with `%`, a negative operand can give a negative result.
 */
private fun Int.modSum(other: Int, mod: Int = MOD): Int {
    return (((this % mod).toLong() + (other % mod)) % mod).toInt()
}

/** Returns (this + other) mod [mod]; assumes [mod] < 2^62 so that the sum of two residues fits in a Long. */
private fun Long.modSum(other: Long, mod: Long = MODL): Long {
    return ((this % mod) + (other % mod)) % mod
}

/** Returns the larger of this and [target]; returns this on ties. */
private fun <T : Comparable<T>> T.max(target: T): T = if (this < target) target else this

/** Returns the smaller of this and [target]; returns this on ties. */
private fun <T : Comparable<T>> T.min(target: T): T = if (this > target) target else this

// The four axis-aligned unit steps, presumably (dRow, dCol) pairs. TODO confirm against usage.
private val DIR = listOf(
    listOf(1, 0), listOf(0, -1), listOf(-1, 0), listOf(0, 1)
)

// Alternative modulus, kept for reference:
// private const val MOD = 998244353
// private const val MODL = 998244353L
// Modulus 10^6 as Int (MOD) and Long (MODL); these are the default `mod` of the modSum helpers.
private const val MOD = 1_000_000
private const val MODL = 1_000_000L
// Floating-point comparison tolerance.
private const val EPS = 0.000001

// ----------------------------------------------------------------------------

/** Entry point: runs solve() on a dedicated thread that requests a 2^28-byte stack, then flushes the writer. */
fun main() {
    val stackSize = 1L shl 28
    val task = Runnable {
        writer.solve()
        writer.flush()
    }
    Thread(null, task, "thread", stackSize).start()
}

/** Reads N M, K and the fixed row (converted to 0-based colours), then prints the number of colourings. */
private fun PrintWriter.solve() {
    val m = readInt() // N in the statement: number of rows
    val n = readInt() // M in the statement: number of columns
    val k = readInt()
    val row = readInts1(n)
    println(jamCount(m, n, k, row))
}

// AC
// private fun jamCount(m: Int, n: Int, k: Int, row: List<Int>): Long {
//     var ms = 1
//     var rowState = 0
//     for (j in 0 until n) {
//         ms *= 3
//         rowState = rowState * 3 + row[j]
//     }
//     if (!check1(rowState, n)) {
//         return 0L
//     }
//     val r1 = k - 1
//     val r2 = m - k
//     val r = Math.max(r1, r2)
//     val dp = Array(r + 1) { IntArray(ms) }
//     dp[0][rowState] = 1
//     for (i in 0 until r) {
//         for (s in 0 until ms) {
//             if (dp[i][s] == 0) {
//                 continue
//             }
//             for (ns in 0 until ms) {
//                 if (check1(ns, n) && check2(s, ns, n)) {
//                     dp[i + 1][ns] = dp[i + 1][ns].modSum(dp[i][s])
//                 }
//             }
//         }
//     }
//     val c1 = dp[r1].modSum()
//     val c2 = dp[r2].modSum()
//     // println("$c1, $c2")
//     return (c1 * c2) % MODL
// }

// private fun check1(state: Int, n: Int): Boolean {
//     var s = state
//     repeat(n - 1) {
//         if (s % 3 == (s / 3) % 3) {
//             return false
//         }
//         s /= 3
//     }
//     return true
// }

// private fun check2(s1: Int, s2: Int, n: Int): Boolean {
//     var a = s1
//     var b = s2
//     repeat(n) {
//         if (a % 3 == b % 3) {
//             return false
//         }
//         a /= 3
//         b /= 3
//     }
//     return true
// }

// TLE
// private fun jamCount2(m: Int, n: Int, k: Int, row: List<Int>): Long {
//     for (j in 1 until n) {
//         if (row[j] == row[j - 1]) {
//             return 0L
//         }
//     }
//     val grid = Array(m) { IntArray(n) }
//     for (j in 0 until n) {
//         grid[k - 1][j] = row[j]
//     }
//     return search(0, m, n, k, grid)
// }

// private fun search(
//     index: Int,
//     m: Int,
//     n: Int,
//     k: Int,
//     grid: Array<IntArray>
// ): Long {
//     if (index == m * n) {
//         return 1L
//     }
//     val r = index / n
//     val c = index % n
//     if (r == k - 1) {
//         return search(index + n, m, n, k, grid)
//     }
//     val left = if (c == 0) -1 else grid[r][c - 1]
//     val up = if (r == 0) -1 else grid[r - 1][c]
//     var result = 0L
//     for (num in 0 until 3) {
//         if (num == left || num == up) {
//             continue
//         }
//         grid[r][c] = num
//         result += search(index + 1, m, n, k, grid)
//         result %= MODL
//     }
//     return result
// }

/**
 * Counts the colourings of an m x n cake whose k-th row (1-based) is fixed to [row].
 * Colours are 0..2 and no two edge-adjacent cells may share one. The result is
 * reduced modulo MODL.
 *
 * Given row k, the k - 1 rows above and the m - k rows below are independent.
 * The number of ways to add r rows next to a fixed row does not depend on the
 * direction. A single DP over "rows added so far" therefore yields both c1
 * (rows above) and c2 (rows below), and the answer is c1 * c2.
 *
 * An earlier version of this function kept the DP in a HashMap<Long, Long>
 * rebuilt on every row and was marked TLE. This version numbers the row states
 * densely and runs the DP over primitive arrays.
 *
 * @param m number of rows
 * @param n number of columns (row width)
 * @param k 1-based index of the fixed row
 * @param row colours of row k, assumed 0-based (0..2) — TODO confirm readInts1 converts from 1..3
 * @return the number of valid colourings modulo MODL, or 0 if the fixed row is itself invalid
 */
private fun jamCount(m: Int, n: Int, k: Int, row: List<Int>): Long {
    // The fixed row must not already contain two equal horizontal neighbours.
    for (j in 1 until n) {
        if (row[j] == row[j - 1]) {
            return 0L
        }
    }

    // state -> all states that may sit directly next to it (vertically).
    val mapping = mutableMapOf<Long, MutableList<Long>>()
    search(0, IntArray(n), IntArray(n), n, mapping)

    // Every valid row appears as a key, because each has at least one compatible
    // neighbour row. The keys therefore give a dense numbering of all states.
    val states = mapping.keys.toLongArray()
    val stateIndex = states.withIndex().associate { (i, s) -> s to i }
    val next = Array(states.size) { i ->
        mapping.getValue(states[i]).map { stateIndex.getValue(it) }.toIntArray()
    }

    // A fixed row that is not a valid state (e.g. a colour outside 0..2) admits no colouring.
    val start = stateIndex[rowToState(row.toIntArray(), n)] ?: return 0L

    val r1 = k - 1 // rows above row k
    val r2 = m - k // rows below row k
    var dp = LongArray(states.size)
    var ndp = LongArray(states.size)
    dp[start] = 1L
    // With zero rows to add there is exactly one (empty) extension.
    var c1 = 1L
    var c2 = 1L
    for (r in 1..maxOf(r1, r2)) {
        ndp.fill(0L)
        for (s in states.indices) {
            val ways = dp[s]
            if (ways == 0L) {
                continue
            }
            for (t in next[s]) {
                ndp[t] = (ndp[t] + ways) % MODL
            }
        }
        // Swap buffers instead of allocating a new array per row.
        val tmp = dp
        dp = ndp
        ndp = tmp
        if (r == r1) {
            c1 = dp.fold(0L) { acc, v -> (acc + v) % MODL }
        }
        if (r == r2) {
            c2 = dp.fold(0L) { acc, v -> (acc + v) % MODL }
        }
    }
    return (c1 * c2) % MODL
}

/**
 * Enumerates every pair of stacked rows of width [n] that are valid colourings.
 * Colours are 0..2, and no two horizontally or vertically adjacent cells are equal.
 * For each upper-row state, the lower-row states allowed beneath it are appended
 * to [mapping].
 *
 * Cells are filled in order: indices 0 until n fill [row1], and indices
 * n until 2n fill [row2]. Row states are encoded with rowToState.
 *
 * @param index next cell to fill, in 0..2n
 * @param row1 scratch array for the upper row (size n)
 * @param row2 scratch array for the lower row (size n)
 * @param n row width
 * @param mapping output: upper-row state -> list of compatible lower-row states
 */
private fun search(
    index: Int,
    row1: IntArray,
    row2: IntArray,
    n: Int,
    mapping: MutableMap<Long, MutableList<Long>>
) {
    if (index == n * 2) {
        // getOrPut already stores the new list in the map; no second write is needed.
        mapping.getOrPut(rowToState(row1, n)) { mutableListOf() }.add(rowToState(row2, n))
        return
    }
    if (index < n) {
        // Upper row: only the left neighbour constrains the colour.
        val left = if (index == 0) -1 else row1[index - 1]
        for (num in 0 until 3) {
            if (num == left) {
                continue
            }
            row1[index] = num
            search(index + 1, row1, row2, n, mapping)
        }
        return
    }
    // Lower row: the colour must differ from both the left neighbour and the cell above.
    val col = index - n
    val left = if (col == 0) -1 else row2[col - 1]
    val up = row1[col]
    for (num in 0 until 3) {
        if (num == left || num == up) {
            continue
        }
        row2[col] = num
        search(index + 1, row1, row2, n, mapping)
    }
}

/**
 * Encodes the first [n] entries of [row] as a base-3 number.
 * Column 0 is the least significant digit (weight 3^0).
 *
 * Previously [n] was ignored and the whole array was encoded. Callers pass
 * arrays of exactly size n, so their results are unchanged. If [row] is
 * shorter than [n], only the available entries are used.
 */
private fun rowToState(row: IntArray, n: Int): Long {
    var state = 0L
    // Horner's rule from the most significant column down to column 0.
    for (j in minOf(n, row.size) - 1 downTo 0) {
        state = state * 3L + row[j]
    }
    return state
}

// TLE 80%
// val POW = mutableListOf(1)

// private fun jamCount(m: Int, n: Int, k: Int, row: List<Int>): Long {
//     for (j in 1 until n) {
//         if (row[j] == row[j - 1]) {
//             return 0L
//         }
//     }
//     repeat(n) {
//         POW.add(POW.last() * 3)
//     }
//     var rowState = 0
//     for (j in 0 until n) {
//         rowState = setColValue(rowState, j, row[j])
//     }
//     val ms = POW.last()
//     val dp = LongArray(ms)
//     dp[rowState] = 1L
//     val ndp = LongArray(ms)
//     val r1 = k - 1
//     val r2 = m - k
//     var c1 = 1L
//     var c2 = 1L
//     for (r in 1 until Math.max(r1, r2) + 1) {
//         ndp.fill(0L)
//         for (s in 0 until ms) {
//             if (dp[s] == 0L) {
//                 continue
//             }
//             for (ns in 0 until ms) {
//                 // println("${getCols(s, n)}, ${getCols(ns, n)}, ${check(s, ns, n)}")
//                 if (check(s, ns, n)) {
//                     ndp[ns] = ndp[ns].modSum(dp[s])
//                 }
//             }
//         }
//         ndp.copyInto(dp)
//         if (r == r1) {
//             c1 = dp.modSum()
//         }
//         if (r == r2) {
//             c2 = dp.modSum()
//         }
//         // println("$r, ${dp.toList()}")
//     }
//     // println("$c1, $c2")
//     return (c1 * c2) % MODL
// }

// private fun check(s1: Int, s2: Int, n: Int): Boolean {
//     for (j in 0 until n) {
//         if (getColValue(s1, j) == getColValue(s2, j)) {
//             return false
//         }
//         if (j > 0 && getColValue(s2, j) == getColValue(s2, j - 1)) {
//             return false
//         }
//     }
//     return true
// }

// private fun getCols(state: Int, n: Int): List<Int> {
//     return (0 until n).map { getColValue(state, it) }
// }

// private fun setColValue(state: Int, offset: Int, value: Int): Int {
//     return state + value * POW[offset]
// }

// private fun getColValue(state: Int, offset: Int): Int {
//     return (state / POW[offset]) % 3
// }

C++14(g++5.4) 解法, 执行用时: 394ms, 内存消耗: 8276K, 提交时间: 2019-08-03 23:27:35

#include<bits/stdc++.h>
#define ll long long
// Answers are reduced modulo 10^6.
#define mod 1000000
using namespace std;
// n: number of rows, m: number of columns, k: 1-based index of the pre-painted row.
// a[0..m-1]: colours (read as 1..3) of row k; the array size assumes m <= 6.
// f[i][j]: number of ways counted so far that end with row i in state v[j]
//          (main restarts the count at 1 on row k).
// mode: base-3 code of row k; ans1/ans2: counts for rows up to k / rows after k.
ll n,m,k,a[6],f[10010][100],mode,ans1,ans2;
// v: every row state with no two equal horizontal neighbours (filled by init).
vector<ll> v;
// Enumerates every row of width m with no two horizontally adjacent cells of the
// same colour, and appends its base-3 code to v. Digit x (weight 3^x) holds the
// colour (0..2) of column x. x is the next column to fill and y is the code built
// so far; call as init(0, 0). States are produced in the same order as before.
// The weights are now computed in exact integer arithmetic; the previous version
// used floating-point pow() with implicit double->int truncation.
void init(int x,int y)
{
    if(x==m)
    {
        v.push_back(y);
        return ;
    }
    int w = 1;                              // 3^x
    for(int i = 0;i<x;i++)
        w *= 3;
    int prev = x==0 ? -1 : (y/(w/3))%3;     // colour of column x-1 (none for x==0)
    for(int c = 0;c<3;c++)
        if(c!=prev)
            init(x+1,y+c*w);
}
// Returns true when row state x may sit directly above row state y, i.e. the
// two base-3 codes differ in each of their m lowest digits (columns).
bool judge(int x,int y)
{
    for(int col = 0;col<m;col++,x/=3,y/=3)
        if(x%3==y%3)
            return false;
    return true;
}
int main()
{
    cin>>n>>m>>k;
    for(int i = 0;i<m;i++)
    {
        cin>>a[i];
        // A colour outside 1..3 could alias a different valid state code, so it
        // is rejected explicitly: such a row admits no valid layout.
        if(a[i]<1||a[i]>3)
        {
            cout<<0;
            return 0;
        }
        // Base-3 code of row k with column 0 as the most significant digit:
        // mode = sum (a[i]-1) * 3^(m-1-i). Horner's rule keeps it exact; the
        // previous version accumulated floating-point pow() results.
        mode = mode*3 + (a[i]-1);
    }
    init(0,0);
    int index = find(v.begin(),v.end(),mode)-v.begin();
    if(index==(int)v.size())    // row k itself has two equal neighbours
    {
        cout<<0;
        return 0;
    }
    if(k==1)
    {
        // Nothing above row k: exactly one way to paint that part.
        ans1 = 1;
        f[k][index] = 1;
    }
    else
    {
        // Row 1 may be any valid row.
        for(int i = 0;i<(int)v.size();i++)
            f[1][i] = 1;
    }
    for(int i = 2;i<=n;i++)
    {
        if(i==k)
        {
            // Row k is fixed: ans1 = ways to paint rows 1..k-1 compatibly with it.
            for(int l = 0;l<(int)v.size();l++)
            {
                if(judge(mode,v[l]))
                {
                    f[i][index]+=f[i-1][l];
                    f[i][index]%=mod;
                }
            }
            ans1 = f[i][index]%mod;
            // Restart counting below row k from that single fixed row.
            f[i][index] = 1;
            continue;
        }
        for(int j = 0;j<(int)v.size();j++)
        {
            for(int l = 0;l<(int)v.size();l++)
            {
                if(judge(v[j],v[l])&&f[i-1][l])
                {
                    f[i][j]+=f[i-1][l];
                    f[i][j]%=mod;
                }
            }
        }
    }
    // ans2 = ways to paint rows k+1..n given row k.
    for(int j = 0;j<(int)v.size();j++)
    {
        ans2 = (ans2+f[n][j])%mod;
    }
    cout<<(ans1*ans2)%mod;
    return 0;
}

C++11(clang++ 3.9) 解法, 执行用时: 63ms, 内存消耗: 16268K, 提交时间: 2020-06-01 10:17:46

#include<iostream>
#define mod 1000000
using namespace std;
// n rows, m columns, K = 1-based index of the pre-painted row.
// c[i] = 3^i; a[1..cnt] = all valid row states (base 3, no equal horizontal neighbours);
// f[i][j] = ways to add i rows next to row K ending in state a[j];
// s = input scratch, num = base-3 code of row K; ans1/ans2 = counts on each side of row K.
long long n,m,K,c[10000],cnt,a[10000],f[10001][200],s,num,ans1,ans2;
// correct[i][j] != 0 iff states a[i] and a[j] may be vertically adjacent.
// Every index into it is a state index <= cnt, and f's second dimension already
// limits cnt to below 200, so a 200x200 table suffices. The former
// long long correct[10000][10000] reserved about 800 MB of address space.
bool correct[200][200];
// Returns 1 when no two neighbouring digits among the m lowest base-3 digits of
// x are equal (a valid row), otherwise 0.
inline int check(int x)
{
	int prev=x%3;
	for(int i=1;i<m;i++)
	{
		x/=3;
		int cur=x%3;
		if(cur==prev)
			return 0;
		prev=cur;
	}
	return 1;
}
// Returns 1 when rows x and y may be stacked, i.e. their m lowest base-3 digits
// differ column by column, otherwise 0.
inline int check2(int x,int y)
{
	for(int col=0;col<m;col++,x/=3,y/=3)
		if(x%3==y%3)
			return 0;
	return 1;
}
int main()
{
	cin>>n>>m>>K;
	// Encode row K in base 3 (first column most significant, colours 1..3 -> 0..2)
	// and tabulate the powers of 3.
	c[0]=1;
	for(int col=1;col<=m;col++)
	{
		cin>>s;
		num=num*3+(s-1);
		c[col]=c[col-1]*3;
	}
	if(!check(num))
	{
		cout<<0;
		return 0;
	}
	// Collect the valid states (1-based) and seed f[0] with row K's state.
	for(int st=0;st<c[m];st++)
	{
		if(!check(st))
			continue;
		a[++cnt]=st;
		if(st==num)
			f[0][cnt]=1;
	}
	// Compatibility table between every pair of valid states.
	for(int i=1;i<=cnt;i++)
		for(int j=1;j<=cnt;j++)
			correct[i][j]=check2(a[i],a[j]);
	// Both sides of row K share one DP; run it for the longer side.
	int longer=max(K-1,n-K);
	int shorter=min(K-1,n-K);
	for(int r=1;r<=longer;r++)
		for(int to=1;to<=cnt;to++)
		{
			long long sum=0;
			for(int from=1;from<=cnt;from++)
				if(correct[to][from])
					sum=(sum+f[r-1][from])%mod;
			f[r][to]=sum;
		}
	for(int j=1;j<=cnt;j++)
	{
		ans1=(ans1+f[longer][j])%mod;
		ans2=(ans2+f[shorter][j])%mod;
	}
	cout<<ans1*ans2%mod;
	return 0;
}

上一题