列表

详情


NC234839. Gachapon

描述

Snuke 发现一个随机数生成器,这个生成器可以生产 0n-1 的整数。一个整数序列 ,代表这些数被生成的概率。整数 i () 被生成的概率是 ()。每次生成整数互不影响。

现在Snuke重复使用生成器生成整数,直到满足:
对于每个 i (), i 生成了 b_i 次。

请计算Snuke使用生成器的次数的期望值,把它取模 998244353 输出。

输入描述

第一行,一个整数 n
接下来n 行,每行两个整数

数据范围:





所有输入的数都是整数

输出描述

输出一个整数,表示答案对 998244353 取模的结果。

假设你的答案是 (最简分数),那么你应当输出 ,其中q 在模 998244353 意义下的逆元(数据保证存在)。

示例1

输入:

2
1 1
1 1

输出:

3

说明:

Snuke 使用生成器的期望次数是3。

示例2

输入:

3
1 3
2 2
3 1

输出:

971485877

说明:

Snuke 使用生成器的期望次数是 \frac{132929}{7200}

示例3

输入:

15
29 3
78 69
19 15
82 14
9 120
14 51
3 7
6 14
28 4
13 12
1 5
32 30
49 24
35 23
2 9

输出:

371626143

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(clang++ 11.0.1) 解法, 执行用时: 1228ms, 内存消耗: 1208K, 提交时间: 2023-07-11 20:30:32

#include <bits/stdc++.h>

using ll = long long;
 
template<class T>
constexpr T power(T a, ll b) {
    T res = 1;
    for (; b; b /= 2, a *= a) {
        if (b % 2) {
            res *= a;
        }
    }
    return res;
}
 
constexpr ll mul(ll a, ll b, ll p) {
    ll res = a * b - ll(1.L * a * b / p) * p;
    res %= p;
    if (res < 0) {
        res += p;
    }
    return res;
}
template<ll P>
struct MLong {
    ll x;
    constexpr MLong() : x{} {}
    constexpr MLong(ll x) : x{norm(x % getMod())} {}
    
    static ll Mod;
    constexpr static ll getMod() {
        if (P > 0) {
            return P;
        } else {
            return Mod;
        }
    }
    constexpr static void setMod(ll Mod_) {
        Mod = Mod_;
    }
    constexpr ll norm(ll x) const {
        if (x < 0) {
            x += getMod();
        }
        if (x >= getMod()) {
            x -= getMod();
        }
        return x;
    }
    constexpr ll val() const {
        return x;
    }
    explicit constexpr operator ll() const {
        return x;
    }
    constexpr MLong operator-() const {
        MLong res;
        res.x = norm(getMod() - x);
        return res;
    }
    constexpr MLong inv() const {
        assert(x != 0);
        return power(*this, getMod() - 2);
    }
    constexpr MLong &operator*=(MLong rhs) & {
        x = mul(x, rhs.x, getMod());
        return *this;
    }
    constexpr MLong &operator+=(MLong rhs) & {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MLong &operator-=(MLong rhs) & {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MLong &operator/=(MLong rhs) & {
        return *this *= rhs.inv();
    }
    friend constexpr MLong operator*(MLong lhs, MLong rhs) {
        MLong res = lhs;
        res *= rhs;
        return res;
    }
    friend constexpr MLong operator+(MLong lhs, MLong rhs) {
        MLong res = lhs;
        res += rhs;
        return res;
    }
    friend constexpr MLong operator-(MLong lhs, MLong rhs) {
        MLong res = lhs;
        res -= rhs;
        return res;
    }
    friend constexpr MLong operator/(MLong lhs, MLong rhs) {
        MLong res = lhs;
        res /= rhs;
        return res;
    }
    friend constexpr std::istream &operator>>(std::istream &is, MLong &a) {
        ll v;
        is >> v;
        a = MLong(v);
        return is;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const MLong &a) {
        return os << a.val();
    }
    friend constexpr bool operator==(MLong lhs, MLong rhs) {
        return lhs.val() == rhs.val();
    }
    friend constexpr bool operator!=(MLong lhs, MLong rhs) {
        return lhs.val() != rhs.val();
    }
};
 
template<>
ll MLong<0LL>::Mod = 1;
 
template<int P>
struct MInt {
    int x;
    constexpr MInt() : x{} {}
    constexpr MInt(ll x) : x{norm(x % getMod())} {}
    
    static int Mod;
    constexpr static int getMod() {
        if (P > 0) {
            return P;
        } else {
            return Mod;
        }
    }
    constexpr static void setMod(int Mod_) {
        Mod = Mod_;
    }
    constexpr int norm(int x) const {
        if (x < 0) {
            x += getMod();
        }
        if (x >= getMod()) {
            x -= getMod();
        }
        return x;
    }
    constexpr int val() const {
        return x;
    }
    explicit constexpr operator int() const {
        return x;
    }
    constexpr MInt operator-() const {
        MInt res;
        res.x = norm(getMod() - x);
        return res;
    }
    constexpr MInt inv() const {
        assert(x != 0);
        return power(*this, getMod() - 2);
    }
    constexpr MInt &operator*=(MInt rhs) & {
        x = 1LL * x * rhs.x % getMod();
        return *this;
    }
    constexpr MInt &operator+=(MInt rhs) & {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr MInt &operator-=(MInt rhs) & {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr MInt &operator/=(MInt rhs) & {
        return *this *= rhs.inv();
    }
    friend constexpr MInt operator*(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res *= rhs;
        return res;
    }
    friend constexpr MInt operator+(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res += rhs;
        return res;
    }
    friend constexpr MInt operator-(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res -= rhs;
        return res;
    }
    friend constexpr MInt operator/(MInt lhs, MInt rhs) {
        MInt res = lhs;
        res /= rhs;
        return res;
    }
    friend constexpr std::istream &operator>>(std::istream &is, MInt &a) {
        ll v;
        is >> v;
        a = MInt(v);
        return is;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const MInt &a) {
        return os << a.val();
    }
    friend constexpr bool operator==(MInt lhs, MInt rhs) {
        return lhs.val() == rhs.val();
    }
    friend constexpr bool operator!=(MInt lhs, MInt rhs) {
        return lhs.val() != rhs.val();
    }
};
 
template<>
int MInt<0>::Mod = 1;
 
template<int V, int P>
constexpr MInt<P> CInv = MInt<P>(V).inv();
 
constexpr int P = 998244353;
using Z = MInt<P>;

struct Comb {
    int n;
    std::vector<Z> _fac;
    std::vector<Z> _invfac;
    std::vector<Z> _inv;
    
    Comb() : n{0}, _fac{1}, _invfac{1}, _inv{0} {}
    Comb(int n) : Comb() {
        init(n);
    }
    
    void init(int m) {
        if (m <= n) return;
        _fac.resize(m + 1);
        _invfac.resize(m + 1);
        _inv.resize(m + 1);
        
        for (int i = n + 1; i <= m; i++) {
            _fac[i] = _fac[i - 1] * i;
        }
        _invfac[m] = _fac[m].inv();
        for (int i = m; i > n; i--) {
            _invfac[i - 1] = _invfac[i] * i;
            _inv[i] = _invfac[i] * _fac[i - 1];
        }
        n = m;
    }
    
    Z fac(int m) {
        if (m > n) init(2 * m);
        return _fac[m];
    }
    Z invfac(int m) {
        if (m > n) init(2 * m);
        return _invfac[m];
    }
    Z inv(int m) {
        if (m > n) init(2 * m);
        return _inv[m];
    }
    Z binom(int n, int m) {
        if (n < m || m < 0) return 0;
        return fac(n) * invfac(m) * invfac(n - m);
    }
} comb;

std::vector<Z> operator*(const std::vector<Z> &a, const std::vector<Z> &b) {
    int n = a.size();
    std::vector<Z> c(n);
    std::vector<ll> d(n);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            int k = (i + j) % n;
            if (d[k] > 8E18) {
                c[k] += d[k];
                d[k] = 0;
            }
            d[k] += 1LL * int(a[i]) * int(b[j]);
        }
    }
    for (int i = 0; i < n; i++) {
        c[i] += d[i];
    }
    return c;
}
 
std::vector<Z> power(std::vector<Z> a, ll n) {
    std::vector<Z> res(a.size());
    res[0] = 1;
    for (; n; n /= 2, a = a * a) {
        if (n & 1) {
            res = res * a;
        }
    }
    return res;
}

// Z::setMod(m);

int main() {
    std::cin.tie(nullptr);
    std::ios_base::sync_with_stdio(false);

    int n;
    std::cin >> n;

    std::vector<int> a(n), b(n);
    for (int i = 0; i < n; ++ i) {
        std::cin >> a[i] >> b[i];
    }

    auto s = std::accumulate(a.begin(), a.end(), 0);
    auto t = std::accumulate(b.begin(), b.end(), 0);

    std::vector dp(s + 1, std::vector<Z>(t + 1));
    dp[0][0] = -1;

    for (int i = 0; i < n; ++ i)  {
        for (int j = s; j >= a[i]; -- j) {
            for (int k = t; k >= 0; -- k) {
                Z base = 1;
                for (int l = 0; l <= std::min(k, b[i] - 1); ++ l, base *= a[i]) {
                    dp[j][k] -= dp[j - a[i]][k - l] * base * comb.invfac(l);
                }
            }
        }

    }

    Z ans = 0;
    for (int i = 1; i <= s; ++ i) {
        Z base = comb.inv(i);
        for (int j = 0; j <= t; ++ j, base *= comb.inv(i)) {
            ans += dp[i][j] * base * s * comb.fac(j);
        }
    }

    std::cout << ans << "\n";

    return 0;
}

上一题