题解 | #Lndjy and the mex#
Lndjy and the mex
https://ac.nowcoder.com/acm/contest/38727/L
Lndjy and the mex
题意
- 给定多重集 ,满足元素是 内的整数,且 。
- 一个序列的权值定义为所有区间的 之和。
- 计算所有长度为 且元素与 完全相同的序列的权值之和。
- 取模 998244353。
- 。
思路
-
对于一个长度为 的区间而言,它的 小于 ,当且仅当 到 中的某些数不在这个区间内出现。
-
于是考虑容斥。对于 的一个大小为 的子集 而言,假设不在子集中的元素的阶乘之积为 ,那么它不出现在区间内的情况共有 种,而其容斥系数为 。
-
展开上述式子,发现原式即为 。其中的 与 均为常量,这意味着子集的贡献只与区间长度 ,子集大小 的奇偶性以及子集元素个数之和有关。
-
首先解决长度固定时的问题。假设当前枚举的 为 ,那么为 到 中的元素 构造多项式 ,则结果中元素个数之和为 的子集的贡献即为 中 次项的系数。
-
假设数列中可能出现的最大 为 ,考虑对于从 到 的每一个 mex 计算结果,则最终的多项式为 。
-
考虑用分治NTT解决这个问题,假设左半侧得到的多项式与前缀和分别为 ,右半侧得到的多项式与前缀和分别为 ,则合并结果为 。
-
现在考虑不同长度时的情况。对于长度为 的区间而言,其贡献为常量 乘上。发现需要计算的式子仍然是卷积形式,因此将分子与分母再用NTT卷一遍即可。
-
最终答案即为区间个数乘 再减去每个 产生的贡献。
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define M 998244353
ll jc[1005000],inv[1005000];
ll ksm(ll a,int p){ll res=1;while(p){if(p&1){res=res*a%M;}a=a*a%M;p>>=1;}return res;}
ll su(ll a,ll b){a+=b;return (a>=M)?a-M:a;}
const int G = 3;
const int Gi = 332748118;
int r[1000500], lim;
void ntt (vector <ll> &A, int type) {
for (int i = 0;i <= lim - 1;++i) {
if (i < r[i]) swap(A[i], A[r[i]]);
}
for (int mid = 1;mid <= lim - 1;mid <<= 1) {
ll Wn = ksm(type == 1 ? G : Gi, (M - 1) / (mid << 1));
for (int j = 0;j <= lim - 1;j += (mid << 1)) {
ll w = 1;
for (int k = 0;k <= mid - 1;++k, w = (w * Wn) % M) {
int x = A[j + k];
int y = w * A[j + mid + k] % M;
A[j + k] = su(x,y);
A[j + mid + k] = su(x,M-y);
}
}
}
if (type == -1) {
ll tmp = ksm(lim,M-2);
for (int i = 0;i <= lim - 1;++i) {
A[i] = A[i] * tmp % M;
}
}
}
vector <ll> operator * (vector <ll> A, vector <ll> B) {
int len = A.size() + B.size() - 1;
lim = 1;
while (lim <= len) lim <<= 1;
for (int i = 0;i <= lim - 1;++i) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) * (lim >> 1));
}
A.resize(lim);
B.resize(lim);
ntt(A, 1);
ntt(B, 1);
for (int i = 0;i <= lim - 1;++i) {
A[i] = A[i] * B[i] % M;
}
ntt(A, -1);
A.resize(len);
return A;
}
int i,j,k,n,m,t,a[100500],li;
ll res,res2,res3;
pair< vector<ll>,vector<ll> > fuk(int l,int r){
if(l==r){
vector<ll> v(a[l]+1);
v[0]=1;
v[a[l]]=M-1;
return {v,v};
}
int i,j,k,md=(l+r)/2;
auto [f1,g1]=fuk(l,md);
auto [f2,g2]=fuk(md+1,r);
g2=f1*g2;
k=max(g1.size(),g2.size());
g1.resize(k);g2.resize(k);
for(i=0;i<k;i++){
g1[i]=su(g1[i],g2[i]);
}
f1=f1*f2;
return {f1,g1};
}
int main() {
ios::sync_with_stdio(0);
jc[0]=inv[0]=1;
for(i=1;i<=1000000;i++){jc[i]=jc[i-1]*i%M;}
inv[1000000]=ksm(jc[1000000],M-2);
for(i=999999;i>=1;i--){inv[i]=inv[i+1]*(i+1)%M;}
cin>>n;
res=ksm(2,M-2)*n%M*(n+1)%M*jc[n]%M;
res2=1;
for(i=1;i<=n+1;i++){
cin>>a[i];
res=res*inv[a[i]]%M;
res2=res2*inv[a[i]]%M;
}
if(!a[1]){
cout<<0;return 0;
}
for(i=1;;i++){
if(!a[i]){
li=i-1;break;
}
}
res=res*li%M;
auto [f,g]=fuk(1,li);
g[0]=0;g.resize(n+1);
for(i=1;i<=n;i++){
g[i]=g[i]*jc[n-i]%M;
}
f=vector<ll>(n+1,0);
for(i=0;i<=n;i++){
f[i]=inv[i];
}
f=f*g;
for(i=1;i<n;i++){
res3=su(res3,jc[i]*f[i]%M*(i+1)%M);
}
res2=res2*res3%M;
res=su(res,res2);
cout<<res;
}