题解 | #Easy Counting Problem#(组合计数,NTT)
Easy Counting Problem
https://ac.nowcoder.com/acm/contest/33189/C
#include<iostream> #include<algorithm> #include<string.h> #include<vector> #include<cassert> using namespace std; // #define debug(x) cout<<"[debug]"#x<<"="<<x<<endl typedef long long ll; typedef long double ld; typedef pair<int,int> pii; const double eps=1e-8; const int INF=0x3f3f3f3f; #ifndef ONLINE_JUDGE #define debug(...) #include<debug> #else #define debug(...) #endif const int N=100005,M=11; const int mod=998244353,G=3,Gi=332748118;//G是原根,Gi是原根的逆元 int bit,tot; int rev[N]; int c[M],sigma; int fact[10000007],infact[10000007]; vector<ll> f[M][M]; ll qmi(ll a,ll b,ll p) { ll res=1; while(b) { if(b&1) res=res*a%p; a=a*a%p; b>>=1; } return res; } void NTT(vector<ll> &a,int inv) { a.resize(tot); for(int i=0;i<tot;i++) if(i<rev[i]) swap(a[i],a[rev[i]]); for(int mid=1;mid<tot;mid*=2) { ll w1=qmi(inv==1?G:Gi,(mod-1)/(2*mid),mod); for(int i=0;i<tot;i+=mid*2) { ll wk=1; for(int j=0;j<mid;j++,wk=wk*w1%mod) { ll x=a[i+j]; ll y=wk*a[i+j+mid]%mod; a[i+j]=(x+y)%mod; a[i+j+mid]=(x-y+mod)%mod; } } } if(inv==-1)//就不用后面除了 { int intot=qmi(tot,mod-2,mod); for(int i=0;i<tot;i++) { a[i]=a[i]*intot%mod; } } } vector<ll> mul(vector<ll> a,vector<ll> b,int deg)//deg是系数的数量,所以有0~deg-1次项 { bit=0; while((1<<bit)<deg) bit++;//至少要系数的数量 tot=1<<bit; //系数项为0~tot-1 for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1)); vector<ll> c(tot); NTT(a,1),NTT(b,1); for(int i=0;i<tot;i++) c[i]=a[i]*b[i]%mod; NTT(c,-1); c.resize(deg); return c; } void init(int n) { fact[0]=1; for(int i=1;i<=n;i++) { fact[i]=1ll*fact[i-1]*i%mod; } infact[n]=qmi(fact[n],mod-2,mod); for(int i=n-1;i>=0;i--) { infact[i]=1ll*infact[i+1]*(i+1)%mod; } } int main() { init(1e7+2); int w; scanf("%d",&w); for(int i=1;i<=w;i++) scanf("%d",&c[i]),sigma+=c[i]; for(int i=0;i<=w;i++) for(int j=0;j<=w;j++) f[i][j].resize(N); f[0][0][0]=1; for(int i=1;i<=w;i++) { for(int j=0;j<=i;j++) { //f[i][j]=f[i-1][j]+f[i-1][j-1]*Ai f[i][j]=f[i-1][j]; if(j) { vector<ll> A; for(int k=0;k<c[i];k++) A.push_back(infact[k]); vector<ll> C=mul(f[i-1][j-1],A,min(sigma,(int)(f[i-1][j-1].size()+A.size()-1))); for(int k=0;k<sigma;k++) f[i][j][k]=(f[i][j][k]+C[k])%mod; } } } int Q; scanf("%d",&Q); while(Q--) { int n; scanf("%d",&n); if(n<sigma) { puts("0"); continue; } ll res=0; for(int i=0;i<=w;i++)//f里面选了多少 { for(int k=0,wk=qmi(w-i,n,mod),inv=qmi(w-i,mod-2,mod);k<=sigma;wk=1ll*wk*inv%mod,k++)//f里面选了多少x次幂 { //f[w][i][k] ll ans=((i%2)?-1:1)*f[w][i][k]*wk%mod*infact[n-k]%mod; if(ans<0) ans=(ans+mod)%mod; res=(res+ans)%mod; } } res=(res*fact[n])%mod; printf("%lld\n",res); } }(学习了懵哥的做法补的,它的题解公式写的好完整qwq太强啦!)