FFT
推介这篇博客吧,写的真的挺好的.
算了说点自己的理解吧...
首先是解决两个多项式相乘把时间复杂度从优化到.学它可以解决一些用前缀不能解决的优化.
里面引进了复数,它是非常的妙的,无论是使得从多项式用系数表示转化成多项式用值来表示,还是用多项式用值来表示转化成多项式用系数表示.
假定我们知道两个多项式的系数表示,那么一定可以求出两个点在至少点的值,为什么要求这么多呢,因为其实可以用脑子想想就了,把多少次方全部看成未知数,那么元方程组由高斯消元或者其他东西都能知道至少需要个不同点的信息.
第一步就是用多项式的系数信息求个点信息,方便分治.用多项式的系数信息求个点信息.因为是乘法,所以它两直接相乘就得到了要求的多项式的点的信息,然后用一个转化把的点的信息当成系数,用单位根的倒数得到,原本的系数=求出来的点的信息/项数.
然后就可以用分治写出了,两个证明部分非常简单,建议看看博客,至此可以写出分治的了.
分治code:
#include <bits/stdc++.h>
using namespace std;
const int N=1e6+5;
const double pi=acos(-1);
struct cp{
double x,y;
cp(){x=y=0;}
cp(double xx,double yy){x=xx,y=yy;}
}f[N<<2],g[N<<2],ans[N<<2];
cp operator + (cp A,cp B){
return cp(A.x+B.x,A.y+B.y);
}
cp operator - (cp A,cp B){
return cp(A.x-B.x,A.y-B.y);
}
cp operator * (cp A,cp B){
return cp(A.x*B.x-A.y*B.y,A.x*B.y+A.y*B.x);
}
void fft(int n,cp *a,int op)//op等于1系数转点值 op等于-1点值转系数.
{
if(n<=1) return;
int mid=(n>>1);
cp a1[mid],a2[mid];
for(int i=0;i<mid;i++)
{
a1[i]=a[i<<1];
a2[i]=a[i<<1|1];
}
fft(mid,a1,op);
fft(mid,a2,op);
cp w1(cos(pi/mid),sin(pi/mid)*op),wt,w(1,0);
for(int i=0;i<mid;i++)
{
wt=w*a2[i];
a[i]=a1[i]+wt;
a[i+mid]=a1[i]-wt;
w=w*w1;
}
}
void run()
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++)
{
int x;
scanf("%d",&x);
f[i].x=x;
}
for(int i=0;i<=m;i++)
{
int x;
scanf("%d",&x);
g[i].x=x;
}
int k=1;
while(k<=n+m) k<<=1;
fft(k,f,1);
fft(k,g,1);
for(int i=0;i<k;i++)
ans[i]=f[i]*g[i];
fft(k,ans,-1);
for(int i=0;i<=n+m;i++)
printf("%.0f ",ans[i].x/k+0.5);
puts("");
}
int main()
{
int T=1;
// scanf("%d",&T);
while(T--) run();
return 0;
}
还有种实现也是分治不过是把分治改成了迭代,说实话个人觉得时间可能真差不多.但是快一点点吧,主要是优化了一些空间吧qwq. 实现也不难.
迭代code:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=4e6+5;
const int mod=998244353;
const double pi=acos(-1);
struct cp{
double x,y;
cp(){x=0,y=0;}
cp(double X,double Y){x=X,y=Y;};
}f[N],g[N],ans[N];
cp operator + (cp A,cp B)
{
return cp(A.x+B.x,A.y+B.y);
}
cp operator - (cp A,cp B)
{
return cp(A.x-B.x,A.y-B.y);
}
cp operator * (cp A,cp B)
{
return cp(A.x*B.x-A.y*B.y,A.x*B.y+A.y*B.x);
}
int b=0;
int rev[N];
void fft(int n,cp *a,int op)
{
for(int i=0;i<n;i++)
if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int len=1;len<=n/2;len<<=1)//枚举分治半长度.
{
cp w1=cp(cos(pi/len),op*sin(pi/len));
for(int i=0;i<=n-(len<<1);i+=(len<<1))
{
cp w=cp(1,0);
for(int j=0;j<len;j++)
{
cp x=a[i+j];cp y=w*a[i+j+len];
a[i+j]=x+y;
a[i+j+len]=x-y;
w=w*w1;
}
}
}
}
int main()
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lf",&f[i].x);
for(int i=0;i<=m;i++) scanf("%lf",&g[i].x);
int k=1;
while(k<=n+m) k<<=1,b++;
for(int i=1;i<=k;i++)
rev[i]=(rev[i>>1]>>1)+((i&1)<<(b-1));
fft(k,f,1);
fft(k,g,1);
for(int i=0;i<=k;i++)
ans[i]=f[i]*g[i];
fft(k,ans,-1);
for(int i=0;i<=(n+m);i++)
{
printf("%.0f ",ans[i].x/k+0.5);
}
puts("");
return 0;
}
补充一下vector的两个多项式的乘法的板子ntt.
using LL = long long;
#define FOR(i, x, y) for (decay<decltype(y)>::type i = (x), _##i = (y); i < _##i; ++i)
#define FORD(i, x, y) for (decay<decltype(x)>::type i = (x), _##i = (y); i > _##i; --i)
const LL MOD = 998244353;
const int G = 3;
LL bin(LL x, LL n, LL MOD) {
LL ret = MOD != 1;
for (x %= MOD; n; n >>= 1, x = x * x % MOD)
if (n & 1) ret = ret * x % MOD;
return ret;
}
inline LL get_inv(LL x, LL p) { return bin(x, p - 2, p); }
LL wn[(N * 10) << 2], rev[(N * 10) << 2];
int NTT_init(int n_) {
int step = 0; int n = 1;
for ( ; n < n_; n <<= 1) ++step;
FOR (i, 1, n)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (step - 1));
int g = bin(G, (MOD - 1) / n, MOD);
wn[0] = 1;
for (int i = 1; i <= n; ++i)
wn[i] = wn[i - 1] * g % MOD;
return n;
}
void NTT(vector<LL>& a, int n, int f) {
FOR (i, 0, n) if (i < rev[i])
std::swap(a[i], a[rev[i]]);
for (int k = 1; k < n; k <<= 1) {
for (int i = 0; i < n; i += (k << 1)) {
int t = n / (k << 1);
FOR (j, 0, k) {
LL w = f == 1 ? wn[t * j] : wn[n - t * j];
LL x = a[i + j];
LL y = a[i + j + k] * w % MOD;
a[i + j] = (x + y) % MOD;
a[i + j + k] = (x - y + MOD) % MOD;
}
}
}
if (f == -1) {
LL ninv = get_inv(n, MOD);
FOR (i, 0, n)
a[i] = a[i] * ninv % MOD;
}
}
vector<LL> operator+(vector<LL> a, const vector<LL>& b){
a.resize(max(a.size(), b.size()));
for(int i = 0; i < b.size(); ++ i)
a[i] = (a[i] + b[i]) % MOD;
return a;
}
vector<LL> conv(vector<LL> a, vector<LL> b) {
int len = a.size() + b.size() - 1;
int n = NTT_init(len);
a.resize(n);
b.resize(n);
NTT(a, n, 1);
NTT(b, n, 1);
FOR (i, 0, n)
a[i] = a[i] * b[i] % MOD;
NTT(a, n, -1);
a.resize(len);
return a;
}