Skip to content

Commit

Permalink
optimize fft
Browse files Browse the repository at this point in the history
  • Loading branch information
ashWhiteHat committed Apr 28, 2022
1 parent 1e6bb51 commit 9a9873a
Showing 1 changed file with 79 additions and 72 deletions.
151 changes: 79 additions & 72 deletions halo2_proofs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,7 @@ pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu
///
/// This will use multithreading if beneficial.
pub fn best_fft<G: Group>(a: &mut [G], omega: G::Scalar, log_n: u32) {
let threads = multicore::current_num_threads();
let log_threads = log2_floor(threads);

if log_n <= log_threads {
serial_fft(a, omega, log_n);
} else {
parallel_fft(a, omega, log_n, log_threads);
}
}

fn serial_fft<G: Group>(a: &mut [G], omega: G::Scalar, log_n: u32) {
fn bitreverse(mut n: u32, l: u32) -> u32 {
fn bitreverse(mut n: usize, l: usize) -> usize {
let mut r = 0;
for _ in 0..l {
r = (r << 1) | (n & 1);
Expand All @@ -189,79 +178,97 @@ fn serial_fft<G: Group>(a: &mut [G], omega: G::Scalar, log_n: u32) {
r
}

let n = a.len() as u32;
let threads = multicore::current_num_threads();
let log_threads = log2_floor(threads);
let n = a.len() as usize;
assert_eq!(n, 1 << log_n);

for k in 0..n {
let rk = bitreverse(k, log_n);
let rk = bitreverse(k, log_n as usize);
if k < rk {
a.swap(rk as usize, k as usize);
a.swap(rk, k);
}
}

let mut m = 1;
for _ in 0..log_n {
let w_m = omega.pow_vartime(&[u64::from(n / (2 * m)), 0, 0, 0]);

let mut k = 0;
while k < n {
let mut w = G::Scalar::one();
for j in 0..m {
let mut t = a[(k + j + m) as usize];
t.group_scale(&w);
a[(k + j + m) as usize] = a[(k + j) as usize];
a[(k + j + m) as usize].group_sub(&t);
a[(k + j) as usize].group_add(&t);
w *= &w_m;
}
// precompute twiddle factors
let mut w = G::Scalar::one();
let mut twiddles = vec![G::Scalar::one(); (n / 2) as usize];
for tw in twiddles.iter_mut() {
*tw = w;
w.group_scale(&omega);
}

k += 2 * m;
if log_n <= log_threads {
let mut chunk = 2_usize;
let mut twiddle_chunk = (n / 2) as usize;
for _ in 0..log_n {
a.chunks_mut(chunk).for_each(|coeffs| {
let (left, right) = coeffs.split_at_mut(chunk / 2);

// case when twiddle factor is one
let (a, left) = left.split_at_mut(1);
let (b, right) = right.split_at_mut(1);
let t = b[0];
b[0] = a[0];
a[0].group_add(&t);
b[0].group_sub(&t);

left.iter_mut()
.zip(right.iter_mut())
.enumerate()
.for_each(|(i, (a, b))| {
let mut t = *b;
t.group_scale(&twiddles[(i + 1) * twiddle_chunk]);
*b = *a;
a.group_add(&t);
b.group_sub(&t);
});
});
chunk *= 2;
twiddle_chunk /= 2;
}

m *= 2;
} else {
recursive_butterfly_arithmetic(a, n, 1, &twiddles)
}
}

fn parallel_fft<G: Group>(a: &mut [G], omega: G::Scalar, log_n: u32, log_threads: u32) {
assert!(log_n >= log_threads);

let num_threads = 1 << log_threads;
let log_new_n = log_n - log_threads;
let mut tmp = vec![vec![G::group_zero(); 1 << log_new_n]; num_threads];
let new_omega = omega.pow_vartime(&[num_threads as u64, 0, 0, 0]);

multicore::scope(|scope| {
let a = &*a;

for (j, tmp) in tmp.iter_mut().enumerate() {
scope.spawn(move |_| {
// Shuffle into a sub-FFT
let omega_j = omega.pow_vartime(&[j as u64, 0, 0, 0]);
let omega_step = omega.pow_vartime(&[(j as u64) << log_new_n, 0, 0, 0]);

let mut elt = G::Scalar::one();

for (i, tmp) in tmp.iter_mut().enumerate() {
for s in 0..num_threads {
let idx = (i + (s << log_new_n)) % (1 << log_n);
let mut t = a[idx];
t.group_scale(&elt);
tmp.group_add(&t);
elt *= &omega_step;
}
elt *= &omega_j;
}

// Perform sub-FFT
serial_fft(tmp, new_omega, log_new_n);
/// This perform recursive butterfly arithmetic
pub fn recursive_butterfly_arithmetic<G: Group>(
a: &mut [G],
n: usize,
twiddle_chunk: usize,
twiddles: &[G::Scalar],
) {
if n == 2 {
let t = a[1];
a[1] = a[0];
a[0].group_add(&t);
a[1].group_sub(&t);
} else {
let (left, right) = a.split_at_mut(n / 2);
rayon::join(
|| recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
|| recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
);

// case when twiddle factor is one
let (a, left) = left.split_at_mut(1);
let (b, right) = right.split_at_mut(1);
let t = b[0];
b[0] = a[0];
a[0].group_add(&t);
b[0].group_sub(&t);

left.iter_mut()
.zip(right.iter_mut())
.enumerate()
.for_each(|(i, (a, b))| {
let mut t = *b;
t.group_scale(&twiddles[(i + 1) * twiddle_chunk]);
*b = *a;
a.group_add(&t);
b.group_sub(&t);
});
}
});

// Unshuffle
let mask = (1 << log_threads) - 1;
for (idx, a) in a.iter_mut().enumerate() {
*a = tmp[idx & mask][idx >> log_threads];
}
}

Expand Down

0 comments on commit 9a9873a

Please sign in to comment.