NC234839. Gachapon
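描述
Snuke 有一个随机数生成器,每次独立地生成一个 $0$ 到 $N-1$ 之间的整数,其中整数 $i$ 被生成的概率为 $A_i / S$,这里 $S = \sum_{i=0}^{N-1} A_i$。Snuke 会反复使用该生成器,直到对每个 $i$($0 \le i \le N-1$),整数 $i$ 都已被生成至少 $B_i$ 次为止。求 Snuke 使用生成器次数的期望值。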
输入描述
第一行,一个整数 $N$。接下来 $N$ 行,每行两个整数 $A_i, B_i$。数据范围:$1 \le N \le 400$,$1 \le A_i$,$1 \le B_i$,$\sum A_i \le 400$,$\sum B_i \le 400$。
所有输入的数都是整数
输出描述
输出一个整数,表示答案对 998244353 取模的结果。假设你的答案是 $\frac{P}{Q}$(最简分数),那么你应当输出 $P \times Q^{-1} \bmod 998244353$,其中 $Q^{-1}$ 是 $Q$ 在模 998244353 意义下的逆元(数据保证存在)。
示例1
输入:
2
1 1
1 1
输出:
3
说明:
Snuke 使用生成器的期望次数是 $3$。
示例2
输入:
3
1 3
2 2
3 1
输出:
971485877
说明:
Snuke 使用生成器的期望次数是 $\frac{132929}{7200}$。
示例3
输入:
15
29 3
78 69
19 15
82 14
9 120
14 51
3 7
6 14
28 4
13 12
1 5
32 30
49 24
35 23
2 9
输出:
371626143
C++(clang++ 11.0.1) 解法, 执行用时: 1228ms, 内存消耗: 1208K, 提交时间: 2023-07-11 20:30:32
// NC234839 "Gachapon": item i is drawn with probability A_i / S
// (S = sum of all A_i); draws continue until every item i has appeared at
// least B_i times. Prints the expected number of draws modulo 998244353.
//
// Method: min-max inclusion-exclusion. With T_i the draw on which item i
// reaches B_i occurrences,
//   E[max_i T_i] = sum over non-empty subsets U of (-1)^(|U|+1) E[min_{i in U} T_i],
// and for a subset U with weight W = sum_{i in U} A_i,
//   E[min_{i in U} T_i] = (S / W) * sum over 0 <= c_i < B_i (i in U) of
//                         (sum c_i)! * prod_{i in U} (A_i^{c_i} / c_i!) / W^(sum c_i).
// A knapsack DP groups the subsets by (W, sum c_i).
#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <vector>

using ll = long long;

// Binary exponentiation: a^b for b >= 0, for any T constructible from 1 with *=.
template<class T>
constexpr T power(T a, ll b) {
    T res = 1;
    for (; b; b /= 2, a *= a) {
        if (b % 2) {
            res *= a;
        }
    }
    return res;
}

// Residue modulo P. A positive P fixes the modulus at compile time;
// P == 0 reads a run-time modulus from Mod (set through setMod()).
template<int P>
struct MInt {
    int x;  // representative in [0, getMod())

    constexpr MInt() : x{} {}
    constexpr MInt(ll x) : x{norm(x % getMod())} {}

    static int Mod;
    constexpr static int getMod() {
        // `if constexpr` discards the `Mod` branch when P > 0, so Mod is not
        // odr-used there. Only MInt<0>::Mod is defined (below); with a plain
        // `if`, a reference to the undefined static could reach the linker.
        if constexpr (P > 0) {
            return P;
        } else {
            return Mod;
        }
    }
    constexpr static void setMod(int Mod_) { Mod = Mod_; }

    // Maps a value in [-mod, 2*mod) into [0, mod).
    constexpr int norm(int x) const {
        if (x < 0) {
            x += getMod();
        }
        if (x >= getMod()) {
            x -= getMod();
        }
        return x;
    }
    constexpr int val() const { return x; }
    explicit constexpr operator int() const { return x; }

    constexpr MInt operator-() const {
        MInt res;
        res.x = norm(getMod() - x);
        return res;
    }
    // Inverse via Fermat's little theorem (x^(mod-2)); requires a prime modulus.
    constexpr MInt inv() const {
        assert(x != 0);
        return power(*this, getMod() - 2);
    }

    constexpr MInt &operator*=(MInt rhs) & {
        x = 1LL * x * rhs.x % getMod();
        return *this;
    }
    constexpr MInt &operator+=(MInt rhs) & {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MInt &operator-=(MInt rhs) & {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MInt &operator/=(MInt rhs) & { return *this *= rhs.inv(); }

    friend constexpr MInt operator*(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res *= rhs;
        return res;
    }
    friend constexpr MInt operator+(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res += rhs;
        return res;
    }
    friend constexpr MInt operator-(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res -= rhs;
        return res;
    }
    friend constexpr MInt operator/(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res /= rhs;
        return res;
    }
    // Stream operators perform I/O and can never be constant-evaluated,
    // so they are not marked constexpr.
    friend std::istream &operator>>(std::istream &is, MInt &a) {
        ll v;
        is >> v;
        a = MInt(v);
        return is;
    }
    friend std::ostream &operator<<(std::ostream &os, const MInt &a) {
        return os << a.val();
    }
    friend constexpr bool operator==(MInt lhs, MInt rhs) { return lhs.val() == rhs.val(); }
    friend constexpr bool operator!=(MInt lhs, MInt rhs) { return lhs.val() != rhs.val(); }
};

template<>
int MInt<0>::Mod = 1;

constexpr int P = 998244353;
using Z = MInt<P>;

// Lazily grown tables of m!, 1/m! and 1/m modulo P (entry 0 of _inv is 0).
struct Comb {
    int n;  // largest index currently covered by the tables
    std::vector<Z> _fac;
    std::vector<Z> _invfac;
    std::vector<Z> _inv;

    Comb() : n{0}, _fac{1}, _invfac{1}, _inv{0} {}
    Comb(int n) : Comb() { init(n); }

    // Extends the tables to cover indices 0..m; no-op when already covered.
    void init(int m) {
        if (m <= n) return;
        _fac.resize(m + 1);
        _invfac.resize(m + 1);
        _inv.resize(m + 1);
        for (int i = n + 1; i <= m; i++) {
            _fac[i] = _fac[i - 1] * i;
        }
        // One modular inverse, then walk down: 1/(i-1)! = i / i!, 1/i = (i-1)! / i!.
        _invfac[m] = _fac[m].inv();
        for (int i = m; i > n; i--) {
            _invfac[i - 1] = _invfac[i] * i;
            _inv[i] = _invfac[i] * _fac[i - 1];
        }
        n = m;
    }

    // Accessors grow the tables to twice the requested index on a miss.
    Z fac(int m) {
        if (m > n) init(2 * m);
        return _fac[m];
    }
    Z invfac(int m) {
        if (m > n) init(2 * m);
        return _invfac[m];
    }
    Z inv(int m) {
        if (m > n) init(2 * m);
        return _inv[m];
    }
    Z binom(int n, int m) {
        if (n < m || m < 0) return 0;
        return fac(n) * invfac(m) * invfac(n - m);
    }
} comb;

int main() {
    std::cin.tie(nullptr);
    std::ios_base::sync_with_stdio(false);

    int n;
    std::cin >> n;
    std::vector<int> a(n), b(n);
    for (int i = 0; i < n; ++i) {
        std::cin >> a[i] >> b[i];
    }

    const int s = std::accumulate(a.begin(), a.end(), 0);
    const int t = std::accumulate(b.begin(), b.end(), 0);

    // dp[j][k] = sum over subsets U with sum_{i in U} A_i == j of
    //            (-1)^(|U|+1) * sum over 0 <= c_i < B_i with sum c_i == k of
    //            prod_{i in U} A_i^{c_i} / c_i!.
    // The empty subset contributes (-1)^1 = -1 at (0, 0); each added item flips the sign.
    std::vector dp(s + 1, std::vector<Z>(t + 1));
    dp[0][0] = -1;

    std::vector<Z> coef;  // coef[l] = A_i^l / l!, rebuilt for each item
    int reachA = 0;       // no non-zero dp row lies above this weight yet
    int reachB = 0;       // no non-zero dp column lies above this count yet
    for (int i = 0; i < n; ++i) {
        // A_i^l / l! depends only on (i, l): build it once per item instead of
        // recomputing it inside the (j, k) loops. B_i >= 1 by the constraints.
        const int cnt = b[i];
        coef.assign(cnt, Z{});
        Z pw = 1;
        for (int l = 0; l < cnt; ++l, pw *= a[i]) {
            coef[l] = pw * comb.invfac(l);
        }

        // Entries beyond these prefix sums are zero both before and after this
        // item, so the loops can stop there without changing the result.
        reachA += a[i];
        reachB += cnt;

        // 0/1 knapsack over weight j: iterating j downward keeps dp[j - a[i]]
        // at its value from before item i (a[i] >= 1, so it is a different row).
        for (int j = reachA; j >= a[i]; --j) {
            const std::vector<Z> &from = dp[j - a[i]];
            std::vector<Z> &to = dp[j];
            for (int k = reachB; k >= 0; --k) {
                Z sum = 0;
                for (int l = 0, top = std::min(k, cnt - 1); l <= top; ++l) {
                    sum += from[k - l] * coef[l];
                }
                to[k] -= sum;
            }
        }
    }

    // ans = S * sum_{j >= 1, k >= 0} dp[j][k] * k! / j^(k+1).
    // Row j = 0 holds only the empty subset, which min-max excludes.
    Z ans = 0;
    for (int j = 1; j <= s; ++j) {
        const Z invJ = comb.inv(j);
        Z w = invJ;  // invJ^(k+1)
        for (int k = 0; k <= t; ++k, w *= invJ) {
            ans += dp[j][k] * w * comb.fac(k);
        }
    }
    ans *= s;

    std::cout << ans << "\n";
    return 0;
}