2023.2.9【模板】快速傅里叶变换(FFT)

好多天没写博客了qwq

题目描述

给定一个 n 次多项式 F(x),和一个 m 次多项式 G(x)。

请求出 F(x) 和 G(x) 的卷积。

朴素(正常)思路

枚举计算的每一位,交叉相乘加起来计算答案,时间复杂度O(\(n^2\))

原地爆炸

这个时候就需要用到NOIP不考的FFT了

前置芝士

有关复数详见lxw (tqqqqqqqqqqqqqqqqqql%%%%%%%%%%%%%)巨佬的博客

复数 - Ricky2007 - 博客园 (cnblogs.com)

*泰勒展开

泰勒中值定理:若\(f(x)\)\(x_0\)处有\(n\)阶导数,那么存在\(x_0\)邻域中的\(x\),有

\(f(x) = f(x_0) + f'(x_0)(x - x_0) + \frac {f''(x_0)}2 {(x - x_0)}^2 + ...+ \frac {f^{(n)}(x_0)}{n!}{(x - x_0)}^n + o({(x - x_0)}^n)\)

泰勒展开经常用于将一个难以计算的函数逼近(或等效)于一个多项式,为了消掉\((x - x_0)\)

,我们经常将函数在0处展开,即\(x_0 = 0\)

*欧拉公式

\(e^{ix} = isinx + cosx\)

证明:将三个柿子泰勒展开可以得到:

\(e^x = 1 + x + \frac1{2!}x^2 + \frac 1{3!}x^3 + ... + \frac1{n!}x^n\)

\(sinx = x - \frac1{3!}x^3 + \frac1{5!}x^5 - \frac1{7!}x^7 + ...\)

\(cosx = 1 - \frac 12x^2 + \frac1{4!}x^4 - \frac1{6!}x^6 + ...\)

我们(神奇地)发现:\(sinx + cosx = 1 + x - \frac 12x^2 - \frac1{3!}x^3 + \frac1{4!}x^4 + \frac1{5!}x^5...\)

这难道不和\(e^x\)长得很像吗?

为了将\(e^x\)中一定项转为负号,我们引入\(i = \sqrt-1\) :

\(e^{ix} = 1 + ix - \frac12x - \frac1{3!}x + ....\)

\(sin\) \(ix + cosx = 1 + ix - \frac12x - \frac1{3!}x^3 ...\)

\(e^{ix} = sin\) \(ix + cosx\)

特殊地,代入\(x = \pi\)得:\(e^{i\pi} = sin\pi * i + cos\pi = -1\)

这不就是那个被很多人当作壁纸的公式吗,我们来思考它的实际作用

在一个笛卡尔坐标系中,一个从源点指向\((x,y)\)的向量,可以使用与\(x\)轴夹角\(\theta\)和长度\(\rho\)来表达(极坐标),即\((\theta,\rho\))

image

我们将Y坐标指定为复数域,X坐标指定为实数域,发现x + yi可以表示一个点

(\(\rho sin i\theta + \rho cos \theta\))

这样以来,再使用欧拉公式转化,就得到了:\(\rho e^{ix \theta}\)

一个单项就可以表达一个坐标,省去许多计算麻烦

*离散傅里叶级数

对于序列\(<c_0,c_1,c_2....c_{n - 1}>\) ,定义其\(k\)次离散傅里叶级数(DF)为:

\(h(\omega^{k}) = c_0 + c_1\omega^k + c_2\omega^{2k} + ... + c_{n - 1}\omega^{(n - 1)k}\)

其中\(\omega = e^{\frac {2~i \pi}n}\)

可以看作在极坐标内,每次走\(k\)单位的角度,一圈是\(2i \pi\)单位角度

image

(B站神犇up@3Blue1Brown的示意图)

*离散傅里叶变换 DFT

对于序列\(<c_0,c_1,c_2,....c_{n - 1}>\),构造一个新的序列:

\(<h(\omega^{0}),h(\omega^{1}),h(\omega^{2}),...,h(\omega^{n - 1})>\)

这个过程叫做离散傅里叶变换,反之,我们称对于一个构造好的序列,求原序列的过程叫做的逆傅里叶变换(IFT)

*定理:一个序列\(<c_0,c_1,...,c_{n - 1}>\)的傅里叶变换后的新序列\(<h(\omega^{0}),h(\omega^{1}),...,h(\omega^{n - 1})>\)\((-k)\)次DF值恰为原序列\(c_k\)\(n\)

证明:设\(h(\omega^{k}) = \Sigma_{x = 0}^{n - 1}\omega^{kx}c_x , g(\omega^{-k}) = \Sigma_{y = 0}^{n - 1}\omega^{-ky}h(\omega^{y})\)

代入\(h(\omega^{y})\)得:\(g(\omega^{-k}) = \Sigma_{y = 0}^{n - 1}\omega^{-ky}\Sigma_{x = 0}^{n - 1}\omega^{xy}c_x\)

提出求和符号:\(g(\omega^{-k}) = \Sigma_{y = 0}^{n - 1}\Sigma_{x = 0}^{n - 1}\omega^{(x - k)y}c_x\)

\(= \Sigma_{x = 0}^{n - 1}c_x\Sigma_{y = 0}^{n - 1}\omega^{(x - k)y}\)

Key1. \(x = k\)时:

\(\Sigma_{y = 0}^{n - 1}\omega^{(x - k)y} = \Sigma_{y=0}^{n - 1}1 = n\)

Key2.\(x \ne k\)时:

\(\Sigma_{y = 0}^{n - 1}[\omega^{(x - k)}]^y = [\omega^{(x - k)}]^0 + [\omega^{(x - k)}]^1 + ... + [\omega^{(x - k)}]^{n - 1} = \frac {1 - [~\omega^{(x - k)}~]^n}{1 - \omega^{(x - k)}}\)

为什么\(x^0 + x^1 + x^2 + ... + x^{n - 1} = \frac{1 - x^n}{1 - x}\)

设其为\(f(x)\),将上式乘\(x\),可以得到

\(x^1 + x^2 + ... + x^n = f(x) - 1 + x^n = xf(x)\)

\((x - 1)f(x) = x^n - 1,f(x) = \frac{x^n - 1}{x - 1}\)

\(\omega = e^{\frac{2~i \pi}n} \to [\omega^{(x - k)}]^n = (\omega^{n})^{x - k} = 1^{x - k} = 1\)

\(\Sigma_{y = 0}^{n - 1}[\omega^{x - k}]^y = \frac{1 - 1}{1 - \omega^{x - k}} = 0\),不计入答案

\(g(\omega^{-k}) = \Sigma_{x = 0}^{n - 1}c_x * [(x = k) ? n:0] = nc_k\)

\(QED\)

通过*****,我们就可以构造出一种计算多项式乘法的新方法:

1.将F(x)和G(x)进行傅里叶变换

2.将两者点乘(\(x \in [0,n - 1]\)上的值分别相乘)

3.将所得序列进行\((-k)\)次方逆傅里叶变换,结果除以\(n\)

\(O(n^2)\)

怎样加快计算过程?

考虑多项式\(h(x) = c_0 + c_1x + ... + c_{n - 1}x^{n - 1}\)

提出奇数\(odd(x) = c_1 + c_3x + c_5x^2 + ... + c_{n - 1}x^{\frac n2 - 1}\)

偶数\(even(x) = c_0 + c_2x + c_4x^2 + ... + c_{n - 2}x^{\frac n2 - 1}\)

可推得合并式\(h(x) = even(x^2) + x * odd(x^2)\)

代入:\(h(\omega^{k}) = even(\omega^{2k}) + \omega^{k} * odd(\omega^{2k})\)

可以发现:\(h(\omega^{k + \frac n2}) = even(\omega^{2k + n}) + \omega^{k + \frac n2} * odd(\omega^{2k + n})\)

\(=even(\omega^{2k} * \omega^{n}) + \omega^{k + \frac n2} * odd(\omega^{2k} * \omega^{n})\)

$\omega = e^{\frac{2~i\pi}n}\to \omega^{n} = 1,\omega^{\frac n2} = -1 $

\(h(\omega^{k + \frac n2}) = even(\omega^{2k}) - \omega^{k} * odd(\omega^{2k})\)

\(h(\omega^{k}) = even(\omega^{2k}) + \omega^{k} * odd(\omega^{2k})\)

其中通过\(\omega^{2k}\),我们将要计算的\(h(\omega^{k})\)\(h(\omega^{k + \frac n2})\)共计n项,转化成\(even\)\(odd\)共计\(n\)项,每次将序列拆成奇项和偶项两部分,共拆\(log_2n\)层,共计时间复杂度\(O(nlog_2n)\)

\(————FFT\)快速傅里叶变换

合并过程图解:例:\(n = 8\)
2023.2.9【模板】快速傅里叶变换-小白菜博客
小技巧:怎样快速将序列拆成奇项和偶项?

考虑到我们要将偶项也拆成其中的奇和偶、再拆为奇和偶...我们可以发现,每次拆项都是按照数字二进制的第K位排序,最后的顺序就是将序列按照最低位为第一关键字,次低位为第二关键字...来排序(或者说是拆分数组),一个数的排名就是它二进制拆分的倒序,于是我们在程序开始预处理每个数字的二进制翻转:

\(rev_i = [(rev_{i >> 1}) >> 1] | (i \& 1)<< d\)

(其中\(d\)为最大数的二进制总位数)

\(FFT\)开始前预处理一遍数组,若\(i < rev[i]\),则\(swap(x[i],x[rev[i]])\)即可,这样可以做到不重不漏地将每个数翻转一次,然后直接按照下标分治即可

Code

#include<bits/stdc++.h>
using namespace std;
const int N = 3e6 + 5;
const double PI = acos(-1.0);
int n,m,rev[N],tt = 1,tw = 0;
struct Complex{
    double r,c;
}a[N],b[N];
Complex operator +(Complex x,Complex y)
{
    Complex z;
    z.r = x.r + y.r;
    z.c = x.c + y.c;
    return z;
}
Complex operator -(Complex x,Complex y)
{
    Complex z;
    z.r = x.r - y.r;
    z.c = x.c - y.c;
    return z;
}
Complex operator *(Complex x,Complex y)
{
    Complex z;
    z.r = x.r * y.r - x.c * y.c;
    z.c = x.c * y.r + x.r * y.c;
    return z;
}
inline int read()
{
    int s = 0,w = 1;
    char k = getchar();
    while(k > '9' || k < '0')
    {
        if(k == '-') w = -w;
        k = getchar();
    }
    while(k <= '9' && k >= '0')
    {
        s = s * 10 + k - '0';
        k = getchar();
    }
    return s * w;
}
inline void FFT(Complex *x,int len,int type)
{
    for(int i = 0;i < tt;i++) if(i < rev[i]) swap(x[i],x[rev[i]]);
    for(int mid = 1;mid < tt;mid <<= 1)
    {
        Complex omega;omega.r = cos(PI / mid);omega.c = type * sin(PI / mid);
        for(int j = 0,R = mid << 1;j < tt;j += R)
        {
            Complex now;now.c = 0;now.r = 1;
            for(int k = j;k < j + mid;k++,now = now * omega)
            {
                Complex X = x[k],Y = x[k + mid];
                x[k] = X + now * Y;
                x[k + mid] = X - now * Y;
            }
        }
    }
}
int main()
{
    n = read();m = read();
    for(int i = 0;i <= n;i++) a[i].r = read();
    for(int i = 0;i <= m;i++) b[i].r = read();
    while(tt <= n + m) tt <<= 1,tw++;
    rev[0] = 0;
    for(int i = 1;i < tt;i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (tw - 1));
    FFT(a,n,1);FFT(b,m,1);
    for(int i = 0;i <= tt;i++) a[i] = a[i] * b[i];
    FFT(a,n,-1);
    for(int i = 0;i <= n + m;i++)
        printf("%d ",(int)(a[i].r / tt + 0.5));
    return 0;
}

买一送一环节

我们知道,FFT中暴力计算\(\omega\)的值,虽然使用了\(double\),但是精度仍然会丢失很多,所以我们就要对其进行魔改 \(———— NTT\) 快速数论变换

考虑到在\(FFT\)中,每一层的单位乘积是\(\omega^k\),考虑用原根替换这个东西

考虑到答案中的每个数不会很大,我们将答案模一个素数,就可以利用它的原根

(在\(int\)范围内,我们一般取998244353,其最小原根是3,原根逆元为332748118,并且

\(998244353 = 119 * 2^{23} + 1\),足够我们完成项数在\(1e6\)以内的运算

于是我们每一步将\(\omega_k\)换作\(g^{\frac {(p - 1)}{k}}\)即可(\(k\)表示当前枚举的部分长度),作为单位乘积每次乘上。

(这就是为什么需要\(P - 1\)有很多个\(2\)因子)

Code

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 3e6 + 5,MOD = 998244353,g = 3,gi = 332748118;
inline ll ksm(ll base,ll pts)
{
    ll ret = 1;
    for(;pts > 0;pts >>= 1,base = base * base % MOD)
        if(pts & 1)
            ret = ret * base % MOD;
    return ret;
}
int rev[N];
ll a[N],b[N],n,m,tt = 1,tw = 0;
inline void NTT(ll *x,ll type)
{
    for(int i = 0;i < tt;i++) if(i < rev[i]) swap(x[i],x[rev[i]]);
    for(int mid = 1,t = 0;mid < tt;mid <<= 1,t++)
    {
        ll dom = ksm((type == 1) ? g : gi,(MOD - 1) >> (t + 1));
        for(int j = 0,R = mid << 1;j < tt;j += R)
        {
            ll w = 1;
            for(int k = j;k < j + mid;k++,w = w * dom % MOD)
            {
                ll X = x[k],Y = x[k + mid] * w % MOD;
                x[k] = (X + Y) % MOD;
                x[k + mid] = (X - Y + MOD) % MOD;
            }
        }
    }
}
int main()
{
    cin>>n>>m;
    for(int i = 0;i <= n;i++) cin>>a[i];
    for(int i = 0;i <= m;i++) cin>>b[i];
    while(tt <= n + m) tt <<= 1,tw++;
    for(int i = 0;i < tt;i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (tw - 1));
    NTT(a,1);NTT(b,1);
    for(int i = 0;i <= tt;i++) a[i] = a[i] * b[i] % MOD;
    NTT(a,-1);
    ll inv = ksm(tt,MOD - 2);
    for(int i = 0;i <= n + m;i++) cout<<a[i] * inv % MOD<<" ";
    return 0;
}