Skip to content

Commit

Permalink
Convenience functions and fixes for multi-phase (#133)
Browse files Browse the repository at this point in the history
* feat: add `clear` function to circuit builder and managers

* feat: add `BaseConfig::initialize`

* fix: break points for multiphase

* fix: clear should not change phase

* chore: remove dbg
  • Loading branch information
jonathanpwang authored Sep 4, 2023
1 parent f39fef3 commit 608b8f2
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 15 deletions.
17 changes: 15 additions & 2 deletions halo2-base/src/gates/circuit/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,18 @@ impl<F: ScalarField> BaseCircuitBuilder<F> {
self.core
.phase_manager
.iter()
.map(|pm| pm.break_points.get().expect("break points not set").clone())
.map(|pm| pm.break_points.borrow().as_ref().expect("break points not set").clone())
.collect()
}

/// Sets the break points of the circuit.
pub fn set_break_points(&mut self, break_points: MultiPhaseThreadBreakPoints) {
if break_points.is_empty() {
return;
}
self.core.touch(break_points.len() - 1);
for (pm, bp) in self.core.phase_manager.iter().zip_eq(break_points) {
pm.break_points.set(bp).unwrap();
*pm.break_points.borrow_mut() = Some(bp);
}
}

Expand All @@ -207,6 +211,15 @@ impl<F: ScalarField> BaseCircuitBuilder<F> {
self
}

/// Clears state and copies, effectively resetting the circuit builder.
pub fn clear(&mut self) {
self.core.clear();
for lm in &mut self.lookup_manager {
lm.cells_to_lookup.lock().unwrap().clear();
lm.copy_manager.lock().unwrap().clear();
}
}

/// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists.
/// * `phase`: The challenge phase (as an index) of the gate thread.
pub fn main(&mut self, phase: usize) -> &mut Context<F> {
Expand Down
9 changes: 9 additions & 0 deletions halo2-base/src/gates/circuit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ impl<F: ScalarField> BaseConfig<F> {
MaybeRangeConfig::WithRange(config) => config.gate.max_rows = usable_rows,
}
}

/// Initialization of config at very beginning of `synthesize`.
/// Loads fixed lookup table, if using.
pub fn initialize(&self, layouter: &mut impl Layouter<F>) {
// only load lookup table if we are actually doing lookups
if let MaybeRangeConfig::WithRange(config) = &self.base {
config.load_lookup_table(layouter).expect("load lookup table should not fail");
}
}
}

impl<F: ScalarField> Circuit<F> for BaseCircuitBuilder<F> {
Expand Down
10 changes: 9 additions & 1 deletion halo2-base/src/gates/flex_gate/threads/multi_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ impl<F: ScalarField> MultiPhaseCoreManager<F> {
self
}

/// Clears all threads in all phases and copy manager.
pub fn clear(&mut self) {
for pm in &mut self.phase_manager {
pm.clear();
}
self.copy_manager.lock().unwrap().clear();
}

/// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists.
/// * `phase`: The challenge phase (as an index) of the gate thread.
pub fn main(&mut self, phase: usize) -> &mut Context<F> {
Expand All @@ -88,7 +96,7 @@ impl<F: ScalarField> MultiPhaseCoreManager<F> {
}

/// Populate `self` up to Phase `phase` (inclusive)
fn touch(&mut self, phase: usize) {
pub(crate) fn touch(&mut self, phase: usize) {
while self.phase_manager.len() <= phase {
let _phase = self.phase_manager.len();
let pm = SinglePhaseCoreManager::new(self.witness_gen_only, self.copy_manager.clone())
Expand Down
26 changes: 16 additions & 10 deletions halo2-base/src/gates/flex_gate/threads/single_phase.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{any::TypeId, cell::OnceCell};
use std::{any::TypeId, cell::RefCell};

use getset::CopyGetters;

Expand Down Expand Up @@ -39,7 +39,7 @@ pub struct SinglePhaseCoreManager<F: ScalarField> {
pub(crate) phase: usize,
/// A very simple computation graph for the basic vertical gate. Must be provided as a "pinning"
/// when running the production prover.
pub break_points: OnceCell<ThreadBreakPoints>,
pub break_points: RefCell<Option<ThreadBreakPoints>>,
}

impl<F: ScalarField> SinglePhaseCoreManager<F> {
Expand Down Expand Up @@ -93,6 +93,12 @@ impl<F: ScalarField> SinglePhaseCoreManager<F> {
self
}

/// Clears all threads and copy manager
pub fn clear(&mut self) {
self.threads = vec![];
self.copy_manager.lock().unwrap().clear();
}

/// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists.
pub fn main(&mut self) -> &mut Context<F> {
if self.threads.is_empty() {
Expand Down Expand Up @@ -147,7 +153,8 @@ impl<F: ScalarField> VirtualRegionManager<F> for SinglePhaseCoreManager<F> {

fn assign_raw(&self, (config, usable_rows): &Self::Config, region: &mut Region<F>) {
if self.witness_gen_only {
let break_points = self.break_points.get().expect("break points not set");
let binding = self.break_points.borrow();
let break_points = binding.as_ref().expect("break points not set");
assign_witnesses(&self.threads, config, region, break_points);
} else {
let mut copy_manager = self.copy_manager.lock().unwrap();
Expand All @@ -159,13 +166,12 @@ impl<F: ScalarField> VirtualRegionManager<F> for SinglePhaseCoreManager<F> {
*usable_rows,
self.use_unknown,
);
self.break_points.set(break_points).unwrap_or_else(|break_points| {
assert_eq!(
self.break_points.get().unwrap(),
&break_points,
"previously set break points don't match"
);
});
let mut bp = self.break_points.borrow_mut();
if let Some(bp) = bp.as_ref() {
assert_eq!(bp, &break_points, "break points don't match");
} else {
*bp = Some(break_points);
}
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions halo2-base/src/virtual_region/copy_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ impl<F: Field + Ord> CopyConstraintManager<F> {
self.assigned_advices.insert(context_cell, cell);
context_cell
}

/// Clears state
pub fn clear(&mut self) {
*self = Self::default();
}
}

impl<F: Field + Ord> Drop for CopyConstraintManager<F> {
Expand Down
4 changes: 2 additions & 2 deletions halo2-base/src/virtual_region/tests/lookups/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,13 @@ fn test_ram_prover() {
let vk = keygen_vk(&params, &circuit).unwrap();
let pk = keygen_pk(&params, vk, &circuit).unwrap();
let circuit_params = circuit.params();
let break_points = circuit.cpu.break_points.get().unwrap().clone();
let break_points = circuit.cpu.break_points.borrow().clone().unwrap();
drop(circuit);

let memory: Vec<_> = (0..mem_len).map(|_| Fr::random(&mut rng)).collect();
let ptrs = [(); CYCLES].map(|_| rng.gen_range(0..memory.len()));
let mut circuit = RAMCircuit::new(memory, ptrs, circuit_params, true);
circuit.cpu.break_points.set(break_points).unwrap();
*circuit.cpu.break_points.borrow_mut() = Some(break_points);
circuit.compute();

let proof = gen_proof(&params, &pk, circuit);
Expand Down

0 comments on commit 608b8f2

Please sign in to comment.