2021牛客暑期多校训练营10 Game of Death(子集反演(容斥原理)+多项式优化)

Game of Death

https://ac.nowcoder.com/acm/contest/11261/G

链接:https://ac.nowcoder.com/acm/contest/11261/G

确实想不到这个状态设计,估计是对子集容斥(子集反演)这个概念接触得不多。

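我们设的状态是 $f(S)$,表示被杀的人恰好是集合 $S$ 的概率。根据各种反演的惯例,我们再设 $g(S)$ 表示被杀的人是 $S$ 的子集的概率,然后对 $f,g$ 做子集反演。
因为
$$g(S)=\sum_{T\subseteq S} f(T),$$
根据子集反演(其实就是容斥):
$$f(S)=\sum_{T\subseteq S}(-1)^{|S|-|T|}\,g(T).$$
(可能不是那么显然,但可以这样想:$g(S)$ 的符号一定是 $+$,那么就需要减去 $g$(比 $S$ 少 1 人的集合),再加上 $g$(比 $S$ 少 2 人的集合)……这样的容斥形式,也不是不能接受。)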

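现在目标变成求 $g(S)$。
显然 $g(S)$ 只和 $|S|$ 有关,记 $x=|S|$,有式子:
$$g(S)=G(x)=\left(q+\frac{x-1}{n-1}\,p\right)^{x}\left(q+\frac{x}{n-1}\,p\right)^{n-x},$$
其中 $q=1-p$,也就是没有命中的概率。

$G(x)$ 为固定某 $x$ 个人时,被射死的人全部落在这 $x$ 人之中的概率。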


(其实意义就是:对于 $S$ 内的人,要么没打中,要么打中了集合内的其他人;对于集合外的人,要么没命中,要么命中的是集合 $S$ 内的人——这样才能保证集合外没有人被杀)。

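因此:

$$f(S)=\sum_{T\subseteq S}(-1)^{|S|-|T|}G(|T|)=\sum_{j=0}^{|S|}\binom{|S|}{j}(-1)^{|S|-j}G(j).$$

令 $F(x)$ 为集合大小为 $x$ 的所有情况的概率总和:

$$F(x)=\binom{n}{x}\sum_{j=0}^{x}\binom{x}{j}(-1)^{x-j}G(j).$$

$F(x)$ 即恰好有 $x$ 人被杀、也就是恰好 $n-x$ 人存活时的答案。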

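怎么算?
观察后面的式子,把组合数拆成阶乘:

$$F(x)=\binom{n}{x}\,x!\sum_{j=0}^{x}\frac{(-1)^{x-j}}{(x-j)!}\cdot\frac{G(j)}{j!},$$

求和部分正是 $A(z)=\sum_{i\ge 0}\frac{(-1)^i}{i!}z^i$ 与 $B(z)=\sum_{j=0}^{n}\frac{G(j)}{j!}z^j$ 的乘积中 $z^x$ 的系数。

所以直接拿 NTT 做卷积就行了。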

CODE:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// Size of the factorial tables f[]/finv[] and base size of the NTT scratch
// arrays; presumably n must stay below this bound.
const int maxn = 3e5 + 10;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1 (NTT uses primitive root 3).
const int mod = 998244353;
// From here on every `int` spelled in this file is really `long long`
// (ll); that is also why the entry point has to be declared `signed main`.
#define int ll
namespace ModInt {
// Modular exponentiation by repeated squaring: returns a^n mod `mod`.
// n must be non-negative (a negative n never reaches 0 under >>=).
int ksm(int a, int n) {
    int res = 1;
    while (n) {
        if (n & 1)
            res = 1ll * res * a % mod;
        a = 1ll * a * a % mod;
        n >>= 1;
    }
    return res;
}

// Integer modulo the prime `mod`.
// Every constructor and operator keeps `val` in [0, mod), provided the
// operands already are. Code that writes `val` directly (e.g. when reading
// input) is responsible for keeping it in that range.
struct modInt {
    int val;

    modInt()
        : val(0) {}
    // Implicit on purpose, so plain integers mix with modInt in expressions.
    // Negative or oversized arguments are reduced into [0, mod).
    modInt(int _v)
        : val(_v % mod) {
        if (val < 0)
            val += mod;
    }

    friend modInt operator+(const modInt& a, const modInt& b) {
        int sum = a.val + b.val;
        if (sum >= mod)
            sum -= mod;
        return raw(sum);
    }
    friend modInt operator*(const modInt& a, const modInt& b) {
        return raw(1ll * a.val * b.val % mod);
    }
    friend modInt operator-(const modInt& a, const modInt& b) {
        int diff = a.val - b.val;
        if (diff < 0)
            diff += mod;
        return raw(diff);
    }
    // Division via Fermat's little theorem. Dividing by 0 silently yields 0.
    friend modInt operator/(const modInt& a, const modInt& b) { return a * modInt(ksm(b.val, mod - 2)); }
    // Exponentiation (x >= 0). NOTE: ^ binds looser than + and *, so parenthesize.
    modInt operator^(int x) const { return ksm(val, x); }
    friend ostream& operator<<(ostream& out, const modInt& a) {
        out << a.val;
        return out;
    }
    // Additive inverse; -0 is 0 (previously this produced the unreduced value `mod`).
    modInt operator-() const {
        return raw(val == 0 ? 0 : mod - val);
    }
    modInt operator+() const {
        return *this;
    }
    // Comparisons of the canonical residues. They return bool, so they can be
    // used directly in conditions.
    bool operator>(const modInt& a) const {
        return val > a.val;
    }
    bool operator<(const modInt& a) const {
        return val < a.val;
    }
    bool operator==(const modInt& a) const {
        return val == a.val;
    }
    bool operator!=(const modInt& a) const {
        return val != a.val;
    }
    bool operator<=(const modInt& a) const {
        return val <= a.val;
    }
    bool operator>=(const modInt& a) const {
        return val >= a.val;
    }
    explicit operator int() const {
        return val;
    }

private:
    // Wraps a value that is already known to be in [0, mod), skipping the `%`.
    static modInt raw(int v) {
        modInt r;
        r.val = v;
        return r;
    }
};
}  // namespace ModInt
namespace IO {
// Writes the integer x to stdout in decimal, with no trailing separator.
// Digits are taken from the signed value directly, so the most negative
// value of T prints correctly instead of overflowing on `-x`.
template <typename T>
void write(T x) {
    char digits[64];  // enough for any built-in integer type, including __int128
    int len = 0;
    const bool negative = x < 0;
    do {
        const T d = x % 10;  // in [-9, 0] when x is negative (division truncates toward zero)
        digits[len++] = static_cast<char>('0' + (negative ? -d : d));
        x /= 10;
    } while (x != 0);
    if (negative)
        std::putchar('-');
    while (len > 0)
        std::putchar(digits[--len]);
}

// Reads the next integer from stdin into x.
// Non-digit characters before the number are skipped, and every '-' among
// them flips the sign. If input ends before any digit appears, x is 0
// (the old version looped forever at EOF).
template <typename T>
void read(T& x) {
    x = 0;
    int ch = std::getchar();  // int, not char: must hold EOF and keep isdigit's argument valid
    bool negative = false;
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-')
            negative = !negative;
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    if (negative)
        x = -x;
}
}  // namespace IO
namespace Math {
// Quadratic non-residue picked by Cipolla(). complex::operator* reads it,
// so complex arithmetic is only meaningful after Cipolla() has set it.
ll w;
const int p = ::mod;
// Element real + imag * sqrt(w) of the extension field F_p[sqrt(w)].
struct complex {
    ll real, imag;
    complex(ll a = 0, ll b = 0) {
        real = a;
        imag = b;
    }
    friend complex operator*(const complex& a, const complex& b) {
        complex ans;
        ans.real = ((a.real * b.real % p + a.imag * b.imag % p * w % p) % p + p) % p;
        ans.imag = ((a.real * b.imag % p + a.imag * b.real % p) % p + p) % p;
        return ans;
    }
};
ll x1, x2;  // not referenced by the code in this file
// a^b mod p by binary exponentiation (b >= 0).
ll ksm(ll a, ll b, ll p) {
    ll ans = 1;
    while (b) {
        if (b & 1)
            ans = (ans * a) % p;
        a = (a * a) % p;
        b >>= 1;
    }
    return ans;
}

// a^b in F_p[sqrt(w)]; returns the real part of the result reduced mod p.
ll ksm(complex a, ll b, ll p) {
    complex ans(1, 0);
    while (b) {
        if (b & 1)
            ans = ans * a;
        a = a * a;
        b >>= 1;
    }
    return ans.real % p;
}

// Cipolla's algorithm.
// If n is a quadratic residue mod p, stores both square roots in x0 and x1
// (x1 = -x0 mod p) and returns true. Otherwise returns false.
bool Cipolla(ll n, ll& x0, ll& x1) {
    n %= p;
    if (n < 0)
        n += p;
    // For n == 0, w = a^2 can never be a non-residue, so the search loop
    // below would never terminate. The only root is 0.
    if (n == 0) {
        x0 = x1 = 0;
        return true;
    }
    // Euler's criterion: n^((p-1)/2) == -1 means n is a non-residue.
    if (ksm(n, (p - 1) >> 1, p) == p - 1)
        return false;
    ll a;
    while (true) {
        a = rand() % p;
        w = ((a * a % p - n) % p + p) % p;
        if (ksm(w, (p - 1) >> 1, p) == p - 1)
            break;
    }
    complex x(a, 1);
    x0 = (ksm(x, (p + 1) >> 1, p) + p) % p;
    x1 = (p - x0 + p) % p;
    return true;
}
}  // namespace Math

namespace NTT {
#define mul(x, y) ((1ll * x * y >= mod ? x * y % mod : 1ll * x * y))
#define dec(x, y) (1ll * x - y < 0 ? 1ll * x - y + mod : 1ll * x - y)
#define add(x, y) (1ll * x + y >= mod ? 1ll * x + y - mod : 1ll * x + y)
#define ck(x) (x >= mod ? x - mod : x)
typedef vector<int> Poly;
int ksm(int a, int n, int mod = ::mod) {
    int res = 1;
    while (n) {
        if (n & 1)
            res = 1ll * res * a % mod;
        a = 1ll * a * a % mod;
        n >>= 1;
    }
    return res;
}
const int img = 86583718;
const int g = 3, INV = ksm(g, mod - 2);
const int mx = 21;
int R[maxn << 2], deer[2][mx][maxn << 2], inv[maxn << 2];
void init(const int t) {
    for (int p = 1; p <= t; ++p) {
        int buf1 = ksm(g, (mod - 1) / (1 << p));
        int buf0 = ksm(INV, (mod - 1) / (1 << p));
        deer[0][p][0] = deer[1][p][0] = 1;
        for (int i = 1; i < (1 << p); ++i) {
            deer[0][p][i] = 1ll * deer[0][p][i - 1] * buf0 % mod;
            deer[1][p][i] = 1ll * deer[1][p][i - 1] * buf1 % mod;
        }
    }
    inv[1] = 1;
    for (int i = 2; i <= (1 << t); ++i)
        inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod;
}

int NTT_init(int n) {
    int lim = 1, l = 0;
    while (lim < n)
        lim <<= 1, l++;
    for (int i = 0; i < lim; ++i)
        R[i] = (R[i >> 1] >> 1) | ((i & 1) << (l - 1));
    return lim;
}
void ntt(Poly& A, int type, int lim) {
    A.resize(lim);
    for (int i = 0; i < lim; ++i)
        if (i < R[i])
            swap(A[i], A[R[i]]);
    for (int mid = 2, j = 1; mid <= lim; mid <<= 1, ++j) {
        int len = mid >> 1;
        for (int pos = 0; pos < lim; pos += mid) {
            int* wn = deer[type][j];
            for (int i = pos; i < pos + len; ++i, ++wn) {
                int tmp = 1ll * (*wn) * A[i + len] % mod;
                A[i + len] = ck(A[i] - tmp + mod);
                A[i] = ck(A[i] + tmp);
            }
        }
    }
    if (type == 0) {
        for (int i = 0; i < lim; ++i)
            A[i] = 1ll * A[i] * inv[lim] % mod;
    }
}
Poly poly_mul(Poly A, Poly B) {
    int deg = A.size() + B.size() - 1;
    int limit = NTT_init(deg);
    Poly C(limit);
    ntt(A, 1, limit);
    ntt(B, 1, limit);
    for (int i = 0; i < limit; ++i)
        C[i] = 1ll * A[i] * B[i] % mod;
    ntt(C, 0, limit);
    C.resize(deg);
    return C;
}
Poly poly_inv(Poly& f, int deg) {
    if (deg == 1)
        return Poly(1, ksm(f[0], mod - 2));

    Poly A(f.begin(), f.begin() + deg);
    Poly B = poly_inv(f, (deg + 1) >> 1);
    int limit = NTT_init(deg << 1);
    ntt(A, 1, limit), ntt(B, 1, limit);
    for (int i = 0; i < limit; ++i)
        A[i] = B[i] * (2 - 1ll * A[i] * B[i] % mod + mod) % mod;
    ntt(A, 0, limit);
    A.resize(deg);
    return A;
}
Poly poly_idev(Poly f) {
    int n = f.size();
    for (int i = n - 1; i - 1 >= 0; --i)
        f[i] = 1ll * f[i - 1] * inv[i] % mod;
    f[0] = 0;
    return f;
}

Poly poly_dev(Poly f) {
    int n = f.size();
    for (int i = 1; i < n; ++i)
        f[i - 1] = 1ll * f[i] * i % mod;
    f.resize(n - 1);
    return f;
}

Poly poly_ln(Poly f, int deg) {
    Poly A = poly_idev(poly_mul(poly_dev(f), poly_inv(f, deg)));
    return A.resize(deg), A;
}

Poly poly_exp(Poly& f, int deg) {
    //cerr<<deg<<endl;
    if (deg == 1)
        return Poly(1, 1);

    Poly B = poly_exp(f, (deg + 1) >> 1);
    B.resize(deg);
    Poly lnB = poly_ln(B, deg);
    for (int i = 0; i < deg; ++i)
        lnB[i] = ck(f[i] - lnB[i] + mod);

    int limit = NTT_init(deg << 1);
    ntt(B, 1, limit), ntt(lnB, 1, limit);
    for (int i = 0; i < limit; ++i)
        B[i] = 1ll * B[i] * (1 + lnB[i]) % mod;
    ntt(B, 0, limit);
    B.resize(deg);
    return B;
}
Poly poly_pow(Poly& f, int k) {
    f = poly_ln(f, f.size());
    for (auto& x : f)
        x = 1ll * x * k % mod;
    return poly_exp(f, f.size());
}
Poly power(Poly f, int k1, int k2, int deg) {
    int s = 0;
    while (f[s] == 0 && s < f.size())
        ++s;
    if (1ll * s * k1 >= deg) {
        return vector<int>(deg);
    }
    int Inv = ksm(f[s], mod - 2, mod);
    int Mul = ksm(f[s], k2);
    deg -= s;
    for (int i = 0; i < deg; ++i)
        f[i] = f[i + s];
    f.resize(deg);
    for (int i = 0; i < deg; ++i)
        f[i] = 1ll * f[i] * Inv % mod;
    auto res1 = poly_ln(f, deg);
    for (int i = 0; i < res1.size(); ++i)
        res1[i] = 1ll * res1[i] * k1 % mod;
    auto res2 = poly_exp(res1, deg);
    for (int i = 0; i < deg; ++i)
        res2[i] = 1ll * res2[i] * Mul % mod;
    deg += s;
    int now = s * k1;
    Poly res;
    res.resize(deg);
    for (int i = deg - 1; i >= now; --i)
        res[i] = res2[i - now];
    for (int i = now - 1; i >= 0; --i)
        res[i] = 0;
    return res;
}
Poly Poly_Sqrt(Poly& f, int deg) {
    if (deg == 1)
        return Poly(1, 1);
    Poly A(f.begin(), f.begin() + deg);
    Poly B = Poly_Sqrt(f, (deg + 1) >> 1);
    Poly IB = poly_inv(B, deg);
    int lim = NTT_init(deg << 1);
    ntt(A, 1, lim), ntt(IB, 1, lim);
    for (int i = 0; i < lim; ++i) {
        A[i] = 1ll * A[i] * IB[i] % mod;
    }
    ntt(A, 0, lim);
    for (int i = 0; i < deg; ++i) {
        A[i] = (1ll * A[i] + B[i]) % mod * inv[2] % mod;
    }
    A.resize(deg);
    return A;
}
Poly Sqrt(Poly& f, int deg) {
    const int Pow = ksm(2, mod - 2);
    int k1 = 1;
    if (f[0] != 1) {
        k1 = ksm(f[0], mod - 2);
        for (int i = 1; i < f.size(); ++i) {
            f[i] = 1ll * k1 * f[i] % mod;
        }
        ll x0, x1;
        assert(Math::Cipolla(f[0], x0, x1));
        k1 = min(x1, x0);
        f[0] = 1;
    }
    auto Ln = poly_ln(f, deg);
    for (int i = 0; i < f.size(); ++i) {
        Ln[i] = 1ll * Ln[i] * Pow % mod;
    }
    auto Exp = poly_exp(Ln, deg);
    for (int i = 0; i < Exp.size(); ++i)
        Exp[i] = 1ll * Exp[i] * k1 % mod;
    return Exp;
}
Poly poly_sin(Poly& f, int deg) {
    Poly A(f.begin(), f.begin() + deg);
    Poly B(deg), C(deg);
    for (int i = 0; i < deg; ++i) {
        A[i] = 1ll * A[i] * img % mod;
    }
    B = poly_exp(A, deg);
    C = poly_inv(B, deg);
    const int inv2i = ksm(img << 1, mod - 2);
    for (int i = 0; i < deg; ++i) {
        A[i] = 1ll * (1ll * B[i] - C[i] + mod) % mod * inv2i % mod;
    }
    return A;
}
Poly poly_cos(Poly& f, int deg) {
    Poly A(f.begin(), f.begin() + deg);
    Poly B(deg), C(deg);
    for (int i = 0; i < deg; ++i) {
        A[i] = 1ll * A[i] * img % mod;
    }
    B = poly_exp(A, deg);
    C = poly_inv(B, deg);
    const int inv2 = ksm(2, mod - 2);
    for (int i = 0; i < deg; ++i) {
        A[i] = (1ll * B[i] + C[i]) % mod * inv2 % mod;
    }
    return A;
}

Poly poly_arcsin(Poly f, int deg) {
    Poly A(f.size()), B(f.size()), C(f.size());
    A = poly_dev(f);
    B = poly_mul(f, f);
    for (int i = 0; i < deg; ++i) {
        B[i] = dec(mod, B[i]);
    }
    B[0] = add(B[0], 1);
    C = Poly_Sqrt(B, deg);
    C = poly_inv(C, deg);
    C = poly_mul(A, C);
    C = poly_idev(C);
    return C;
}

Poly poly_arctan(Poly f, int deg) {
    Poly A(f.size()), B(f.size()), C(f.size());
    A = poly_dev(f);
    B = poly_mul(f, f);
    B[0] = add(B[0], 1);
    C = poly_inv(B, deg);
    C = poly_mul(A, C);
    C = poly_idev(C);
    C.resize(deg);
    return C;
}
};  // namespace NTT
using NTT::Poly;
using namespace IO;
using mi = ModInt::modInt;
// n: read in main() (the number of people). t: not used by the code shown.
int n, t;
// f[i] = i! and finv[i] = (i!)^-1 modulo mod, both filled at the start of main().
mi f[maxn], finv[maxn];
// Hit probability given as the fraction a / b; main() sets p = a / b (mod `mod`).
mi a,b,p;

// Binomial coefficient C(n, m) modulo mod, taken from the factorial tables
// f[] and finv[]. Out-of-range arguments (negative, or m > n) give 0.
mi C(int n,int m){
    const bool valid = n >= 0 && m >= 0 && m <= n;
    if (!valid)
        return 0;
    const mi numerator = f[n];
    return numerator * finv[m] * finv[n - m];
}

signed main() {
    f[0] = 1;
    for (int i = 1; i < maxn; ++i)
        f[i] = f[i - 1] * i;
    finv[maxn - 1] = f[maxn - 1] ^ (mod - 2);
    for (int i = maxn - 2; i >= 0; --i) {
        finv[i] = (i + 1) * finv[i + 1];
        //assert((finv[i] * f[i]).val == 1);
    }
    NTT::init(NTT::mx - 1);
    read(n),read(a.val),read(b.val);
    p=a/b;
    Poly F(n+1),G(n+1);
    for(int i=0;i<=n;++i){
        mi x=((mi)1-p+(mi(i)-1)/(n-1)*p);
        mi y=((mi)1-p+p*(mi(i))/(n-1));
        x=x^(i),y=y^(n-i);
        G[i]=(x*y*finv[i]).val;
    }
    int now=1;
    for(int i=0;i<=n;++i,now=(-now+mod)%mod){
        F[i]=(now*finv[i]).val;
    }
    F=NTT::poly_mul(F,G);
    for(int i=0;i<F.size();++i){
        F[i]=int(F[i]*C(n,i));
    }
    for(int i=n;i>=0;--i){
        write(int(F[i]*f[i]));putchar('\n');
    }
    return 0;
}
全部评论

相关推荐

3 收藏 评论
分享
牛客网
牛客企业服务