列表

详情


NC220470. DeckRandomisation

描述

    Alice and Bob love playing Don'tminion, which typically involves a lot of shuffling of decks of different sizes. Because they play so often, they are not only very quick at shuffling, but also very consistent. Each time Alice shuffles her deck, her cards get permuted in the same way, just like Bob always permutes his cards the same way when he shuffles them. This isn't good for playing games, but raises an interesting question.

    They know that if they take turns shuffling, then at some point the deck will end up ordered in the same way as when they started. Alice shuffles once first, then Bob shuffles once, then Alice shuffles again, et cetera. They start with a sorted deck. What they do not know, however, is how many shuffles it will take before the deck is sorted again.

    Can you help them compute how many shuffles it will take? As Alice and Bob can only do $10^{12}$ shuffles in the limited time they have, any number strictly larger than this should be returned as "huge" instead.

输入描述

The input consists of:
- One line contains a single integer $n$ ($1 \le n \le 10^5$), the number of cards in the deck.
- One line contains $n$ distinct integers $a_1, \ldots, a_n$ ($1 \le a_i \le n$), where $a_i$ is the new position of the card previously at position $i$ when Alice shuffles the deck.
- One line contains $n$ distinct integers $b_1, \ldots, b_n$ ($1 \le b_i \le n$), where $b_i$ is the new position of the card previously at position $i$ when Bob shuffles the deck.

输出描述

Output a single positive integer, the minimal number of shuffles required to sort the deck, or "huge" when this number is strictly larger than $10^{12}$.

示例1

输入:

3
2 3 1
3 1 2

输出:

2

示例2

输入:

6
5 1 6 3 2 4
4 6 5 1 3 2

输出:

5

示例3

输入:

8
1 4 2 6 7 8 5 3
3 6 8 4 7 1 5 2

输出:

10

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

pypy3(pypy3.6.1) 解法, 执行用时: 148ms, 内存消耗: 43644K, 提交时间: 2021-04-04 21:11:25

# Sentinel meaning "unset / infinitely large", and capacity of the 1-indexed arrays.
inf = 10 ** 17
N = 100100

# n: number of cards; a, b: 1-indexed shuffles (index 0 is a placeholder).
n = 0
a = [0]
b = [0]
# c: composed permutation (Alice then Bob); tot: number of cycles of c.
c = [0] * N
tot = 0
# vis[i]: id of the cycle of c that contains position i.
vis = [0] * N
# T[k]: length of cycle k; mod, m, tt: scratch arrays used by solve2.
T = [0] * N
mod = [0] * N
m = [0] * N
tt = [0] * N
# Scratch scalars assigned by init/solve2 (note: ``len`` shadows the builtin).
tmp = 0
len = 0


def init():
    """Compose the two shuffles into ``c`` and split ``c`` into cycles.

    ``c[i] = b[a[i]]`` is where the card that starts at position ``i`` ends
    up after one Alice shuffle followed by one Bob shuffle.  Every cycle of
    ``c`` gets an id ``1..tot``, numbered in order of its smallest position;
    ``vis[p]`` receives the id of the cycle containing ``p`` and ``T[id]``
    that cycle's length.

    Reads the globals ``n``, ``a`` and ``b``; writes ``c``, ``vis``, ``T``
    and ``tot``.  Scratch values are kept local instead of being written to
    the module-level ``tmp``/``len`` globals.
    """
    global tot
    for i in range(1, n + 1):
        c[i] = b[a[i]]
    for i in range(1, n + 1):
        if vis[i] != 0:
            continue
        tot += 1
        # Cycles are disjoint and labelled as a whole, so the cycle through
        # the first unlabelled position is entirely unlabelled: a single
        # walk labels every member and measures the length.
        length, j = 0, i
        while vis[j] == 0:
            vis[j] = tot
            length += 1
            j = c[j]
        T[tot] = length


def gcd(a, b):
    """Return the greatest common divisor of ``a`` and ``b`` (iterative Euclid)."""
    while b != 0:
        a, b = b, a % b
    return a


def solve1():
    """Return the smallest even number of shuffles that restores the deck.

    After ``2k`` shuffles the deck has been permuted by ``c`` exactly ``k``
    times, so it is sorted iff ``k`` is a multiple of every cycle length,
    i.e. the answer is ``2 * lcm(T[1..tot])``.  As soon as the running lcm
    exceeds ``10**12``, ``int(1e13)`` is returned instead; ``main`` prints
    "huge" for any value above ``1e12``.
    """
    ans = T[1]
    for i in range(2, tot + 1):
        # Integer arithmetic keeps the lcm exact.  The previous true division
        # ``/`` produced a float, and ``ans * T[i]`` can reach ~1e17 > 2**53,
        # so the result could be rounded and later truncated by ``int``.
        ans = ans // gcd(ans, T[i]) * T[i]
        if ans > 10 ** 12:
            return int(1e13)
    return ans * 2


def exgcd(a, b, x, y):
    """Extended Euclid: return ``(g, x, y)`` with ``a*x + b*y == g``.

    For non-negative inputs ``g == gcd(a, b)``.  The incoming ``x`` and
    ``y`` arguments are ignored and exist only for call compatibility.
    The coefficients equal those of the classic recursive formulation,
    computed here by the equivalent forward iteration.
    """
    prev_r, cur_r = a, b
    prev_s, cur_s = 1, 0
    prev_t, cur_t = 0, 1
    while cur_r != 0:
        q = prev_r // cur_r
        prev_r, cur_r = cur_r, prev_r - q * cur_r
        prev_s, cur_s = cur_s, prev_s - q * cur_s
        prev_t, cur_t = cur_t, prev_t - q * cur_t
    return prev_r, prev_s, prev_t


def inv(a, b):
    """Return the ``x`` coefficient of ``exgcd(a, b)``, shifted to be non-negative.

    When ``gcd(a, b) == 1`` (and ``b > 0``) this is the modular inverse of
    ``a`` modulo ``b``.  A negative coefficient is moved into ``[0, b)``;
    a non-negative one is returned unchanged.
    """
    coeff = exgcd(a, b, 0, 0)[1]
    return coeff % b if coeff < 0 else coeff


def excrt(n, M, C):
    """Solve ``x ≡ C[i] (mod M[i])`` for ``i = 1..n`` by pairwise merging.

    ``M`` and ``C`` are 1-indexed lists and are overwritten in place: entry
    ``i`` receives the merge of congruences ``1..i``.  Returns ``C[n]``, the
    smallest non-negative solution modulo ``lcm(M[1..n])``, or
    ``int(1e13)`` when the congruences are inconsistent.
    """
    for i in range(2, n + 1):
        m1, m2, c1, c2 = M[i - 1], M[i], C[i - 1], C[i]
        g = gcd(m1, m2)
        if (c2 - c1) % g != 0:
            return int(1e13)
        step = m2 // g
        # k solves (m1/g) * k ≡ (c2 - c1)/g (mod m2/g), so c1 + k*m1
        # satisfies both congruences.
        k = inv(m1 // g, step) * ((c2 - c1) // g) % step
        M[i] = m1 // g * m2
        # Reduce using C itself.  The old code read the global ``mod`` here,
        # which was correct only because the caller happened to pass it as C.
        C[i] = (k * m1 + c1) % M[i]
    return C[n]


def solve2():
    """Return the smallest odd number of shuffles that restores the deck.

    ``2x + 1`` shuffles are ``x`` Alice+Bob rounds (``c`` applied ``x``
    times) followed by one more Alice shuffle, so the deck is sorted iff
    ``c`` applied ``x`` times equals the inverse of Alice's shuffle.  The
    function inverts ``a`` in place.  It then checks, cycle by cycle of
    ``c``, that this inverse acts on the cycle as one fixed shift ``len``
    along ``c``.  The shifts become congruences ``x ≡ len (mod cycle
    length)``, which are merged by ``excrt``.  Returns ``2x + 1``, or
    ``int(1e13)`` when no odd count works.

    Side effects: overwrites the globals ``a`` (now Alice's inverse), ``tt``,
    ``m``, ``mod``, ``tmp`` and ``len``; ``excrt`` also overwrites ``T``.
    """
    global n, a, b, c, tot, vis, T, mod, m, tt, inf, N, tmp, len
    # Invert Alice's shuffle in place: afterwards a[p] is the position whose
    # card Alice moves onto p.
    for i in range(1, n+1):
        tt[i] = a[i]
    for i in range(1, n+1):
        a[tt[i]] = i
    # m[i] == inf marks "shift of i's cycle not computed yet".
    for i in range(1, n+1):
        m[i] = inf
    for i in range(1, n+1):
        if m[i] != inf:
            continue
        # The inverse must keep i inside its own cycle of c.
        if vis[a[i]] != vis[i]:
            return int(1e13)
        else:
            tmp = 0
            # len = number of c-steps from i to a[i].
            len, tmp, p1, p2 = 0, i, 0, 0
            while tmp != a[i]:
                tmp = c[tmp]
                len += 1
            m[i] = len
            # Walk the cycle from i and from a[i] in lockstep; the inverse
            # must map every member the same len steps along c.
            p1, p2 = c[i], c[a[i]]
            while p1 != i:
                if a[p1] != p2:
                    return int(1e13)
                else:
                    m[p1] = len
                p1 = c[p1]
                p2 = c[p2]
    # mod[k] will hold the required residue for cycle k.
    for i in range(1, n+1):
        mod[i] = inf
    # Map a shift equal to the full cycle length to 0.  (len starts at 0 and
    # stops on reaching a[i], so it is at most cycle length - 1; this looks
    # defensive.)
    for i in range(1, n+1):
        if m[i] == T[vis[i]]:
            m[i] = 0
    # Collect one residue per cycle; members must agree.
    for i in range(1, n+1):
        if mod[vis[i]] == inf:
            mod[vis[i]] = m[i]
        elif mod[vis[i]] != m[i]:
            return int(1e13)
    # x ≡ mod[k] (mod T[k]) for every cycle k; the odd answer is 2x + 1.
    return excrt(tot, T, mod) * 2 + 1


def main():
    """Read the deck size and both shuffles from stdin and print the answer.

    Prints the smaller of the even-count (``solve1``) and odd-count
    (``solve2``) candidates, or ``huge`` when that minimum exceeds ``1e12``.
    """
    global n
    n = int(input())
    # a and b keep their leading 0 placeholder so positions are 1-indexed.
    a.extend(map(int, input().split()))
    b.extend(map(int, input().split()))
    init()
    # Arguments evaluate left to right: solve1 runs before solve2, which
    # overwrites ``a`` and (through excrt) ``T``.
    best = min(int(1e13), solve1(), solve2())
    print("huge" if best > 1e12 else int(best))


# Script entry point: solve only when run directly, not when imported.
if __name__ == "__main__":
    main()

C++(clang++11) 解法, 执行用时: 49ms, 内存消耗: 2428K, 提交时间: 2021-04-13 15:37:07

#include<cstdio>
using ll = long long;
const ll N = 100005;               // capacity of the 1-indexed arrays
const ll MX = 1000000000000LL;     // answers above this are printed as "huge"
bool vis[N];                       // cycle-visited flags
int n, tt;                         // deck size; input scratch value
int a[N], b[N], c[N], fa[N], d[N]; // shuffles, composition, cycle leader, offset
ll ans1 = 1, ans2 = -1, tmp;       // even / odd candidates (-1 = none)
// Greatest common divisor via the iterative Euclidean algorithm.
ll gcd(ll x,ll y){
    while(y){
        ll r=x%y;
        x=y;
        y=r;
    }
    return x;
}
// Extended Euclid on (a, b): returns g = gcd(a, b) and sets x, y so that
// a*x + b*y ≡ g (mod `mod`). y is reduced modulo `mod` at every level, and
// x takes the already-reduced y of the level below, keeping values small.
ll exgcd(ll a,ll b,ll &x,ll &y,ll mod){
    if(!b){
        x=1;
        y=0;
        return a;
    }
    ll inner_x,inner_y;
    const ll g=exgcd(b,a%b,inner_x,inner_y,mod);
    x=inner_y;
    y=(inner_x-a/b*inner_y)%mod;
    return g;
}
// Smaller of two 64-bit values (local helper; <algorithm> is not included).
inline ll min(ll x,ll y){
    if(y<x) return y;
    return x;
}
// Least common multiple; divides by the gcd before multiplying.
inline ll lcm(ll x,ll y){
    const ll g=gcd(x,y);
    return x/g*y;
}
// Merge x ≡ a2 (mod m2) into the accumulated congruence x ≡ a1 (mod m1).
// On success a1 becomes the smallest non-negative solution modulo
// l = lcm(m1, m2), m1 becomes l, and false is returned. Returns true, leaving
// a1 and m1 untouched, when the two congruences are inconsistent.
// The multiplier k is reduced modulo m2/g before it multiplies m1, so no
// intermediate exceeds max(l, (m2/g)^2). The old one-line update
// multiplied (m2/g) * m1 * x (up to ~1e5 * 5e11 * 1e5 here) and could
// overflow long long.
inline bool excrt(ll &a1,ll &m1,ll a2,ll m2){
    ll x,y,g=exgcd(m1,m2,x,y,m2);
    if((a1-a2)%g) return true;
    ll step=m2/g;                        // modulus of the multiplier k
    ll l=m1/g*m2;                        // lcm(m1, m2)
    ll inv=(x%step+step)%step;           // (m1/g) * inv ≡ 1 (mod step)
    ll diff=((a2-a1)/g%step+step)%step;  // (a2-a1)/g reduced into [0, step)
    ll k=diff*inv%step;                  // both factors < step
    a1=(a1%m1+m1)%m1+k*m1;               // lies in [0, l)
    m1=l;
    return false;
}
// Reads n and both shuffles, then prints the smallest number of alternating
// shuffles (Alice first) that restores the deck, or "huge" if it exceeds MX.
// Even candidate ans1: 2 * lcm of the cycle lengths of c.
// Odd candidate ans2: x such that 2x + 1 shuffles work (merged by excrt).
int main(){
    scanf("%d",&n);
    // a[p] = position whose card Alice moves onto p (inverse of her shuffle).
    for(int i=1;i<=n;i++)
        scanf("%d",&tt),a[tt]=i;
    // b[p] = position whose card Bob moves onto p (inverse of his shuffle).
    for(int i=1;i<=n;i++)
        scanf("%d",&tt),b[tt]=i;
    // c composes the two inverted shuffles; its cycle lengths equal those of
    // one Alice+Bob round.
    for(int i=1;i<=n;i++)
        c[i]=b[a[i]];
    // Walk each unvisited cycle of c starting at its smallest member i:
    // fa[j] = i (cycle leader), d[j] = steps from i to j along c. When the
    // walk closes, d[i] is overwritten with the cycle length.
    // b is rebuilt here as the inverse of a, i.e. Alice's forward shuffle.
    for(int i=1,j=1;i<=n;i++,j=i){
        b[a[i]]=i;
        if(vis[i]) continue;
        for(;!vis[j];j=c[j]){
            vis[j]=1; fa[j]=i;
            d[c[j]]=d[j]+1;
        }
        // ans1 == -1 means the lcm already exceeds the limit.
        if(ans1==-1) continue;
        ans1=lcm(ans1,d[i]);
        if(ans1*2>MX) ans1=-1;
    }
    // Odd case: need c applied x times to map every i to b[i]. So b[i] must
    // share i's cycle, and x ≡ d[b[i]] - d[i] (mod cycle length).
    // (a1, m1) is the merged congruence x ≡ a1 (mod m1) so far.
    ll a1=0,a2,m1=1,m2;
    for(int i=1;i<=n;i++){
        if(fa[i]!=fa[b[i]]){ans2=-1;break;}
        m2=d[fa[i]];
        a2=(d[b[i]]-d[i]+m2)%m2;
        // Once 2*m1+1 exceeds MX, any other solution a1 + k*m1 is too large,
        // so only check that the current a1 satisfies the new congruence.
        if(m1*2+1>MX){
            if(a1%m2!=a2){ans2=-1; break;}}
        else if(excrt(a1,m1,a2,m2)){
            ans2=-1; break;}
        ans2=a1;
    }
    // Print the smaller valid candidate, or "huge" if neither fits.
    if(ans2==-1||ans2*2+1>MX){
        if(ans1==-1||ans1*2>MX) puts("huge");
        else printf("%lld\n",ans1*2);
    }
    else if(ans1==-1||ans1*2>MX)
        printf("%lld\n",ans2*2+1);
    else printf("%lld\n",min(ans1*2,ans2*2+1));
    return 0;
}

上一题