列表

详情


NC17707. Floorfiller

描述

Niuniu wants to fill an n x m sheet with 0s and 1s.

Niuniu wants the XOR sum of every row and every column to be 0.

In other words, every row and every column contains an even number of 1s.

Two sheets are considered the same if one can be turned into the other by cyclic shifts (vertical and/or horizontal).

Formally, for two sheets A and B, if we can find x and y such that

$A_{i,j} = B_{(i+x) \bmod n,\ (j+y) \bmod m}$ for all $0 \le i < n$ and $0 \le j < m$,

we consider A and B to be the same sheet.

Niuniu wants to know the number of distinct ways to fill the sheet.

As the result might be very large, he wants to know it modulo 998244353.

输入描述

The first line contains two integers n and m.

$1 \le n \le 10^9$
$1 \le m \le 10^9$

输出描述

You should output one integer, which is the answer modulo 998244353.

示例1

输入:

4 4

输出:

48

示例2

输入:

4 6

输出:

1448

示例3

输入:

998244353 998244353

输出:

295980207

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

Java(javac 1.8) 解法, 执行用时: 2802ms, 内存消耗: 152936K, 提交时间: 2018-08-17 00:40:04

import java.util.*;
import java.math.*;

public class Main{
    /** Prime modulus required by the problem statement. */
    public static final int MODER = 998244353;

    /**
     * Computes {@code a^exp} by binary exponentiation in plain {@code int} arithmetic (no modulus).
     * Only used here for prime powers that divide an {@code int}, so the result fits; the final
     * squaring of {@code a} may overflow, but that value is never used.
     *
     * @param a   base
     * @param exp exponent; values {@code <= 0} return 1
     * @return {@code a^exp}
     */
    public static int power(int a, int exp){
        int ret = 1;
        for ( ; exp > 0; exp >>= 1){
            if ((exp & 1) == 1){
                ret *= a;
            }
            a *= a;
        }
        return ret;
    }

    /**
     * Lists every positive divisor of {@code n}: for each {@code i} with {@code i * i <= n} that
     * divides {@code n}, adds {@code i} and then (if different) {@code n / i}.
     *
     * @param n a positive integer
     * @return the divisors of {@code n}, unsorted
     */
    public static Vector <Integer> getFact(int n) {
        Vector <Integer> fact = new Vector<>();
        // Compare in long: i * i overflows int once i exceeds 46340 (n close to Integer.MAX_VALUE).
        for (int i = 1; (long) i * i <= n; ++ i) {
            if (n % i == 0) {
                fact.add(i);
                if ((long) i * i != n) {
                    fact.add(n / i);
                }
            }
        }
        return fact;
    }

    /**
     * Returns the distinct prime factors of {@code n} in increasing order, by trial division.
     *
     * @param n a positive integer
     * @return distinct primes dividing {@code n}
     */
    public static Vector <Integer> getPrime(int n) {
        Vector <Integer> prime = new Vector<>();
        // long comparison for the same overflow reason as in getFact.
        for (int i = 2; (long) i * i <= n; ++ i) {
            if (n % i == 0) {
                while (n % i == 0) {
                    n /= i;
                }
                prime.add(i);
            }
        }
        if (n > 1) {
            // Whatever remains after trial division is itself prime.
            prime.add(n);
        }
        return prime;
    }

    /**
     * For each divisor {@code d = fact[i]} of {@code tot}, computes Euler's totient of
     * {@code tot / d} from the distinct primes of {@code tot}. phi(tot / d) equals the number of
     * cyclic shifts {@code s} in {@code [0, tot)} with {@code gcd(s, tot) == d}.
     *
     * @param fact  divisors of {@code tot}
     * @param prime distinct prime factors of {@code tot}
     * @param tot   the dimension being shifted (n or m)
     * @return totients, index-aligned with {@code fact}
     */
    public static Vector <Integer> getPhi(Vector <Integer> fact, Vector <Integer> prime, int tot) {
        int n = fact.size();
        Vector <Integer> phi = new Vector<>();
        for (int i = 0; i < n; ++ i) {
            int x = tot / fact.elementAt(i);
            int tmp = 1;
            for (int u : prime) {
                int cnt = 0;
                while (x % u == 0) {
                    x /= u;
                    ++ cnt;
                }
                if (cnt > 0) {
                    // phi(u^cnt) = (u - 1) * u^(cnt - 1)
                    tmp *= (u - 1) * power(u, cnt - 1);
                }
            }
            phi.add(tmp);
        }
        return phi;
    }

    /**
     * Computes {@code a^exp mod moder} by square-and-multiply.
     *
     * @return {@code a^exp mod moder}; returns 1 (unreduced) when {@code exp <= 0}
     */
    public static BigInteger powermod(BigInteger a, long exp, BigInteger moder) {
        BigInteger ret = BigInteger.ONE;
        for ( ; exp > 0; exp >>= 1) {
            if ((exp & 1) == 1) {
                ret = ret.multiply(a).mod(moder);
            }
            a = a.multiply(a).mod(moder);
        }
        return ret;
    }

    /** Greatest common divisor by Euclid's algorithm. */
    public static int gcd(int a, int b) {
        if (b == 0) {
            return a;
        }
        return gcd(b, a % b);
    }

    /** Least common multiple, computed in {@code long} so the product cannot overflow. */
    public static long lcm(int a, int b) {
        return (long) a * b / gcd(a, b);
    }

    /**
     * Counts the n x m 0/1 sheets whose rows and columns all have even parity, up to cyclic
     * shifts, modulo {@link #MODER}. It averages the fixed-sheet counts over all
     * {@code n * m} (row shift, column shift) pairs (Burnside's lemma).
     *
     * <p>The sum is accumulated modulo {@code n * m * MODER}, so dividing the residue by
     * {@code n * m} is exact even when n or m is a multiple of MODER.
     *
     * @param n number of rows, {@code >= 1}
     * @param m number of columns, {@code >= 1}
     * @return the answer modulo {@link #MODER}
     */
    public static long solve(int n, int m) {
        Vector <Integer> fact1 = getFact(n);
        Vector <Integer> fact2 = getFact(m);
        Vector <Integer> prime1 = getPrime(n);
        Vector <Integer> prime2 = getPrime(m);
        Vector <Integer> phi1 = getPhi(fact1, prime1, n);
        Vector <Integer> phi2 = getPhi(fact2, prime2, m);
        int sz1 = fact1.size(), sz2 = fact2.size();
        BigInteger groupSize = BigInteger.valueOf(n).multiply(BigInteger.valueOf(m));
        BigInteger mod = groupSize.multiply(BigInteger.valueOf(MODER));
        BigInteger two = BigInteger.valueOf(2);
        BigInteger sum = BigInteger.ZERO;
        for (int i = 0; i < sz1; ++ i) {
            for (int j = 0; j < sz2; ++ j) {
                int u = fact1.elementAt(i), v = fact2.elementAt(j);
                // Number of shift pairs whose row shift has gcd u with n and column shift gcd v with m.
                BigInteger weight = BigInteger.valueOf(phi1.elementAt(i))
                        .multiply(BigInteger.valueOf(phi2.elementAt(j))).mod(mod);
                // Such shifts have orders n / u and m / v; every cell orbit has this common length.
                long period = lcm(n / u, m / v);
                long cntv = period * u / n, cntu = period * v / m;
                assert (cntu & 1) == 1 || (cntv & 1) == 1;
                // n * m / period orbits are free, minus the parity constraints that remain independent.
                long exp;
                if ((cntu & 1) == 1 && (cntv & 1) == 1) {
                    exp = (long) n * m / period - u - v + 1;
                }
                else if ((cntu & 1) == 1) {
                    exp = (long) n * m / period - v;
                }
                else {
                    exp = (long) n * m / period - u;
                }
                BigInteger term = weight.multiply(powermod(two, exp, mod)).mod(mod);
                sum = sum.add(term).mod(mod);
            }
        }
        // The residue is < n * m * MODER and divisible by n * m, so the quotient is < MODER.
        return sum.divide(groupSize).longValue();
    }

    /** Reads n and m from standard input and prints the answer. */
    public static void main(String[] args){
        try (Scanner scanner = new Scanner(System.in)) {
            int n = scanner.nextInt(), m = scanner.nextInt();
            System.out.println(solve(n, m));
        }
    }
}

C++ 解法, 执行用时: 228ms, 内存消耗: 412K, 提交时间: 2021-11-27 01:04:49

#include <bits/stdc++.h>
using namespace std;

// Modulus used by add()/mul()/qpm(). main() squares it when one dimension equals 998244353
// and divides it back down before the final modular inverse of m.
long long mod = 998244353;

// Adds two residues already reduced below mod, folding the sum back into [0, mod).
long long add (long long x, long long y) {
    const long long sum = x + y;
    return sum < mod ? sum : sum - mod;
}
// Modular product x * y % mod. The product is widened to 128 bits because mod can be
// raised to ~998244353^2 in main().
long long mul (long long x, long long y) {
    const __int128 product = static_cast<__int128>(x) * y;
    return static_cast<long long>(product % mod);
}
// Binary exponentiation: x^y modulo the global mod. The exponent is 128-bit because
// main() passes exponents around n * m, which can exceed 64 bits' comfortable range.
long long qpm(long long x, __int128 y) {
    long long result = 1;
    for (long long base = x; y != 0; y >>= 1) {
        if (y & 1) result = mul(result, base);
        base = mul(base, base);
    }
    return result;
}

// Euler's totient of x by trial division over its prime factors.
int get_phi (int x) {
    long long phi = x;
    int rem = x;
    for (long long p = 2; p * p <= rem; p++) {
        if (rem % p != 0) continue;
        phi = phi / p * (p - 1);
        while (rem % p == 0) rem /= p;
    }
    // A leftover factor greater than 1 is a prime larger than sqrt of the original x.
    if (rem > 1) {
        phi /= rem;
        phi *= rem - 1;
    }
    return (int)phi;
}

// All positive divisors of x, emitted as d, x / d for each d <= sqrt(x) (no duplicate for a square root).
vector<int> get_fac(int x) {
    vector<int> divisors;
    const int limit = sqrt(x + 0.5);
    for (int d = 1; d <= limit; ++d) {
        if (x % d != 0) continue;
        divisors.push_back(d);
        const int partner = x / d;
        if (partner != d) divisors.push_back(partner);
    }
    return divisors;
}

// Reads n and m, averages the number of even-parity sheets fixed by each of the n * m cyclic
// shift pairs (Burnside), and prints the result modulo 998244353.
int main ()
{
    long long n, m; scanf("%lld %lld", &n, &m);
    // Both dimensions equal the modulus: presumably the generic path would need arithmetic
    // modulo mod^3, which does not fit in 64 bits, so the known answer is printed directly.
    if (n == m && n == mod) {
        puts("295980207");
        return 0;
    }
    // Ensure that, if either dimension equals the modulus, it is n.
    if (m == mod) swap(n, m);
    if (n == mod) {
        // Work modulo mod^2 so the final sum can be divided exactly by n (= mod) below.
        mod *= mod;
    }
    vector<int> factor1, factor2;
    factor1 = get_fac((int)n);
    factor2 = get_fac((int)m);
    long long ans = 0;
    // a, b: orders of the row and column shifts; phi(a) * phi(b) shift pairs have these orders.
    for (int a : factor1) {
        for (int b : factor2) {
            int G = __gcd(a, b);
            // (n / a) * (m / b) * G == n * m / lcm(a, b): number of cell orbits under the shift pair.
            __int128 u = (__int128)n / a * m / b * G;
            // Remove the parity constraints that stay independent, depending on lcm/b and lcm/a parity.
            if ((a / G) & 1) u -= m / b;
            if ((b / G) & 1) u -= n / a;
            if (((a / G) & 1) && ((b / G) & 1)) u++;
            long long p = mul(get_phi(a), get_phi(b));
            long long q = qpm(2, u);
            long long v = mul(p, q);
            ans = add(ans, v);
        }
    }
    // Divide the sum by n * m.
    if (n == 998244353) {
        // Exact division by the modulus, then return to working modulo 998244353.
        ans /= n;
        mod /= n;
    }
    // Otherwise divide by n via its Fermat inverse n^(mod - 2).
    else ans = mul(ans, qpm(n, mod - 2));
    ans = mul(ans, qpm(m, mod - 2));
    printf("%d\n", (int)ans);
    return 0;
}

C++14(g++5.4) 解法, 执行用时: 201ms, 内存消耗: 380K, 提交时间: 2018-09-18 14:18:38

#include<cstdio>
#include<algorithm>
#include<vector>
using namespace std;
// Type aliases and the global modulus shared by the helpers below.
typedef long long ll;
typedef long double ld;
typedef pair<int,int>P;  // (divisor, phi(divisor)) pairs stored in f[] by deal()
// Working modulus; main() multiplies it by n when n == 998244353 and divides it back afterwards.
ll mod=998244353;
// Sum of two residues below mod, folded back into [0, mod).
ll add(ll x,ll y)
{
	const ll s = x + y;
	return s >= mod ? s - mod : s;
}
// Returns x * y modulo the global mod, in [0, mod).
// main() can raise mod to 998244353^2 (~1e18), so the 64-bit product x * y may overflow.
// Widening to __int128 keeps the product exact. It avoids estimating the quotient in long
// double and relying on signed 64-bit overflow wrapping, which is undefined behaviour.
// Requires __int128 support (GCC/Clang on 64-bit targets).
ll mul(ll x,ll y)
{
	__int128 r=(__int128)x*y%mod;
	if(r<0)r+=mod;  // normalise in case a caller ever passes a negative operand
	return (ll)r;
}
// Square-and-multiply: x^y modulo the global mod (y >= 0).
ll Pow(ll x,ll y)
{
	ll result=1;
	for(;y;y>>=1,x=mul(x,x))
		if(y&1)result=mul(result,x);
	return result;
}
// f[id]: (divisor, phi(divisor)) pairs of the id-th dimension (0 -> n, 1 -> m), filled by deal().
vector<P>f[2];
// Distinct prime factors of the number most recently passed to deal(); read by phi().
vector<int>fact;
// Euler's totient of n, where n divides the number whose distinct primes are in the global fact.
int phi(int n)
{
	int res=n;
	for(int p:fact)
		if(n%p==0)res=res/p*(p-1);
	return res;
}
// Stores the distinct primes of n in fact, then appends every divisor d of n with phi(d) to f[id].
void deal(int n,int id)
{
	fact.clear();
	int rest=n;
	for(int p=2;p*p<=rest;p++)
	{
		if(rest%p!=0)continue;
		fact.push_back(p);
		do rest/=p; while(rest%p==0);
	}
	if(rest>1)fact.push_back(rest);  // leftover prime above sqrt(n)
	for(int d=1;d*d<=n;d++)
	{
		if(n%d!=0)continue;
		f[id].push_back(P(d,phi(d)));
		if(d*d!=n)f[id].push_back(P(n/d,phi(n/d)));
	}
}
// Greatest common divisor, iterative Euclid.
int gcd(int a,int b)
{
	while(b)
	{
		int t=a%b;
		a=b;
		b=t;
	}
	return a;
}
// Reads n and m, averages the number of even-parity sheets fixed by each of the n * m cyclic
// shift pairs (Burnside), and prints the result modulo 998244353.
int main()
{
	int n,m;
	scanf("%d%d",&n,&m);
	// Both dimensions equal the modulus: presumably the generic path would need a modulus of
	// mod^3, which exceeds 64 bits, so the known answer is printed directly.
	if(n==mod&&m==mod)
	{
		printf("295980207\n");
		return 0;
	}
	// Ensure that, if either dimension equals the modulus, it is n; then work modulo mod^2
	// so the final sum can be divided exactly by n.
	if(m==mod)swap(n,m);
	if(n==mod)mod=mod*n;
	ll ans=0;
	deal(n,0),deal(m,1);
	// a, b: orders of the row and column shifts; res = phi(a) * phi(b) shift pairs have them.
	for(int i=0;i<f[0].size();i++)
		for(int j=0;j<f[1].size();j++)
		{
			ll res=mul(f[0][i].second,f[1][j].second);
			int a=f[0][i].first,b=f[1][j].first;
			// l: common orbit length; t starts as the number of cell orbits n * m / l.
			ll l=(ll)a*b/gcd(a,b);
			ll t=(ll)n*m/l;
			// Subtract the parity constraints that remain independent for this shift pair.
			if(((l/a)&1)&&((l/b)&1))t-=n/a+m/b-1;	
			else if((l/a)&1)t-=n/a;
			else if((l/b)&1)t-=m/b;
			ans=add(ans,mul(res,Pow(2,t)));
		}
	// Divide by n * m: exactly by n when n is the modulus (then restore mod), otherwise
	// via Fermat inverses.
	if(n==998244353)ans/=n,mod/=n;
	else ans=mul(ans,Pow(n,mod-2));
	ans=mul(ans,Pow(m,mod-2));
	printf("%lld\n",ans);
	return 0;
}

Python(2.7.3) 解法, 执行用时: 3437ms, 内存消耗: 3112K, 提交时间: 2018-08-16 18:49:54

def gcd(x, y):
	# Euclid's algorithm, iteratively: gcd(x, y) == gcd(y, x % y) until y reaches 0.
	while y != 0:
		x, y = y, x % y
	return x

def lcm(x, y):
	# Least common multiple of positive ints. Divide before multiplying to keep the
	# intermediate small. '//' keeps the result an exact integer: it matches Python 2's int '/'
	# and avoids lossy float results under Python 3.
	return x // gcd(x, y) * y

def phi(x):
	"""Euler's totient of a positive integer x, by trial division up to sqrt(x).

	Uses floor division '//' so the arithmetic stays exact on both Python 2 and 3.
	"""
	re = x
	i = 2
	while i * i <= x:
		if x % i == 0:
			x //= i
			re = re // i * (i - 1)
			while x % i == 0:
				x //= i
		i += 1
	if x > 1:
		# Leftover factor is a prime larger than the square root of the original x.
		re = re // x * (x - 1)
	return re

def divisor(x):
	"""Return all positive divisors of x in increasing order (exact ints on Python 2 and 3)."""
	found = set()
	i = 1
	while i * i <= x:
		if x % i == 0:
			found.add(i)
			found.add(x // i)
		i += 1
	return sorted(found)

n, m = map(int, raw_input().split())
MOD = 998244353
# Accumulate modulo MOD * n * m so the final division by n * m (the number of shift
# pairs) is exact.
p = MOD * n * m
# phi of each divisor is loop-invariant, so compute it once per divisor instead of
# once per (i, j) pair.
nd = [(i, phi(i)) for i in divisor(n)]
md = [(j, phi(j)) for j in divisor(m)]

ans = 0
for i, phi_i in nd:
	for j, phi_j in md:
		l = lcm(i, j)
		# Number of cell orbits, minus the parity constraints that remain independent.
		cnt = n * m // l
		odd_i = (l // i) % 2 == 1
		odd_j = (l // j) % 2 == 1
		if odd_i and odd_j:
			cnt -= n // i + m // j - 1
		elif odd_i:
			cnt -= n // i
		elif odd_j:
			cnt -= m // j
		else:
			# l is the lcm of i and j, so l / i and l / j cannot both be even.
			assert False
		ans = (ans + pow(2, cnt, p) * phi_i * phi_j) % p

assert ans % (n * m) == 0
ans //= n * m
print(ans)

上一题