FFT

Apr 12, 23

FFT

$c_i=\sum_{j=0}^{i} a_j b_{i-j}$

위와 같은 convolution을 $O(n \log n)$에 계산하기 위해서 쓴다고 함.

https://blog.myungwoo.kr/54에서 베껴옴

#include <bits/stdc++.h>
using namespace std;
 
// sz(v): the size of container v, cast to int.
#define sz(v) ((int)(v).size())
// all(v): expands to the two arguments `v.begin(), v.end()` so a whole
// container can be handed to an iterator-pair algorithm.
#define all(v) (v).begin(), (v).end()
 
// Double-precision complex number; the element type of the vectors that
// fft() transforms.
using complex_t = complex<double>;
 
// In-place iterative radix-2 FFT of `a`.
//
// The length of `a` is expected to be a power of two (the bit-reversal step
// derives its bit count from the highest set bit of the length). The twiddle
// factor stored at index k + j (0 <= j < k) is e^{i*pi*j/k}, i.e. the
// exponent sign is positive, and the result is not normalised.
//
// Twiddle tables live in function-local statics and are only ever extended,
// so repeated calls reuse whatever earlier calls already computed.
void fft(vector<complex_t>& a) {
    const int n = sz(a);
    const int log_n = 31 - __builtin_clz(n);

    // `precise` is built in long double so the repeated multiplications that
    // derive each level from the previous one lose less accuracy; `twiddle`
    // is its double copy, which is what the butterflies actually read
    // (original author's note: about 10% faster if double).
    static vector<complex<long double>> precise(2, 1);
    static vector<complex_t> twiddle(2, 1);
    static int built = 2;  // indices [0, built) of both tables are filled
    while (built < n) {
        precise.resize(n);
        twiddle.resize(n);
        const auto step = polar(1.0L, acos(-1.0L) / built);
        for (int i = built; i < 2 * built; ++i) {
            // Even indices copy the parent entry; odd indices rotate it by `step`.
            if (i & 1) {
                precise[i] = precise[i / 2] * step;
            } else {
                precise[i] = precise[i / 2];
            }
            twiddle[i] = precise[i];
        }
        built *= 2;
    }

    // Bit-reversal permutation: rev[i] is i with its low log_n bits reversed.
    vector<int> rev(n);
    for (int i = 0; i < n; ++i) {
        rev[i] = (rev[i >> 1] | ((i & 1) << log_n)) >> 1;
    }
    for (int i = 0; i < n; ++i) {
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    }

    // Butterfly passes over blocks of size 2 * half.
    for (int half = 1; half < n; half *= 2) {
        for (int lo = 0; lo < n; lo += 2 * half) {
            for (int j = 0; j < half; ++j) {
                const complex_t& w = twiddle[half + j];
                const complex_t& v = a[lo + j + half];
                // Complex product w * v written out by hand rather than via
                // operator* (original author's note: about 25% faster).
                const complex_t z(w.real() * v.real() - w.imag() * v.imag(),
                                  w.real() * v.imag() + w.imag() * v.real());
                a[lo + j + half] = a[lo + j] - z;
                a[lo + j] += z;
            }
        }
    }
}
 
// Convolution of two real sequences: res[i] = sum_j a[j] * b[i - j].
//
// Returns an empty vector if either input is empty; otherwise the result has
// a.size() + b.size() - 1 elements. Only two fft() calls are used: `a` goes
// into the real parts and `b` into the imaginary parts of one complex buffer,
// the transformed buffer is squared, and the product is recovered from the
// imaginary part of a second transform. For integral T the value is rounded
// to the nearest integer; otherwise it is converted as-is.
template <typename T>
vector<T> conv(const vector<T>& a, const vector<T>& b) {
    if (a.empty() || b.empty()) return {};

    const int out_len = sz(a) + sz(b) - 1;
    // Transform length: a power of two strictly greater than out_len.
    const int log_n = 32 - __builtin_clz(out_len);
    const int n = 1 << log_n;

    vector<complex_t> in(n), out(n);
    for (int i = 0; i < sz(a); ++i) in[i] = a[i];
    for (int i = 0; i < sz(b); ++i) in[i].imag(b[i]);

    fft(in);
    for (auto& v : in) v *= v;
    for (int i = 0; i < n; ++i) {
        // Index of -i modulo n (n is a power of two).
        const int mirrored = (n - i) & (n - 1);
        out[i] = in[mirrored] - conj(in[i]);
    }
    fft(out);

    vector<T> res(out_len);
    for (int i = 0; i < out_len; ++i) {
        const double raw = imag(out[i]);
        // Round half away from zero for integral T; no bias otherwise.
        const double bias = is_integral_v<T> ? (raw > 0 ? 0.5 : -0.5) : 0;
        res[i] = static_cast<T>(raw / (4 * n) + bias);
    }
    return res;
}

koosaga님 팀노트에 있는거라고함


using namespace std;
// Double-precision complex number; element type of the vectors passed to fft().
typedef complex<double> base;
// Shorthand for long long; element type of multiply()'s inputs and result.
typedef long long ll;
// pi, obtained as acos(-1). Not referenced by the code visible in this
// section (fft() calls acos(-1) directly).
const double PI = acos(-1);
 
// In-place iterative FFT of `a`.
//
// a   - samples to transform; the length must be a power of two.
// inv - when true, the twiddle angle is negated and every element is divided
//       by the length at the end (inverse transform).
void fft(vector<base>& a, bool inv = false) {
    const int n = a.size();

    // Bit-reversal permutation; `rev` tracks the bit-reversed counterpart of i
    // by incrementing from the most significant bit downwards.
    for (int i = 1, rev = 0; i < n; ++i) {
        int bit = n >> 1;
        for (; rev >= bit; bit >>= 1) rev -= bit;
        rev += bit;
        if (i < rev) swap(a[i], a[rev]);
    }

    // roots[i] = e^{i * ang * i} for the first half of the n-th roots of unity.
    const int sign = inv ? -1 : 1;
    const double ang = 2 * acos(-1) / n * sign;
    vector<base> roots;
    roots.reserve(n / 2);
    for (int i = 0; i < n / 2; ++i) {
        roots.emplace_back(cos(ang * i), sin(ang * i));
    }

    // Butterfly passes over blocks of length len = 2, 4, ..., n.
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len / 2;
        const int stride = n / len;  // step through `roots` for this level
        for (int start = 0; start < n; start += len) {
            for (int k = 0; k < half; ++k) {
                const base upper = a[start + k];
                const base lower = a[start + k + half] * roots[stride * k];
                a[start + k] = upper + lower;
                a[start + k + half] = upper - lower;
            }
        }
    }

    if (inv) {
        for (base& x : a) x /= n;
    }
}
 
// Multiplies two integer sequences (polynomial coefficient vectors) via FFT.
//
// v, w - input sequences; they are only read, so they are taken by const
//        reference and const vectors or temporaries may be passed.
// Returns a vector of length n, where n is the smallest power of two that is
// at least max(2, v.size() + w.size()). Entry i is the real part of the
// inverse-transformed product rounded to the nearest integer; entries past
// the true product length are padding.
vector<ll> multiply(const vector<ll>& v, const vector<ll>& w) {
    vector<base> fv(v.begin(), v.end()), fw(w.begin(), w.end());

    // fft() needs a power-of-two length, so pad both inputs up to one.
    size_t n = 2;
    while (n < v.size() + w.size()) n <<= 1;
    fv.resize(n);
    fw.resize(n);

    fft(fv, false);
    fft(fw, false);
    for (size_t i = 0; i < n; i++) fv[i] *= fw[i];
    fft(fv, true);

    vector<ll> ret(n);
    for (size_t i = 0; i < n; i++) ret[i] = static_cast<ll>(round(fv[i].real()));
    return ret;
}