2021牛客暑期多校训练营10 Game of Death(子集反演(容斥原理)+多项式优化)
Game of Death
https://ac.nowcoder.com/acm/contest/11261/G
链接:https://ac.nowcoder.com/acm/contest/11261/G
确实想不到这个状态设计,估计是对子集容斥(子集反演)这个概念接触不多。
我们设的状态是 $f(S)$,表示被杀的人恰好是集合 $S$;根据各种反演的惯例,再设 $g(S)$ 表示被杀的人是 $S$ 的子集,然后对二者做子集反演。
因为 $g(S)=\sum_{T\subseteq S} f(T)$,
根据子集反演(其实就是容斥):$f(S)=\sum_{T\subseteq S}(-1)^{|S|-|T|}\,g(T)$
(可能不是那么显然,但想想:g(S)的符号一定是+,那么就需要-g(比S少1人的集合)……这样的容斥形式,也不是不能接受)。
现在目标变成求 $g(S)$。显然 $g(S)$ 只与 $|S|$ 有关,记 $g(S)=G(|S|)$。
令q=1-p,也就是没有命中的概率。
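$G(x)$ 为固定某 $x$ 个人、被杀的人全都在这 $x$ 人之中的概率:$G(x)=\left(q+\frac{x-1}{n-1}p\right)^{x}\left(q+\frac{x}{n-1}p\right)^{n-x}$。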
(其实意义就是:对于 $S$ 内的人,要么没打中,要么打中了集合内的其他人;对于集合外的人,要么没命中,要么命中的是集合内的人。)
因此:$f(S)=\sum_{j=0}^{|S|}\binom{|S|}{j}(-1)^{|S|-j}\,G(j)$
令 $F(x)$ 为集合大小为 $x$ 的所有情况的概率总和:$F(x)=\sum_{|S|=x}f(S)=\binom{n}{x}\sum_{j=0}^{x}\binom{x}{j}(-1)^{x-j}\,G(j)$
$F(x)$ 即为恰好 $k=n-x$ 人存活时的答案。
怎么算?
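观察后面的式子:$\sum_{j=0}^{x}\binom{x}{j}(-1)^{x-j}G(j)=x!\sum_{j=0}^{x}\frac{G(j)}{j!}\cdot\frac{(-1)^{x-j}}{(x-j)!}$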
所以直接拿 $\sum_{j}\frac{G(j)}{j!}z^{j}$ 和 $\sum_{i}\frac{(-1)^{i}}{i!}z^{i}$ 做卷积(NTT)就行了。
CODE:
// Game of Death (2021 Nowcoder multi-school, round 10, problem G).
// Reads n, a, b (hit probability p = a / b modulo 998244353) and prints,
// for k = 0..n, the probability that exactly k people survive.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 3e5 + 10;
const int mod = 998244353;
// From here on every `int` is really `long long`, so 64-bit intermediates are
// the default everywhere; this is also why the entry point is `signed main`.
#define int ll

namespace ModInt {
// Modular exponentiation: a^n mod `mod`.
int ksm(int a, int n) {
    int res = 1;
    while (n) {
        if (n & 1) res = 1ll * res * a % mod;
        a = 1ll * a * a % mod;
        n >>= 1;
    }
    return res;
}

// A residue modulo `mod`; constructors and operators keep `val` in [0, mod).
struct modInt {
    int val;
    modInt() : val(0) {}
    // Reduce into [0, mod) so negative or oversized inputs are handled.
    modInt(int _v) {
        val = _v % mod;
        if (val < 0) val += mod;
    }
    friend modInt operator+(const modInt& a, const modInt& b) {
        modInt res;
        res.val = a.val + b.val;
        if (res.val >= mod) res.val -= mod;
        return res;
    }
    friend modInt operator-(const modInt& a, const modInt& b) {
        modInt res;
        res.val = a.val - b.val;
        // Only a negative difference needs fixing (the old macro tested `< mod`
        // and therefore added `mod` to almost every result).
        if (res.val < 0) res.val += mod;
        return res;
    }
    friend modInt operator*(const modInt& a, const modInt& b) {
        return modInt(1ll * a.val * b.val % mod);
    }
    // Division via Fermat's little theorem (mod is prime).
    friend modInt operator/(const modInt& a, const modInt& b) {
        return a * modInt(ksm(b.val, mod - 2));
    }
    // Exponentiation (not XOR).
    modInt operator^(int x) const { return ksm(val, x); }
    friend ostream& operator<<(ostream& out, const modInt& a) {
        out << a.val;
        return out;
    }
    modInt operator-() const { return modInt(val == 0 ? 0 : mod - val); }
    modInt operator+() const { return *this; }
    bool operator>(const modInt& a) const { return val > a.val; }
    bool operator<(const modInt& a) const { return val < a.val; }
    bool operator==(const modInt& a) const { return val == a.val; }
    bool operator<=(const modInt& a) const { return val <= a.val; }
    bool operator>=(const modInt& a) const { return val >= a.val; }
    explicit operator int() const { return val; }
};
};  // namespace ModInt

namespace IO {
// Writes a signed integer in decimal to stdout.
template <typename T>
void write(T x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
// Reads a signed decimal integer from stdin; stops cleanly at EOF.
template <typename T>
void read(T& x) {
    x = 0;
    int ch = getchar();  // keep getchar's int result so EOF is distinguishable
    int f = 1;
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') f *= -1;
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    x *= f;
}
};  // namespace IO

namespace Math {
// Quadratic non-residue chosen by the current Cipolla run; read by
// complex::operator* to represent the extension field F_p[sqrt(w)].
ll w;
const int p = ::mod;
// Element real + imag * sqrt(w) of F_p[sqrt(w)].
struct complex {
    ll real, imag;
    complex(ll a = 0, ll b = 0) {
        real = a;
        imag = b;
    }
    friend complex operator*(const complex& a, const complex& b) {
        complex ans;
        ans.real = ((a.real * b.real % p + a.imag * b.imag % p * w % p) % p + p) % p;
        ans.imag = ((a.real * b.imag % p + a.imag * b.real % p) % p + p) % p;
        return ans;
    }
};
ll x1, x2;
ll ksm(ll a, ll b, ll p) {
    ll ans = 1;
    while (b) {
        if (b & 1) ans = (ans * a) % p;
        a = (a * a) % p;
        b >>= 1;
    }
    return ans;
}
// Power in the extension field; returns the real part.
ll ksm(complex a, ll b, ll p) {
    complex ans(1, 0);
    while (b) {
        if (b & 1) ans = ans * a;
        a = a * a;
        b >>= 1;
    }
    return ans.real % p;
}
// Cipolla's algorithm: stores both square roots of n (mod p) in x0/x1 and
// returns true, or returns false when n is a non-residue.
bool Cipolla(ll n, ll& x0, ll& x1) {
    n %= p;
    if (n < 0) n += p;
    // For n == 0 every w = a*a - n is a square, so the search below would
    // never terminate; the only root is 0.
    if (n == 0) {
        x0 = x1 = 0;
        return true;
    }
    if (ksm(n, (p - 1) >> 1, p) == p - 1) return false;
    ll a;
    while (true) {
        a = rand() % p;
        w = ((a * a % p - n) % p + p) % p;
        if (ksm(w, (p - 1) >> 1, p) == p - 1) break;
    }
    complex x(a, 1);
    x0 = (ksm(x, (p + 1) >> 1, p) + p) % p;
    x1 = (p - x0 + p) % p;
    return true;
}
};  // namespace Math

namespace NTT {
#define add(x, y) (1ll * (x) + (y) >= mod ? 1ll * (x) + (y) - mod : 1ll * (x) + (y))
#define ck(x) ((x) >= mod ? (x) - mod : (x))
typedef vector<int> Poly;

int ksm(int a, int n, int mod = ::mod) {
    int res = 1;
    while (n) {
        if (n & 1) res = 1ll * res * a % mod;
        a = 1ll * a * a % mod;
        n >>= 1;
    }
    return res;
}

const int img = 86583718;                // square root of -1 modulo mod (used by sin/cos)
const int g = 3, INV = ksm(g, mod - 2);  // primitive root and its inverse
const int mx = 21;
// R: bit-reversal permutation; deer[type][p][i]: i-th power of the 2^p-th root
// of unity (type 1 forward, type 0 inverse); inv[i]: modular inverse of i.
// NOTE(review): deer is a large static table; only the first 2^p entries of
// row p are ever filled.
int R[maxn << 2], deer[2][mx][maxn << 2], inv[maxn << 2];

// Precomputes twiddle tables for transform lengths up to 2^t and inverses up to 2^t.
void init(const int t) {
    for (int p = 1; p <= t; ++p) {
        int buf1 = ksm(g, (mod - 1) / (1 << p));
        int buf0 = ksm(INV, (mod - 1) / (1 << p));
        deer[0][p][0] = deer[1][p][0] = 1;
        for (int i = 1; i < (1 << p); ++i) {
            deer[0][p][i] = 1ll * deer[0][p][i - 1] * buf0 % mod;
            deer[1][p][i] = 1ll * deer[1][p][i - 1] * buf1 % mod;
        }
    }
    inv[1] = 1;
    for (int i = 2; i <= (1 << t); ++i) inv[i] = 1ll * inv[mod % i] * (mod - mod / i) % mod;
}

// Returns the smallest power of two >= n and fills R for that length.
int NTT_init(int n) {
    int lim = 1, l = 0;
    while (lim < n) lim <<= 1, l++;
    // Guard l == 0 (lim == 1): shifting by l - 1 would be a negative shift.
    for (int i = 0; i < lim; ++i) R[i] = (R[i >> 1] >> 1) | (l ? (i & 1) << (l - 1) : 0);
    return lim;
}

// In-place iterative NTT of length lim (type 1 forward, type 0 inverse incl. 1/lim).
void ntt(Poly& A, int type, int lim) {
    A.resize(lim);
    for (int i = 0; i < lim; ++i)
        if (i < R[i]) swap(A[i], A[R[i]]);
    for (int mid = 2, j = 1; mid <= lim; mid <<= 1, ++j) {
        int len = mid >> 1;
        for (int pos = 0; pos < lim; pos += mid) {
            int* wn = deer[type][j];
            for (int i = pos; i < pos + len; ++i, ++wn) {
                int tmp = 1ll * (*wn) * A[i + len] % mod;
                A[i + len] = ck(A[i] - tmp + mod);
                A[i] = ck(A[i] + tmp);
            }
        }
    }
    if (type == 0) {
        for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * inv[lim] % mod;
    }
}

// Full product A * B (size A.size() + B.size() - 1).
Poly poly_mul(Poly A, Poly B) {
    int deg = A.size() + B.size() - 1;
    int limit = NTT_init(deg);
    Poly C(limit);
    ntt(A, 1, limit);
    ntt(B, 1, limit);
    for (int i = 0; i < limit; ++i) C[i] = 1ll * A[i] * B[i] % mod;
    ntt(C, 0, limit);
    C.resize(deg);
    return C;
}

// 1 / f mod x^deg by Newton iteration; requires f.size() >= deg and f[0] != 0.
Poly poly_inv(Poly& f, int deg) {
    if (deg == 1) return Poly(1, ksm(f[0], mod - 2));
    Poly A(f.begin(), f.begin() + deg);
    Poly B = poly_inv(f, (deg + 1) >> 1);
    int limit = NTT_init(deg << 1);
    ntt(A, 1, limit), ntt(B, 1, limit);
    for (int i = 0; i < limit; ++i) A[i] = B[i] * (2 - 1ll * A[i] * B[i] % mod + mod) % mod;
    ntt(A, 0, limit);
    A.resize(deg);
    return A;
}

// Integral with zero constant term, keeping the same length (top term dropped).
Poly poly_idev(Poly f) {
    int n = f.size();
    for (int i = n - 1; i - 1 >= 0; --i) f[i] = 1ll * f[i - 1] * inv[i] % mod;
    if (n > 0) f[0] = 0;
    return f;
}

// Derivative (length shrinks by one; empty stays empty).
Poly poly_dev(Poly f) {
    int n = f.size();
    if (n == 0) return f;
    for (int i = 1; i < n; ++i) f[i - 1] = 1ll * f[i] * i % mod;
    f.resize(n - 1);
    return f;
}

// ln f mod x^deg; assumes f[0] == 1 and f.size() >= deg.
Poly poly_ln(Poly f, int deg) {
    Poly A = poly_idev(poly_mul(poly_dev(f), poly_inv(f, deg)));
    A.resize(deg);
    return A;
}

// exp f mod x^deg by Newton iteration; assumes f[0] == 0 and f.size() >= deg.
Poly poly_exp(Poly& f, int deg) {
    if (deg == 1) return Poly(1, 1);
    Poly B = poly_exp(f, (deg + 1) >> 1);
    B.resize(deg);
    Poly lnB = poly_ln(B, deg);
    for (int i = 0; i < deg; ++i) lnB[i] = ck(f[i] - lnB[i] + mod);
    int limit = NTT_init(deg << 1);
    ntt(B, 1, limit), ntt(lnB, 1, limit);
    for (int i = 0; i < limit; ++i) B[i] = 1ll * B[i] * (1 + lnB[i]) % mod;
    ntt(B, 0, limit);
    B.resize(deg);
    return B;
}

// f^k via exp(k ln f); overwrites f with k * ln f. Assumes f[0] == 1.
Poly poly_pow(Poly& f, int k) {
    f = poly_ln(f, f.size());
    for (auto& x : f) x = 1ll * x * k % mod;
    return poly_exp(f, f.size());
}

// f^k mod x^deg for f with a possibly non-unit/zero leading part.
// k1 scales the logarithm and the x^s shift, k2 is the exponent applied to the
// lowest non-zero coefficient (presumably k mod mod and k mod (mod - 1) — confirm with callers).
Poly power(Poly f, int k1, int k2, int deg) {
    f.resize(deg);  // only the first deg coefficients affect the result
    int s = 0;
    while (s < deg && f[s] == 0) ++s;  // bounds check before reading f[s]
    if (s == deg || 1ll * s * k1 >= deg) return Poly(deg);
    int Inv = ksm(f[s], mod - 2, mod);
    int Mul = ksm(f[s], k2);
    deg -= s;
    for (int i = 0; i < deg; ++i) f[i] = f[i + s];
    f.resize(deg);
    for (int i = 0; i < deg; ++i) f[i] = 1ll * f[i] * Inv % mod;
    auto res1 = poly_ln(f, deg);
    for (auto& v : res1) v = 1ll * v * k1 % mod;
    auto res2 = poly_exp(res1, deg);
    for (int i = 0; i < deg; ++i) res2[i] = 1ll * res2[i] * Mul % mod;
    deg += s;
    int now = s * k1;
    Poly res(deg);  // zero-filled: coefficients below x^now stay 0
    for (int i = now; i < deg && i - now < (int)res2.size(); ++i) res[i] = res2[i - now];
    return res;
}

// sqrt f mod x^deg by Newton iteration; assumes f[0] == 1 and f.size() >= deg.
Poly Poly_Sqrt(Poly& f, int deg) {
    if (deg == 1) return Poly(1, 1);
    Poly A(f.begin(), f.begin() + deg);
    Poly B = Poly_Sqrt(f, (deg + 1) >> 1);
    // Pad the half-precision root: poly_inv copies deg coefficients and the
    // averaging loop below reads B[0..deg).
    B.resize(deg);
    Poly IB = poly_inv(B, deg);
    int lim = NTT_init(deg << 1);
    ntt(A, 1, lim), ntt(IB, 1, lim);
    for (int i = 0; i < lim; ++i) A[i] = 1ll * A[i] * IB[i] % mod;
    ntt(A, 0, lim);
    for (int i = 0; i < deg; ++i) A[i] = (1ll * A[i] + B[i]) % mod * inv[2] % mod;
    A.resize(deg);
    return A;
}

// sqrt f mod x^deg via exp(ln(f) / 2); normalizes f[0] to 1 (modifying f) and
// multiplies by a square root of the original f[0] found with Cipolla.
Poly Sqrt(Poly& f, int deg) {
    const int Pow = ksm(2, mod - 2);
    int k1 = 1;
    if (f[0] != 1) {
        k1 = ksm(f[0], mod - 2);
        for (int i = 1; i < (int)f.size(); ++i) f[i] = 1ll * k1 * f[i] % mod;
        ll x0, x1;
        // Call outside assert so the roots are still computed under NDEBUG.
        const bool has_root = Math::Cipolla(f[0], x0, x1);
        assert(has_root);
        (void)has_root;
        k1 = min(x1, x0);
        f[0] = 1;
    }
    auto Ln = poly_ln(f, deg);
    for (auto& v : Ln) v = 1ll * v * Pow % mod;
    auto Exp = poly_exp(Ln, deg);
    for (auto& v : Exp) v = 1ll * v * k1 % mod;
    return Exp;
}

// sin f = (e^{i f} - e^{-i f}) / (2i); requires f.size() >= deg and f[0] == 0.
Poly poly_sin(Poly& f, int deg) {
    Poly A(f.begin(), f.begin() + deg);
    Poly B(deg), C(deg);
    for (int i = 0; i < deg; ++i) A[i] = 1ll * A[i] * img % mod;
    B = poly_exp(A, deg);
    C = poly_inv(B, deg);
    const int inv2i = ksm(img << 1, mod - 2);
    for (int i = 0; i < deg; ++i) A[i] = 1ll * (1ll * B[i] - C[i] + mod) % mod * inv2i % mod;
    return A;
}

// cos f = (e^{i f} + e^{-i f}) / 2; requires f.size() >= deg and f[0] == 0.
Poly poly_cos(Poly& f, int deg) {
    Poly A(f.begin(), f.begin() + deg);
    Poly B(deg), C(deg);
    for (int i = 0; i < deg; ++i) A[i] = 1ll * A[i] * img % mod;
    B = poly_exp(A, deg);
    C = poly_inv(B, deg);
    const int inv2 = ksm(2, mod - 2);
    for (int i = 0; i < deg; ++i) A[i] = (1ll * B[i] + C[i]) % mod * inv2 % mod;
    return A;
}

// arcsin f = integral of f' / sqrt(1 - f^2), truncated to deg terms.
Poly poly_arcsin(Poly f, int deg) {
    Poly A = poly_dev(f);
    Poly B = poly_mul(f, f);
    if ((int)B.size() < deg) B.resize(deg);
    for (int i = 0; i < deg; ++i) B[i] = (mod - B[i]) % mod;  // 0 stays 0, not mod
    B[0] = add(B[0], 1);
    Poly C = Poly_Sqrt(B, deg);
    C = poly_inv(C, deg);
    C = poly_mul(A, C);
    C = poly_idev(C);
    C.resize(deg);
    return C;
}

// arctan f = integral of f' / (1 + f^2), truncated to deg terms.
Poly poly_arctan(Poly f, int deg) {
    Poly A = poly_dev(f);
    Poly B = poly_mul(f, f);
    if ((int)B.size() < deg) B.resize(deg);
    B[0] = add(B[0], 1);
    Poly C = poly_inv(B, deg);
    C = poly_mul(A, C);
    C = poly_idev(C);
    C.resize(deg);
    return C;
}
#undef add
#undef ck
};  // namespace NTT

using NTT::Poly;
using namespace IO;
using mi = ModInt::modInt;
int n, t;
mi f[maxn], finv[maxn];  // factorials and inverse factorials
mi a, b, p;

// Binomial coefficient C(n, m); 0 when the arguments are out of range.
mi C(int n, int m) {
    if (n < 0 || m < 0 || n - m < 0) return 0;
    return f[n] * finv[n - m] * finv[m];
}

signed main() {
    f[0] = 1;
    for (int i = 1; i < maxn; ++i) f[i] = f[i - 1] * i;
    finv[maxn - 1] = f[maxn - 1] ^ (mod - 2);
    for (int i = maxn - 2; i >= 0; --i) finv[i] = (i + 1) * finv[i + 1];
    NTT::init(NTT::mx - 1);

    ll av, bv;
    read(n), read(av), read(bv);
    a = av;  // the constructor reduces the raw input modulo mod
    b = bv;
    p = a / b;
    // NOTE(review): the formulas divide by n - 1, so n >= 2 is assumed.
    const mi inv_nm1 = mi(1) / (n - 1);

    // G[i] = G(i) / i!, where G(i) is the probability that every death happens
    // inside a fixed set of i people: each insider misses or hits one of the other
    // i - 1 insiders, each outsider misses or hits one of the i insiders.
    // F[i] = (-1)^i / i!.
    Poly F(n + 1), G(n + 1);
    for (int i = 0; i <= n; ++i) {
        mi x = (mi)1 - p + (mi(i) - 1) * inv_nm1 * p;
        mi y = (mi)1 - p + p * mi(i) * inv_nm1;
        x = x ^ i;
        y = y ^ (n - i);
        G[i] = (x * y * finv[i]).val;
    }
    int sign = 1;
    for (int i = 0; i <= n; ++i, sign = (mod - sign) % mod) F[i] = (sign * finv[i]).val;

    // Coefficient x of the product is sum_j G(j)/j! * (-1)^(x-j)/(x-j)!;
    // multiplying by x! * C(n, x) gives the probability that exactly x people die.
    F = NTT::poly_mul(F, G);
    // Printing x = n..0 lists survivor counts k = n - x = 0..n.
    for (int i = n; i >= 0; --i) {
        write(int(F[i] * C(n, i) * f[i]));
        putchar('\n');
    }
    return 0;
}