快速傅里叶变换(FFT)

有趣啊,都已经到NOI的难度了,救命

首先,我们先讲述一下前置知识。已经明白的读者请移步后文

虚数

定义\(z = a + bi\),其中 \(a, b \in R\ \ i = \sqrt{-1}\)

运算原则

\[\begin{aligned}
(a+bi) + (c+di) &= (a+c) + (b+d)i \\
(a+bi)(c+di) &= (ac - bd) + (ad + bc)i \\
\cfrac {(a+bi)}{(c+di)} &= \cfrac {ac + bd}{c^2 + d^2} + \cfrac {bc - ad}{c^2 + d^2}
\end{aligned}
\]

重要性质

\[e^{ix} = \cos x + i \sin x
\]

所以说,一个复数也可以写作 \(z = re^{i\theta}\) 的形式。

其中 \(r\) 为它的模,\(\theta\) 为它的辐角

证明

我们通过欧拉公式在 0 处展开:

\[\begin{aligned}
e^x &= 1 + x + \cfrac {x^2}{2!} + \cfrac {x^3}{3!} + \dots \\
\cos x &= 1 - \cfrac {x^2}{2!} + \cfrac {x^4}{4!} - \cfrac {x^6}{6!} + \dots \\
\sin x &= x - \cfrac {x^3}{3!} + \cfrac {x^5}{5!} - \cfrac {x^7}{7!} + \dots \\
\end{aligned}
\]

那么我们考虑如何把三者扯上关系呢?

由于已知 \(i^2 = -1, i^3 = -i\) 那么我们先考虑 \(e^{ix}\)

\[e^{ix} = 1 + ix - \cfrac {x^2}{2!} - \cfrac {ix^3}{3!} + \dots
\]

那么,很明显, \(e^{ix} = \cos x + i \sin x\)

得证。

代码实现

在 C++ 中我们其实可以直接使用 std::complex<double>

文档可以参考 std::complex - cppreference.com

但是毕竟是 stl,其使用细节肯定没有那么顺手,所以建议手写复数模板。

考虑到实际中我们几乎不需要用到除法,所以,我们仅实现三则运算。

struct Complex {
    double real, imag;
    Complex() : real(0), imag(0) {}
    Complex(double re, double im) : real(re), imag(im) {};
    inline Complex operator + (const Complex & b) {
        return Complex(real + b.real, imag + b.imag);
    }
    inline Complex operator - (const Complex & b) {
        return Complex(real - b.real, imag - b.imag);
    }
    inline Complex operator * (const Complex & b) {
        return Complex(real * b.real - imag * b.imag, real * b.imag + imag * b.real);
    }
};

单位根

快速傅里叶变换的核心就是利用的单位根的一些独特的性质来快速实现的

由于我们知道复数可以写作 \(z = re^{i\theta}\),所以说,两个复数 \(z_1, z_2\) 相乘的结果也可以表示为 \(r_1r_2 e^{i(\theta_1+\theta_2)}\)

相乘之后的图类上

扯远了……


单位根的定义:方程 \(z^n = 1\) 在复数范围内的 \(n\) 个根。

那么,不经过证明的给出,每一个根应该为 \(e^{i\frac{2k\pi}{n}}\)

这里我们记 \(\omega_n\) 为主 \(n\) 次单位根, \(\omega_n^k = e^{i\frac{2k\pi}{n}}\)

举个例子,主 \(8\) 次单位根的 \(8\) 个值改写为形如 \((r, \theta)\) 的极坐标后,位置类似于下图:

三个引理

  • 消去定理:\(\omega_{dn}^{dk} = \omega_n^k\)

证明:考虑展开即可:\(\omega_{dn}^{dk} = e^{i\frac{2dk\pi}{dn}} = e^{i\frac{2k\pi}{n}} = w_n^k\)

  • 折半引理:\((\omega_n^{k+\frac n2})^2 = (w_n^k)^2 = \omega_{\frac n2}^k\)

这个引理是快速傅里叶变化的核心

证明:也是考虑展开

\[\begin{aligned}
\omega_n^{k+\frac n2} &= \omega_n^k\omega_n^{\frac n2} = -\omega_n^k \\
(\omega_n^k)^2 &= \omega_n^{2k} = \omega_{\frac n2}^k
\end{aligned}
\]

  • 求和引理:\(\sum_{i=0}^{n-1} (\omega_n^k)^i = 0\)

证明

根据等比数列公式

\[\sum_{i=0}^{n-1} (\omega_n^k)^i
= \cfrac {(w_n^k)^n - 1}{w_n^k - 1}
= \cfrac {(w_n^n)^k - 1}{w_n^k - 1}
= \cfrac {1^k - 1}{w_n^k - 1} = 0
\]

得证


多项式

(OI中)一般形式\(F(x) = a_0 + a_1x + a_2 x^2 + \cdots + a_n x^n\)

上述多项式为一元多项式。

我们可以改写上式:\(\sum_{i=0}^n a_ix^i\)

我们对于多项式运算定义如下:

\[\begin{aligned}
A(x) &= \sum_{i = 0}^n a_i x^i \\
B(x) &= \sum_{i = 0}^n b_i x^i \\
\end{aligned}
\]

  • 加法:
\[A(x) + B(x) = \sum_{i = 0}^n (a_i + b_i) x^i
\]

  • 乘法

一般情况下,我们可以通过补零的方式,将两个次数不同的多项式调整到次数相同。这里我们都补充到 \(n\) 的长度

\[\begin{aligned}
c_i &= \sum_{j = 0}^{i} a_j b_{i - j} \\
A(x)B(x) &= \sum_{i = 0}^{2n} c_i x^i \\
\end{aligned}
\]

我们称这个系数向量 \(c\) 为向量 \(a, b\) 的卷积,记作 \(a \otimes b\)

表示方法

  • 系数表示

    它将一个多项式表示成由其系数构成的向量的形式

    例如 \(A = [a_0, a_1, a_2, \dots, a_n]^T\)

    加法即为 \(A_1+ A_2\),直接相加即可。时间复杂度 \(O(n)\)

    乘法则做向量卷积,为 \(A_1 \otimes A_2\)。一般来说,时间复杂度为 \(O(n^2)\)

    如果给定 \(x\) 求值,则可以使用霍纳法则或者秦九昭算法。时间复杂度为 \(O(n)\)

  • 点值表示

    用至少 \(n\) 个多项式上的点来表示

    一般形式如 \(\{(x_0, A(x_0)), (x_1, A(x_1), \dots, (x_n, A(x_n))\}\)

    进行运算是,一般要保证两个多项式在同一位置取值相同,即 \(x_i\) 相同

    加法运算直接将两点坐标相加即可,时间复杂度为 \(O(n)\)

    乘法运算只需要将两点坐标相乘即可。时间复杂度为 \(O(n)\),太好了!

    如果我们需要 \(A(x)\) ,这个过程叫做插值,可以通过拉格朗日插值公式进行计算,复杂度为 \(O(n^2)\),这里不展开讲述。

离散傅里叶变换(DFT)

DFT(Discrete Fourier Transform) 是快速傅里叶变换(FFT)的基础,也是快速数论变换(NTT)的基础

变换操作是对于一个向量而言

这个变换操作是我们通过单位根的性质构造出来的!请不要求什么证明

不妨设这个向量为 \(C = [c_0, c_1, c_2, \dots, c_{n-1}]^T\)

我们定义一个变换公式

\[h(x) = \sum_{i = 0}^{n-1}c_i x^i
\]

那么变换过后的序列为

这里不用向量是因为我不知道该如何表示

\[< h(\omega^0), h(\omega^1), h(\omega^2), \dots, h(\omega^{n-1})>
\]

其中 \(\omega\) 代表主 \(n\) 次的单位根。

对于上述序列,我们称形如 \(h(\omega^k)\) 的项为 \(k\) 次离散傅里叶级数。

我们将每一项展开,那么可以得到下图:

图片来自网络

这个时候,我们变换后的序列类似于用点值表示的序列。

两个变换后的序列相乘即是一一对应的乘法即可。

例如两个序列 \(<h(\omega_0), h(\omega_1),h(\omega_2)>\)\(<g(\omega_0), g(\omega_1), g(\omega_2)>\)

那么相乘并不需要卷积,结果即是 \(<h(\omega_0)g(\omega_0), h(\omega_1)g(\omega_1), h(\omega_2)g(\omega_2)>\)


重要性质

对于上述序列,其 \(-k\) 次离散傅里叶变换后的的值恰为 \(n \times c_k\)

证明

我们用 \(g(\omega^{-k})\) 表示变换后的结果

\[\begin{aligned}
g(\omega^{-k}) &= h(\omega^0) \omega^{-0k} + h(\omega_1) \omega^{-k} + h(w_2) \omega^{-2k} + \cdots + h(\omega^{n-1}) \omega^{-(n-1)k} \\
&= \sum_{i=0}^{n-1} h(\omega^i)\omega^{-ik} \\
&= \sum_{i=0}^{n-1} \omega^{-ik} \sum_{j = 0}^{n-1} c_j \omega^{ij} \\
&= \sum_{i=0}^{n-1} \sum_{j=0}^{n-1} \omega^{(j-k)i} c_j \\
&= \sum_{i=0}^{n-1} c_j \sum_{j=0}^{n-1} \omega^{(j-k)i}
\end{aligned}
\]

我们分类讨论一下:

  • \(j = k\),那么此时 \(\sum_{j=0}^{n-1} \omega^{(j-k)y} = \sum_{j=0}^{n-1} \omega^0 = n\)。对于 \(c_k\) 做出的贡献为 \(n\)

  • \(j \ne k\),将 \(k\) 看作常数,那么此时 \(\sum_{j=0}^{n-1} \omega^{(j-k)i} = \sum_{j=0}^{n-1} \omega^{j}\)。依据上文中求和引理,其值为 \(0\),也就是对 \(c_j\) 做出的贡献为 \(0\)

综上所述,只有 \(c_k\) 对于答案做出了 \(n\) 次贡献,所以 \(g(\omega^{-k}) = n \times c_k\)


于是,我们可以得出两个多项式卷积的计算方法:

  • 将两个多项式改写为向量的形式,并分别对其做一次离散傅里叶变换 (DFT)

  • 将变换过后的两个序列相乘(点对点相乘)得出一个新的序列

  • 我们再对此序列做一次逆傅里叶变换,也就是将序列变为 \(<g(\omega^{-0}), g(\omega^{-1}),g(\omega^{-2}), \dots, g(\omega^{-(n-1)})>\)

    也就是 \(<n \times c_0, n \times c_1, \dots, n \times c_{n-1}>\)

    最后对于每一项除以 \(n\) 即可。

但是

最朴素的 DFT 变化的时间复杂度为 \(O(n^2)\),三次变换还不如直接暴力计算……所以我们就需要快速傅里叶变换来优化了。

快速傅里叶变换(FFT)

FFT -> Fast-Fast-TLE

我们对于 \(f(x) = c_0 + c_1 x + c_2 x^2 + \dots + c_{n-1}x^{n-1}\),分离其奇数项和偶数项,构造出另外两个向量

\[\begin{aligned}
f_{even}(x) &= c_0 + c_2 x + c_4 x^2 + \dots + c_{n-2} x^{\frac n2 - 1} \\
f_{odd}(x) &= c_1 + c_3 x + c_5 x^2 + \dots + c_{n-1} x^{\frac n2 - 1} \\
\end{aligned}
\]

那么,不难发现:

\[f(x) = f_{even}(x^2) + x f_{odd}(x^2)
\]

也就是说

\[\begin{aligned}
f(\omega_n^k) &= f_{even}(\omega_{n}^{2k}) + \omega^kf_{odd}(\omega_n^{2k}) \\
f(\omega_n^{k+\frac n2}) &= f_{even}(\omega_n^{2k+n}) + \omega^{k+\frac n2}f_{odd}(\omega_n^{2k+n})
\end{aligned}
\]

根据单位根的性质稍微化一下……

补充一个点:

\[\omega_n^{\frac n2} = -1
\]

证明

考虑我们在上面画出的图中,\(\omega_n^{\frac n2}\) 所在的位置。就是 -1 了

\[\begin{aligned}
f(\omega_n^k) &= f_{even}(\omega_{\frac n2}^k) + \omega_n^kf_{odd}(\omega_{\frac n2}^k) \\
f(\omega_n^{k+\frac n2}) &= f_{even}(\omega_{\frac n2}^k) - \omega_n^kf_{odd}(\omega_{\frac n2}^k) \\
\end{aligned}
\]

于是,我们就可以递归分治了,其复杂度为 \(O(nlogn)\)

其实我们还要考虑一点,我们要保证长度为 \(2^k\) 才能保证可以正确的分治。

应为区间合并必须要满足长度相等才行。

所以说,我们要把两个多项式通过补 \(0\) 的方式补齐到 \(2^k\) 项,合并之后就是 \(2^{k+1}\) 项。

对于模板题:【模板】多项式乘法(FFT) - 洛谷

可以写出如下龟速代码:

#include <iostream>
#include <algorithm>
#include <vector>

struct Complex {
    double real, imag;
    Complex() : real(0), imag(0) {}
    Complex(double re, double im) : real(re), imag(im) {};
    inline Complex operator + (const Complex & b) { return Complex(real + b.real, imag + b.imag); }
    inline Complex operator - (const Complex & b) { return Complex(real - b.real, imag - b.imag); }
    inline Complex operator * (const Complex & b) { return Complex(real * b.real - imag * b.imag, real * b.imag + imag * b.real); }
};
typedef std::vector<Complex> Vector;

const double PI = acos(-1);

void FFT(Vector &v, int n, int inv) {
    if (n == 1) return; // 递归边界,只有一个元素,不做变换

    // 奇偶变化为两个向量
    int mid = n >> 1;
    Vector even(mid), odd(mid);
    for (int i(0); i < n; i += 2) {
        even[i >> 1] = v[i], odd[i >> 1] = v[i + 1]; 
    }
    // 递归操作 
    FFT(even, mid, inv), FFT(odd, mid, inv);

    // 进行合并操作
    // 定义基本 omega
    Complex omega(cos(PI * 2 / n), inv * sin(PI * 2 / n));
    // 当前旋转因子
    Complex w(1, 0);
    for (int i(0); i < mid; ++i, w = w * omega) {
        v[i] = even[i] + w * odd[i];
        v[i + mid] = even[i] - w * odd[i];
    }
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(0), std::cout.tie(0);

    int n, m;
    std::cin >> n >> m;

    // 获取最终的长度,必须是 2 的次幂,且比两个向量卷起来要长 
    int O(1);
    while (O <= m + n) O <<= 1;
    // std::cout << PI << " " << O << std::endl;

    Vector A(O), B(O);
    for (int i(0); i <= n; ++i) std::cin >> A[i].real; 
    for (int i(0); i <= m; ++i) std::cin >> B[i].real;

    FFT(A, O, 1);
    FFT(B, O, 1);

    // 我们单点相乘,然后进行逆变换,求出每一项的系数
    for (int i(0); i < O; ++i) A[i] = A[i] * B[i];
    FFT(A, O, -1);

    // 最后进行输出
    // 记得两个东西卷起来之后是 n + m 次的
    for (int i(0); i <= n + m; ++i) {
        // 这里是向上取整? 
        std::cout << (long long)(A[i].real / O + 0.5) << ' ';
    } std::cout << std::flush;
    return 0;
}

我不知道为什么网上这么多写递归版本的都是错的。

例如本题第一篇题解,其实并不是方法有问题,是他的写法错了。

链接:题解 P3803 【【模板】多项式乘法(FFT)】 - attack 的博客

例如知乎上的一篇文章,其在递归时的边界是有问题的。虽然说不影响正确性……

链接:快速傅里叶变换 - 星夜

当然,还是有代码正确,但是代码……

链接:FFT-快速傅里叶变换 - heartbeats

蝶形优化

既然是龟速了……那么我们一定要想办法优化。

我们观察递归时合并如下:

不难发现,每一次合并的单位根的次数取决于合并之后的元素个数。

而有多少个元素又可以由当前层数来确定。

其实上图体现的不明显,我们换一张手稿:

\(n = 8\) 时所求的部分大概如此:

我们将其化成表格:

表格内的元素我们通过了一定的方法对其编了号,可以看作对应数组的下标。

\(k=1\) \(c_0\) \(c_4\) \(c_2\) \(c_6\) \(c_1\) \(c_5\) \(c_3\) \(c_7\)
\(h(\omega^0)\) 0 1 2 3 4 5 6 7
\(k=2\) \(c_0,c_4\) \(c_2,c_6\) \(c_1,c_5\) \(c_3,c_7\)
\(h(\omega^0)\) 0 2 4 6
\(h(\omega^1)\) 1 3 5 7
\(k=4\) \(c_0,c_4,c_2,c_6\) \(c_1,c_5,c_3,c_7\)
\(h(\omega^0)\) 0 4
\(h(\omega^1)\) 1 5
\(h(\omega^2)\) 2 6
\(h(\omega^3)\) 3 7

可以发现,若 \(a\)\(a+k\) 能够凑成一对做出贡献,那么一定是对位置在 \(a\)\(a+k\) 的位置的元素做出贡献。

那么,我们就可以直接替换即可。

这个优化叫做蝶形优化……

还有一个小问题没有解决:初始值的位置怎么处理?

其实我们可以通过观察其二进制得出答案:下标的二进制恰好和目标二进制互为倒叙。

一共只有 \(log(n)\) 位!

也就是说当 \(n = 8\)rev[3] = rev[0b011] = 0b110 = 6

我们考虑可以通过DP在 \(O(n)\) 内实现。

假设我们已经处理完了 \(1 \sim n-1\) 的所有 \(rev\)

考虑 \(n = (abcd)_2\),那么我们已经知道了 \(rev[(0abc)_2] = (cba0)_2\),需要 \(rev[(abcd)_2] = (dcba)_2\)

通过瞪眼法,我们可以轻易的得出

\(rev[(abcd)_2] = (dcba)_2 = (rev[(0abc)_2]>>1)\ or\ ((d\&1)<<3)\)

改写为递推式即是:

dp[x] = (dp[x>>1] >> 1) | ((x&1) << log2(n))


参考代码

#include <complex>
#include <iostream>
#include <algorithm>
#include <vector>

struct Complex {
    double real, imag;
    Complex() : real(0), imag(0) {}
    Complex(double re, double im) : real(re), imag(im) {};
    inline Complex operator + (const Complex & b) { return Complex(real + b.real, imag + b.imag); }
    inline Complex operator - (const Complex & b) { return Complex(real - b.real, imag - b.imag); }
    inline Complex operator * (const Complex & b) { return Complex(real * b.real - imag * b.imag, real * b.imag + imag * b.real); }
};

typedef std::vector<Complex> Vector;

const double PI = acos(-1);
int O(1), logO(0);

void FFT(std::vector<int> &rev, Vector &v, int inv) {
    for (int i(0); i < O; ++i) {
        if (i < rev[i]) std::swap(v[i], v[rev[i]]);
    }

    // 第 log(k) 次合并,一共logO次 
    // 合并之后区间的长度为 k
    for (int k(1); k < O; k <<= 1) {
        Complex omega(cos(PI / k), inv * sin(PI / k));
        for (int i(0); i < O; i += (k<<1)) { // 处理行 
            Complex w(1, 0);
            for (int j = 0; j < k; ++j, w = w * omega) {
                Complex s = v[i + j], t = v[i + j + k] * w;
                v[i + j] = s + t, v[i + j + k] = s - t; 
            }
        }
    }

    if (inv == -1) for (int i(0); i < O; ++i) v[i].real /= O;
}


int main() {
    int n, m;
    std::cin >> n >> m;

    while (O <= n + m) O <<= 1, ++logO;

    Vector A(O), B(O);
    for (int i(0); i <= n; ++i) std::cin >> A[i].real; 
    for (int i(0); i <= m; ++i) std::cin >> B[i].real;

    std::vector<int> rev(O);
    for (int i(0); i < O; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (logO - 1));

    FFT(rev, A, 1), FFT(rev, B, 1);

    for (int i(0); i < O; ++i) A[i] = A[i] * B[i];
    FFT(rev, A, -1);

    for (int i(0); i <= n + m; ++i) {
        std::cout << (long long)(A[i].real + 0.1) << ' '; // 向上取整 
    } std::cout << std::flush;
    return 0;
}

那么恭喜你,大概时明白了 FFT 了吧!