【博客】「算法笔记」快速沃尔什变换 FWT
文章链接:http://orzsiyuan.com/archives/Algorithm/Fast-Walsh-Hadamard-Transform
快速沃尔什变换用于解决多项式位运算卷积。其计算过程和 类似。
概述
首先我们回忆一下多项式卷积:
在[「算法笔记」快速傅里叶变换 FFT](http://orzsiyuan.com/archives/Algorithm/Fast-Fourier-Transform) 中我们将多项式 和 转化为点值表达,然后重新转化为系数表达。
接下来考虑如下卷积形式:
其中 分别表示按位或、按位与、按位异或。
这样就没法使用 解决了,我们需要引进新的方法:快速沃尔什变换。
下文会分类介绍这 类卷积的解决方法。其中定义部分是一种构造出来的变换方法,证明部分将使用数学归纳法证明一些引理或定理,代码部分为该卷积的实现。
符号表示
首先我们认为下文提及的多项式的长度均为 的非负整数次幂。为了方便表述,我们定义如下如下符号及其含义,下文不再赘述。
其中需要尤其注意:这里定义的多项式异或、与、或均为卷积形式!并不是对应系数位运算!
或卷积
定义
证明
两个多项式相加后的 变换等于分别 的和:
当 时
当 时
两个多项式或卷积的 变换等于分别 的积:
当 时
当 时
逆变换
代码
void FWTor(std::vector<int> &a, bool rev) {
int n = a.size();
for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
if (!rev) add(a[i + j + m], a[i + j]);
else sub(a[i + j + m], a[i + j]);
}
}
}
与卷积
定义
证明
两个多项式相加后的 变换等于分别 的和:
当 时
当 时
两个多项式与卷积的 变换等于分别 的积:
当 时
当 时
逆变换
代码
void FWTand(std::vector<int> &a, bool rev) {
int n = a.size();
for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
if (!rev) add(a[i + j], a[i + j + m]);
else sub(a[i + j], a[i + j + m]);
}
}
}
异或卷积
定义
证明
两个多项式相加后的 变换等于分别 的和:
当 时
当 时
两个多项式异或卷积的 变换等于分别 的积:
当 时
当 时
逆变换
代码
void FWTxor(std::vector<int> &a, bool rev) {
int n = a.size(), inv2 = (P + 1) >> 1;
for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
int x = a[i + j], y = a[i + j + m];
if (!rev) {
a[i + j] = (x + y) % P;
a[i + j + m] = (x - y + P) % P;
} else {
a[i + j] = 1LL * (x + y) * inv2 % P;
a[i + j + m] = 1LL * (x - y + P) * inv2 % P;
}
}
}
}
完整代码
#include <cstdio>
#include <algorithm>
#include <vector>
const int P = 998244353;
void add(int &x, int y) {
(x += y) >= P && (x -= P);
}
void sub(int &x, int y) {
(x -= y) < 0 && (x += P);
}
struct FWT {
int extend(int n) {
int N = 1;
for (; N < n; N <<= 1);
return N;
}
void FWTor(std::vector<int> &a, bool rev) {
int n = a.size();
for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
if (!rev) add(a[i + j + m], a[i + j]);
else sub(a[i + j + m], a[i + j]);
}
}
}
void FWTand(std::vector<int> &a, bool rev) {
int n = a.size();
for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
if (!rev) add(a[i + j], a[i + j + m]);
else sub(a[i + j], a[i + j + m]);
}
}
}
void FWTxor(std::vector<int> &a, bool rev) {
int n = a.size(), inv2 = (P + 1) >> 1;
for (int l = 2, m = 1; l <= n; l <<= 1, m <<= 1) {
for (int j = 0; j < n; j += l) for (int i = 0; i < m; i++) {
int x = a[i + j], y = a[i + j + m];
if (!rev) {
a[i + j] = (x + y) % P;
a[i + j + m] = (x - y + P) % P;
} else {
a[i + j] = 1LL * (x + y) * inv2 % P;
a[i + j + m] = 1LL * (x - y + P) * inv2 % P;
}
}
}
}
std::vector<int> Or(std::vector<int> a1, std::vector<int> a2) {
int n = std::max(a1.size(), a2.size()), N = extend(n);
a1.resize(N), FWTor(a1, false);
a2.resize(N), FWTor(a2, false);
std::vector<int> A(N);
for (int i = 0; i < N; i++) A[i] = 1LL * a1[i] * a2[i] % P;
FWTor(A, true);
return A;
}
std::vector<int> And(std::vector<int> a1, std::vector<int> a2) {
int n = std::max(a1.size(), a2.size()), N = extend(n);
a1.resize(N), FWTand(a1, false);
a2.resize(N), FWTand(a2, false);
std::vector<int> A(N);
for (int i = 0; i < N; i++) A[i] = 1LL * a1[i] * a2[i] % P;
FWTand(A, true);
return A;
}
std::vector<int> Xor(std::vector<int> a1, std::vector<int> a2) {
int n = std::max(a1.size(), a2.size()), N = extend(n);
a1.resize(N), FWTxor(a1, false);
a2.resize(N), FWTxor(a2, false);
std::vector<int> A(N);
for (int i = 0; i < N; i++) A[i] = 1LL * a1[i] * a2[i] % P;
FWTxor(A, true);
return A;
}
} fwt;
int main() {
int n;
scanf("%d", &n);
std::vector<int> a1(n), a2(n);
for (int i = 0; i < n; i++) scanf("%d", &a1[i]);
for (int i = 0; i < n; i++) scanf("%d", &a2[i]);
std::vector<int> A;
A = fwt.Or(a1, a2);
for (int i = 0; i < n; i++) {
printf("%d%c", A[i], " \n"[i == n - 1]);
}
A = fwt.And(a1, a2);
for (int i = 0; i < n; i++) {
printf("%d%c", A[i], " \n"[i == n - 1]);
}
A = fwt.Xor(a1, a2);
for (int i = 0; i < n; i++) {
printf("%d%c", A[i], " \n"[i == n - 1]);
}
return 0;
}
习题
- 「Luogu 4717」【模板】快速沃尔什变换
- 「BZOJ 4589」Hard Nim
- 「hihoCoder 1230」The Celebration of Rabbits
- 「HDU 5909」Tree Cutting