题解 | #Easy Counting Problem#(组合计数,NTT)

Easy Counting Problem

https://ac.nowcoder.com/acm/contest/33189/C

#include<iostream>
#include<algorithm>
#include<string.h>
#include<vector>
#include<cassert>
using namespace std;
// #define debug(x) cout<<"[debug]"#x<<"="<<x<<endl
typedef long long ll;
typedef long double ld;
typedef pair<int,int> pii;
const double eps=1e-8;
const int INF=0x3f3f3f3f;
#ifndef ONLINE_JUDGE
#define debug(...)
#include<debug>
#else
#define debug(...)
#endif

const int N=100005,M=11;
const int mod=998244353,G=3,Gi=332748118;//G是原根,Gi是原根的逆元
int bit,tot;
int rev[N];
int c[M],sigma;
int fact[10000007],infact[10000007];
vector<ll> f[M][M];
ll qmi(ll a,ll b,ll p)
{
    ll res=1;
    while(b)
    {
        if(b&1)
        res=res*a%p;

        a=a*a%p;
        b>>=1;
    }
    return res;
}
void NTT(vector<ll> &a,int inv)
{
    a.resize(tot);
    for(int i=0;i<tot;i++)
        if(i<rev[i]) swap(a[i],a[rev[i]]);

    for(int mid=1;mid<tot;mid*=2)
    {
        ll w1=qmi(inv==1?G:Gi,(mod-1)/(2*mid),mod);
        for(int i=0;i<tot;i+=mid*2)
        {
            ll wk=1;
            for(int j=0;j<mid;j++,wk=wk*w1%mod)
            {
                ll x=a[i+j];
                ll y=wk*a[i+j+mid]%mod;
                a[i+j]=(x+y)%mod;
                a[i+j+mid]=(x-y+mod)%mod;
            }
        }
    }

    if(inv==-1)//就不用后面除了
    {
        int intot=qmi(tot,mod-2,mod);
        for(int i=0;i<tot;i++)
        {
            a[i]=a[i]*intot%mod;
        }
    }
}
vector<ll> mul(vector<ll> a,vector<ll> b,int deg)//deg是系数的数量,所以有0~deg-1次项
{
    bit=0;
    while((1<<bit)<deg) bit++;//至少要系数的数量
    tot=1<<bit; //系数项为0~tot-1
    for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    
    vector<ll> c(tot);
    NTT(a,1),NTT(b,1);
    for(int i=0;i<tot;i++) c[i]=a[i]*b[i]%mod;
    NTT(c,-1);
    c.resize(deg);
    return c;
}
void init(int n)
{
    fact[0]=1;
    for(int i=1;i<=n;i++)
    {
        fact[i]=1ll*fact[i-1]*i%mod;
    }

    infact[n]=qmi(fact[n],mod-2,mod);
    for(int i=n-1;i>=0;i--)
    {
        infact[i]=1ll*infact[i+1]*(i+1)%mod;
    }
}
int main()
{
    init(1e7+2);
    int w;
    scanf("%d",&w);
    for(int i=1;i<=w;i++) scanf("%d",&c[i]),sigma+=c[i];
    for(int i=0;i<=w;i++)
        for(int j=0;j<=w;j++)
            f[i][j].resize(N);
    
    f[0][0][0]=1;
    for(int i=1;i<=w;i++)
    {
        for(int j=0;j<=i;j++)
        {
            //f[i][j]=f[i-1][j]+f[i-1][j-1]*Ai
            f[i][j]=f[i-1][j];
            if(j)
            {
                vector<ll> A;
                for(int k=0;k<c[i];k++) A.push_back(infact[k]);
                vector<ll> C=mul(f[i-1][j-1],A,min(sigma,(int)(f[i-1][j-1].size()+A.size()-1)));
                for(int k=0;k<sigma;k++) f[i][j][k]=(f[i][j][k]+C[k])%mod;
            }
        }
    }

    int Q;
    scanf("%d",&Q);
    while(Q--)
    {
        int n;
        scanf("%d",&n);
        if(n<sigma)
        {
            puts("0");
            continue;
        }

        ll res=0;
        for(int i=0;i<=w;i++)//f里面选了多少
        {
            for(int k=0,wk=qmi(w-i,n,mod),inv=qmi(w-i,mod-2,mod);k<=sigma;wk=1ll*wk*inv%mod,k++)//f里面选了多少x次幂
            {
                //f[w][i][k]
                ll ans=((i%2)?-1:1)*f[w][i][k]*wk%mod*infact[n-k]%mod;
                if(ans<0) ans=(ans+mod)%mod;
                res=(res+ans)%mod;
            }
        }
        res=(res*fact[n])%mod;
        printf("%lld\n",res);
    }
}
(学习了懵哥的做法补的,它的题解公式写的好完整qwq太强啦!)

全部评论

相关推荐

11-04 14:10
东南大学 Java
_可乐多加冰_:去市公司包卖卡的
点赞 评论 收藏
分享
感性的干饭人在线蹲牛友:🐮 应该是在嘉定这边叭,禾赛大楼挺好看的
点赞 评论 收藏
分享
2 收藏 评论
分享
牛客网
牛客企业服务