Skip to content

Commit

Permalink
Actually use precomputed witness (#932)
Browse files Browse the repository at this point in the history
* actually use precomputed witness

* simplify multiframe witness cache plumbing

* use OnceCell instead of Option for caching MultiFrame witness
  • Loading branch information
arthurpaulino authored Dec 1, 2023
1 parent 86b5844 commit 212f295
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 34 deletions.
55 changes: 34 additions & 21 deletions src/lem/multiframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use bellpepper_core::{num::AllocatedNum, Circuit, ConstraintSystem, SynthesisErr
use elsa::sync::FrozenMap;
use ff::PrimeField;
use nova::{supernova::NonUniformCircuit, traits::Engine};
use once_cell::sync::OnceCell;
use rayon::prelude::*;
use std::sync::Arc;

Expand Down Expand Up @@ -48,7 +49,8 @@ pub struct MultiFrame<'a, F: LurkField, C: Coprocessor<F>> {
input: Option<Vec<Ptr>>,
output: Option<Vec<Ptr>>,
frames: Option<Vec<Frame>>,
cached_witness: Option<WitnessCS<F>>,
/// Cached witness and output for this `MultiFrame`
cached_witness: OnceCell<(WitnessCS<F>, Vec<AllocatedNum<F>>)>,
num_frames: usize,
folding_config: Arc<FoldingConfig<F, C>>,
pc: usize,
Expand Down Expand Up @@ -403,24 +405,23 @@ impl<'a, F: LurkField, C: Coprocessor<F> + 'a> MultiFrameTrait<'a, F, C> for Mul
store.to_scalar_vector(io)
}

fn compute_witness(&self, s: &Store<F>) -> WitnessCS<F> {
let mut wcs = WitnessCS::new();
fn cache_witness(&mut self, s: &Store<F>) -> Result<(), SynthesisError> {
let _ = self.cached_witness.get_or_try_init(|| {
let mut wcs = WitnessCS::new();

let z_scalar = s.to_scalar_vector(self.input.as_ref().unwrap());
let z_scalar = s.to_scalar_vector(self.input.as_ref().unwrap());

let mut bogus_cs = WitnessCS::<F>::new();
let z: Vec<AllocatedNum<F>> = z_scalar
.iter()
.map(|x| AllocatedNum::alloc(&mut bogus_cs, || Ok(*x)).unwrap())
.collect::<Vec<_>>();

let _ = nova::traits::circuit::StepCircuit::synthesize(self, &mut wcs, z.as_slice());

wcs
}
let mut bogus_cs = WitnessCS::<F>::new();
let z: Vec<AllocatedNum<F>> = z_scalar
.iter()
.map(|x| AllocatedNum::alloc_infallible(&mut bogus_cs, || *x))
.collect::<Vec<_>>();

fn cached_witness(&mut self) -> &mut Option<WitnessCS<F>> {
&mut self.cached_witness
let output =
nova::traits::circuit::StepCircuit::synthesize(self, &mut wcs, z.as_slice())?;
Ok::<_, SynthesisError>((wcs, output))
})?;
Ok(())
}

fn output(&self) -> &Option<<Self::EvalFrame as FrameLike<Ptr, Ptr>>::FrameIO> {
Expand Down Expand Up @@ -511,7 +512,7 @@ impl<'a, F: LurkField, C: Coprocessor<F> + 'a> MultiFrameTrait<'a, F, C> for Mul
input: None,
output: None,
frames: None,
cached_witness: None,
cached_witness: OnceCell::new(),
num_frames,
folding_config,
pc,
Expand Down Expand Up @@ -559,7 +560,7 @@ impl<'a, F: LurkField, C: Coprocessor<F> + 'a> MultiFrameTrait<'a, F, C> for Mul
input: Some(chunk[0].input.to_vec()),
output: Some(output),
frames: Some(inner_frames),
cached_witness: None,
cached_witness: OnceCell::new(),
num_frames: reduction_count,
folding_config: folding_config.clone(),
pc: 0,
Expand Down Expand Up @@ -644,7 +645,7 @@ impl<'a, F: LurkField, C: Coprocessor<F> + 'a> MultiFrameTrait<'a, F, C> for Mul
input: Some(input),
output: Some(output),
frames: Some(frames_to_add),
cached_witness: None,
cached_witness: OnceCell::new(),
num_frames,
folding_config: folding_config.clone(),
pc,
Expand Down Expand Up @@ -835,7 +836,19 @@ impl<'a, F: LurkField, C: Coprocessor<F>> nova::traits::circuit::StepCircuit<F>
{
assert_eq!(self.arity(), z.len());

let n_ptrs = self.arity() / 2;
if cs.is_witness_generator() {
if let Some((w, output)) = self.cached_witness.get() {
// nothing has been inputized so far
assert_eq!(cs.inputs_slice(), &[F::ONE]);
assert_eq!(w.inputs_slice(), &[F::ONE]);
// output must have the same length as the input
assert_eq!(output.len(), z.len());
cs.extend_aux(w.aux_slice());
return Ok(output.clone());
}
};

let n_ptrs = z.len() / 2;
let mut input = Vec::with_capacity(n_ptrs);
for i in 0..n_ptrs {
input.push(AllocatedPtr::from_parts(
Expand Down Expand Up @@ -863,7 +876,7 @@ impl<'a, F: LurkField, C: Coprocessor<F>> nova::traits::circuit::StepCircuit<F>
}
};

let mut output = Vec::with_capacity(self.arity());
let mut output = Vec::with_capacity(z.len());
for ptr in output_ptrs {
output.push(ptr.tag().clone());
output.push(ptr.hash().clone());
Expand Down
10 changes: 4 additions & 6 deletions src/proof/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ pub mod supernova;
mod tests;

use ::nova::traits::{circuit::StepCircuit, Engine};
use bellpepper::util_cs::witness_cs::WitnessCS;
use bellpepper_core::{test_cs::TestConstraintSystem, Circuit, ConstraintSystem, SynthesisError};
use std::sync::Arc;

Expand Down Expand Up @@ -120,11 +119,10 @@ pub trait MultiFrameTrait<'a, F: LurkField, C: Coprocessor<F> + 'a>:
/// Returns true if the supplied instance directly precedes this one in a sequential computation trace.
fn precedes(&self, maybe_next: &Self) -> bool;

/// Populates a WitnessCS with the witness values for the given store.
fn compute_witness(&self, s: &Self::Store) -> WitnessCS<F>;

/// Returns a reference to the cached witness values
fn cached_witness(&mut self) -> &mut Option<WitnessCS<F>>;
/// Cache the witness internally, which can be used later during synthesis.
/// This function can be called in parallel to speed up the witness generation
/// for a series of `MultiFrameTrait` instances.
fn cache_witness(&mut self, s: &Self::Store) -> Result<(), SynthesisError>;

/// The output of the last frame
fn output(&self) -> &Option<<Self::EvalFrame as FrameLike<Self::Ptr, Self::ContPtr>>::FrameIO>;
Expand Down
11 changes: 4 additions & 7 deletions src/proof/nova.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,13 +281,10 @@ where
// Skip the very first circuit's witness, so `prove_step` can begin immediately.
// That circuit's witness will not be cached and will just be computed on-demand.
cc.par_iter().skip(1).for_each(|mf| {
let witness = {
let mf1 = mf.lock().unwrap();
mf1.compute_witness(store)
};
let mut mf2 = mf.lock().unwrap();

*mf2.cached_witness() = Some(witness);
mf.lock()
.unwrap()
.cache_witness(store)
.expect("witness caching failed");
});
});

Expand Down

1 comment on commit 212f295

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmarks

Table of Contents

Overview

This benchmark report shows the Fibonacci GPU benchmark.
NVIDIA L4
Intel(R) Xeon(R) CPU @ 2.20GHz
125.78 GB RAM
Workflow run: https://github.com/lurk-lab/lurk-rs/actions/runs/7066296206

Benchmark Results

LEM Fibonacci Prove - rc = 100

fib-ref=86b5844e7136799afe245f46638e049eb0a4b66b fib-ref=212f295f846e31d752f2ca2e5d81200b651a514c
num-100 4.03 s (✅ 1.00x) 3.86 s (✅ 1.04x faster)
num-200 8.14 s (✅ 1.00x) 7.77 s (✅ 1.05x faster)

LEM Fibonacci Prove - rc = 600

fib-ref=86b5844e7136799afe245f46638e049eb0a4b66b fib-ref=212f295f846e31d752f2ca2e5d81200b651a514c
num-100 3.43 s (✅ 1.00x) 3.33 s (✅ 1.03x faster)
num-200 7.58 s (✅ 1.00x) 7.31 s (✅ 1.04x faster)

Made with criterion-table

Please sign in to comment.