Skip to content

Commit

Permalink
Added new recursive FFT implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Brechtpd committed Feb 26, 2020
1 parent e0183b2 commit 1d8db42
Show file tree
Hide file tree
Showing 5 changed files with 650 additions and 2 deletions.
54 changes: 54 additions & 0 deletions libfqfft/evaluation_domain/domains/recursive_domain.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/** @file
*****************************************************************************
Declaration of interfaces for the "recursive" evaluation domain.
Rhe domain has size m = 2^k and consists of the m-th roots of unity.
*****************************************************************************
* @author This file is part of libfqfft, developed by SCIPR Lab
* and contributors (see AUTHORS).
* @copyright MIT license (see LICENSE file)
*****************************************************************************/

#ifndef RECURSIVE_DOMAIN_HPP_
#define RECURSIVE_DOMAIN_HPP_

#include <vector>

#include <libfqfft/evaluation_domain/evaluation_domain.hpp>
#include "recursive_domain_aux.hpp"
#include "prover_config.hpp"

namespace libfqfft {

template<typename FieldT>
class recursive_domain : public evaluation_domain<FieldT> {
public:

FieldT omega;

recursive_domain(const size_t m, const libsnark::Config& config = libsnark::Config());

void FFT(std::vector<FieldT> &a);
void iFFT(std::vector<FieldT> &a);
void cosetFFT(std::vector<FieldT> &a, const FieldT &g);
void icosetFFT(std::vector<FieldT> &a, const FieldT &g);
std::vector<FieldT> evaluate_all_lagrange_polynomials(const FieldT &t);
FieldT get_domain_element(const size_t idx);
FieldT compute_vanishing_polynomial(const FieldT &t);
void add_poly_Z(const FieldT &coeff, std::vector<FieldT> &H);
void divide_by_Z_on_coset(std::vector<FieldT> &P);

public:

void iFFT_internal(std::vector<FieldT> &a);

fft_data<FieldT> data;
};

} // libfqfft

#include <libfqfft/evaluation_domain/domains/recursive_domain.tcc>

#endif // RECURSIVE_DOMAIN_HPP_
187 changes: 187 additions & 0 deletions libfqfft/evaluation_domain/domains/recursive_domain.tcc
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/** @file
*****************************************************************************
Implementation of interfaces for the "recursive" evaluation domain.
See recursive_domain.hpp .
*****************************************************************************
* @author This file is part of libfqfft, developed by SCIPR Lab
* and contributors (see AUTHORS).
* @copyright MIT license (see LICENSE file)
*****************************************************************************/

#ifndef RECURSIVE_DOMAIN_TCC_
#define RECURSIVE_DOMAIN_TCC_

#include <libff/algebra/fields/field_utils.hpp>
#include <libff/common/double.hpp>
#include <libff/common/utils.hpp>

#include <libfqfft/evaluation_domain/domains/recursive_domain_aux.hpp>

namespace libfqfft {

template<typename FieldT>
recursive_domain<FieldT>::recursive_domain(const size_t m, const libsnark::Config& config) : evaluation_domain<FieldT>(m)
{
if (m <= 1) throw InvalidSizeException("recursive(): expected m > 1");
const size_t logm = libff::log2(m);

if (!std::is_same<FieldT, libff::Double>::value)
{
if (logm > (FieldT::s)) throw DomainSizeException("recursive(): expected logm <= FieldT::s");
}

try { omega = libff::get_root_of_unity<FieldT>(m); }
catch (const std::invalid_argument& e) { throw DomainSizeException(e.what()); }

data.m = m;
data.smt = config.smt;

data.stages = get_stages(m, config.radixes);

auto ranges = libsnark::get_cpu_ranges(0, m);

data.scratch.resize(m);

// Generate stage twiddles
for (unsigned int inv = 0; inv < 2; inv++)
{
bool inverse = (inv == 0);
const FieldT o = inverse ? omega.inverse() : omega;
std::vector<std::vector<FieldT>>& stageTwiddles = inverse ? data.iTwiddles : data.fTwiddles;

std::vector<FieldT>& twiddles = data.scratch;

// Twiddles
{
#ifdef MULTICORE
#pragma omp parallel for
#endif
for (size_t j = 0; j < ranges.size(); ++j)
{
const FieldT w_m = o;
FieldT w = (w_m^ranges[j].first);
for (unsigned int i = ranges[j].first; i < ranges[j].second; i++)
{
twiddles[i] = w;
w *= w_m;
}
}
}

// Re-order twiddles for cache friendliness
unsigned int n = data.stages.size();
stageTwiddles.resize(n);
for (unsigned int l = 0; l < n; l++)
{
const unsigned int radix = data.stages[l].radix;
const unsigned int stage_length = data.stages[l].length;

unsigned int numTwiddles = stage_length * (radix - 1);
stageTwiddles[l].resize(numTwiddles + 1);

// Set j
stageTwiddles[l][numTwiddles] = twiddles[(twiddles.size() * 3) / 4];

unsigned int stride = m / (stage_length * radix);
std::vector<unsigned int> tws(radix - 1, 0);
for (unsigned int i = 0; i < stage_length; i++)
{
for(unsigned int j = 0; j < radix-1; j++)
{
stageTwiddles[l][i*(radix-1) + j] = twiddles[tws[j]];
tws[j] += (j+1)*stride;
}
}
}
}
}

template<typename FieldT>
void recursive_domain<FieldT>::FFT(std::vector<FieldT> &a)
{
_recursive_FFT(data, a, false);
}

template<typename FieldT>
void recursive_domain<FieldT>::iFFT(std::vector<FieldT> &a)
{
iFFT_internal(a);

const FieldT sconst = FieldT(this->m).inverse();
#ifdef MULTICORE
#pragma omp parallel for
#endif
for (size_t i = 0; i < this->m; ++i)
{
a[i] *= sconst;
}
}

template<typename FieldT>
void recursive_domain<FieldT>::iFFT_internal(std::vector<FieldT> &a)
{
_recursive_FFT(data, a, true);
}

template<typename FieldT>
void recursive_domain<FieldT>::cosetFFT(std::vector<FieldT> &a, const FieldT &g)
{
_multiply_by_coset_and_constant(this->m, a, g);
FFT(a);
}

template<typename FieldT>
void recursive_domain<FieldT>::icosetFFT(std::vector<FieldT> &a, const FieldT &g)
{
iFFT_internal(a);
const FieldT sconst = FieldT(this->m).inverse();
_multiply_by_coset_and_constant(this->m, a, g.inverse(), sconst);
}

template<typename FieldT>
std::vector<FieldT> recursive_domain<FieldT>::evaluate_all_lagrange_polynomials(const FieldT &t)
{
return _basic_radix2_evaluate_all_lagrange_polynomials(this->m, t);
}

template<typename FieldT>
FieldT recursive_domain<FieldT>::get_domain_element(const size_t idx)
{
return omega^idx;
}

template<typename FieldT>
FieldT recursive_domain<FieldT>::compute_vanishing_polynomial(const FieldT &t)
{
return (t^this->m) - FieldT::one();
}

template<typename FieldT>
void recursive_domain<FieldT>::add_poly_Z(const FieldT &coeff, std::vector<FieldT> &H)
{
if (H.size() != this->m+1) throw DomainSizeException("recursive: expected H.size() == this->m+1");

H[this->m] += coeff;
H[0] -= coeff;
}

template<typename FieldT>
void recursive_domain<FieldT>::divide_by_Z_on_coset(std::vector<FieldT> &P)
{
const FieldT coset = FieldT::multiplicative_generator;
const FieldT Z_inverse_at_coset = this->compute_vanishing_polynomial(coset).inverse();
#ifdef MULTICORE
#pragma omp parallel for
#endif
for (size_t i = 0; i < this->m; ++i)
{
P[i] *= Z_inverse_at_coset;
}
}

} // libfqfft

#endif // RECURSIVE_DOMAIN_TCC_
60 changes: 60 additions & 0 deletions libfqfft/evaluation_domain/domains/recursive_domain_aux.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/** @file
*****************************************************************************
Declaration of interfaces for auxiliary functions for the "basic radix-2" evaluation domain.
These functions compute the radix-2 FFT (in single- or multi-thread mode) and,
also compute Lagrange coefficients.
*****************************************************************************
* @author This file is part of libfqfft, developed by SCIPR Lab
* and contributors (see AUTHORS).
* @copyright MIT license (see LICENSE file)
*****************************************************************************/

#ifndef RECURSIVE_DOMAIN_AUX_HPP_
#define RECURSIVE_DOMAIN_AUX_HPP_

#include <vector>

namespace libfqfft {

struct fft_stage
{
fft_stage(unsigned int _radix, unsigned int _length) : radix(_radix), length(_length) {}
unsigned int radix;
unsigned int length;
};

template<typename FieldT>
struct fft_data
{
unsigned int m;
bool smt;

std::vector<fft_stage> stages;

std::vector<std::vector<FieldT>> fTwiddles;
std::vector<std::vector<FieldT>> iTwiddles;

std::vector<FieldT> scratch;
};

/**
* Compute the FFT of the vector a over the set S={omega^{0},...,omega^{m-1}}.
*/
template<typename FieldT>
void _recursive_FFT(fft_data<FieldT>& data, std::vector<FieldT>& in, bool inverse);

/**
* Translate the vector a to a coset defined by g + extra constant multiplication.
*/
template<typename FieldT>
void _multiply_by_coset_and_constant(unsigned int m, std::vector<FieldT> &a, const FieldT &g, const FieldT &c = FieldT::one());


} // libfqfft

#include <libfqfft/evaluation_domain/domains/recursive_domain_aux.tcc>

#endif // RECURSIVE_DOMAIN_AUX_HPP_
Loading

0 comments on commit 1d8db42

Please sign in to comment.