NC17707. Floorfiller
描述
输入描述
The first line contains two integers, which are n and m.
$1 \le n \le 10^9$
$1 \le m \le 10^9$
输出描述
You should output one integer, which is the answer modulo 998244353.
示例1
输入:
4 4
输出:
48
示例2
输入:
4 6
输出:
1448
示例3
输入:
998244353 998244353
输出:
295980207
Java(javac 1.8) 解法, 执行用时: 2802ms, 内存消耗: 152936K, 提交时间: 2018-08-17 00:40:04
import java.util.*;
import java.math.*;

/**
 * NC17707 "Floorfiller": reads n and m and prints the answer modulo 998244353.
 *
 * <p>The answer is accumulated as a sum over every pair of divisors (u of n, v of m) of
 * phi(n/u) * phi(m/v) * 2^exp. That sum is then divided exactly by n*m. NOTE(review): this has
 * the shape of a Burnside/Polya orbit count; confirm against the problem statement.
 *
 * <p>Because the division must be exact, the sum is kept modulo n*m*MODER rather than MODER,
 * and it is divided only at the end. The quotient is then the answer modulo MODER.
 */
public class Main {
    public static final int MODER = 998244353;

    /**
     * Returns a^exp using plain int arithmetic (no modulus).
     *
     * <p>Callers must keep the true result within int range. getPhi only asks for prime powers
     * that divide an int.
     */
    public static int power(int a, int exp) {
        int ret = 1;
        while (exp > 0) {
            if ((exp & 1) == 1) {
                ret *= a;
            }
            exp >>= 1;
            // Square only while bits remain. The base is never squared past what is needed;
            // an extra squaring could overflow even though its value would be unused.
            if (exp > 0) {
                a *= a;
            }
        }
        return ret;
    }

    /** Returns every positive divisor of n; each small divisor i is followed by n / i. */
    public static Vector<Integer> getFact(int n) {
        Vector<Integer> fact = new Vector<>();
        // i <= n / i is the overflow-free form of i * i <= n.
        for (int i = 1; i <= n / i; ++i) {
            if (n % i == 0) {
                fact.add(i);
                if (i != n / i) {
                    fact.add(n / i);
                }
            }
        }
        return fact;
    }

    /** Returns the distinct prime factors of n in increasing order. */
    public static Vector<Integer> getPrime(int n) {
        Vector<Integer> prime = new Vector<>();
        for (int i = 2; i <= n / i; ++i) {
            if (n % i == 0) {
                while (n % i == 0) {
                    n /= i;
                }
                prime.add(i);
            }
        }
        if (n > 1) {
            prime.add(n);
        }
        return prime;
    }

    /**
     * For each divisor d in {@code fact}, returns phi(tot / d) at the same index.
     *
     * @param fact divisors of tot
     * @param prime distinct prime factors of tot (they cover every prime of tot / d)
     * @param tot the number whose divisors are listed
     * @return phi(tot / fact[i]) for each i, in the same order as {@code fact}
     */
    public static Vector<Integer> getPhi(Vector<Integer> fact, Vector<Integer> prime, int tot) {
        Vector<Integer> phi = new Vector<>();
        for (int d : fact) {
            int x = tot / d;
            int tmp = 1;
            for (int u : prime) {
                int cnt = 0;
                while (x % u == 0) {
                    x /= u;
                    ++cnt;
                }
                if (cnt > 0) {
                    // phi(u^cnt) = (u - 1) * u^(cnt - 1)
                    tmp *= (u - 1) * power(u, cnt - 1);
                }
            }
            phi.add(tmp);
        }
        return phi;
    }

    /**
     * Returns a^exp mod moder, delegating to {@link BigInteger#modPow}.
     *
     * <p>For exp <= 0 this returns 1 unreduced, as the former hand-written loop did.
     */
    public static BigInteger powermod(BigInteger a, long exp, BigInteger moder) {
        if (exp <= 0) {
            return BigInteger.ONE;
        }
        return a.modPow(BigInteger.valueOf(exp), moder);
    }

    /** Greatest common divisor by Euclid's algorithm. */
    public static int gcd(int a, int b) {
        if (b == 0) {
            return a;
        }
        return gcd(b, a % b);
    }

    /** Least common multiple, computed in long so a * b cannot overflow. */
    public static long lcm(int a, int b) {
        return (long) a * b / gcd(a, b);
    }

    /**
     * Computes the answer for an n x m instance, modulo {@link #MODER}.
     *
     * @param n first input value (1 <= n <= 1e9)
     * @param m second input value (1 <= m <= 1e9)
     * @return the answer modulo 998244353
     */
    public static long solve(int n, int m) {
        Vector<Integer> fact1 = getFact(n);
        Vector<Integer> fact2 = getFact(m);
        Vector<Integer> phi1 = getPhi(fact1, getPrime(n), n);
        Vector<Integer> phi2 = getPhi(fact2, getPrime(m), m);

        BigInteger nm = BigInteger.valueOf(n).multiply(BigInteger.valueOf(m));
        // The full sum is divisible by n*m. Reducing modulo n*m*MODER keeps that divisibility,
        // so the final division below is exact and yields the answer modulo MODER.
        BigInteger mod = nm.multiply(BigInteger.valueOf(MODER));
        BigInteger two = BigInteger.valueOf(2);
        BigInteger sum = BigInteger.ZERO;

        for (int i = 0; i < fact1.size(); ++i) {
            int u = fact1.get(i);
            BigInteger w1 = BigInteger.valueOf(phi1.get(i)); // hoisted out of the inner loop
            for (int j = 0; j < fact2.size(); ++j) {
                int v = fact2.get(j);
                long l = lcm(n / u, m / v);
                long cntv = l * u / n; // = l / (n / u)
                long cntu = l * v / m; // = l / (m / v)
                boolean oddU = (cntu & 1) == 1;
                boolean oddV = (cntv & 1) == 1;
                // l/(n/u) and l/(m/v) are coprime, so they cannot both be even.
                assert oddU || oddV;
                long base = (long) n * m / l;
                long exp;
                if (oddU && oddV) {
                    exp = base - u - v + 1;
                } else if (oddU) {
                    exp = base - v;
                } else {
                    exp = base - u;
                }
                BigInteger term = w1.multiply(BigInteger.valueOf(phi2.get(j)))
                        .multiply(powermod(two, exp, mod));
                sum = sum.add(term).mod(mod);
            }
        }
        return sum.divide(nm).longValue();
    }

    /** Reads "n m" from standard input and prints the answer. */
    public static void main(String[] args) {
        try (Scanner scanner = new Scanner(System.in)) {
            int n = scanner.nextInt();
            int m = scanner.nextInt();
            System.out.println(solve(n, m));
        }
    }
}
C++ 解法, 执行用时: 228ms, 内存消耗: 412K, 提交时间: 2021-11-27 01:04:49
#include <bits/stdc++.h> using namespace std; long long mod = 998244353; long long add (long long x, long long y) { return x + y >= mod ? x + y - mod : x + y; } long long mul (long long x, long long y) { return (__int128)x * y % mod; } long long qpm(long long x, __int128 y) { long long r = 1; while (y) { if (y & 1) r = mul(r, x); x = mul(x, x); y >>= 1; } return r; } int get_phi (int x) { long long ret = x; for (long long i = 2; i * i <= x; i++) { if (x % i == 0) ret = ret / i * (i - 1); while (x % i == 0) x /= i; } if (x > 1) { ret /= x; ret *= x - 1; } return (int)ret; } vector<int> get_fac(int x) { vector<int> res; for (int i = 1, sq = sqrt(x + 0.5); i <= sq; i++) { if (x % i) continue; res.push_back(i); if (i * i != x) { res.push_back(x / i); } } return res; } int main () { long long n, m; scanf("%lld %lld", &n, &m); if (n == m && n == mod) { puts("295980207"); return 0; } if (m == mod) swap(n, m); if (n == mod) { mod *= mod; } vector<int> factor1, factor2; factor1 = get_fac((int)n); factor2 = get_fac((int)m); long long ans = 0; for (int a : factor1) { for (int b : factor2) { int G = __gcd(a, b); __int128 u = (__int128)n / a * m / b * G; if ((a / G) & 1) u -= m / b; if ((b / G) & 1) u -= n / a; if (((a / G) & 1) && ((b / G) & 1)) u++; long long p = mul(get_phi(a), get_phi(b)); long long q = qpm(2, u); long long v = mul(p, q); ans = add(ans, v); } } if (n == 998244353) { ans /= n; mod /= n; } else ans = mul(ans, qpm(n, mod - 2)); ans = mul(ans, qpm(m, mod - 2)); printf("%d\n", (int)ans); return 0; }
C++14(g++5.4) 解法, 执行用时: 201ms, 内存消耗: 380K, 提交时间: 2018-09-18 14:18:38
#include<cstdio> #include<algorithm> #include<vector> using namespace std; typedef long long ll; typedef long double ld; typedef pair<int,int>P; ll mod=998244353; ll add(ll x,ll y) { x+=y; if(x>=mod)x-=mod; return x; } ll mul(ll x,ll y) { return (x*y-(ll)(x/(ld)mod*y+1e-3)*mod+mod)%mod; } ll Pow(ll x,ll y) { ll ans=1; while(y) { if(y&1)ans=mul(ans,x); x=mul(x,x); y>>=1; } return ans; } vector<P>f[2]; vector<int>fact; int phi(int n) { int ans=n; for(int i=0;i<fact.size();i++) if(n%fact[i]==0)ans=ans/fact[i]*(fact[i]-1); return ans; } void deal(int n,int id) { fact.clear(); int nn=n; for(int i=2;i*i<=nn;i++) if(nn%i==0) { fact.push_back(i); while(nn%i==0)nn/=i; } if(nn>1)fact.push_back(nn); for(int i=1;i*i<=n;i++) if(n%i==0) { f[id].push_back(P(i,phi(i))); if(i*i!=n)f[id].push_back(P(n/i,phi(n/i))); } } int gcd(int a,int b) { return b?gcd(b,a%b):a; } int main() { int n,m; scanf("%d%d",&n,&m); if(n==mod&&m==mod) { printf("295980207\n"); return 0; } if(m==mod)swap(n,m); if(n==mod)mod=mod*n; ll ans=0; deal(n,0),deal(m,1); for(int i=0;i<f[0].size();i++) for(int j=0;j<f[1].size();j++) { ll res=mul(f[0][i].second,f[1][j].second); int a=f[0][i].first,b=f[1][j].first; ll l=(ll)a*b/gcd(a,b); ll t=(ll)n*m/l; if(((l/a)&1)&&((l/b)&1))t-=n/a+m/b-1; else if((l/a)&1)t-=n/a; else if((l/b)&1)t-=m/b; ans=add(ans,mul(res,Pow(2,t))); } if(n==998244353)ans/=n,mod/=n; else ans=mul(ans,Pow(n,mod-2)); ans=mul(ans,Pow(m,mod-2)); printf("%lld\n",ans); return 0; }
Python(2.7.3) 解法, 执行用时: 3437ms, 内存消耗: 3112K, 提交时间: 2018-08-16 18:49:54
import sys

# Answers are reported modulo this prime.
MOD = 998244353


def gcd(x, y):
    """Greatest common divisor (iterative Euclid)."""
    while y:
        x, y = y, x % y
    return x


def lcm(x, y):
    """Least common multiple of positive integers x and y."""
    return x // gcd(x, y) * y


def phi(x):
    """Euler's totient of x by trial division."""
    re = x
    i = 2
    while i * i <= x:
        if x % i == 0:
            re = re // i * (i - 1)
            while x % i == 0:
                x //= i
        i += 1
    if x > 1:
        re = re // x * (x - 1)
    return re


def divisor(x):
    """Sorted list of the positive divisors of x."""
    a = set()
    i = 1
    while i * i <= x:
        if x % i == 0:
            a.add(i)
            a.add(x // i)
        i += 1
    return sorted(a)


def solve(n, m):
    """Return the answer for inputs n, m, modulo 998244353.

    The sum below is exactly divisible by n*m. Accumulating it modulo MOD*n*m keeps that
    divisibility, so the final floor division is exact and yields the answer modulo MOD.
    Works on both Python 2 and Python 3 (only // is used for integer division).
    """
    p = MOD * n * m
    # Each totient depends on one divisor only; compute it once instead of once per pair.
    nd = [(d, phi(d)) for d in divisor(n)]
    md = [(d, phi(d)) for d in divisor(m)]
    ans = 0
    for i, phi_i in nd:
        for j, phi_j in md:
            l = lcm(i, j)
            cnt = n * m // l
            odd_i = (l // i) % 2 == 1
            odd_j = (l // j) % 2 == 1
            if odd_i and odd_j:
                cnt -= n // i + m // j - 1
            elif odd_i:
                cnt -= n // i
            elif odd_j:
                cnt -= m // j
            else:
                # l/i and l/j are coprime, so they cannot both be even.
                raise AssertionError("l/i and l/j cannot both be even")
            ans = (ans + pow(2, cnt, p) * phi_i * phi_j) % p
    assert ans % (n * m) == 0
    return ans // (n * m)


if __name__ == "__main__":
    n, m = map(int, sys.stdin.readline().split())
    print(solve(n, m))