diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index b20e375f..33799495 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -11,6 +11,7 @@ num-traits = "0.2" rand_chacha = "0.3" rustc-hash = "1.1" ff = "0.12" +rayon = "1.6.1" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" log = "0.4" diff --git a/halo2-base/src/gates/builder.rs b/halo2-base/src/gates/builder.rs index 9cd68db0..22c2ce93 100644 --- a/halo2-base/src/gates/builder.rs +++ b/halo2-base/src/gates/builder.rs @@ -17,6 +17,9 @@ use std::{ env::{set_var, var}, }; +mod parallelize; +pub use parallelize::*; + /// Vector of thread advice column break points pub type ThreadBreakPoints = Vec<usize>; /// Vector of vectors tracking the thread break points across different halo2 phases diff --git a/halo2-base/src/gates/builder/parallelize.rs b/halo2-base/src/gates/builder/parallelize.rs new file mode 100644 index 00000000..ab9171d5 --- /dev/null +++ b/halo2-base/src/gates/builder/parallelize.rs @@ -0,0 +1,38 @@ +use itertools::Itertools; +use rayon::prelude::*; + +use crate::{utils::ScalarField, Context}; + +use super::GateThreadBuilder; + +/// Utility function to parallelize an operation involving [`Context`]s in phase `phase`. 
+pub fn parallelize_in<F, T, R, FR>( + phase: usize, + builder: &mut GateThreadBuilder<F>, + input: Vec<T>, + f: FR, +) -> Vec<R> +where + F: ScalarField, + T: Send, + R: Send, + FR: Fn(&mut Context<F>, T) -> R + Send + Sync, +{ + let witness_gen_only = builder.witness_gen_only(); + // to prevent concurrency issues with context id, we generate all the ids first + let ctx_ids = input.iter().map(|_| builder.get_new_thread_id()).collect_vec(); + let (outputs, mut ctxs): (Vec<_>, Vec<_>) = input + .into_par_iter() + .zip(ctx_ids.into_par_iter()) + .map(|(input, ctx_id)| { + // create new context + let mut ctx = Context::new(witness_gen_only, ctx_id); + let output = f(&mut ctx, input); + (output, ctx) + }) + .unzip(); + // we collect the new threads to ensure they are a FIXED order, otherwise later `assign_threads_in` will get confused + builder.threads[phase].append(&mut ctxs); + + outputs +} diff --git a/halo2-ecc/src/ecc/fixed_base.rs b/halo2-ecc/src/ecc/fixed_base.rs index 1d2ff19a..de9e8d86 100644 --- a/halo2-ecc/src/ecc/fixed_base.rs +++ b/halo2-ecc/src/ecc/fixed_base.rs @@ -2,7 +2,7 @@ use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip}; use crate::fields::{FieldChip, PrimeField, Selectable}; use group::Curve; -use halo2_base::gates::builder::GateThreadBuilder; +use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder}; use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; use itertools::Itertools; use rayon::prelude::*; @@ -107,6 +107,7 @@ where curr_point.unwrap() } +/* To reduce total amount of code, just always use msm_par below. 
// basically just adding up individual fixed_base::scalar_multiply except that we do all batched normalization of cached points at once to further save inversion time during witness generation // we also use the random accumulator for some extra efficiency (which also works in scalar multiply case but that is TODO) pub fn msm<F, FC, C>( @@ -212,6 +213,7 @@ where .collect_vec(); chip.sum::<C>(ctx, scalar_mults) } +*/ /// # Assumptions /// * `points.len() = scalars.len()` @@ -269,25 +271,23 @@ where C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine); let field_chip = chip.field_chip(); - let witness_gen_only = builder.witness_gen_only(); let zero = builder.main(phase).load_zero(); - let thread_ids = (0..scalars.len()).map(|_| builder.get_new_thread_id()).collect::<Vec<_>>(); - let (new_threads, scalar_mults): (Vec<_>, Vec<_>) = cached_points_affine - .par_chunks(cached_points_affine.len() / points.len()) - .zip_eq(scalars.into_par_iter()) - .zip(thread_ids.into_par_iter()) - .map(|((cached_points, scalar), thread_id)| { - let mut thread = Context::new(witness_gen_only, thread_id); - let ctx = &mut thread; - + let scalar_mults = parallelize_in( + phase, + builder, + cached_points_affine + .chunks(cached_points_affine.len() / points.len()) + .zip_eq(scalars) + .collect(), + |ctx, (cached_points, scalar)| { let cached_points = cached_points .iter() .map(|point| chip.assign_constant_point(ctx, *point)) .collect_vec(); let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev(); - debug_assert_eq!(scalar.len(), scalar_len); + assert_eq!(scalar.len(), scalar_len); let bits = scalar .into_iter() .flat_map(|scalar_chunk| { @@ -319,9 +319,8 @@ where field_chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window) }; } - (thread, curr_point.unwrap()) - }) - .unzip(); - builder.threads[phase].extend(new_threads); + curr_point.unwrap() + }, + ); chip.sum::<C>(builder.main(phase), scalar_mults) } diff --git a/halo2-ecc/src/ecc/mod.rs 
b/halo2-ecc/src/ecc/mod.rs index 8b3895f1..eea89a31 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -971,6 +971,7 @@ impl<'chip, F: PrimeField, FC: FieldChip<F>> EccChip<'chip, F, FC> { self.field_chip.assert_equal(ctx, P.y, Q.y); } + /// None of elements in `points` can be point at infinity. pub fn sum<C>( &self, ctx: &mut Context<F>, @@ -1153,21 +1154,15 @@ impl<'chip, F: PrimeField, FC: FieldChip<F>> EccChip<'chip, F, FC> { #[cfg(feature = "display")] println!("computing length {} fixed base msm", points.len()); - // heuristic to decide when to use parallelism - if points.len() < 25 { - let ctx = builder.main(phase); - fixed_base::msm(self, ctx, points, scalars, max_scalar_bits_per_cell, clump_factor) - } else { - fixed_base::msm_par( - self, - builder, - points, - scalars, - max_scalar_bits_per_cell, - clump_factor, - phase, - ) - } + fixed_base::msm_par( + self, + builder, + points, + scalars, + max_scalar_bits_per_cell, + clump_factor, + phase, + ) // Empirically does not seem like pippenger is any better for fixed base msm right now, because of the cost of `select_by_indicator` // Cell usage becomes around comparable when `points.len() > 100`, and `clump_factor` should always be 4 diff --git a/halo2-ecc/src/ecc/pippenger.rs b/halo2-ecc/src/ecc/pippenger.rs index 58e7c739..934a7432 100644 --- a/halo2-ecc/src/ecc/pippenger.rs +++ b/halo2-ecc/src/ecc/pippenger.rs @@ -7,11 +7,13 @@ use crate::{ fields::{FieldChip, PrimeField, Selectable}, }; use halo2_base::{ - gates::{builder::GateThreadBuilder, GateInstructions}, + gates::{ + builder::{parallelize_in, GateThreadBuilder}, + GateInstructions, + }, utils::CurveAffineExt, - AssignedValue, Context, + AssignedValue, }; -use rayon::prelude::*; // Reference: https://jbootle.github.io/Misc/pippenger.pdf @@ -238,7 +240,6 @@ where // get a main thread let ctx = builder.main(phase); - let witness_gen_only = ctx.witness_gen_only(); // single-threaded computation: for scalar in scalars { for (scalar_chunk, 
bool_chunk) in @@ -250,32 +251,28 @@ where } } } - // see multi-product comments for explanation of below let c = clump_factor; let num_rounds = (points.len() + c - 1) / c; + // to avoid adding two points that are equal or negative of each other, + // we use a trick from halo2wrong where we load a "sufficiently generic" `C` point as witness + // note that while we load a random point, an adversary could load a specifically chosen point, so we must carefully handle edge cases with constraints + // we call it "any point" instead of "random point" to emphasize that "any" sufficiently generic point will do let any_base = load_random_point::<F, FC, C>(chip, ctx); let mut any_points = Vec::with_capacity(num_rounds); any_points.push(any_base); for _ in 1..num_rounds { any_points.push(ec_double(chip, ctx, any_points.last().unwrap())); } - // we will use a different thread per round - // to prevent concurrency issues with context id, we generate all the ids first - let thread_ids = (0..num_rounds).map(|_| builder.get_new_thread_id()).collect::<Vec<_>>(); - // now begins multi-threading + // now begins multi-threading // multi_prods is 2d vector of size `num_rounds` by `scalar_bits` - let (new_threads, multi_prods): (Vec<_>, Vec<_>) = points - .par_chunks(c) - .zip(any_points.par_iter()) - .zip(thread_ids.into_par_iter()) - .enumerate() - .map(|(round, ((points_clump, any_point), thread_id))| { + let multi_prods = parallelize_in( + phase, + builder, + points.chunks(c).into_iter().zip(any_points.iter()).enumerate().collect(), + |ctx, (round, (points_clump, any_point))| { // compute all possible multi-products of elements in points[round * c .. round * (c+1)] - // create new thread - let mut thread = Context::new(witness_gen_only, thread_id); - let ctx = &mut thread; // stores { any_point, any_point + points[0], any_point + points[1], any_point + points[0] + points[1] , ... 
} let mut bucket = Vec::with_capacity(1 << c); let any_point = into_strict_point(chip, ctx, any_point.clone()); @@ -294,7 +291,7 @@ where bucket.push(new_point); } } - let multi_prods = bool_scalars + bool_scalars .iter() .map(|bits| { strict_ec_select_from_bits( @@ -304,31 +301,19 @@ where &bits[round * c..round * c + points_clump.len()], ) }) - .collect::<Vec<_>>(); - - (thread, multi_prods) - }) - .unzip(); - // we collect the new threads to ensure they are a FIXED order, otherwise later `assign_threads_in` will get confused - builder.threads[phase].extend(new_threads); + .collect::<Vec<_>>() + }, + ); // agg[j] = sum_{i=0..num_rounds} multi_prods[i][j] for j = 0..scalar_bits - let thread_ids = (0..scalar_bits).map(|_| builder.get_new_thread_id()).collect::<Vec<_>>(); - let (new_threads, mut agg): (Vec<_>, Vec<_>) = thread_ids - .into_par_iter() - .enumerate() - .map(|(i, thread_id)| { - let mut thread = Context::new(witness_gen_only, thread_id); - let ctx = &mut thread; - let mut acc = multi_prods[0][i].clone(); - for multi_prod in multi_prods.iter().skip(1) { - let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true); - acc = into_strict_point(chip, ctx, _acc); - } - (thread, acc) - }) - .unzip(); - builder.threads[phase].extend(new_threads); + let mut agg = parallelize_in(phase, builder, (0..scalar_bits).collect(), |ctx, i| { + let mut acc = multi_prods[0][i].clone(); + for multi_prod in multi_prods.iter().skip(1) { + let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true); + acc = into_strict_point(chip, ctx, _acc); + } + acc + }); // gets the LAST thread for single threaded work let ctx = builder.main(phase);