FFT
Apr 12, 23 · FFT
$c_i=\sum_{j=0}^{i} a_j b_{i-j}$
위와 같은 convolution을 $O(n \log n)$에 계산하기 위해서 쓴다고 함.
https://blog.myungwoo.kr/54에서 베껴옴
#include <bits/stdc++.h>
using namespace std;
// sz(v): size of container `v`, cast to int.
#define sz(v) ((int)(v).size())
// all(v): expands to the two arguments `v.begin(), v.end()` for range-taking algorithms.
#define all(v) (v).begin(), (v).end()
// Complex sample type used by fft() and conv() below.
using complex_t = complex<double>;
/**
 * In-place iterative radix-2 FFT.
 *
 * Overwrites `a` with its discrete Fourier transform
 *     A[k] = sum_j a[j] * exp(+2*pi*i*j*k / n),   n = a.size(),
 * without any normalisation.
 *
 * @param a  Samples to transform in place. The size must be a power of two
 *           (checked with assert); sizes 0 and 1 are returned unchanged.
 *
 * The twiddle-factor tables are function-local statics that grow on demand
 * and are reused by later calls, so concurrent calls from several threads
 * are not safe.
 */
void fft(std::vector<std::complex<double>>& a) {
    const int n = static_cast<int>(a.size());
    // A transform of length 0 or 1 is the identity. Returning early also keeps
    // __builtin_clz away from 0, for which its result is undefined.
    if (n <= 1) return;
    // Any other size would index past the end of `a` and of the tables below.
    assert((n & (n - 1)) == 0 && "fft: size must be a power of two");
    const int L = 31 - __builtin_clz(n);  // log2(n)

    // R keeps the roots in long double so the doubling recurrence does not
    // accumulate error; rt is the double copy read by the butterflies.
    // For a stage of half-length k, rt[k + j] == exp(i*pi*j/k), 0 <= j < k.
    static std::vector<std::complex<long double>> R(2, std::complex<long double>(1.0L));
    static std::vector<std::complex<double>> rt(2, std::complex<double>(1.0));
    static int built = 2;  // entries [0, built) of both tables are valid
    if (built < n) {
        R.resize(n);
        rt.resize(n);
        for (; built < n; built *= 2) {
            const std::complex<long double> x =
                std::polar(1.0L, std::acos(-1.0L) / built);
            for (int i = built; i < 2 * built; i++) {
                // Even slots repeat the previous level, odd slots advance by x.
                R[i] = (i & 1) ? R[i / 2] * x : R[i / 2];
                rt[i] = R[i];
            }
        }
    }

    // Bit-reversal permutation: rev[i] is i with its L low bits reversed.
    std::vector<int> rev(n);
    for (int i = 0; i < n; i++) rev[i] = (rev[i / 2] | (i & 1) << L) / 2;
    for (int i = 0; i < n; i++) {
        if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    }

    // Butterfly passes; k is the half-length of the blocks being merged.
    for (int k = 1; k < n; k *= 2) {
        for (int i = 0; i < n; i += 2 * k) {
            for (int j = 0; j < k; j++) {
                // Hand-rolled complex product: skips the extra special-case
                // handling that std::complex's operator* may perform.
                const std::complex<double>& w = rt[j + k];
                const std::complex<double>& v = a[i + j + k];
                const std::complex<double> z(
                    w.real() * v.real() - w.imag() * v.imag(),
                    w.real() * v.imag() + w.imag() * v.real());
                a[i + j + k] = a[i + j] - z;
                a[i + j] += z;
            }
        }
    }
}
// Linear convolution of `a` and `b`: res[i] = sum_j a[j] * b[i - j].
//
// Returns a vector of length a.size() + b.size() - 1, or an empty vector when
// either input is empty. Both inputs are packed into a single complex array
// (a in the real parts, b in the imaginary parts), so only two calls to fft()
// are made. For integral T each coefficient is rounded to the nearest integer
// (halves away from zero) before the cast; otherwise it is cast directly.
template <typename T>
vector<T> conv(const vector<T>& a, const vector<T>& b) {
    if (a.empty() || b.empty()) return {};

    const int len_a = static_cast<int>(a.size());
    const int len_b = static_cast<int>(b.size());
    const int out_len = len_a + len_b - 1;
    // Transform length: the power of two strictly greater than out_len.
    const int n = 1 << (32 - __builtin_clz(out_len));

    // packed[i] = a[i] + i * b[i], zero-padded up to n.
    vector<complex_t> packed(n);
    for (int i = 0; i < len_a; i++) packed[i].real(a[i]);
    for (int i = 0; i < len_b; i++) packed[i].imag(b[i]);

    fft(packed);
    for (int i = 0; i < n; i++) packed[i] *= packed[i];

    // Combine each squared bin with its mirror (index -i mod n); after the
    // second forward transform the imaginary parts carry 4 * n * res[i].
    vector<complex_t> spec(n);
    for (int i = 0; i < n; i++) {
        const int mirror = (n - i) & (n - 1);
        spec[i] = packed[mirror] - conj(packed[i]);
    }
    fft(spec);

    const int denom = 4 * n;
    vector<T> res(out_len);
    for (int i = 0; i < out_len; i++) {
        const double raw = imag(spec[i]);
        // Half-unit bias so the truncating cast rounds to nearest for integers.
        const double bias = is_integral_v<T> ? (raw > 0 ? 0.5 : -0.5) : 0;
        res[i] = static_cast<T>(raw / denom + bias);
    }
    return res;
}
koosaga님 팀노트에 있는 거라고 함
using namespace std;
// Complex sample type used by fft() and multiply() below.
typedef complex<double> base;
// Integer type of the polynomial coefficients taken and returned by multiply().
typedef long long ll;
// Pi. NOTE(review): not referenced by the code visible here; fft() calls acos(-1) directly.
const double PI = acos(-1);
// In-place iterative radix-2 FFT.
//
// a   : samples, overwritten with the transform. The size must be a power of
//       two; the algorithm relies on it.
// inv : false -> forward transform using roots exp(+2*pi*i*k/n);
//       true  -> inverse transform using the conjugate roots, with every
//                element divided by n afterwards.
void fft(vector<base>& a, bool inv = false) {
    const int n = a.size();

    // Bit-reversal permutation; `rev` tracks the bit-reversed value of i and
    // is advanced like a counter that carries from the top bit downwards.
    for (int i = 1, rev = 0; i < n; ++i) {
        int bit = n >> 1;
        for (; rev >= bit; bit >>= 1) rev -= bit;
        rev += bit;
        if (i < rev) swap(a[i], a[rev]);
    }

    // roots[t] = exp(i * ang * t) for t in [0, n/2); the sign of ang selects
    // the direction of the transform.
    double ang = 2 * acos(-1) / n;
    if (inv) ang = -ang;
    vector<base> roots(n / 2);
    for (int t = 0; t < n / 2; ++t) {
        roots[t] = base(cos(ang * t), sin(ang * t));
    }

    // Butterfly passes over blocks of length len = 2, 4, ..., n.
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len / 2;
        const int stride = n / len;  // spacing in `roots` for this block length
        for (int start = 0; start < n; start += len) {
            for (int k = 0; k < half; ++k) {
                const base lo = a[start + k];
                const base hi = a[start + k + half] * roots[stride * k];
                a[start + k] = lo + hi;
                a[start + k + half] = lo - hi;
            }
        }
    }

    if (inv) {
        for (base& x : a) x /= n;
    }
}
// Multiplies two integer polynomials (coefficient vectors) using fft().
//
// Returns a vector of length n, where n is the smallest power of two that is
// at least 2 and at least v.size() + w.size(); entry i is the coefficient of
// x^i, rounded to the nearest integer. Entries past the true degree of the
// product are the rounded (normally zero) padding values.
vector<ll> multiply(vector<ll>& v, vector<ll>& w) {
    // fft() needs a power-of-two length, so pad up to one.
    int n = 2;
    while (n < v.size() + w.size()) {
        n <<= 1;
    }

    vector<base> fv(v.begin(), v.end());
    vector<base> fw(w.begin(), w.end());
    fv.resize(n);
    fw.resize(n);

    // Transform both inputs, multiply pointwise, transform back.
    fft(fv, false);
    fft(fw, false);
    for (int i = 0; i < n; ++i) {
        fv[i] *= fw[i];
    }
    fft(fv, true);

    vector<ll> ret;
    ret.reserve(n);
    for (const base& c : fv) {
        ret.push_back(static_cast<ll>(round(c.real())));
    }
    return ret;
}