快速傅里叶变换和快速数论变换FFT&NTT
快速傅里叶变换(FFT)
作用:加速多项式乘法
朴素高精度乘法时间O(n^2),但FFT能O(nlog2n)的时间解决
前置知识:
1.点值表示法:
f(x)={( x0,f(x0) ),( x1,f(x1) ) ,( x2, f(x2) ), ( x3, f(x3) ), ( x4, f(x4) ), ... , (xn-1, f(xn-1) )}
g(x)={( x0,g(x0) ),( x1,g(x1) ) ,( x2, g(x2) ), ( x3, g(x3) ), ( x4, g(x4) ), ... , (xn-1, g(xn-1) )}
设它们乘积为h(x),那么
h(x)={( x0,f(x0)g(x0) ),( x1, f(x1)g(x1) ), ( x2, f(x2)g(x2) ), ... , ( xn-1, f(xn-1)g(xn-1) )}
2.复数
(a1,θ1) *(a2,θ2)为(a1a2,θ1+θ2)
快速傅里叶变换的实现:
const double PI = acos(-1); typedef complex <double> cp; cp omega(int n, int k) { return cp(cos(2 * PI * k / n), sin(2 * PI * k / n)); } void fft(cp *a, int n, bool inv) { if(n == 1) return ; static cp buf[N]; int m = n / 2; for(int i = 0; i < m ; i++){ buf[i] = a[2 * i]; buf[i + m] = a[2 * i + 1]; } for(int i = 0; i < n; i++) a[i] = buf[i]; fft(a, m, inv); fft(a + m, m, inv); for(int i = 0; i < m; i++) { cp x = omega(n, i); if(inv) x = conj(x); //conj是一个自带的求共轭复数的函数,精度较高。当复数模为1时,共轭复数等于倒数 buf[i] = a[i] + x * a[i + m]; buf[i + m] = a[i] - x * a[i + m]; } for(int i = 0; i < n; i++) a[i] = buf[i]; }
注1:a中i属于0-m-1是A1(i)的取值,a中i属于m到n-1是A2(i)的取值,那么递归的分析,对现在A1中的第i位置,其是由0-m/2-1的得来的,然后最后再更新a数组,a(i)就是A(omega(n, i))的取值了。
注2:n是偶数,多项式的项数。
注3:inv表示单位根是否要取倒数,FFT的逆变换即点值表示法转化为系数表示法。
做法:把点值表示法作为系数,用取了倒数的单位根代入求个点值表示法,得到Zi再除以n就是i的系数ai(证明参考:https://www.cnblogs.com/RabbitHu/p/FFT.html)以下提到"博客"均指此篇博客。
注4:在逆FFT的时候,应该用floor(a[i].real() / n + 0.5);来得到res[i](精度问题)
优化fft(非递归)
发现每次都往下其实都是先去递归,使得所有元素到达其应该在的地方,然后再不断往上递归对a赋值的。
规律:参考博客,可以发现,一个位置a上的数,最后所在的位置是"a二进制翻转得到的数"
据此写出非递归版本fft:先把每个数放到最后的位置上,然后不断向上,从而求出最终答案。
#include<iostream> #include<cstdio> #include<complex> using namespace std; const double PI = acos(-1); typedef complex <double> cp; cp a[N], b[N], omg[N], inv[N]; void init() { for(int i = 0; i < n; i++) { omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n)); inv[i] = conj(omg[i]); } } void fft(cp *a, cp *omg) { int lim = 0; while((1 << lim) < n) lim++; for(int i = 0; i < n; i++) { int t = 0; for(int j = 0; j < lim; j++) if((i >> j) & 1) t |= (1 << (lim - j - 1)); if(i < t) swap(a[i], a[t]); //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换) } static cp buf[N]; for(int l = 2; l <= n; l *= 2) { //区间长度 int m = l / 2; for(int j = 0; j < n; j += l) //区间起点 for(int i = 0; i < m; i++) { buf[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m]; buf[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m]; } for(int j = 0; j < n; j++) a[j] = buf[j]; } }
蝴蝶变换:
之前为什么需要buf数组:
如果
a[j + i] = a[j + i] + omg[n / l * i] * a[j + i + m]
a[j + i + m] = a[j + i] - omg[n / l * i] * a[j + i + m]
会对更新a[j + i + m]造成影响。
而通过蝴蝶变换
cp t = omg[n / l * i] * a[j + i + m]
a[j + i + m] = a[j + i] - t
a[j + i] = a[j + i] + t
不就顺序换了一下??
反正就不用buf数组就是了。
FFT最终模板:
const double PI = acos(-1); typedef complex <double> cp; cp a[N], b[N], omg[N], inv[N]; void init() { for(int i = 0; i < n; i++) { omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n)); inv[i] = conj(omg[i]); } } void fft(cp *a, cp *omg) { int lim = 0; while((1 << lim) < n) lim++; for(int i = 0; i < n; i++) { int t = 0; for(int j = 0; j < lim; j++) if((i >> j) & 1) t |= (1 << (lim - j - 1)); if(i < t) swap(a[i], a[t]); //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换) } for(int l = 2; l <= n; l *= 2) { //区间长度 int m = l / 2; for(cp *p = a; p != a + n; p += l) for(int i = 0; i < m; i++) { cp t = omg[n / l * i] * p[i + m]; p[i + m] = p[i] - t; p[i] += t; } } }
题1:a*bIII http://www.acmicpc.sdnu.edu.cn/problem/show/1531
注意三点:
1.0与其他数相乘只输出一个0.
2.memset是可以对复数初始化的
3.FFT中的n需要是2的倍数
#include<bits/stdc++.h> using namespace std; #define int long long const int N = 1000005; const double PI = acos(-1); typedef complex <double> cp; char sa[N], sb[N]; int n = 1, lena, lenb, ans[N]; cp a[N], b[N], omg[N], inv[N]; void init(){ for(int i = 0; i < n; i++){ omg[i] = cp(cos(2 * PI * i / n), sin(2 * PI * i / n)); inv[i] = conj(omg[i]); } } void fft(cp *a, cp *omg) { int lim = 0; while((1 << lim) < n) lim++; for(int i = 0; i < n; i++) { int t = 0; for(int j = 0; j < lim; j++) if((i >> j) & 1) t |= (1 << (lim - j - 1)); if(i < t) swap(a[i], a[t]); //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换) } for(int l = 2; l <= n; l *= 2) { //区间长度 int m = l / 2; for(cp *p = a; p != a + n; p += l) for(int i = 0; i < m; i++) { cp t = omg[n / l * i] * p[i + m]; p[i + m] = p[i] - t; p[i] += t; } } } signed main() { while(~scanf("%s%s",sa,sb)) { memset(ans, 0, sizeof(ans)); memset(a, 0, sizeof(a)); memset(b, 0, sizeof(b)); n = 1; lena = strlen(sa), lenb = strlen(sb); if(lena == 1 && sa[0] == '0' || lenb == 1 && sb[0] == '0') { printf("0\n"); continue; } while(n < lena + lenb) n *= 2; for(int i = 0; i < lena; i++){ a[i].real(sa[lena - 1 - i] - '0'); } for(int i = 0; i < lenb; i++){ b[i].real(sb[lenb - 1 - i] - '0'); } init(); fft(a, omg); fft(b, omg); for(int i = 0; i < n; i++) { a[i] *= b[i]; } fft(a, inv); for(int i = 0; i < n; i++) { ans[i] += floor(a[i].real() / n + 0.5); ans[i + 1] += ans[i] / 10; ans[i] %= 10; } int beg; for(int i = n-1; i >= 0; i--) { if(ans[i] != 0){ beg = i; break; } } for(int i = beg; i >= 0; i--){ printf("%lld",ans[i]); } putchar('\n'); } return 0; }
FFT的缺点是它的复数运算double精度问题导致它实际上是k*nlongn的,会比NTT的常数大很多。
NTT
前置知识:
原根:对于g,p属于Z, 如果g^i mod p ( 1<=i<=p-1)的值互不相同,则称g为p的原根。
或者说对于任意i,j(1<=i<j <= p-1) g^i mod p /= g^j mod p,那么g为p的原根。
常见模数有:998244353,1004535809,469762049,这几个的原根都为3
在NTT中,我们拿原根来代替FFT的单位根
#define g 3 const int mod = 998244353; const int N = 300000; inline get_rev() { int lim = 0; while((1 << lim) < n) lim++; for(int i = 0; i < n; i++) { rev[i] = (rev[i >> 1] >> 1 | ((i & 1) << (lim - 1))); } } inline void ntt(int *a, int inv) { for(int i = 0; i < n; i++) { if(i < rev[i]) swap(a[i], a[rev[i]]); //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换) } for(int l = 2; l <= n; l *= 2) { //区间长度 int m = l / 2; int tmp = q_pow(g, (mod-1)/l); if(inv == -1) tmp = q_pow(tmp, mod-2); for(int i = 0; i < n; i += l) { int omega = 1; for(int j = 0; j < m; j++, omega = omega*tmp%mod) { int x = a[i + j], y = omega * a[i + j + m] % mod; a[i + j] = (x + y) % mod, a[i + j + m] = (x - y + mod) % mod; } } } if(inv == -1) { int nI = q_pow(n, mod-2); for(int i = 0; i < n; i++) { a[i] = a[i] * nI % mod; } } }
例1:http://www.acmicpc.sdnu.edu.cn/problem/show/1532
#include<bits/stdc++.h> using namespace std; #define int long long #define g 3 const int mod = 998244353; const int N = 300000; inline int read(){ int x=0,f=1;char ch=getchar(); while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();} while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();} return x*f; } inline void write(int x) { if(x<0)x=-x,putchar('-'); if(x>9)write(x/10); putchar(x%10+'0'); } char sa[N], sb[N]; int n = 1, a[N], b[N], rev[N], lena, lenb; inline int q_pow(int a, int b){ int ans = 1; while(b > 0){ if(b & 1){ ans = ans * a % mod; } a = a * a % mod; b >>= 1; } return ans; } inline get_rev() { int lim = 0; while((1 << lim) < n) lim++; for(int i = 0; i < n; i++) { rev[i] = (rev[i >> 1] >> 1 | ((i & 1) << (lim - 1))); } } inline void ntt(int *a, int inv) { for(int i = 0; i < n; i++) { if(i < rev[i]) swap(a[i], a[rev[i]]); //i < t的限制使得每对点只被交换一次(否则交换两次相当于没交换) } for(int l = 2; l <= n; l *= 2) { //区间长度 int m = l / 2; int tmp = q_pow(g, (mod-1)/l); if(inv == -1) tmp = q_pow(tmp, mod-2); for(int i = 0; i < n; i += l) { int omega = 1; for(int j = 0; j < m; j++, omega = omega*tmp%mod) { int x = a[i + j], y = omega * a[i + j + m] % mod; a[i + j] = (x + y) % mod, a[i + j + m] = (x - y + mod) % mod; } } } if(inv == -1) { int nI = q_pow(n, mod-2); for(int i = 0; i < n; i++) { a[i] = a[i] * nI % mod; } } } signed main() { while(~scanf("%s%s",sa,sb)) { n = 1; memset(a, 0, sizeof(a)); memset(b, 0, sizeof(b)); lena = strlen(sa), lenb = strlen(sb); for(int i = 0; i < lena; i++) { a[i] = sa[lena - 1 - i] - '0'; } for(int i = 0; i < lenb; i++) { b[i] = sb[lenb - 1 - i] - '0'; } while(n < lena + lenb) n <<= 1; get_rev(); ntt(a, 1); ntt(b, 1); for(int i = 0; i < n; i++) { a[i] = a[i] * b[i] % mod; } ntt(a, -1); for(int i = 0; i < n; i++) { a[i + 1] += a[i] / 10; a[i] %= 10; } int cnt = n; while(cnt >= 0 && a[cnt] == 0) cnt--; if(cnt == -1) { printf("0"); } else { for(int i = cnt; i >= 0; i--){ write(a[i]); } } putchar('\n'); } return 0; }