diff --git a/halo2_proofs/src/arithmetic.rs b/halo2_proofs/src/arithmetic.rs index bab2a737be..d0864d55c4 100644 --- a/halo2_proofs/src/arithmetic.rs +++ b/halo2_proofs/src/arithmetic.rs @@ -169,18 +169,7 @@ pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu /// /// This will use multithreading if beneficial. pub fn best_fft(a: &mut [G], omega: G::Scalar, log_n: u32) { - let threads = multicore::current_num_threads(); - let log_threads = log2_floor(threads); - - if log_n <= log_threads { - serial_fft(a, omega, log_n); - } else { - parallel_fft(a, omega, log_n, log_threads); - } -} - -fn serial_fft(a: &mut [G], omega: G::Scalar, log_n: u32) { - fn bitreverse(mut n: u32, l: u32) -> u32 { + fn bitreverse(mut n: usize, l: usize) -> usize { let mut r = 0; for _ in 0..l { r = (r << 1) | (n & 1); @@ -189,79 +178,97 @@ fn serial_fft(a: &mut [G], omega: G::Scalar, log_n: u32) { r } - let n = a.len() as u32; + let threads = multicore::current_num_threads(); + let log_threads = log2_floor(threads); + let n = a.len() as usize; assert_eq!(n, 1 << log_n); for k in 0..n { - let rk = bitreverse(k, log_n); + let rk = bitreverse(k, log_n as usize); if k < rk { - a.swap(rk as usize, k as usize); + a.swap(rk, k); } } - let mut m = 1; - for _ in 0..log_n { - let w_m = omega.pow_vartime(&[u64::from(n / (2 * m)), 0, 0, 0]); - - let mut k = 0; - while k < n { - let mut w = G::Scalar::one(); - for j in 0..m { - let mut t = a[(k + j + m) as usize]; - t.group_scale(&w); - a[(k + j + m) as usize] = a[(k + j) as usize]; - a[(k + j + m) as usize].group_sub(&t); - a[(k + j) as usize].group_add(&t); - w *= &w_m; - } + // precompute twiddle factors + let mut w = G::Scalar::one(); + let mut twiddles = vec![G::Scalar::one(); (n / 2) as usize]; + for tw in twiddles.iter_mut() { + *tw = w; + w.group_scale(&omega); + } - k += 2 * m; + if log_n <= log_threads { + let mut chunk = 2_usize; + let mut twiddle_chunk = (n / 2) as usize; + for _ in 0..log_n { + a.chunks_mut(chunk).for_each(|coeffs| { + let (left, right) = coeffs.split_at_mut(chunk / 2); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0].group_add(&t); + b[0].group_sub(&t); + + left.iter_mut() + .zip(right.iter_mut()) + .enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t.group_scale(&twiddles[(i + 1) * twiddle_chunk]); + *b = *a; + a.group_add(&t); + b.group_sub(&t); + }); + }); + chunk *= 2; + twiddle_chunk /= 2; } - - m *= 2; + } else { + recursive_butterfly_arithmetic(a, n, 1, &twiddles) } } -fn parallel_fft(a: &mut [G], omega: G::Scalar, log_n: u32, log_threads: u32) { - assert!(log_n >= log_threads); - - let num_threads = 1 << log_threads; - let log_new_n = log_n - log_threads; - let mut tmp = vec![vec![G::group_zero(); 1 << log_new_n]; num_threads]; - let new_omega = omega.pow_vartime(&[num_threads as u64, 0, 0, 0]); - - multicore::scope(|scope| { - let a = &*a; - - for (j, tmp) in tmp.iter_mut().enumerate() { - scope.spawn(move |_| { - // Shuffle into a sub-FFT - let omega_j = omega.pow_vartime(&[j as u64, 0, 0, 0]); - let omega_step = omega.pow_vartime(&[(j as u64) << log_new_n, 0, 0, 0]); - - let mut elt = G::Scalar::one(); - - for (i, tmp) in tmp.iter_mut().enumerate() { - for s in 0..num_threads { - let idx = (i + (s << log_new_n)) % (1 << log_n); - let mut t = a[idx]; - t.group_scale(&elt); - tmp.group_add(&t); - elt *= &omega_step; - } - elt *= &omega_j; - } - - // Perform sub-FFT - serial_fft(tmp, new_omega, log_new_n); +/// This perform recursive butterfly arithmetic +pub fn recursive_butterfly_arithmetic( + a: &mut [G], + n: usize, + twiddle_chunk: usize, + twiddles: &[G::Scalar], +) { + if n == 2 { + let t = a[1]; + a[1] = a[0]; + a[0].group_add(&t); + a[1].group_sub(&t); + } else { + let (left, right) = a.split_at_mut(n / 2); + rayon::join( + || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles), + || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles), + ); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0].group_add(&t); + b[0].group_sub(&t); + + left.iter_mut() + .zip(right.iter_mut()) + .enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t.group_scale(&twiddles[(i + 1) * twiddle_chunk]); + *b = *a; + a.group_add(&t); + b.group_sub(&t); }); - } - }); - - // Unshuffle - let mask = (1 << log_threads) - 1; - for (idx, a) in a.iter_mut().enumerate() { - *a = tmp[idx & mask][idx >> log_threads]; } }