1 条题解

  • 0
    @ 2023-1-3 22:23:12

    真是道屑题,NTT 还过不了,必须要 FFT。

    这题是一个标准的多项式乘法的板子。题目翻译是我给的。

    前置知识:高中数学

    既然是一篇题解,我就从零开始讲多项式乘法吧。

    首先你要知道在 OI 中,多项式指整式。

    整式

    我们在初中的时候学过整式乘法。我们重复使用乘法分配律,将两个整式乘起来。这个过程我把它叫做卷积

    例如 A(x)=x2+2x1,B(x)=x22x1A(x)=x^2+2x-1,B(x)=x^2-2x-1,我们要求 C(x)=A(x)B(x)C(x)=A(x)\cdot B(x)

    我们可以得到

    $$\begin{aligned} C(x)&=A(x)\cdot B(x)\\ &=(x^2+2x-1)(x^2-2x-1)\\ &=x^2(x^2-2x-1)+2x(x^2-2x-1)-(x^2-2x-1)\\ &=x^4-2x^3-x^2+2x^3-4x^2-2x-x^2+2x+1\\ &=x^4-6x^2+1 \end{aligned} $$

    我们有没有什么方法能加速乘法呢?

    点值表示法

    从最简单的讲起。

    平面几何五大公设(出自《几何原本》):

    1. 两点可以确定唯一一条直线。
    2. 一条线段可以向两端无限延长。
    3. 给定确定的点和长度,一定可以做出唯一的一个圆,以这个点为圆心,以这个长度为半径。
    4. 所有直角都相等。
    5. 同一平面内,两条直线被第三条直线所截,如果一侧的两个内角(同旁内角)之和小于两个直角(即 180180^\circ),那么这两条直线一定在这侧相交。

    其中第五公设最为出名,大家给了它一个名字:平行公设。

    我们要的就是第一公设。

    通过解析几何的学习,我们知道:给定两个点 (x0,y0),(x1,y1)(x_0,y_0),(x_1,y_1)(假设 x0x1x_0\neq x_1),我们一定能确定唯一一条直线

    $$y=\frac{y_1-y_0}{x_1-x_0}x+\frac{x_1y_0-x_0y_1}{x_1-x_0} $$

    其实这个公设可以推广到任意次整式:给定 d+1d+1 个点 (x0,y0),(x1,y1),,(xd,yd)(x_0,y_0),(x_1,y_1),\cdots,(x_d,y_d)ij,xixj\forall i\neq j,x_i\neq x_j),有唯一一个 dd 次函数经过这些点。

    这里给出一种证明:

    我们设这个函数为 f(x)f(x)。以下记 f(x)[i]f(x)\lbrack i\rbrackf(x)f(x) 的第 ii 项系数,即 $f(x)=\sum\limits_{i=0}^{d}(f(x)\lbrack i\rbrack\times x^i)$。

    根据矩阵乘法的意义,我们可以写出这样的式子:

    $$\left\lbrack\begin{matrix} 1&x_0&x_0^2&\cdots&x_0^d\\ 1&x_1&x_1^2&\cdots&x_1^d\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ 1&x_d&x_d^2&\cdots&x_d^d \end{matrix}\right\rbrack \times \left\lbrack\begin{matrix} f(x)\lbrack0\rbrack\\ f(x)\lbrack1\rbrack\\ \vdots\\ f(x)\lbrack d\rbrack \end{matrix}\right\rbrack = \left\lbrack\begin{matrix} f(x_0)\\ f(x_1)\\ \vdots\\ f(x_d) \end{matrix}\right\rbrack $$

    我们可以证明,左边那个矩阵是可逆的。所以 f(x)f(x) 是唯一的。

    整式乘法

    我们发现,点值的乘法只需要将纵坐标相乘就可以了。

    假设 A(x)=y0,B(x)=y1A(x)=y_0,B(x)=y_1,使用定义易得 C(x)=y0y1C(x)=y_0y_1

    于是我们可以设计一套快速的整式乘法的算法:

    1. 将两个整式变成点值表示法
    2. 将这些点值乘起来
    3. 将结果变成系数表示法,即得结果

    然而我们还不会做第一步和第三步。

    傅里叶变换

    这是一个解决第一个和第三个问题的算法。

    我们回想一下奇函数和偶函数的性质。

    例如 y=kx2(kR{0})y=kx^2(k\in\mathbb{R}\setminus\{0\})。如果我们知道了 (x,y)(x,y) 在这个函数的图像上,那么 (x,y)(-x,y) 肯定也在这个函数的图像上。

    同样的,对于函数 y=kx3(kR{0})y=kx^3(k\in\mathbb{R}\setminus\{0\})。如果我们知道了 (x,y)(x,y) 在这个函数的图像上,那么 (x,y)(-x,-y) 肯定也在这个函数的图像上。

    于是我们考虑将一个函数分成若干个奇函数之和和若干个偶函数之和。

    例如 f(x)=x3+x22x+3f(x)=x^3+x^2-2x+3。按照上面的思路,我们可以得到

    $$\begin{aligned} f(x)&=x^3+x^2-2x+3\\ &=(x^3-2x)+(x^2+3)\\ &=x(x^2-2)+(x^2+3)\\ \end{aligned} $$

    我们记 fo(x)=x2,fe(x)=x+3f_{o}(x)=x-2,f_{e}(x)=x+3。容易得到 f(x)=fe(x2)+xfo(x2)f(x)=f_e(x^2)+xf_o(x^2)

    事实上,这个式子对于任意的整式 f(x)f(x) 都成立。


    我们将 x-x 带入上式,得到 f(x)=fe(x2)xfo(x2)f(-x)=f_e(x^2)-xf_o(x^2)

    看来很想一个分治的过程呢。我们看看能不能分治。

    但是这里有一个问题:假设我们要求的点值的横坐标是 {±x0,±x1,,±xd2}\{\pm x_0,\pm x_1,\cdots,\pm x_{\frac{d}{2}}\},这些值是可以正负配对的。

    但是我们丢给 fe(x)f_e(x) 的点值的横坐标就是 {x02,x12,,xd22}\{x_0^2,x_1^2,\cdots,x_{\frac{d}{2}}^2\},这些值不可以正负配对。

    解决这个问题唯一的方法就是将数据扩充到复数域。复数是课内知识,这里不详述。

    仍然是 f(x)=x3+x22x+3f(x)=x^3+x^2-2x+3。由于是一个 33 次整式,我们需要至少 44 个点值。那这些点值的横坐标应该是什么呢?

    由于到最后这些点值需要正负配对,我们不妨设这些点值的横坐标为 x1,x1,x2,x2x_1,-x_1,x_2,-x_2

    这里有个好处就是我们可以任意选择点值。那么我们设 x1=1x_1=1。于是这棵“分治树”就应该长这样:

    $$\begin{matrix} 1&&&&-1&&&&x_2&&&&-x_2\\ &&1&&&&&&&&x_2^2\\ &&&&&&1 \end{matrix} $$

    在第二层,由于 11x22x_2^2 这两个横坐标所代表的点值是正负配对的。于是 x22=1x_2^2=-1

    从这里,我们得到 x2=±ix_2=\pm i。我们假设 x2=ix_2=i,那么“分治树”是这样的:

    $$\begin{matrix} 1&&&&-1&&&&i&&&&-i\\ &&1&&&&&&&&-1\\ &&&&&&1 \end{matrix} $$

    于是需要的点值的横坐标是 1,1,i.i1,-1,i.-i

    我们对于 1616 次整式,“分治树”是这样的:为什么这张图让我想起了树状数组

    $$\begin{matrix} &\omega_{16}^{0}&&\omega_{16}^{8}&&\omega_{16}^{4}&&\omega_{16}^{12}&&\omega_{16}^{2}&&\omega_{16}^{10}&&\omega_{16}^{6}&&\omega_{16}^{14}&&\omega_{16}^{1}&&\omega_{16}^{9}&&\omega_{16}^{5}&&\omega_{16}^{13}&&\omega_{16}^{3}&&\omega_{16}^{11}&&\omega_{16}^{7}&&\omega_{16}^{15}\\ &&\omega_8^0&&&&\omega_8^4&&&&\omega_8^2&&&&\omega_8^6&&&&\omega_8^1&&&&\omega_8^5&&&&\omega_8^3&&&&\omega_8^7\\ &&&&\omega_4^0&&&&&&&&\omega_4^2&&&&&&&&\omega_4^1&&&&&&&&\omega_4^3\\ &&&&&&&&\omega_2^0&&&&&&&&&&&&&&&&\omega_2^1\\ &&&&&&&&&&&&&&&&\omega_1^0 \end{matrix} $$

    对于 2k(kZ+)2^k(k\in\mathbb{Z}^+) 也有对应的“分治树”。

    观察发现 ωki=ω2ki×ω2ki+k\omega_k^i=-\omega_{2k}^i\times\omega_{2k}^{i+k}

    证明:我们可以得到 ω2ki+k=ω2ki\omega_{2k}^{i+k}=-\omega_{2k}^i。于是

    $$\begin{aligned} -\omega_{2k}^i\times\omega_{2k}^{i+k}&=-\omega_{2k}^i\times(-\omega_{2k}^i)\\ &=\omega_{2k}^{2i}\\ &=\omega_{k}^i \end{aligned} $$

    我们将这些值作为傅里叶变换产生的点值的横坐标,就得到了快速傅里叶变换算法。

    下面是 Python 代码:

    def fft(p, n):
        if n == 1:
            return p
        omega = complex(str(math.cos(2 * pi / n)) + '+' + str(math.sin(2 * pi / n)) + 'j') # 由于 Python 默认以 j 作为虚数单位,所以这里使用 j
        p_e = p[::2]
        p_o = p[1::2]
        y_e = fft(p_e, n // 2)
        y_o = fft(p_o, n // 2)
        y = [complex('0')] * n
        k = complex('1')
        for j in range(n // 2):
            y[j] = y_e[j] + k * y_o[j]
            y[j + n // 2] = y_e[j] - k * y_o[j]
            k *= omega
        return y
    

    我们发现递归树上 ω16\omega_{16} 的次数有点意思:二进制反转之后刚好是 {0,1,,15}\{0,1,\cdots,15\}

    于是我们可以 Θ(n)\Theta(n) 将序列二进制下标反转,然后再做迭代版傅里叶变换就可以加速了。

    总时间复杂度:T(n)=2T(n2)+Θ(n)=Θ(nlogn)T(n)=2T(\frac{n}{2})+\Theta(n)=\Theta(n\log n)

    傅里叶逆变换

    通过傅里叶变换,我们将整式从系数表示法转为了点值表示法。那问题来了:我们怎么把点值表示法转回系数表示法?

    这个时候就要请出傅里叶逆变换。

    假设 $\forall i\in\lbrack0,d\rbrack,y_i=f(\omega_{d+1}^i)$。

    设另外有一个向量 (c0,c1,,cd)(c_0,c_1,\cdots,c_{d}) 满足 $\forall i\in\lbrack0,d\rbrack,c_k=\sum\limits_{j=0}^dy_j\omega_{d+1}^{-jk}$,即整式 g(x)=i=0dyig(x)=\sum\limits_{i=0}^dy_idd 个点值 $(g(\omega_{d+1}^0),g(\omega_{d+1}^{-1}),\cdots,g(\omega_{d+1}^{-d}))$。

    那么 (c0,c1,,cd)(c_0,c_1,\cdots,c_d) 满足 k[0,d],\forall k\in\lbrack0,d\rbrack,

    $$\begin{aligned} c_k&=\sum_{i=0}^dy_i\omega_{d+1}^{-ik}\\ &=\sum_{i=0}^{d}(\sum_{j=0}^da_j\omega_{d+1}^{ij})\omega_{d+1}^{-ik}\\ &=\sum_{j=0}^d\sum_{i=0}^da_j\omega_{d+1}^{ij-ik}\\ &=\sum_{j=0}^da_j\sum_{i=0}^d\omega_{d+1}^{(j-k)i}\\ &=\sum_{j=0}^da_j\sum_{i=0}^d(\omega_{d+1}^{j-k})^i\\ \end{aligned} $$

    S(x)=i=0dxiS(x)=\sum\limits_{i=0}^dx^i,那么 $S(\omega_{d+1}^k)=\sum\limits_{i=0}^d\omega_{d+1}^{ik}$

    使用单位根的性质(或错位相减法),得到当 kZ+k\in\mathbb{Z}^+ 的时候 S(ωd+1k)=0S(\omega_{d+1}^k)=0

    咱们继续考虑 ckc_k

    $$c_k=\sum_{j=0}^da_j\sum_{i=0}^d(\omega_{d+1}^{j-k})^i $$

    这东西在 jkj\neq k 时值为 00,当 j=kj=k 时值为 d+1d+1。因此 $\mathcal{F}^{-1}(f(x))\lbrack i\rbrack=\frac{1}{d+1}f(\omega_{d+1}^{i})$。

    使用类似傅里叶变换的分治思路即可达到 Θ(nlogn)\Theta(n\log n)

    上代码:

    //#define __MY_USE_LONG_DOUBLE__ // 这里不需要 long double
    #include <bits/stdc++.h>
    using namespace std;
    #ifdef __MY_USE_LONG_DOUBLE__
    #define double long double
    #define llround llroundl
    const double pi = acosl(-1.0L);
    const double two = 2.0L;
    #else
    const double pi = acos(-1.0);
    const double two = 2.0;
    #endif
    template <typename T>
    struct cp_base {
        T m_real, m_imag;
        inline constexpr cp_base(const T& r, const T& i) : m_real(r), m_imag(i) {}
        inline constexpr cp_base(const T& r = T()) : m_real(r), m_imag(0) {}
        inline constexpr cp_base operator+(const cp_base &o) const { return cp_base(m_real + o.m_real, m_imag + o.m_imag); }
        inline constexpr cp_base operator-(const cp_base &o) const { return cp_base(m_real - o.m_real, m_imag - o.m_imag); }
        inline constexpr cp_base operator*(const cp_base &o) const { return cp_base(m_real * o.m_real - m_imag * o.m_imag, m_real * o.m_imag + m_imag * o.m_real); }
        inline constexpr cp_base& operator*=(const  cp_base &o) { return *this = *this * o; }
        inline constexpr cp_base conj() const { return cp_base(m_real, -m_imag); }
        inline constexpr T real() const { return m_real; }
        inline constexpr T imag() const { return m_imag; }
        inline constexpr T& real() { return m_real; }
        inline constexpr T& imag() { return m_imag; }
    };
    typedef cp_base<double> cp;
    vector<size_t> rev;
    vector<cp> omegas;
    inline void get_rev(size_t len, int x) {
        if (len == rev.size()) return;
        rev.resize(len);
        for (size_t i = 0; i < len; i++) rev[i] = rev[i >> 1ull] >> 1ull | (i & 1ull) << x;
        omegas.resize(len);
        for (size_t i = 0; i < len; i++) omegas[i] = cp(std::cos(two * pi / len * i), std::sin(two * pi / len * i));
    }
    inline void FFT(cp* a, size_t n) {
        for (size_t i = 1ull; i < n; ++i) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
        for (size_t i = 2ull; i <= n; i <<= 1ull) for (size_t j = 0ull, l = (i >> 1ull), ch = n / i; j < n; j += i) for (size_t k = j, now = 0ull; k < j + l; k++) {
            cp x = a[k], y = a[k + l] * omegas[now];
            a[k] = x + y;
            a[k + l] = x - y;
            now += ch;
        }
    }
    inline void IFFT(cp* a, size_t n) {
        for (size_t i = 1ull; i < n; ++i) if (i < rev[i]) std::swap(a[i], a[rev[i]]);
        for (size_t i = 2ull; i <= n; i <<= 1ull) for (size_t j = 0ull, l = (i >> 1ull), ch = n / i; j < n; j += i) for (size_t k = j, now = 0ull; k < j + l; k++) {
            cp x = a[k], y = a[k + l] * omegas[now].conj();
            a[k] = x + y;
            a[k + l] = x - y;
            now += ch;
        }
        for (size_t i = 0; i < n; i++) {
            a[i].real() /= n;
            a[i].imag() /= n;
        }
    }
    void mul(long long* a, long long* b, long long* c, size_t len) {
        cp *P = new cp[len], *Q = new cp[len];
        for (int i = 0; i < len; i++) P[i] = cp(a[i], b[i]);
        FFT(P, len);
        Q[0] = P[0].conj();
        for (int i = 1; i < len; i++) Q[i] = P[len - i].conj();
        cp *A = new cp[len];
        for (int i = 0; i < len; i++) {
            A[i] = (P[i] + Q[i]) * cp(0, 1) * (Q[i] - P[i]);
            A[i].real() /= 4;
            A[i].imag() /= 4;
        }
        IFFT(A, len);
        for (int i = 0; i < len; i++) c[i] = (long long)(A[i].real() + 0.5);
        delete[] P;
        delete[] Q;
        delete[] A;
    }
    void solve() {
        int n, m;
        scanf("%d", &n);
        n++;
        m = n;
        size_t len = 1;
        int x = -1;
        while (len < n + m) len <<= 1, x++;
        long long *a = new long long[len], *b = new long long[len], *c = new long long[len];
        memset(a, 0, len * sizeof(long long));
        memset(b, 0, len * sizeof(long long));
        for (int i = 0; i < n; i++) scanf("%lld", &a[i]);
        for (int i = 0; i < m; i++) scanf("%lld", &b[i]);
        get_rev(len, x);
        mul(a, b, c, len);
        for (int i = 0; i < n + m - 1; i++) printf("%lld%c", c[i], " \n"[i == n + m - 2]);
        delete[] a;
        delete[] b;
        delete[] c;
    }
    int main() {
        int T;
        scanf("%d", &T);
        while (T--) solve();
    }
    

    这份代码里还涉及到了三次变两次优化和预处理单位根这两种科技。可以参考 这篇文章

    练习:

    1. 【模板】多项式乘法(FFT)
    2. 【模板】A*B Problem 升级版(FFT 快速傅里叶变换)
    3. Fast Multiplication
    4. 多项式乘法

    参考:

    1. 题解 P3803 【【模板】多项式乘法(FFT)】
    2. 快速傅里叶变换(FFT)——有史以来最巧妙的算法?
    • 1

    信息

    ID
    2899
    时间
    1000ms
    内存
    1536MiB
    难度
    10
    标签
    递交数
    4
    已通过
    2
    上传者