任意模数多项式乘法
任意模数多项式乘法
NTT
因为任意模数多项式乘法的系数可能超过了 所以一般的 就做不了。所以对于 的做法是找到 个模数,最后使用 合并一下,但是这里的合并中的答案可能超过 的值域,所以要用快速乘实现。
FFT
任意模数多项式乘法用 实现的问题主要是因为精度问题,这里使用了 次变化的 实现,采用了 作为底,那么一个多项式,我们根据 。那么 。就拆分成了系数比较小的四个多项式。最后合并一下就好了。
代码
这里只给出了 的代码实现。
#include<bits/stdc++.h> using namespace std; const int N = 4e5 + 100,base = 1 << 15; #define db long double #define ll long long const db pi = acos(-1); int read() { int x = 0,f = 0;char ch = getchar(); while(!isdigit(ch)) {if(ch=='-')f=1;ch=getchar();} while(isdigit(ch)) {x=x*10+ch-'0';ch=getchar();} return f?-x:x; } struct comp{ db x,y; comp(db a,db b) {x=a;y=b;} comp() {x=y=0;} comp operator + (const comp a) const {return comp(x+a.x,y+a.y);} comp operator - (const comp a) const {return comp(x-a.x,y-a.y);} comp operator * (const comp a) const {return comp(a.x*x-a.y*y,y*a.x+a.y*x);} }A1[N],A2[N],B1[N],B2[N]; int limit = 1,L,r[N],n,m,p; void fft(comp *a,int type) { for(int i = 0;i < limit;i++) if(r[i] < i) swap(a[r[i]],a[i]); for(int mid = 1;mid < limit;mid <<= 1) { comp wn = comp(cos(1.0 * pi / mid),type * sin(1.0 * pi / mid)); for(int i = 0;i < limit;i += (mid << 1)) { comp w = comp(1,0); for(int j = 0;j < mid;j++,w = w * wn) { comp x = a[i + j],y = w * a[i + j + mid]; a[i + j] = x + y;a[i + j + mid] = x - y; } } } if(type == -1) for(int i = 0;i < limit;i++) a[i].x = a[i].x / limit; } void Merge(comp *a,comp *b,int B,int *f) { static comp g[N]; for(int i = 0;i < limit;i++) g[i] = a[i] * b[i];fft(g,-1); for(int i = 0;i < limit;i++) f[i] = (f[i] + 1ll * B * ((ll)floor(g[i].x + 0.5) % p) % p) % p; } void MTT(comp *a,comp *b,comp *c,comp *d,int *f) { fft(a,1);fft(b,1);fft(c,1);fft(d,1); Merge(a,c,base * 1ll * base % p,f);Merge(a,d,base % p,f); Merge(b,c,base % p,f);Merge(b,d,1,f); } int main() { n = read();m = read();p = read(); for(int i = 0,x;i <= n;i++) {x = read();A1[i].x = x / base;B1[i].x = x % base;} for(int i = 0,x;i <= m;i++) {x = read();A2[i].x = x / base;B2[i].x = x % base;} while(limit <= n + m) limit <<= 1,L++; for(int i = 0;i < limit;i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << L - 1); static int Ans[N];memset(Ans,0,sizeof(Ans)); MTT(A1,B1,A2,B2,Ans); for(int i = 0;i <= n + m;i++) printf("%d ",Ans[i]); return 0 & printf("\n"); }
数学 文章被收录于专栏
关于多项式