forked from Ethsnarks/libfqfft
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added new recursive FFT implementation
- Loading branch information
Showing
5 changed files
with
650 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/** @file | ||
***************************************************************************** | ||
Declaration of interfaces for the "recursive" evaluation domain. | ||
Rhe domain has size m = 2^k and consists of the m-th roots of unity. | ||
***************************************************************************** | ||
* @author This file is part of libfqfft, developed by SCIPR Lab | ||
* and contributors (see AUTHORS). | ||
* @copyright MIT license (see LICENSE file) | ||
*****************************************************************************/ | ||
|
||
#ifndef RECURSIVE_DOMAIN_HPP_ | ||
#define RECURSIVE_DOMAIN_HPP_ | ||
|
||
#include <vector> | ||
|
||
#include <libfqfft/evaluation_domain/evaluation_domain.hpp> | ||
#include "recursive_domain_aux.hpp" | ||
#include "prover_config.hpp" | ||
|
||
namespace libfqfft { | ||
|
||
template<typename FieldT> | ||
class recursive_domain : public evaluation_domain<FieldT> { | ||
public: | ||
|
||
FieldT omega; | ||
|
||
recursive_domain(const size_t m, const libsnark::Config& config = libsnark::Config()); | ||
|
||
void FFT(std::vector<FieldT> &a); | ||
void iFFT(std::vector<FieldT> &a); | ||
void cosetFFT(std::vector<FieldT> &a, const FieldT &g); | ||
void icosetFFT(std::vector<FieldT> &a, const FieldT &g); | ||
std::vector<FieldT> evaluate_all_lagrange_polynomials(const FieldT &t); | ||
FieldT get_domain_element(const size_t idx); | ||
FieldT compute_vanishing_polynomial(const FieldT &t); | ||
void add_poly_Z(const FieldT &coeff, std::vector<FieldT> &H); | ||
void divide_by_Z_on_coset(std::vector<FieldT> &P); | ||
|
||
public: | ||
|
||
void iFFT_internal(std::vector<FieldT> &a); | ||
|
||
fft_data<FieldT> data; | ||
}; | ||
|
||
} // libfqfft | ||
|
||
#include <libfqfft/evaluation_domain/domains/recursive_domain.tcc> | ||
|
||
#endif // RECURSIVE_DOMAIN_HPP_ |
187 changes: 187 additions & 0 deletions
187
libfqfft/evaluation_domain/domains/recursive_domain.tcc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
/** @file | ||
***************************************************************************** | ||
Implementation of interfaces for the "recursive" evaluation domain. | ||
See recursive_domain.hpp . | ||
***************************************************************************** | ||
* @author This file is part of libfqfft, developed by SCIPR Lab | ||
* and contributors (see AUTHORS). | ||
* @copyright MIT license (see LICENSE file) | ||
*****************************************************************************/ | ||
|
||
#ifndef RECURSIVE_DOMAIN_TCC_ | ||
#define RECURSIVE_DOMAIN_TCC_ | ||
|
||
#include <libff/algebra/fields/field_utils.hpp> | ||
#include <libff/common/double.hpp> | ||
#include <libff/common/utils.hpp> | ||
|
||
#include <libfqfft/evaluation_domain/domains/recursive_domain_aux.hpp> | ||
|
||
namespace libfqfft { | ||
|
||
template<typename FieldT> | ||
recursive_domain<FieldT>::recursive_domain(const size_t m, const libsnark::Config& config) : evaluation_domain<FieldT>(m) | ||
{ | ||
if (m <= 1) throw InvalidSizeException("recursive(): expected m > 1"); | ||
const size_t logm = libff::log2(m); | ||
|
||
if (!std::is_same<FieldT, libff::Double>::value) | ||
{ | ||
if (logm > (FieldT::s)) throw DomainSizeException("recursive(): expected logm <= FieldT::s"); | ||
} | ||
|
||
try { omega = libff::get_root_of_unity<FieldT>(m); } | ||
catch (const std::invalid_argument& e) { throw DomainSizeException(e.what()); } | ||
|
||
data.m = m; | ||
data.smt = config.smt; | ||
|
||
data.stages = get_stages(m, config.radixes); | ||
|
||
auto ranges = libsnark::get_cpu_ranges(0, m); | ||
|
||
data.scratch.resize(m); | ||
|
||
// Generate stage twiddles | ||
for (unsigned int inv = 0; inv < 2; inv++) | ||
{ | ||
bool inverse = (inv == 0); | ||
const FieldT o = inverse ? omega.inverse() : omega; | ||
std::vector<std::vector<FieldT>>& stageTwiddles = inverse ? data.iTwiddles : data.fTwiddles; | ||
|
||
std::vector<FieldT>& twiddles = data.scratch; | ||
|
||
// Twiddles | ||
{ | ||
#ifdef MULTICORE | ||
#pragma omp parallel for | ||
#endif | ||
for (size_t j = 0; j < ranges.size(); ++j) | ||
{ | ||
const FieldT w_m = o; | ||
FieldT w = (w_m^ranges[j].first); | ||
for (unsigned int i = ranges[j].first; i < ranges[j].second; i++) | ||
{ | ||
twiddles[i] = w; | ||
w *= w_m; | ||
} | ||
} | ||
} | ||
|
||
// Re-order twiddles for cache friendliness | ||
unsigned int n = data.stages.size(); | ||
stageTwiddles.resize(n); | ||
for (unsigned int l = 0; l < n; l++) | ||
{ | ||
const unsigned int radix = data.stages[l].radix; | ||
const unsigned int stage_length = data.stages[l].length; | ||
|
||
unsigned int numTwiddles = stage_length * (radix - 1); | ||
stageTwiddles[l].resize(numTwiddles + 1); | ||
|
||
// Set j | ||
stageTwiddles[l][numTwiddles] = twiddles[(twiddles.size() * 3) / 4]; | ||
|
||
unsigned int stride = m / (stage_length * radix); | ||
std::vector<unsigned int> tws(radix - 1, 0); | ||
for (unsigned int i = 0; i < stage_length; i++) | ||
{ | ||
for(unsigned int j = 0; j < radix-1; j++) | ||
{ | ||
stageTwiddles[l][i*(radix-1) + j] = twiddles[tws[j]]; | ||
tws[j] += (j+1)*stride; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
template<typename FieldT> | ||
void recursive_domain<FieldT>::FFT(std::vector<FieldT> &a) | ||
{ | ||
_recursive_FFT(data, a, false); | ||
} | ||
|
||
template<typename FieldT> | ||
void recursive_domain<FieldT>::iFFT(std::vector<FieldT> &a) | ||
{ | ||
iFFT_internal(a); | ||
|
||
const FieldT sconst = FieldT(this->m).inverse(); | ||
#ifdef MULTICORE | ||
#pragma omp parallel for | ||
#endif | ||
for (size_t i = 0; i < this->m; ++i) | ||
{ | ||
a[i] *= sconst; | ||
} | ||
} | ||
|
||
template<typename FieldT> | ||
void recursive_domain<FieldT>::iFFT_internal(std::vector<FieldT> &a) | ||
{ | ||
_recursive_FFT(data, a, true); | ||
} | ||
|
||
template<typename FieldT> | ||
void recursive_domain<FieldT>::cosetFFT(std::vector<FieldT> &a, const FieldT &g) | ||
{ | ||
_multiply_by_coset_and_constant(this->m, a, g); | ||
FFT(a); | ||
} | ||
|
||
template<typename FieldT> | ||
void recursive_domain<FieldT>::icosetFFT(std::vector<FieldT> &a, const FieldT &g) | ||
{ | ||
iFFT_internal(a); | ||
const FieldT sconst = FieldT(this->m).inverse(); | ||
_multiply_by_coset_and_constant(this->m, a, g.inverse(), sconst); | ||
} | ||
|
||
template<typename FieldT> | ||
std::vector<FieldT> recursive_domain<FieldT>::evaluate_all_lagrange_polynomials(const FieldT &t) | ||
{ | ||
return _basic_radix2_evaluate_all_lagrange_polynomials(this->m, t); | ||
} | ||
|
||
template<typename FieldT> | ||
FieldT recursive_domain<FieldT>::get_domain_element(const size_t idx) | ||
{ | ||
return omega^idx; | ||
} | ||
|
||
template<typename FieldT> | ||
FieldT recursive_domain<FieldT>::compute_vanishing_polynomial(const FieldT &t) | ||
{ | ||
return (t^this->m) - FieldT::one(); | ||
} | ||
|
||
template<typename FieldT> | ||
void recursive_domain<FieldT>::add_poly_Z(const FieldT &coeff, std::vector<FieldT> &H) | ||
{ | ||
if (H.size() != this->m+1) throw DomainSizeException("recursive: expected H.size() == this->m+1"); | ||
|
||
H[this->m] += coeff; | ||
H[0] -= coeff; | ||
} | ||
|
||
template<typename FieldT> | ||
void recursive_domain<FieldT>::divide_by_Z_on_coset(std::vector<FieldT> &P) | ||
{ | ||
const FieldT coset = FieldT::multiplicative_generator; | ||
const FieldT Z_inverse_at_coset = this->compute_vanishing_polynomial(coset).inverse(); | ||
#ifdef MULTICORE | ||
#pragma omp parallel for | ||
#endif | ||
for (size_t i = 0; i < this->m; ++i) | ||
{ | ||
P[i] *= Z_inverse_at_coset; | ||
} | ||
} | ||
|
||
} // libfqfft | ||
|
||
#endif // RECURSIVE_DOMAIN_TCC_ |
60 changes: 60 additions & 0 deletions
60
libfqfft/evaluation_domain/domains/recursive_domain_aux.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
/** @file | ||
***************************************************************************** | ||
Declaration of interfaces for auxiliary functions for the "basic radix-2" evaluation domain. | ||
These functions compute the radix-2 FFT (in single- or multi-thread mode) and, | ||
also compute Lagrange coefficients. | ||
***************************************************************************** | ||
* @author This file is part of libfqfft, developed by SCIPR Lab | ||
* and contributors (see AUTHORS). | ||
* @copyright MIT license (see LICENSE file) | ||
*****************************************************************************/ | ||
|
||
#ifndef RECURSIVE_DOMAIN_AUX_HPP_ | ||
#define RECURSIVE_DOMAIN_AUX_HPP_ | ||
|
||
#include <vector> | ||
|
||
namespace libfqfft { | ||
|
||
struct fft_stage | ||
{ | ||
fft_stage(unsigned int _radix, unsigned int _length) : radix(_radix), length(_length) {} | ||
unsigned int radix; | ||
unsigned int length; | ||
}; | ||
|
||
template<typename FieldT> | ||
struct fft_data | ||
{ | ||
unsigned int m; | ||
bool smt; | ||
|
||
std::vector<fft_stage> stages; | ||
|
||
std::vector<std::vector<FieldT>> fTwiddles; | ||
std::vector<std::vector<FieldT>> iTwiddles; | ||
|
||
std::vector<FieldT> scratch; | ||
}; | ||
|
||
/** | ||
* Compute the FFT of the vector a over the set S={omega^{0},...,omega^{m-1}}. | ||
*/ | ||
template<typename FieldT> | ||
void _recursive_FFT(fft_data<FieldT>& data, std::vector<FieldT>& in, bool inverse); | ||
|
||
/** | ||
* Translate the vector a to a coset defined by g + extra constant multiplication. | ||
*/ | ||
template<typename FieldT> | ||
void _multiply_by_coset_and_constant(unsigned int m, std::vector<FieldT> &a, const FieldT &g, const FieldT &c = FieldT::one()); | ||
|
||
|
||
} // libfqfft | ||
|
||
#include <libfqfft/evaluation_domain/domains/recursive_domain_aux.tcc> | ||
|
||
#endif // RECURSIVE_DOMAIN_AUX_HPP_ |
Oops, something went wrong.