牛客国庆集训派对Day4 E 乒乓球【公式+ntt】
题目:小 Bo 是某省乒乓球名列前茅的选手,现在他有 n 颗乒乓球一字排开,第 i 颗乒乓球的权值为 wi
每次他会随机从现有的乒乓球中等概率选一颗拿走,然后得到的收益是这颗球左边第一个乒乓球和右边第一个乒乓球的权值的乘积,如果左边没有乒乓球或者右边没有乒乓球,则收益为 0,这个过程会重复进行到所有球都被拿走为止
现在小 Bo 想知道他的期望总收益
为了方便,你只需要输出答案对 998244353 取模的值
分析:
这个题,我们考虑对于间隔为 d的 d+1个数,wi,wi+1,…,wi+d ,对于 wi*wi+d概率应该为
(d−1)!∗2/(d+1)! 也就是说,最旁边的两个最后取的情况
所以对于给定的d 我们只要算 wi∗wj (j=i+d)的和就可以;
很显然就是用fft
由于这里要取模,所以用了ntt
#include <bits/stdc++.h>
#define cl(a) memset(a,0,sizeof(a))
#define ll long long
#define pb(i) push_back(i)
#define mp make_pair
using namespace std;
const int maxn=1e6+50;
const int inf=0x3f3f3f3f;
const int mod=998244353;
const double pi = acos(-1.0);
typedef pair<int,int> PII;
ll fpow(ll n, ll k, ll p = mod) {ll r = 1; for (; k; k >>= 1) {if (k & 1) r = r * n%p; n = n * n%p;} return r;}
ll pp(ll x, ll y){
ll ret = 1;
while(y){
if(y & 1) ret = ret * x % mod;
x = x * x % mod;
y >>= 1;
}
return ret;
}
int rev[maxn << 2], inv[maxn];
void init(int p){
for(int i = 0; i < (1 << p); ++i){
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (p - 1));
}
}
void fft(ll *a, int n, int op)
{
for(int i = 0; i < n; ++i) if(i < rev[i]) swap(a[i], a[rev[i]]);
ll now, x, y, ome;
for(int i = 1; i < n; i <<= 1){
ome = pp(3, (mod - 1) / (i << 1));
if(op == -1) ome = pp(ome, mod - 2);
for(int j = 0; j < n; j += i << 1){
now = 1;
for(int k = j; k < j + i; ++k){
x = a[k], y = now * a[k + i] % mod;
a[k] = x + y, a[k + i] = x - y;
if(a[k] >= mod) a[k] -= mod;
if(a[k + i] < 0) a[k + i] += mod;
now = now * ome % mod;
}
}
}
if(op == -1){
now = pp(n, mod - 2);
for(int i = 0; i < n; ++i) a[i] = a[i] * now % mod;
}
}
ll a[maxn],b[maxn];
vector<int>vc;
void solve(int p)
{
init(p);
fft(a, (1 << p), 1), fft(b, (1 << p), 1);
for(int i = 0; i < (1 << p); ++i) a[i] = a[i] * b[i] % mod;
fft(a, (1 << p), -1);
}
int main()
{
inv[1]=1;
for(int i = 2; i < maxn; i ++)
inv[i] = (mod - mod / i) * 1ll * inv[mod % i] % mod;
int n;
scanf("%d",&n);
for(int i=0;i<n;i++)
{
scanf("%lld",&a[i]);
b[n-1-i]=a[i];
}
int p=1;
while((1<<p)<=n)p++;
p++;
solve(p);
ll sum=0;
for(int i=1;i<n-1;i++)
{
sum+= 2ll * inv[i + 1] % mod * inv[i + 2] % mod * a[n + i] % mod;
if(sum>=mod) sum-=mod;
}
printf("%lld\n",sum);
return 0;
}