快速沃尔什变换小记
概述
是用来处理集合卷积的问题。也就是求解类型的问题。其中或运算可以改为。
寻找点值
因为总是看不下去那么长的推导,所以每次都是看到一半。然后就在加上自己的一点理解,简单推导一下吧(背过结论就行)
以或运算为例。为什么说是集合卷积呢。因为或运算等价于求集合并。也就是求 。
那么我们类似于,先将他转化为点值,然后进行乘法运算后,在转换回来。
如何转化为点值呢。或者说他的点值长什么样子呢。
我们求一个。然后我们求一下的乘积。
我们令表示转化后的结果,表示转化或的结果,表示卷积转化后的结果。
也就是
所以就是我们要的点值表达式!
相互转化
有了点值之后我们还需要在点值与多项式之间相互转化,那么应该怎么转化呢。
其实很简单,观察。这个式子,其实就是一个高维前缀和嘛。。
然后转化回去同样的来个高维差分就ok了。
高维前缀和代码如下:
void fwt_or(int *a,int xs) { for(int i = 0;i < n;++i) for(int j = 0;j < (1 << n);++j) if(!((j >> i) & 1)) a[j | (1 << i)] += a[j]; }
对于另外两种运算
对于和。与类似,也有不同之处。因为表示的是集合交,所以他是枚举超集和而不是子集和。
至于,背板子吧我也不会推导啊qwq
板子
板子里面,时表示,即将多项式转化为点值。时表示,即将点值转化回多项式。
或运算
void fwt_or(int *a,int xs) { for(int i = 0;i < n;++i) for(int j = 0;j < (1 << n);++j) if(!((j >> i) & 1)) a[j | (1 << i)] += xs * a[j]; }
and运算
void fwt_and(int *a,int xs) { for(int i = 0;i < n;++i) for(int j = 0;j < (1 << n);++j) if(!((j >> i) & 1)) a[j] += xs * a[j | (1 << i)]; }
异或运算
void fwt_xor(int *a,int xs) { for(int i = 0;i < n;++i) { for(int j = 0;j < (1 << n);++j) { if(!((j >> i) & 1)) { int l = a[j],r = a[j | (1 << i)]; a[j] = l + r;a[j] %= mod; a[j | (1 << i)] = l - r;a[j | (1 << i)] %= mod; } } } if(xs == -1) { int inv = qm(1 << n,mod - 2); for(int i = 0;i < (1 << n);++i) a[i] = 1ll * a[i] * inv % mod; } }
小技巧
在很多题目中,需要进行多次运算,我们不需要每次将数组来回转化,只要先将多项式转化为点值,然后对点值进行快速幂运算,最后在转化回去就行。
模板题
/* * @Author: wxyww * @Date: 2020-04-26 08:03:27 * @Last Modified time: 2020-04-26 08:43:59 */ #include<cstdio> #include<iostream> #include<cstdlib> #include<cstring> #include<algorithm> #include<queue> #include<vector> #include<ctime> using namespace std; typedef long long ll; const int N = 1 << 20,mod = 998244353; ll read() { ll x = 0,f = 1;char c = getchar(); while(c < '0' || c > '9') { if(c == '-') f = -1; c = getchar(); } while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); } return x * f; } int A[N],B[N],n; void fwt_and(int *a,int xs) { for(int i = 0;i < n;++i) { for(int j = 0;j < (1 << n);++j) { if(!((j >> i) & 1)) { a[j] += xs * a[j | (1 << i)]; a[j] %= mod; } } } } void fwt_or(int *a,int xs) { for(int i = 0;i < n;++i) { for(int j = 0;j < (1 << n);++j) { if(!((j >> i) & 1)) { a[j | (1 << i)] += xs * a[j]; a[j | (1 << i)] %= mod; } } } } ll qm(ll x,ll y) { ll ret = 1; for(;y;y >>= 1,x = x * x % mod) if(y & 1) ret = ret * x % mod; return ret; } void fwt_xor(int *a,int xs) { for(int i = 0;i < n;++i) { for(int j = 0;j < (1 << n);++j) { if(!((j >> i) & 1)) { int l = a[j],r = a[j | (1 << i)]; a[j] = l + r;a[j] %= mod; a[j | (1 << i)] = l - r;a[j | (1 << i)] %= mod; } } } if(xs == -1) { int inv = qm(1 << n,mod - 2); for(int i = 0;i < (1 << n);++i) { a[i] = 1ll * a[i] * inv % mod; } } } int tmp1[N],tmp2[N]; int main() { n = read(); for(int i = 0;i < (1 << n);++i) A[i] = read(); for(int i = 0;i < (1 << n);++i) B[i] = read(); memcpy(tmp1,A,sizeof(tmp1)); memcpy(tmp2,B,sizeof(tmp2)); fwt_or(tmp1,1);fwt_or(tmp2,1); for(int i = 0;i < (1 << n);++i) tmp1[i] = 1ll * tmp1[i] * tmp2[i] % mod; fwt_or(tmp1,-1); for(int i = 0;i < (1 << n);++i) printf("%d ",(tmp1[i] + mod) % mod);puts(""); memcpy(tmp1,A,sizeof(tmp1)); memcpy(tmp2,B,sizeof(tmp2)); fwt_and(tmp1,1);fwt_and(tmp2,1); for(int i = 0;i < (1 << n);++i) tmp1[i] = 1ll * tmp1[i] * tmp2[i] % mod; fwt_and(tmp1,-1); for(int i = 0;i < (1 << n);++i) printf("%d ",(tmp1[i] + mod) % mod);puts(""); memcpy(tmp1,A,sizeof(tmp1)); memcpy(tmp2,B,sizeof(tmp2)); fwt_xor(tmp1,1);fwt_xor(tmp2,1); for(int i = 0;i < (1 << n);++i) tmp1[i] = 1ll * tmp1[i] * tmp2[i] % mod; fwt_xor(tmp1,-1); for(int i = 0;i < (1 << n);++i) printf("%d ",(tmp1[i] + mod) % mod);puts(""); return 0; }