Быстрое преобразование Фурье

April 14, 2017

Мотивировка

Основная задача, которая решается при помощи быстрого преобразования Фурье (Fast Fourier Transform, FFT) — это умножение многочленов за время $O(n \log n)$.

Тривиально (по определению) многочлены степеней $m$ и $n$ умножаются за время $O(nm)$. Долгое время считалось, что быстрее это сделать невозможно, А.Н. Колмогоров даже принимал попытки доказать нижнюю границу, пока в 1960 году А. Карацуба не придумал способ умножать многочлены степени $n$ за время $O(n^{\log_{2}3})$.

Как известно, многочлен степени строго меньше $n$ однозначно определяется своими значениями в $n$ (вообще говоря комплексных) точках. Действительно, если есть два различных многочлена с одинаковыми значениями в $n$ точках, то их разность имеет $n$ комплексных корней, причём она является ненулевым многочленом степени строго меньше $n$, что противоречит основной теореме алгебры. С другой стороны, интерполяционный многочлен Лагранжа в явном виде предъявляет многочлен степени строго меньше $n$ по значениям в $n$ точках.

Таким образом, многочлены можно хранить не в виде вектора коэффициентов, а в виде набора значений в некоторых точках. Над многочленами в таком виде очень удобно производить арифметические операции, в том числе умножать их за время $O(n)$ (нужно просто перемножить значения в соответствующих точках). С другой стороны, непонятно, как считать значения в других точках, да и знать сами коэффициенты бывает полезно. А воостанавливать коэффициенты по значениям в некотором наборе точек сложно, тот же интерполяционный многочлен Лагранжа вычисляется за время $O(n^2)$. Да и многочлены нам обычно задаются в форме вектора коэффициентов, получить значеня в $n$ произвольных точках вряд ли можно быстрее, чем за $O(n^2)$.

Хитрость FFT в том, что точки, в которых считаются значения многочлена, выбираются отнюдь не произвольным образом.

Описание

Итак, FFT преобразует вектор $\langle a_{0}, a_{1}, \ldots, a_{n-1} \rangle$ в вектор $\langle b_{0}, b_{1}, \ldots, b_{n-1} \rangle$, где $b_{j} = \sum_{k=0}^{n-1} a_{k} e^{2 \pi i \frac{j}{n}}$, $n=2^{m}$, иначе говоря, преобразует вектор коэффициентов многочлена степени $n-1$ в набор его значений в точках $\omega_{j} = e^{2 \pi i \frac{j}{n}}$.

Тут сразу возникает два вопроса:

  1. Что делать с многочленами других степеней?
  2. Почему именно эти точки?

С первым вопросом всё просто: нужно дополнять коэффициенты нулями до ближайшей степени двойки.

Со вторым же вопросом дело чуть хитрее. $\omega_{k}$ - это комплексные корни из $1$ $n$-й степени, то есть $\omega_{i}^n = 1$. У них есть замечательные свойства: $\omega_{k} = \omega_{1}^{k}$, $\omega_{j} \omega_{k} = \omega_{j+k}$.

Разделяй и властвуй

Обозначим $A(x) = \sum_{k=0}^{n-1} a_{k} x^{k}$, $A_{0}(x) = \sum_{k=0}^{\frac{n}{2} - 1} a_{2k} x^{k}$, $A_{1}(x) = \sum_{k=0}^{\frac{n}{2} - 1} a_{2k + 1} x^{k}$. Легко проверить, что $A(x) = A_{0}(x^{2}) + x A_{1}(x^{2})$.

$A_{0}$ и $A_{1}$ - многочлены степени $\frac{n}{2} = 2^{m-1}$, к ним можно применить FFT и получить набор значений в точках $e^{2 \pi i \frac{k}{n / 2}} = e^{2 \pi i \frac{2k}{n}} = \omega_{2k}$.

$A(\omega_{k}) = A_{0}(\omega_{k}^{2}) + \omega_{k} A_{1}(\omega_{k}^{2}) = A_{0}(\omega_{2k}) + \omega_{k} A_{1}(\omega_{2k})$. Таким образом, зная FFT для $A_{0}$ и $A_{1}$, можно вычислить FFT для $A$.

Осталось определить базу рекурсии — $n=1$. Для этого нужно посчитать значение константного многочлена $A=a_{0}$ в точке $\omega_{0} = 1$. Это значение, очевидно, равно $a_{0}$.

Таким образом, алгоритм работает за время $T(n) = 2 T(\frac{n}{2}) + O(n)$. По мастер-теореме $T(n) = O(n \log n)$.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
const double PI = 4 * atan(1.);
typedef complex<double> cd;

vector<cd> FFT(vector<cd> A)
{
	int n = (int)A.size();
	if (n == 1) return A;
	vector<cd> A0, A1;
	for (int i = 0; i < n; i++)
	{
		if (i % 2 == 0)
			A0.push_back(A[i]);
		else
			A1.push_back(A[i]);
	}
	A0 = FFT(A0);
	A1 = FFT(A1);
	for (int i = 0; i < n; i++)
	{
		cd w = cd(cos(2 * PI * i / n), sin(2 * PI * i / n));
		A[i] = A0[i % (n / 2)] + w * A1[i % (n / 2)];
	}
	return A;
}

Обратное преобразование

Хорошо, мы научились по многочлену вычислять его значение в $n$ особых точках за $O(n \log n)$, потом мы можем перемножить значения и получить FFT от произведения многочленов. Но мы пока не умеем восстанавливать многочлен по его FFT.

Можно рассмотреть FFT как линейное преобразование:

Возведём матрицу преобразования в квадрат: $R_{ij} = \sum_{k=0}^{n-1} \omega_{i}^{k} \omega_{k}^{j} = \sum_{k=0}^{n-1} \omega_{k(i+j)} = \sum_{k=0}^{n-1} \omega_{i+j}^{k}$

Если $\omega_{i+j} = 1$, то $R_{ij} = n$. В противном случае можно посчитать сумму геометрической прогрессии и получить $0$.

Таким образом, обратное преобразование выглядит следующим образом:

  1. Применить прямое преобразование
  2. Разделить все элементы на $n$
  3. Развернуть массив без первого элемента
1
2
3
4
5
6
7
8
9
vector<cd> inverseFFT(vector<cd> A)
{
	A = FFT(A);
	int n = (int)A.size();
	for (int i = 0; i < n; i++)
		A[i] /= n;
	reverse(A.begin() + 1, A.end());
	return A;
}

Оптимизации

В общем-то, это всё; однако данная реализация работает не очень быстро и потребляет $O(n \log n)$ памяти.

Что мы глобально делаем:

  1. Переставляем коэффициенты
  2. Делаем рекурсивные запуски
  3. Склеиваем результаты

Если один раз в самом начале применить все перестановки индексов, то можно будет сразу двигаться от более глубоких уровней рекурсии в начало.

Как переставляются индексы: в левую часть идут чётные индексы, в правую — нечётные, потом к каждой из половин рекурсивно примерняется то же правило. Если рассмотреть битовую запись числа, то мы сначала сортируем по младшему биту, потом “отрезаем” его и делаем дальше то же самое. Нетрудно догадаться, что мы на самом деле сортируем индексы по реверснутой битовой записи; строгое доказательство проводится мат.индукцией по номеру уровня.

Реверснутую битовую запись всех чисел от $0$ до $2^{m}-1$ можно предподсчитать с помощью ДП, воспользовавшись следующей идеей: $rev(mask) = rev(mask \oplus 2^{k}) \oplus 2^{m - 1 - k}$.

Также можно заранее предподсчитать комплексные корни из $1$. Тут есть тонкий момент: синусы и косинусы весьма медленные, в то время как если считать $\omega_{k} = \omega{k-1} \omega_{1}$, то может набежать большая погрешность. Можно использовать промежуточные варианты, например, честно посчитать первые $2^{\frac{m}{2}}$ значений, а также значения по индексам, кратным $2^{\frac{m}{2}}$, а все остальные корни разложить в произведение двух уже посчитанных.

Наконец, можно уменьшить количество умножений комплексных чисел. Вспомним, что

Их можно считать одновременно, это уменьшит количество умножений вдвое.

Код

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
const double PI = 4 * atan(1.);
typedef complex<double> cd;

const int LOG = 18;
const int N = 1 << LOG;
cd w[N + 5];
int rev[N + 5];

void initFFT()
{
	for (int i = 0; i < N; i++)
		w[i] = cd(cos(2 * PI * i / N), sin(2 * PI * i / N));

	int k = 0;
	rev[0] = 0;
	for (int mask = 1; mask < N; mask++)
	{
		if (mask >> (k + 1)) k++; // k - the most significant bit of mask
		rev[mask] = rev[mask ^ (1 << k)] ^ (1 << (LOG - 1 - k));
	}
}

cd F[2][N]; // maintain two layers
void FFT(cd *A, int k) // n = (1 << k)
{
	int L = 1 << k;
	// rearrange coefficients
	for (int mask = 0; mask < L; mask++)
		F[0][rev[mask] >> (LOG - k)] = A[mask];
	int t = 0, nt = 1; // t - current, nt - new
	for (int lvl = 0; lvl < k; lvl++)
	{
		int len = 1 << lvl;
		for (int st = 0; st < L; st += (len << 1))
			for (int i = 0; i < len; i++)
			{
				cd summand = F[t][st + len + i] * w[i << (LOG - 1 - lvl)];
				F[nt][st + i] = F[t][st + i] + summand;
				F[nt][st + len + i] = F[t][st + i] - summand;
			}
		swap(t, nt); // change layers
	}
	for (int i = 0; i < L; i++)
		A[i] = F[t][i];
}

Умножение многочленов

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
vector<cd> multiply(vector<cd> A, vector<cd> B)
{
	int sz1 = (int)A.size(), sz2 = (int)B.size();
	int k = 0;
	// deg(A) = sz1 - 1, deg(B) = sz2 - 1, deg(AB) = sz1 + sz2 - 2
	while((1 << k) < (sz1 + sz2 - 1)) k++;
	int L = 1 << k;
	cd C[L], D[L];
	for (int i = 0; i < L; i++)
		C[i] = D[i] = 0;
	for (int i = 0; i < sz1; i++)
		C[i] = A[i];
	for (int i = 0; i < sz2; i++)
		D[i] = B[i];
	FFT(C, k);
	FFT(D, k);
	for (int i = 0; i < L; i++)
		C[i] *= D[i];
	FFT(C, k);
	reverse(C + 1, C + L);
	vector<cd> res;
	res.resize(sz1 + sz2 - 1);
	for (int i = 0; i < sz1 + sz2 - 1; i++)
		res.push_back(C[i] / L);
	return res;
}

Размер применяемого FFT должен быть строго больше, чем степень произведения многочленов.

В задачах часто многочлены имеют целочисленные коэффициенты, причём неотрицательные. Понятно, что при таких условиях коэффициенты у произведения будут тоже целыми и неотрицательными.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
typedef long long ll;

vector<ll> multiply(vector<ll> A, vector<ll> B)
{
	int sz1 = (int)A.size(), sz2 = (int)B.size();
	int k = 0;
	// deg(A) = sz1 - 1, deg(B) = sz2 - 1, deg(AB) = sz1 + sz2 - 2
	while((1 << k) < (sz1 + sz2 - 1)) k++;
	int L = 1 << k;
	cd C[L], D[L];
	for (int i = 0; i < L; i++)
		C[i] = D[i] = 0;
	for (int i = 0; i < sz1; i++)
		C[i] = A[i];
	for (int i = 0; i < sz2; i++)
		D[i] = B[i];
	FFT(C, k);
	FFT(D, k);
	for (int i = 0; i < L; i++)
		C[i] *= D[i];
	FFT(C, k);
	reverse(C + 1, C + L);
	vector<ll> res;
	res.resize(sz1 + sz2 - 1);
	for (int i = 0; i < sz1 + sz2 - 1; i++)
		res.push_back((ll)(C[i].real)() / L + 0.5));
	return res;
}

С таким способом округления следует быть осторожным, он верен только для неотрицательных чисел. Если коэффициенты многочлена могут быть отрицательными, следует округлять аккуратно:

1
2
3
4
5
6
7
typedef long long ll;

ll myRound(double x)
{
	if (x > 0) return (ll)(x + 0.5);
	return (ll)(x - 0.5);
}

Также нужно помнить, что double имеет точность около $15$ знаков, поэтому расчитывать на точное умножение многочленов можно только если коэффициенты произведения не превосходят $10^{14}$, а лучше и ещё меньше.

Разные мелочи

При реализации я пользовался встроенным классом complex. Чтобы его использовать, нужно подключить заголовочный файл complex.

Разумеется, можно использовать не только complex<double>, но и complex<float> или complex<long double>. Первый вариант действительно может помочь уменьшить время работы и объём потребляемой памяти.

Удивительно, но написание своего класса Complex также может уменьшить время работы.

В общем случае нам нужно 3 вызова FFT, чтобы перемножить два многочлена. Существуют методы, позволяющие проводить 2 FFT одновременно. Но если у нас есть два набора из $n$ и $m$ многочленов, и мы хотим посчитать попарные произведения, то для этого достаточно $n + m + nm$ вызовов FFT.

Операции по модулю

Часто в задачах просят посчитать что-то по модулю некоторого числа. Это же может относиться и к произведению многочленов. Конечно, в таком случае у нас все коэффициенты целые неотрицательные. Однако если модуль порядка $10^{9}$, то коэффициенты произведения могут получиться порядка $10^{23}$, что нельзя не то что точно сохранить в double, а даже сохранить в long long, чтобы потом взять остаток по модулю. Что же делать?

“Хороший” модуль

Почему мы вообще использовали комплексные числа, если все коэффициенты исходных многочленов были действительными? Дело в том, что нам нужно было $n$ разных корней из $1$, чего действительные числа предоставить не могут. Если бы только был какой-то другой объект с такими свойствами…

Как известно из теории чисел, ненулевые остатки по модулю простого числа $P$ образуют циклическую группу порядка $P-1$ по умножению. Пусть $P-1$ делится на достаточно большую степень двойки, то есть $P-1 = nz$ для некоторого целого $z$. Тогда в этой группе тоже есть $n$ разных корней из $1$ степени $n$, и они тоже обладают всеми необходимыми нам свойствами.

Таким образом, мы можем выполнять все операции по модулю $P$, а комплексные корни из $1$ нужно заменить на корни из $1$ по модулю.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
typedef long long ll;
const ll MOD = 998244353; // most popular "good" prime number
ll add(ll x, ll y)
{
	x += y;
	if (x >= MOD) return x - MOD;
	return x;
}
ll sub(ll x, ll y)
{
	x -= y;
	if (x < 0) return x + MOD;
	return x;
}
ll mult(ll x, ll y)
{
	return (x * y) % MOD;
}
ll binPow(ll x, ll p)
{
	if (p == 0) return 1;
	if (p == 2 || (p & 1)) return mult(x, binPow(x, p - 1));
	return binPow(binPow(x, p / 2), 2);
}
ll modRev(ll x)
{
	return binPow(x, MOD - 2);
}

const int LOG = 18;
const int N = 1 << LOG;
ll w[N + 5];
int rev[N + 5];

void initFFT()
{
	// finding root
	ll W = -1;
	for (ll x = 2;; x++)
	{
		ll y = x;
		for (int i = 1; i < LOG; i++)
			y = mult(y, y);
		// y = x ** (n / 2)
		// so, y != 1
		// but y ** 2 == 1
		if (y == MOD - 1)
		{
			W = x;
			break;
		}
	}
	if (W == -1) throw;
	w[0] = 1;
	for (int i = 1; i < N; i++)
		w[i] = mult(w[i - 1], W); // no precision errors now

	int k = 0;
	rev[0] = 0;
	for (int mask = 1; mask < N; mask++)
	{
		if (mask >> (k + 1)) k++; // k - the most significant bit of mask
		rev[mask] = rev[mask ^ (1 << k)] ^ (1 << (LOG - 1 - k));
	}
}

ll F[2][N]; // maintain two layers
void FFT(ll *A, int k) // n = (1 << k)
{
	int L = 1 << k;
	// rearrange coefficients
	for (int mask = 0; mask < L; mask++)
		F[0][rev[mask] >> (LOG - k)] = A[mask];
	int t = 0, nt = 1; // t - current, nt - new
	for (int lvl = 0; lvl < k; lvl++)
	{
		int len = 1 << lvl;
		for (int st = 0; st < L; st += (len << 1))
			for (int i = 0; i < len; i++)
			{
				ll summand = mult(F[t][st + len + i], w[i << (LOG - 1 - lvl)]);
				F[nt][st + i] = add(F[t][st + i], summand);
				F[nt][st + len + i] = sub(F[t][st + i], summand);
			}
		swap(t, nt); // change layers
	}
	for (int i = 0; i < L; i++)
		A[i] = F[t][i];
}

vector<ll> multiply(vector<ll> A, vector<ll> B)
{
	int sz1 = (int)A.size(), sz2 = (int)B.size();
	int k = 0;
	// deg(A) = sz1 - 1, deg(B) = sz2 - 1, deg(AB) = sz1 + sz2 - 2
	while((1 << k) < (sz1 + sz2 - 1)) k++;
	int L = 1 << k;
	ll C[L], D[L];
	for (int i = 0; i < L; i++)
		C[i] = D[i] = 0;
	for (int i = 0; i < sz1; i++)
		C[i] = A[i];
	for (int i = 0; i < sz2; i++)
		D[i] = B[i];
	FFT(C, k);
	FFT(D, k);
	for (int i = 0; i < L; i++)
		C[i] = mult(C[i], D[i]);
	FFT(C, k);
	reverse(C + 1, C + L);
	vector<cd> res;
	res.resize(sz1 + sz2 - 1);
	// important change here
	// we should divide by L modulo MOD
	ll rL = modRev(L);
	for (int i = 0; i < sz1 + sz2 - 1; i++)
		res.push_back(mult(C[i], rL));
	return res;
}

Таким образом, если нам повезло, то всё хорошо. Но нужно понимать, что в рамках соревнований “повезло” — это значит авторы задачи сделали так, чтобы нам повезло. Поэтому если вы видите в задаче “необычный модуль”, это может быть сильным намёком на то, что в задаче требуется FFT. С другой стороны, давать такую огромную подсказку — это нежелательно для авторов. Так или иначе, коротенький список часто встречающихся “необычных модулей”, подходящих для написания FFT по модулю: $7340033 = 7 \cdot 2^{20} + 1$, $998244353 = 119 \cdot 2^{23} + 1$.

Любой модуль

Тут есть два принципиально разных подхода:

Несколько хороших модулей + Китайская теорема об остатках

Название говорит само за себя. Можно незавсимо выполнить умножение по 3 разным хорошим модулям, а потом узнать искомые коэффициенты при помощи КТО.

Разбить на многочлены с меньшими коэффициентами

Выберем $Q \approx 1000$. Запишем

Тогда

Коэффициенты многочленов $A_{i}(x)B_{j}(x)$ будут порядка $nQ^2 \le 10^{12}$.

Применение

Понятно, что FFT используется для умножения многочленов, так что это будет скорее описание применений умножения многочленов.

Умножение длинных чисел

Длинные числа умножаются как многочлены (коэффициенты — это цифры), только потом нужно сделать переносы. Однако лучше разбивать цифры на группы по три, то есть рассматривать число в системе счисления по основанию $1000$, — таким образом степени многочленов уменьшаются в три раза.

Вычисление скалярных произведений

Есть массив $A$ длины $n$ и массив $B$ длины $m$ ($m \le n$). Нужно посчитать скалярные произведения $B$ со всеми подотрезками $A$ длины $m$.

Перевернём $B$ и умножим с $A$ как многочлены. Несложно понять, что после переворота скалярное произведение превращается в свёртку. Коэффициенты с $(m-1)$-го по $(n-1)$-й — это ответы.

Нечёткий поиск

Есть строчки $S$ и $P$ над алфавитом $\Sigma$. Для каждой подстроки $S$ длины $ | P | $ найти расстояние Хэмминга (количество позиций, в которых строки отличаются).

Будем считать не расстояние Хэмминга, а $ | P | $ - расстояние, то есть количество совпадающих позиций. Переберём символ из $\Sigma$. В каждой строчке заменим все вхождения данного символа на $1$, а всех остальных — на $0$. Мы свели задачу к предыдущей.

Сложность — $O( | \Sigma | n \log n)$.

Используя эту задачу, можно решать задачи вида “Найдите все вхождения с не более чем $k$ ошибками”.

Все суммы

Есть два множества чисел $A$ и $B$, все числа целые от $0$ до $n$. Выдать все числа из $\lbrace x = a + b \mid a \in A, b \in B \rbrace$.

Сопоставим каждому набору многочлен $ \sum_{a \in A} x^{a} $ и перемножим их. Коэффициент при $x^{k}$ будет равен количеству способов представить $k$ в виде суммы элементов $A$ и $B$.

Динамическое программирование

Иногда в ДП переход выглядит как свёртка, и его можно ускорить при помощи FFT.

Задачи