FFT

spring·2024년 11월 19일
0

Convolution

Ci=j=0iaibijC_{i} = \sum_{j=0}^{i}a_{i}\cdot b_{i-j}

  • conv가 다항식의 곱셈과 같다.

예를 들어 a=[1,2], b=[3,4] 라고 가정한다면, c는 [3,10,8] 이다.

이는 다항식의 곱셈으로 변경할 수 있는데,

(1x1+2)(3x1+4)=3x2+10x+8(1\cdot x^{1}+2)(3\cdot x^{1}+4) = 3x^{2}+10x+8 과 같이 계수가 위의 결과와 동일하다.

vector<int> a = { 1,2 };
vector<int> b = { 3,4 };
const int N = a.size();
vector<int> c(2 * N - 1);
for (int i = 0; i < N; i++) {
    for (int j = 0; j < N; j++) {
        c[i + j] += a[i] * b[j];
    }
}

라그랑주 다항식

https://ko.wikipedia.org/wiki/%EB%9D%BC%EA%B7%B8%EB%9E%91%EC%A3%BC_%EB%8B%A4%ED%95%AD%EC%8B%9D

#include<iostream>
#include<vector>
#define _USE_MATH_DEFINES
#include <math.h>
#include <complex>
#include <vector>
#include <algorithm>
using namespace std;

#define sz(v) ((int)(v).size())
#define all(v) (v).begin(), (v).end()
typedef complex<double> base;

void fft(vector<base>& a, bool invert)
{
    int n = sz(a);
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j >= bit; bit >>= 1) j -= bit;
        j += bit;
        if (i < j) swap(a[i], a[j]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        double ang = 2 * M_PI / len * (invert ? -1 : 1);
        base wlen(cos(ang), sin(ang));
        for (int i = 0; i < n; i += len) {
            base w(1);
            for (int j = 0; j < len / 2; j++) {
                base u = a[i + j], v = a[i + j + len / 2] * w;
                a[i + j] = u + v;
                a[i + j + len / 2] = u - v;
                w *= wlen;
            }
        }
    }
    if (invert) {
        for (int i = 0; i < n; i++) a[i] /= n;
    }
}

template <typename T>
vector<T> multiply(const vector<T>& a, const vector<T>& b)
{
    vector<base> fa(all(a)), fb(all(b));
    int n = 1;
    int m = sz(a) + sz(b) - 1;
    while (n < m)
        n <<= 1;
    fa.resize(n);
    fb.resize(n);
    fft(fa, false);
    fft(fb, false);
    for (int i = 0; i < n; i++)
        fa[i] *= fb[i];
    fft(fa, true);
    vector<T> ret(m);
    for (int i = 0; i < m; i++)
        ret[i] = static_cast<T>(fa[i].real() + (fa[i].real() > 0 ? 0.5 : -0.5));
    return ret;
}
using namespace std;
int main() {
    vector<int> a = { 1,2,3,4 };
    vector<int> b = { 5,6,7,8 };
    const int N = 4;
    vector<int> c(2 * N - 1);
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 4; j++) {
            c[i + j] += a[i] * b[j];
        }
    }
    for (auto& e : c) {
        cout << e << ", ";
    }
    cout << endl;

    vector<int> d = multiply(a, b);
    for (auto& e : d) {
        cout << e << ", ";
    }
    //5, 16, 34, 60,
    //5, 16, 34, 60, 61, 52, 32,
    return 0;
}
profile
Researcher & Developer @ NAVER Corp | Designer @ HONGIK Univ.

0개의 댓글