diff --git a/src/lem/multiframe.rs b/src/lem/multiframe.rs index 9b3aa42d73..86718ad84c 100644 --- a/src/lem/multiframe.rs +++ b/src/lem/multiframe.rs @@ -5,6 +5,7 @@ use bellpepper_core::{num::AllocatedNum, Circuit, ConstraintSystem, SynthesisErr use elsa::sync::FrozenMap; use ff::PrimeField; use nova::{supernova::NonUniformCircuit, traits::Engine}; +use once_cell::sync::OnceCell; use rayon::prelude::*; use std::sync::Arc; @@ -48,7 +49,8 @@ pub struct MultiFrame<'a, F: LurkField, C: Coprocessor> { input: Option>, output: Option>, frames: Option>, - cached_witness: Option>, + /// Cached witness and output for this `MultiFrame` + cached_witness: OnceCell<(WitnessCS, Vec>)>, num_frames: usize, folding_config: Arc>, pc: usize, @@ -403,24 +405,23 @@ impl<'a, F: LurkField, C: Coprocessor + 'a> MultiFrameTrait<'a, F, C> for Mul store.to_scalar_vector(io) } - fn compute_witness(&self, s: &Store) -> WitnessCS { - let mut wcs = WitnessCS::new(); + fn cache_witness(&mut self, s: &Store) -> Result<(), SynthesisError> { + let _ = self.cached_witness.get_or_try_init(|| { + let mut wcs = WitnessCS::new(); - let z_scalar = s.to_scalar_vector(self.input.as_ref().unwrap()); + let z_scalar = s.to_scalar_vector(self.input.as_ref().unwrap()); - let mut bogus_cs = WitnessCS::::new(); - let z: Vec> = z_scalar - .iter() - .map(|x| AllocatedNum::alloc(&mut bogus_cs, || Ok(*x)).unwrap()) - .collect::>(); - - let _ = nova::traits::circuit::StepCircuit::synthesize(self, &mut wcs, z.as_slice()); - - wcs - } + let mut bogus_cs = WitnessCS::::new(); + let z: Vec> = z_scalar + .iter() + .map(|x| AllocatedNum::alloc_infallible(&mut bogus_cs, || *x)) + .collect::>(); - fn cached_witness(&mut self) -> &mut Option> { - &mut self.cached_witness + let output = + nova::traits::circuit::StepCircuit::synthesize(self, &mut wcs, z.as_slice())?; + Ok::<_, SynthesisError>((wcs, output)) + })?; + Ok(()) } fn output(&self) -> &Option<>::FrameIO> { @@ -511,7 +512,7 @@ impl<'a, F: LurkField, C: Coprocessor + 'a> MultiFrameTrait<'a, F, C> for Mul input: None, output: None, frames: None, - cached_witness: None, + cached_witness: OnceCell::new(), num_frames, folding_config, pc, @@ -559,7 +560,7 @@ impl<'a, F: LurkField, C: Coprocessor + 'a> MultiFrameTrait<'a, F, C> for Mul input: Some(chunk[0].input.to_vec()), output: Some(output), frames: Some(inner_frames), - cached_witness: None, + cached_witness: OnceCell::new(), num_frames: reduction_count, folding_config: folding_config.clone(), pc: 0, @@ -644,7 +645,7 @@ impl<'a, F: LurkField, C: Coprocessor + 'a> MultiFrameTrait<'a, F, C> for Mul input: Some(input), output: Some(output), frames: Some(frames_to_add), - cached_witness: None, + cached_witness: OnceCell::new(), num_frames, folding_config: folding_config.clone(), pc, @@ -835,7 +836,19 @@ impl<'a, F: LurkField, C: Coprocessor> nova::traits::circuit::StepCircuit { assert_eq!(self.arity(), z.len()); - let n_ptrs = self.arity() / 2; + if cs.is_witness_generator() { + if let Some((w, output)) = self.cached_witness.get() { + // nothing has been inputized so far + assert_eq!(cs.inputs_slice(), &[F::ONE]); + assert_eq!(w.inputs_slice(), &[F::ONE]); + // output must have the same length as the input + assert_eq!(output.len(), z.len()); + cs.extend_aux(w.aux_slice()); + return Ok(output.clone()); + } + }; + + let n_ptrs = z.len() / 2; let mut input = Vec::with_capacity(n_ptrs); for i in 0..n_ptrs { input.push(AllocatedPtr::from_parts( @@ -863,7 +876,7 @@ impl<'a, F: LurkField, C: Coprocessor> nova::traits::circuit::StepCircuit } }; - let mut output = Vec::with_capacity(self.arity()); + let mut output = Vec::with_capacity(z.len()); for ptr in output_ptrs { output.push(ptr.tag().clone()); output.push(ptr.hash().clone()); diff --git a/src/proof/mod.rs b/src/proof/mod.rs index 6ccc6e8733..ff679ef1a2 100644 --- a/src/proof/mod.rs +++ b/src/proof/mod.rs @@ -16,7 +16,6 @@ pub mod supernova; mod tests; use ::nova::traits::{circuit::StepCircuit, Engine}; -use bellpepper::util_cs::witness_cs::WitnessCS; use bellpepper_core::{test_cs::TestConstraintSystem, Circuit, ConstraintSystem, SynthesisError}; use std::sync::Arc; @@ -120,11 +119,10 @@ pub trait MultiFrameTrait<'a, F: LurkField, C: Coprocessor + 'a>: /// Returns true if the supplied instance directly precedes this one in a sequential computation trace. fn precedes(&self, maybe_next: &Self) -> bool; - /// Populates a WitnessCS with the witness values for the given store. - fn compute_witness(&self, s: &Self::Store) -> WitnessCS; - - /// Returns a reference to the cached witness values - fn cached_witness(&mut self) -> &mut Option>; + /// Cache the witness internally, which can be used later during synthesis. + /// This function can be called in parallel to speed up the witness generation + /// for a series of `MultiFrameTrait` instances. + fn cache_witness(&mut self, s: &Self::Store) -> Result<(), SynthesisError>; /// The output of the last frame fn output(&self) -> &Option<>::FrameIO>; diff --git a/src/proof/nova.rs b/src/proof/nova.rs index 75a6869076..1fb95cf500 100644 --- a/src/proof/nova.rs +++ b/src/proof/nova.rs @@ -281,13 +281,10 @@ where // Skip the very first circuit's witness, so `prove_step` can begin immediately. // That circuit's witness will not be cached and will just be computed on-demand. cc.par_iter().skip(1).for_each(|mf| { - let witness = { - let mf1 = mf.lock().unwrap(); - mf1.compute_witness(store) - }; - let mut mf2 = mf.lock().unwrap(); - - *mf2.cached_witness() = Some(witness); + mf.lock() + .unwrap() + .cache_witness(store) + .expect("witness caching failed"); }); });