快速傅里叶变换和快速数论变换FFT&NTT

快速傅里叶变换(FFT)

作用:加速多项式乘法

朴素高精度乘法时间O(n^2),但FFT能O(nlog2n)的时间解决

前置知识:

1.点值表示法:

f(x)={( x0,f(x0) ),( x1,f(x1) ) ,( x2, f(x2) ), ( x3, f(x3) ), ( x4, f(x4) ), ... , (xn-1, f(xn-1) )}

g(x)={( x0,g(x0) ),( x1,g(x1) ) ,( x2, g(x2) ), ( x3, g(x3) ), ( x4, g(x4) ), ... , (xn-1, g(xn-1) )}

设它们乘积为h(x),那么

h(x)={( x0,f(x0)g(x0) ),( x1, f(x1)g(x1) ), ( x2, f(x2)g(x2) ), ... , ( xn-1, f(xn-1)g(xn-1) )}

2.复数

(a1,θ1) *(a2,θ2)为(a1a2,θ1+θ2)

快速傅里叶变换的实现:

const double PI = acos(-1);
typedef complex <double> cp;

cp omega(int n, int k) {
    return cp(cos(2 * PI * k / n), sin(2 * PI * k / n));
}

void fft(cp *a, int n, bool inv) {
    if(n == 1) return ;
    static cp buf[N];
    int m = n / 2;
    for(int i = 0; i < m ; i++){        
        buf[i] = a[2 * i];
        buf[i + m] = a[2 * i + 1];
    }
    for(int i = 0; i < n; i++)
        a[i] = buf[i];
    fft(a, m, inv);
    fft(a + m, m, inv);
    for(int i = 0; i < m; i++) {
        cp x = omega(n, i);
        if(inv) x = conj(x);
        //conj是一个自带的求共轭复数的函数,精度较高。当复数模为1时,共轭复数等于倒数
        buf[i] = a[i] + x * a[i + m];
        buf[i + m] = a[i] - x * a[i + m];
    }
    for(int i = 0; i < n; i++)
        a[i] = buf[i];
}

注1:a中i属于0-m-1是A1(i)的取值,a中i属于m到n-1是A2(i)的取值,那么递归的分析,对现在A1中的第i位置,其是由0-m/2-1的得来的,然后最后再更新a数组,a(i)就是A(omega(n, i))的取值了。

注2:n是偶数,多项式的项数。

注3:inv表示单位根是否要取倒数,FFT的逆变换即点值表示法转化为系数表示法。

做法:把点值表示法作为系数,用取了倒数的单位根代入求个点值表示法,得到Zi再除以n就是i的系数ai(证明参考:https://www.cnblogs.com/RabbitHu/p/FFT.html)以下提到"博客"均指此篇博客。

注4:在逆FFT的时候,应该用floor(a[i].real() / n + 0.5);来得到res[i](精度问题)

优化fft(非递归)

发现每次都往下其实都是先去递归,使得所有元素到达其应该在的地方,然后再不断往上递归对a赋值的。

规律:参考博客,可以发现,一个位置a上的数,最后所在的位置是"a二进制翻转得到的数"

据此写出非递归版本fft:先把每个数放到最后的位置上,然后不断向上,从而求出最终答案。

#include<iostream>
#include<cstdio>
#include<complex>
using namespace std;
const double PI = acos(-1);
typedef complex <double> cp;

cp a[N], b[N], omg[N], inv[N];

void init() {
    for(int i = 0; i < n; i++) {
        omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
        inv[i] = conj(omg[i]);
    }
}

void fft(cp *a, cp *omg) {
    int lim = 0;
    while((1 << lim) < n) lim++;
    for(int i = 0; i < n; i++) {
        int t = 0;
        for(int j = 0; j < lim; j++)
            if((i >> j) & 1) t |= (1 << (lim - j - 1));
        if(i < t) swap(a[i], a[t]);    //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换) 
    }
    static cp buf[N];
    for(int l = 2; l <= n; l *= 2) {    //区间长度 
        int m = l / 2;
        for(int j = 0; j < n; j += l)    //区间起点 
            for(int i = 0; i < m; i++) {
                buf[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m];
                buf[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m];
            }
        for(int j = 0; j < n; j++)
            a[j] = buf[j];
    }
}

蝴蝶变换:

之前为什么需要buf数组:

如果

a[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m]

a[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m]

会对更新a[j + i + m]造成影响。

而通过蝴蝶变换

cp t = omg[n / l * i] * a[j + i + m]

a[j + i + m] = a[j + i] - t

a[j + i] = a[j + i] + t

不就顺序换了一下??

反正就不用buf数组就是了。

FFT最终模板:

const double PI = acos(-1);
typedef complex <double> cp;

cp a[N], b[N], omg[N], inv[N];

void init() {
    for(int i = 0; i < n; i++) {
        omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
        inv[i] = conj(omg[i]);
    }
}

void fft(cp *a, cp *omg) {
    int lim = 0;
    while((1 << lim) < n) lim++;
    for(int i = 0; i < n; i++) {
        int t = 0;
        for(int j = 0; j < lim; j++)
            if((i >> j) & 1) t |= (1 << (lim - j - 1));
        if(i < t) swap(a[i], a[t]);    //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换) 
    }
    for(int l = 2; l <= n; l *= 2) {    //区间长度 
        int m = l / 2;
        for(cp *p = a; p != a + n; p += l)
            for(int i = 0; i < m; i++) {
                cp t = omg[n / l * i] * p[i + m];
                p[i + m] = p[i] - t;
                p[i] += t;
            }
    }
}

题1:a*bIII http://www.acmicpc.sdnu.edu.cn/problem/show/1531

注意三点:

1.0与其他数相乘只输出一个0.

2.memset是可以对复数初始化的

3.FFT中的n需要是2的倍数

#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1000005;
const double PI = acos(-1);
typedef complex <double> cp;

char sa[N], sb[N];
int n = 1, lena, lenb, ans[N];
cp a[N], b[N], omg[N], inv[N];

void init(){
    for(int i = 0; i < n; i++){
        omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n));
        inv[i] = conj(omg[i]);
    }
}

void fft(cp *a, cp *omg) {
    int lim = 0;
    while((1 << lim) < n) lim++;
    for(int i = 0; i < n; i++) {
        int t = 0;
        for(int j = 0; j < lim; j++)
            if((i >> j) & 1) t |= (1 << (lim - j - 1));
        if(i < t) swap(a[i], a[t]);    //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换) 
    }
    for(int l = 2; l <= n; l *= 2) {    //区间长度 
        int m = l / 2;
        for(cp *p = a; p != a + n; p += l)
            for(int i = 0; i < m; i++) {
                cp t = omg[n / l * i] * p[i + m];
                p[i + m] = p[i] - t;
                p[i] += t;
            }
    }
}


signed main()
{
    while(~scanf("%s%s",sa,sb)) {
        memset(ans, 0, sizeof(ans));
        memset(a, 0, sizeof(a));
        memset(b, 0, sizeof(b));
        n = 1;
        lena = strlen(sa), lenb = strlen(sb);
        if(lena == 1  && sa[0] == '0' || lenb == 1  && sb[0] == '0') {
            printf("0\n");
            continue;
        }
        while(n < lena + lenb) n *= 2;
        for(int i = 0; i < lena; i++){
            a[i].real(sa[lena - 1 - i] - '0');
        }
        for(int i = 0; i < lenb; i++){
            b[i].real(sb[lenb - 1 - i] - '0');
        }
        init();
        fft(a, omg);
        fft(b, omg);
        for(int i = 0; i < n; i++) {
            a[i] *= b[i];
        }
        fft(a, inv);
        for(int i = 0; i < n; i++) {
            ans[i] += floor(a[i].real() / n + 0.5);
            ans[i + 1] += ans[i] / 10;
            ans[i] %= 10; 
        }
        int beg;
        for(int i = n-1; i >= 0; i--) {
            if(ans[i] != 0){
                beg = i;
                break;
            }
        }
        for(int i = beg; i >= 0; i--){
            printf("%lld",ans[i]);
        }
        putchar('\n');
    }
    return 0;
}


FFT的缺点是它的复数运算double精度问题导致它实际上是k*nlongn的,会比NTT的常数大很多。

NTT

前置知识:

原根:对于g,p属于Z, 如果g^i mod p ( 1<=i<=p-1)的值互不相同,则称g为p的原根。

或者说对于任意i,j(1<=i<j <= p-1) g^i mod p /= g^j mod p,那么g为p的原根。

常见模数有:998244353,1004535809,469762049,这几个的原根都为3

在NTT中,我们拿原根来代替FFT的单位根

#define g 3
const int mod = 998244353;
const int N = 300000;

inline get_rev()
{
    int lim = 0;
    while((1 << lim) < n) lim++;
    for(int i = 0; i < n; i++) {
        rev[i] = (rev[i >> 1] >> 1 | ((i & 1) << (lim - 1)));
    }
}

inline void ntt(int *a, int inv) {
    for(int i = 0; i < n; i++) {
        if(i < rev[i]) swap(a[i], a[rev[i]]);    //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换) 
    }
    for(int l = 2; l <= n; l *= 2) {    //区间长度 
        int m = l / 2;
        int tmp = q_pow(g, (mod-1)/l);
        if(inv == -1) tmp = q_pow(tmp, mod-2);
        for(int i = 0; i < n; i += l) {
            int omega = 1;
            for(int j = 0; j < m; j++, omega = omega*tmp%mod) {
                int x = a[i + j], y = omega * a[i + j + m] % mod;
                a[i + j] = (x + y) % mod, a[i + j + m] = (x - y + mod) % mod;
            }
        }
    }
    if(inv == -1)
    {
        int nI = q_pow(n, mod-2);
        for(int i = 0; i < n; i++) {
            a[i] = a[i] * nI % mod;
        }
    }
} 

例1:http://www.acmicpc.sdnu.edu.cn/problem/show/1532

#include<bits/stdc++.h>

using namespace std;
#define int long long
#define g 3
const int mod = 998244353;
const int N = 300000;

inline int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    return x*f;
}
inline void write(int x)
{
    if(x<0)x=-x,putchar('-');
    if(x>9)write(x/10);
    putchar(x%10+'0');
}

char sa[N], sb[N];
int n = 1, a[N], b[N], rev[N], lena, lenb;

inline int q_pow(int a, int b){
    int ans = 1;
    while(b > 0){
        if(b & 1){
            ans = ans * a % mod;
        }
        a = a * a % mod;
        b >>= 1; 
    } 
    return ans;
}

inline get_rev()
{
    int lim = 0;
    while((1 << lim) < n) lim++;
    for(int i = 0; i < n; i++) {
        rev[i] = (rev[i >> 1] >> 1 | ((i & 1) << (lim - 1)));
    }
}

inline void ntt(int *a, int inv) {
    for(int i = 0; i < n; i++) {
        if(i < rev[i]) swap(a[i], a[rev[i]]);    //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换) 
    }
    for(int l = 2; l <= n; l *= 2) {    //区间长度 
        int m = l / 2;
        int tmp = q_pow(g, (mod-1)/l);
        if(inv == -1) tmp = q_pow(tmp, mod-2);
        for(int i = 0; i < n; i += l) {
            int omega = 1;
            for(int j = 0; j < m; j++, omega = omega*tmp%mod) {
                int x = a[i + j], y = omega * a[i + j + m] % mod;
                a[i + j] = (x + y) % mod, a[i + j + m] = (x - y + mod) % mod;
            }
        }
    }
    if(inv == -1)
    {
        int nI = q_pow(n, mod-2);
        for(int i = 0; i < n; i++) {
            a[i] = a[i] * nI % mod;
        }
    }
} 



signed main()
{
    while(~scanf("%s%s",sa,sb))
    {
        n = 1;
        memset(a, 0, sizeof(a));
        memset(b, 0, sizeof(b));

        lena = strlen(sa), lenb = strlen(sb);
        for(int i = 0; i < lena; i++) {
            a[i] = sa[lena - 1 - i] - '0';
        }
        for(int i = 0; i < lenb; i++) {
            b[i] = sb[lenb - 1 - i] - '0';
        }
        while(n < lena + lenb) n <<= 1;
        get_rev();

        ntt(a, 1);
        ntt(b, 1);
        for(int i = 0; i < n; i++) {
            a[i] = a[i] * b[i] % mod;
        }
        ntt(a, -1);

        for(int i = 0; i < n; i++) {
            a[i + 1] += a[i] / 10;
            a[i] %= 10;
        }
        int cnt = n;
        while(cnt >= 0 && a[cnt] == 0) cnt--;
        if(cnt == -1) {
            printf("0");
        }
        else {
            for(int i = cnt; i >= 0; i--){
                write(a[i]);
            }
        }
        putchar('\n');
    }
    return 0;
}
全部评论

相关推荐

Java抽象带篮子:难蚌,点进图片上面就是我的大头😆
点赞 评论 收藏
分享
3 收藏 评论
分享
牛客网
牛客企业服务