diff --git a/libfqfft/evaluation_domain/domains/recursive_domain.hpp b/libfqfft/evaluation_domain/domains/recursive_domain.hpp new file mode 100755 index 0000000..ac6343d --- /dev/null +++ b/libfqfft/evaluation_domain/domains/recursive_domain.hpp @@ -0,0 +1,54 @@ +/** @file + ***************************************************************************** + + Declaration of interfaces for the "recursive" evaluation domain. + + Rhe domain has size m = 2^k and consists of the m-th roots of unity. + + ***************************************************************************** + * @author This file is part of libfqfft, developed by SCIPR Lab + * and contributors (see AUTHORS). + * @copyright MIT license (see LICENSE file) + *****************************************************************************/ + +#ifndef RECURSIVE_DOMAIN_HPP_ +#define RECURSIVE_DOMAIN_HPP_ + +#include + +#include +#include "recursive_domain_aux.hpp" +#include "prover_config.hpp" + +namespace libfqfft { + +template +class recursive_domain : public evaluation_domain { +public: + + FieldT omega; + + recursive_domain(const size_t m, const libsnark::Config& config = libsnark::Config()); + + void FFT(std::vector &a); + void iFFT(std::vector &a); + void cosetFFT(std::vector &a, const FieldT &g); + void icosetFFT(std::vector &a, const FieldT &g); + std::vector evaluate_all_lagrange_polynomials(const FieldT &t); + FieldT get_domain_element(const size_t idx); + FieldT compute_vanishing_polynomial(const FieldT &t); + void add_poly_Z(const FieldT &coeff, std::vector &H); + void divide_by_Z_on_coset(std::vector &P); + +public: + + void iFFT_internal(std::vector &a); + + fft_data data; +}; + +} // libfqfft + +#include + +#endif // RECURSIVE_DOMAIN_HPP_ diff --git a/libfqfft/evaluation_domain/domains/recursive_domain.tcc b/libfqfft/evaluation_domain/domains/recursive_domain.tcc new file mode 100755 index 0000000..2cb3261 --- /dev/null +++ b/libfqfft/evaluation_domain/domains/recursive_domain.tcc @@ -0,0 +1,187 @@ +/** @file + ***************************************************************************** + + Implementation of interfaces for the "recursive" evaluation domain. + + See recursive_domain.hpp . + + ***************************************************************************** + * @author This file is part of libfqfft, developed by SCIPR Lab + * and contributors (see AUTHORS). + * @copyright MIT license (see LICENSE file) + *****************************************************************************/ + +#ifndef RECURSIVE_DOMAIN_TCC_ +#define RECURSIVE_DOMAIN_TCC_ + +#include +#include +#include + +#include + +namespace libfqfft { + +template +recursive_domain::recursive_domain(const size_t m, const libsnark::Config& config) : evaluation_domain(m) +{ + if (m <= 1) throw InvalidSizeException("recursive(): expected m > 1"); + const size_t logm = libff::log2(m); + + if (!std::is_same::value) + { + if (logm > (FieldT::s)) throw DomainSizeException("recursive(): expected logm <= FieldT::s"); + } + + try { omega = libff::get_root_of_unity(m); } + catch (const std::invalid_argument& e) { throw DomainSizeException(e.what()); } + + data.m = m; + data.smt = config.smt; + + data.stages = get_stages(m, config.radixes); + + auto ranges = libsnark::get_cpu_ranges(0, m); + + data.scratch.resize(m); + + // Generate stage twiddles + for (unsigned int inv = 0; inv < 2; inv++) + { + bool inverse = (inv == 0); + const FieldT o = inverse ? omega.inverse() : omega; + std::vector>& stageTwiddles = inverse ? data.iTwiddles : data.fTwiddles; + + std::vector& twiddles = data.scratch; + + // Twiddles + { +#ifdef MULTICORE + #pragma omp parallel for +#endif + for (size_t j = 0; j < ranges.size(); ++j) + { + const FieldT w_m = o; + FieldT w = (w_m^ranges[j].first); + for (unsigned int i = ranges[j].first; i < ranges[j].second; i++) + { + twiddles[i] = w; + w *= w_m; + } + } + } + + // Re-order twiddles for cache friendliness + unsigned int n = data.stages.size(); + stageTwiddles.resize(n); + for (unsigned int l = 0; l < n; l++) + { + const unsigned int radix = data.stages[l].radix; + const unsigned int stage_length = data.stages[l].length; + + unsigned int numTwiddles = stage_length * (radix - 1); + stageTwiddles[l].resize(numTwiddles + 1); + + // Set j + stageTwiddles[l][numTwiddles] = twiddles[(twiddles.size() * 3) / 4]; + + unsigned int stride = m / (stage_length * radix); + std::vector tws(radix - 1, 0); + for (unsigned int i = 0; i < stage_length; i++) + { + for(unsigned int j = 0; j < radix-1; j++) + { + stageTwiddles[l][i*(radix-1) + j] = twiddles[tws[j]]; + tws[j] += (j+1)*stride; + } + } + } + } +} + +template +void recursive_domain::FFT(std::vector &a) +{ + _recursive_FFT(data, a, false); +} + +template +void recursive_domain::iFFT(std::vector &a) +{ + iFFT_internal(a); + + const FieldT sconst = FieldT(this->m).inverse(); +#ifdef MULTICORE + #pragma omp parallel for +#endif + for (size_t i = 0; i < this->m; ++i) + { + a[i] *= sconst; + } +} + +template +void recursive_domain::iFFT_internal(std::vector &a) +{ + _recursive_FFT(data, a, true); +} + +template +void recursive_domain::cosetFFT(std::vector &a, const FieldT &g) +{ + _multiply_by_coset_and_constant(this->m, a, g); + FFT(a); +} + +template +void recursive_domain::icosetFFT(std::vector &a, const FieldT &g) +{ + iFFT_internal(a); + const FieldT sconst = FieldT(this->m).inverse(); + _multiply_by_coset_and_constant(this->m, a, g.inverse(), sconst); +} + +template +std::vector recursive_domain::evaluate_all_lagrange_polynomials(const FieldT &t) +{ + return _basic_radix2_evaluate_all_lagrange_polynomials(this->m, t); +} + +template +FieldT recursive_domain::get_domain_element(const size_t idx) +{ + return omega^idx; +} + +template +FieldT recursive_domain::compute_vanishing_polynomial(const FieldT &t) +{ + return (t^this->m) - FieldT::one(); +} + +template +void recursive_domain::add_poly_Z(const FieldT &coeff, std::vector &H) +{ + if (H.size() != this->m+1) throw DomainSizeException("recursive: expected H.size() == this->m+1"); + + H[this->m] += coeff; + H[0] -= coeff; +} + +template +void recursive_domain::divide_by_Z_on_coset(std::vector &P) +{ + const FieldT coset = FieldT::multiplicative_generator; + const FieldT Z_inverse_at_coset = this->compute_vanishing_polynomial(coset).inverse(); +#ifdef MULTICORE + #pragma omp parallel for +#endif + for (size_t i = 0; i < this->m; ++i) + { + P[i] *= Z_inverse_at_coset; + } +} + +} // libfqfft + +#endif // RECURSIVE_DOMAIN_TCC_ diff --git a/libfqfft/evaluation_domain/domains/recursive_domain_aux.hpp b/libfqfft/evaluation_domain/domains/recursive_domain_aux.hpp new file mode 100755 index 0000000..eea8121 --- /dev/null +++ b/libfqfft/evaluation_domain/domains/recursive_domain_aux.hpp @@ -0,0 +1,60 @@ +/** @file + ***************************************************************************** + + Declaration of interfaces for auxiliary functions for the "basic radix-2" evaluation domain. + + These functions compute the radix-2 FFT (in single- or multi-thread mode) and, + also compute Lagrange coefficients. + + ***************************************************************************** + * @author This file is part of libfqfft, developed by SCIPR Lab + * and contributors (see AUTHORS). + * @copyright MIT license (see LICENSE file) + *****************************************************************************/ + +#ifndef RECURSIVE_DOMAIN_AUX_HPP_ +#define RECURSIVE_DOMAIN_AUX_HPP_ + +#include + +namespace libfqfft { + +struct fft_stage +{ + fft_stage(unsigned int _radix, unsigned int _length) : radix(_radix), length(_length) {} + unsigned int radix; + unsigned int length; +}; + +template +struct fft_data +{ + unsigned int m; + bool smt; + + std::vector stages; + + std::vector> fTwiddles; + std::vector> iTwiddles; + + std::vector scratch; +}; + +/** + * Compute the FFT of the vector a over the set S={omega^{0},...,omega^{m-1}}. + */ +template +void _recursive_FFT(fft_data& data, std::vector& in, bool inverse); + +/** + * Translate the vector a to a coset defined by g + extra constant multiplication. + */ +template +void _multiply_by_coset_and_constant(unsigned int m, std::vector &a, const FieldT &g, const FieldT &c = FieldT::one()); + + +} // libfqfft + +#include + +#endif // RECURSIVE_DOMAIN_AUX_HPP_ diff --git a/libfqfft/evaluation_domain/domains/recursive_domain_aux.tcc b/libfqfft/evaluation_domain/domains/recursive_domain_aux.tcc new file mode 100755 index 0000000..411f655 --- /dev/null +++ b/libfqfft/evaluation_domain/domains/recursive_domain_aux.tcc @@ -0,0 +1,345 @@ +/** @file + ***************************************************************************** + + Implementation of interfaces for auxiliary functions for the "recursive" evaluation domain. + + See recursive_domain_aux.hpp . + + ***************************************************************************** + * @author This file is part of libfqfft, developed by SCIPR Lab + * and contributors (see AUTHORS). + * @copyright MIT license (see LICENSE file) + *****************************************************************************/ + +#ifndef RECURSIVE_DOMAIN_AUX_TCC_ +#define RECURSIVE_DOMAIN_AUX_TCC_ + +#include +#include +#include "prover_config.hpp" + +#ifdef MULTICORE +#include +#endif + +#include + +#include + +#ifdef DEBUG +#include +#endif + +namespace libfqfft { + +static inline std::vector get_stages(unsigned int n, const std::vector& radixes) +{ + std::vector stages; + + // Use the specified radices + for (unsigned int i = 0; i < radixes.size(); i++) + { + n /= radixes[i]; + stages.push_back(fft_stage(radixes[i], n)); + } + + // Fill in the rest of the tree if needed + unsigned int p = 4; + while (n > 1) + { + while (n % p) + { + switch (p) + { + case 4: p = 2; break; + } + } + n /= p; + stages.push_back(fft_stage(p, n)); + }; + + for (unsigned int i = 0; i < stages.size(); i++) + { + std::cout << "Stage " << i << ": " << stages[i].radix << ", " << stages[i].length << std::endl; + } + + return stages; +} + +template +static void butterfly_2(std::vector& out, const std::vector& twiddles, unsigned int stride, unsigned int stage_length, unsigned int out_offset) +{ + unsigned int out_offset2 = out_offset + stage_length; + + FieldT t = out[out_offset2]; + out[out_offset2] = out[out_offset] - t; + out[out_offset] += t; + out_offset2++; + out_offset++; + + for (unsigned int k = 1; k < stage_length; k++) + { + FieldT t = twiddles[k] * out[out_offset2]; + out[out_offset2] = out[out_offset] - t; + out[out_offset] += t; + out_offset2++; + out_offset++; + } +} + +template +static void butterfly_2_parallel(std::vector& out, const std::vector& twiddles, unsigned int stride, unsigned int stage_length, unsigned int out_offset, unsigned int num_threads) +{ + unsigned int out_offset2 = out_offset + stage_length; + + auto ranges = libsnark::get_cpu_ranges(0, stage_length, num_threads); + +#ifdef MULTICORE + #pragma omp parallel for num_threads(num_threads) +#endif + for (unsigned int c = 0; c < ranges.size(); c++) + { + unsigned int offset1 = out_offset + ranges[c].first; + unsigned int offset2 = out_offset2 + ranges[c].first; + for (unsigned int k = ranges[c].first; k < ranges[c].second; k++) + { + FieldT t = twiddles[k] * out[offset2]; + out[offset2] = out[offset1] - t; + out[offset1] += t; + offset2++; + offset1++; + } + } +} + +template +static void butterfly_4(std::vector& out, const std::vector& twiddles, unsigned int stride, unsigned int stage_length, unsigned int out_offset) +{ + const FieldT j = twiddles[twiddles.size() - 1]; + unsigned int tw = 0; + + /* Case twiddle == one */ + { + const unsigned i0 = out_offset; + const unsigned i1 = out_offset + stage_length; + const unsigned i2 = out_offset + stage_length*2; + const unsigned i3 = out_offset + stage_length*3; + + const FieldT z0 = out[i0]; + const FieldT z1 = out[i1]; + const FieldT z2 = out[i2]; + const FieldT z3 = out[i3]; + + const FieldT t1 = z0 + z2; + const FieldT t2 = z1 + z3; + const FieldT t3 = z0 - z2; + const FieldT t4j = j * (z1 - z3); + + out[i0] = t1 + t2; + out[i1] = t3 - t4j; + out[i2] = t1 - t2; + out[i3] = t3 + t4j; + + out_offset++; + tw += 3; + } + + for (unsigned int k = 1; k < stage_length; k++) + { + const unsigned i0 = out_offset; + const unsigned i1 = out_offset + stage_length; + const unsigned i2 = out_offset + stage_length*2; + const unsigned i3 = out_offset + stage_length*3; + + const FieldT z0 = out[i0]; + const FieldT z1 = out[i1] * twiddles[tw]; + const FieldT z2 = out[i2] * twiddles[tw+1]; + const FieldT z3 = out[i3] * twiddles[tw+2]; + + const FieldT t1 = z0 + z2; + const FieldT t2 = z1 + z3; + const FieldT t3 = z0 - z2; + const FieldT t4j = j * (z1 - z3); + + out[i0] = t1 + t2; + out[i1] = t3 - t4j; + out[i2] = t1 - t2; + out[i3] = t3 + t4j; + + out_offset++; + tw += 3; + } +} + +template +static void butterfly_4_parallel(std::vector& out, const std::vector& twiddles, unsigned int stride, unsigned int stage_length, unsigned int out_offset, unsigned int num_threads) +{ + const FieldT j = twiddles[twiddles.size() - 1]; + const auto ranges = libsnark::get_cpu_ranges(0, stage_length, num_threads); + +#ifdef MULTICORE + #pragma omp parallel for +#endif + for (unsigned int c = 0; c < ranges.size(); c++) + { + unsigned int offset = out_offset + ranges[c].first; + unsigned int tw = 3 * ranges[c].first; + for (unsigned int k = ranges[c].first; k < ranges[c].second; k++) + { + const unsigned i0 = offset; + const unsigned i1 = offset + stage_length; + const unsigned i2 = offset + stage_length*2; + const unsigned i3 = offset + stage_length*3; + + const FieldT z0 = out[i0]; + const FieldT z1 = out[i1] * twiddles[tw]; + const FieldT z2 = out[i2] * twiddles[tw+1]; + const FieldT z3 = out[i3] * twiddles[tw+2]; + + const FieldT t1 = z0 + z2; + const FieldT t2 = z1 + z3; + const FieldT t3 = z0 - z2; + const FieldT t4j = j * (z1 - z3); + + out[i0] = t1 + t2; + out[i1] = t3 - t4j; + out[i2] = t1 - t2; + out[i3] = t3 + t4j; + + offset++; + tw += 3; + } + } +} + +template +void _recursive_FFT_inner( + std::vector& in, + std::vector& out, + const std::vector>& twiddles, + const std::vector& stages, + unsigned int in_offset, + unsigned int out_offset, + unsigned int stride, + unsigned int level, + unsigned int num_threads) +{ + const unsigned int radix = stages[level].radix; + const unsigned int stage_length = stages[level].length; + + if (num_threads > 1) + { + if (stage_length == 1) + { + for (unsigned int i = 0; i < radix; i++) + { + out[out_offset + i] = in[in_offset + i * stride]; + } + } + else + { + unsigned int num_threads_recursive = (num_threads >= radix) ? radix : num_threads; + #pragma omp parallel for num_threads(num_threads_recursive) + for (unsigned int i = 0; i < radix; i++) + { + unsigned int num_threads_in_recursion = (num_threads < radix) ? 1 : (num_threads + i) / radix; + if (smt) + { + omp_set_num_threads(num_threads_in_recursion * 2); + } + // std::cout << "Start thread on " << level << ": " << num_threads_in_recursion << std::endl; + _recursive_FFT_inner(in, out, twiddles, stages, in_offset + i*stride, out_offset + i*stage_length, stride*radix, level+1, num_threads_in_recursion); + } + } + + switch (radix) + { + case 2: butterfly_2_parallel(out, twiddles[level], stride, stage_length, out_offset, num_threads); break; + case 4: butterfly_4_parallel(out, twiddles[level], stride, stage_length, out_offset, num_threads); break; + default: std::cout << "error" << std::endl; assert(false); + } + } + else + { + if (stage_length == 1) + { + for (unsigned int i = 0; i < radix; i++) + { + out[out_offset + i] = in[in_offset + i * stride]; + } + } + else + { + for (unsigned int i = 0; i < radix; i++) + { + _recursive_FFT_inner(in, out, twiddles, stages, in_offset + i*stride, out_offset + i*stage_length, stride*radix, level+1, num_threads); + } + } + + /*if (smt) + { + switch (radix) + { + case 2: butterfly_2_parallel(out, twiddles[level], stride, stage_length, out_offset, 2); break; + case 4: butterfly_4_parallel(out, twiddles[level], stride, stage_length, out_offset, 2); break; + default: std::cout << "error" << std::endl; assert(false); + } + } + else*/ + { + switch (radix) + { + case 2: butterfly_2(out, twiddles[level], stride, stage_length, out_offset); break; + case 4: butterfly_4(out, twiddles[level], stride, stage_length, out_offset); break; + default: std::cout << "error" << std::endl; assert(false); + } + } + } +} + +template +void _recursive_FFT(fft_data& data, std::vector& in, bool inverse) +{ +#ifdef MULTICORE + size_t num_threads = omp_get_max_threads(); + if (data.smt) + { + num_threads /= 2; + } +#else + size_t num_threads = 1; +#endif + if (data.smt) + { + _recursive_FFT_inner(in, data.scratch, inverse? data.iTwiddles : data.fTwiddles, data.stages, 0, 0, 1, 0, num_threads); + } + else + { + _recursive_FFT_inner(in, data.scratch, inverse? data.iTwiddles : data.fTwiddles, data.stages, 0, 0, 1, 0, num_threads); + } + std::swap(in, data.scratch); +} + +template +void _multiply_by_coset_and_constant(unsigned int m, std::vector &a, const FieldT &g, const FieldT &c) +{ + auto ranges = libsnark::get_cpu_ranges(1, m); + + a[0] *= c; +#ifdef MULTICORE + #pragma omp parallel for +#endif + for (size_t j = 0; j < ranges.size(); ++j) + { + FieldT u = c * (g^ranges[j].first); + for (unsigned int i = ranges[j].first; i < ranges[j].second; i++) + { + a[i] *= u; + u *= g; + } + } +} + +} // libfqfft + +#endif // RECURSIVE_DOMAIN_AUX_TCC_ diff --git a/libfqfft/evaluation_domain/get_evaluation_domain.tcc b/libfqfft/evaluation_domain/get_evaluation_domain.tcc index 299537c..fc42536 100755 --- a/libfqfft/evaluation_domain/get_evaluation_domain.tcc +++ b/libfqfft/evaluation_domain/get_evaluation_domain.tcc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -38,7 +39,8 @@ std::shared_ptr > get_evaluation_domain(const size_t m const size_t small = min_size - big; const size_t rounded_small = (1ul<(min_size)); } + try { result.reset(new recursive_domain(min_size)); } + catch(...) { try { result.reset(new basic_radix2_domain(min_size)); } catch(...) { try { result.reset(new extended_radix2_domain(min_size)); } catch(...) { try { result.reset(new step_radix2_domain(min_size)); } catch(...) { try { result.reset(new basic_radix2_domain(big + rounded_small)); } @@ -46,7 +48,7 @@ std::shared_ptr > get_evaluation_domain(const size_t m catch(...) { try { result.reset(new step_radix2_domain(big + rounded_small)); } catch(...) { try { result.reset(new geometric_sequence_domain(min_size)); } catch(...) { try { result.reset(new arithmetic_sequence_domain(min_size)); } - catch(...) { throw DomainSizeException("get_evaluation_domain: no matching domain"); }}}}}}}} + catch(...) { throw DomainSizeException("get_evaluation_domain: no matching domain"); }}}}}}}}} return result; }