From e8006dec3da20747f74e284a5ce053c804489026 Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Thu, 27 Apr 2023 10:07:06 +0100 Subject: [PATCH 1/6] feat: configure with Params object --- Cargo.toml | 2 +- benches/accum_conv.rs | 1 + benches/accum_dot.rs | 1 + benches/accum_matmul.rs | 1 + benches/accum_matmul_relu.rs | 1 + benches/accum_pack.rs | 1 + benches/accum_sum.rs | 1 + benches/accum_sumpool.rs | 1 + benches/pairwise_add.rs | 1 + benches/pairwise_pow.rs | 1 + benches/relu.rs | 1 + examples/conv2d_mnist/main.rs | 2 ++ examples/mlp_4d.rs | 1 + src/circuit/ops/layouts.rs | 12 +++++----- src/circuit/tests.rs | 23 +++++++++++++++++++ src/commands.rs | 2 +- src/execute.rs | 42 +++++++++++++++++++++++++++++------ src/graph/mod.rs | 31 ++++++++++++++++++++++++++ src/graph/model.rs | 5 +++-- src/graph/vars.rs | 5 +++-- src/pfsys/evm/aggregation.rs | 1 + src/pfsys/mod.rs | 18 +++++++++++---- 22 files changed, 131 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ed5516876..3224b7fa6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -128,4 +128,4 @@ render = ["halo2_proofs/dev-graph", "plotters"] tensorflow = ["dep:tensorflow"] onnx = ["dep:tract-onnx"] python-bindings = ["pyo3", "pyo3-log"] -ezkl = ["onnx", "serde", "serde_json", "log", "colored", "env_logger", "tabled", "colored_json"] +ezkl = ["onnx", "serde", "serde_json", "log", "colored", "env_logger", "tabled", "colored_json", "halo2_proofs/circuit-params"] diff --git a/benches/accum_conv.rs b/benches/accum_conv.rs index ac0885dc4..e00e7ea5f 100644 --- a/benches/accum_conv.rs +++ b/benches/accum_conv.rs @@ -35,6 +35,7 @@ struct MyCircuit { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = (); fn without_witnesses(&self) -> Self { self.clone() diff --git a/benches/accum_dot.rs b/benches/accum_dot.rs index c48451fa4..5ba066c71 100644 --- a/benches/accum_dot.rs +++ b/benches/accum_dot.rs @@ -28,6 +28,7 @@ struct MyCircuit { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = (); fn without_witnesses(&self) -> Self { self.clone() diff --git a/benches/accum_matmul.rs b/benches/accum_matmul.rs index 77b31cf45..f034b250c 100644 --- a/benches/accum_matmul.rs +++ b/benches/accum_matmul.rs @@ -28,6 +28,7 @@ struct MyCircuit { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = (); fn without_witnesses(&self) -> Self { self.clone() diff --git a/benches/accum_matmul_relu.rs b/benches/accum_matmul_relu.rs index d0312acb4..d83e1a88a 100644 --- a/benches/accum_matmul_relu.rs +++ b/benches/accum_matmul_relu.rs @@ -35,6 +35,7 @@ struct MyConfig { impl Circuit for MyCircuit { type Config = MyConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = (); fn without_witnesses(&self) -> Self { self.clone() diff --git a/benches/accum_pack.rs b/benches/accum_pack.rs index 325678991..531c5c3fa 100644 --- a/benches/accum_pack.rs +++ b/benches/accum_pack.rs @@ -28,6 +28,7 @@ struct MyCircuit { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = (); fn without_witnesses(&self) -> Self { self.clone() diff --git a/benches/accum_sum.rs b/benches/accum_sum.rs index 3e2f83ed9..607616332 100644 --- a/benches/accum_sum.rs +++ b/benches/accum_sum.rs @@ -28,6 +28,7 @@ struct MyCircuit { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = (); fn without_witnesses(&self) -> Self { self.clone() diff --git a/benches/accum_sumpool.rs b/benches/accum_sumpool.rs index 30e9c150e..a4f71b989 100644 --- a/benches/accum_sumpool.rs +++ b/benches/accum_sumpool.rs @@ -30,6 +30,7 @@ struct MyCircuit { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = (); fn without_witnesses(&self) -> Self { self.clone() diff --git a/benches/pairwise_add.rs b/benches/pairwise_add.rs index 0fd0d9e30..d0ee88774 100644 --- a/benches/pairwise_add.rs +++ b/benches/pairwise_add.rs @@ -28,6 +28,7 @@ struct MyCircuit { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = (); fn without_witnesses(&self) -> Self { self.clone() diff --git a/benches/pairwise_pow.rs b/benches/pairwise_pow.rs index 1ae08eab1..8b5fe736c 100644 --- a/benches/pairwise_pow.rs +++ b/benches/pairwise_pow.rs @@ -28,6 +28,7 @@ struct MyCircuit { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = (); fn without_witnesses(&self) -> Self { self.clone() diff --git a/benches/relu.rs b/benches/relu.rs index 33945b901..dfdbef8f6 100644 --- a/benches/relu.rs +++ b/benches/relu.rs @@ -25,6 +25,7 @@ struct NLCircuit { impl Circuit for NLCircuit { type Config = Config; type FloorPlanner = SimpleFloorPlanner; + type Params = (); fn without_witnesses(&self) -> Self { self.clone() diff --git a/examples/conv2d_mnist/main.rs b/examples/conv2d_mnist/main.rs index 5c2ce7e4b..d541ec488 100644 --- a/examples/conv2d_mnist/main.rs +++ b/examples/conv2d_mnist/main.rs @@ -29,6 +29,7 @@ use halo2curves::pasta::vesta; use halo2curves::pasta::Fp as F; use mnist::*; use rand::rngs::OsRng; +use std::marker::PhantomData; use std::time::Instant; mod params; @@ -131,6 +132,7 @@ where PADDING, >; type FloorPlanner = SimpleFloorPlanner; + type Params = PhantomData; fn without_witnesses(&self) -> Self { self.clone() diff --git a/examples/mlp_4d.rs b/examples/mlp_4d.rs index 25c031c25..63ef1064a 100644 --- a/examples/mlp_4d.rs +++ b/examples/mlp_4d.rs @@ -39,6 +39,7 @@ impl; type FloorPlanner = SimpleFloorPlanner; + type Params = PhantomData; fn without_witnesses(&self) -> Self { self.clone() diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index e42d68ab2..687b10f49 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -317,7 +317,7 @@ pub fn sum_axes( } /// Max accumulated layout -pub fn max_axes( +pub fn max_axes( config: &mut BaseConfig, region: &mut Option<&mut Region>, values: &[ValTensor; 1], @@ -369,7 +369,7 @@ pub fn max_axes( } /// Min accumulated layout -pub fn min_axes( +pub fn min_axes( config: &mut BaseConfig, region: &mut Option<&mut Region>, values: &[ValTensor; 1], @@ -622,7 +622,7 @@ pub fn matmul( } /// Iff -pub fn iff( +pub fn iff( config: &mut BaseConfig, region: &mut Option<&mut Region>, values: &[ValTensor; 3], @@ -792,7 +792,7 @@ pub fn sumpool( } /// Convolution accumulated layout -pub fn max_pool2d( +pub fn max_pool2d( config: &mut BaseConfig, region: &mut Option<&mut Region>, values: &[ValTensor; 1], @@ -1294,7 +1294,7 @@ pub fn mean( } /// max layout -pub fn max( +pub fn max( config: &mut BaseConfig, region: &mut Option<&mut Region>, values: &[ValTensor; 1], @@ -1408,7 +1408,7 @@ pub fn max( } /// min layout -pub fn min( +pub fn min( config: &mut BaseConfig, region: &mut Option<&mut Region>, values: &[ValTensor; 1], diff --git a/src/circuit/tests.rs b/src/circuit/tests.rs index 05ca54e85..c80063f7a 100644 --- a/src/circuit/tests.rs +++ b/src/circuit/tests.rs @@ -11,6 +11,9 @@ use halo2curves::pasta::Fp as F; use rand::rngs::OsRng; use std::marker::PhantomData; +#[derive(Default)] +struct TestParams; + #[cfg(test)] mod matmul { @@ -28,6 +31,7 @@ mod matmul { impl Circuit for MatmulCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -101,6 +105,7 @@ mod matmul_col_overflow { impl Circuit for MatmulCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -174,6 +179,7 @@ mod dot { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -244,6 +250,7 @@ mod dot_col_overflow { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -314,6 +321,7 @@ mod sum { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -382,6 +390,7 @@ mod sum_col_overflow { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -450,6 +459,7 @@ mod composition { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -539,6 +549,7 @@ mod conv { impl Circuit for ConvCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -665,6 +676,7 @@ mod sumpool { impl Circuit for ConvCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -743,6 +755,7 @@ mod add_w_shape_casting { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -813,6 +826,7 @@ mod add { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -883,6 +897,7 @@ mod add_with_overflow { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -953,6 +968,7 @@ mod sub { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -1023,6 +1039,7 @@ mod mult { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -1093,6 +1110,7 @@ mod pow { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -1161,6 +1179,7 @@ mod pack { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -1229,6 +1248,7 @@ mod rescaled { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -1308,6 +1328,7 @@ mod matmul_relu { impl Circuit for MyCircuit { type Config = MyConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -1411,6 +1432,7 @@ mod rangecheck { impl Circuit for MyCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() @@ -1503,6 +1525,7 @@ mod relu { impl Circuit for ReLUCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = TestParams; fn without_witnesses(&self) -> Self { self.clone() diff --git a/src/commands.rs b/src/commands.rs index 8e6a8f632..5051f00e5 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -44,7 +44,7 @@ impl std::fmt::Display for StrategyType { } /// Parameters specific to a proving run -#[derive(Debug, Args, Deserialize, Serialize, Clone)] +#[derive(Debug, Args, Deserialize, Serialize, Clone, Default)] pub struct RunArgs { /// The tolerance for error on model outputs #[arg(short = 'T', long, default_value = "0")] diff --git a/src/execute.rs b/src/execute.rs index c4d6c01fe..cdf5e3873 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -5,7 +5,7 @@ use crate::eth::{ deploy_verifier, fix_verifier_sol, get_ledger_signing_provider, get_provider, get_wallet_signing_provider, send_proof, verify_proof_via_solidity, }; -use crate::graph::{vector_to_quantized, Model, ModelCircuit}; +use crate::graph::{vector_to_quantized, Model, ModelCircuit, ModelParams}; use crate::pfsys::evm::aggregation::{AggregationCircuit, PoseidonTranscript}; #[cfg(not(target_arch = "wasm32"))] use crate::pfsys::evm::{aggregation::gen_aggregation_evm_verifier, single::gen_evm_verifier}; @@ -424,8 +424,10 @@ fn create_evm_verifier( let public_inputs = circuit.prepare_public_inputs(&data)?; let num_instance = public_inputs.iter().map(|x| x.len()).collect(); let params = load_params_cmd(params_path, logrows)?; + let model_circuit_params = load_model_circuit_params(); - let vk = load_vk::, Fr, ModelCircuit>(vk_path)?; + let vk = + load_vk::, Fr, ModelCircuit>(vk_path, model_circuit_params)?; trace!("params computed"); let (deployment_code, yul_code) = gen_evm_verifier(¶ms, &vk, num_instance)?; @@ -473,7 +475,7 @@ fn create_evm_aggregate_verifier( ) -> Result<(), Box> { let params: ParamsKZG = load_params::>(params_path)?; - let agg_vk = load_vk::, Fr, AggregationCircuit>(vk_path)?; + let agg_vk = load_vk::, Fr, AggregationCircuit>(vk_path, ())?; let deployment_code = gen_aggregation_evm_verifier( ¶ms, @@ -561,9 +563,14 @@ fn aggregate( // the K used when generating the application snark proof. we assume K is homogenous across snarks to aggregate let params_app = load_params_cmd(params_path, app_logrows)?; + let model_circuit_params = load_model_circuit_params(); + for (proof_path, vk_path) in aggregation_snarks.iter().zip(aggregation_vk_paths) { - let vk = - load_vk::, Fr, ModelCircuit>(vk_path.to_path_buf())?; + let vk = load_vk::, Fr, ModelCircuit>( + vk_path.to_path_buf(), + // safe to clone as the inner model is wrapped in an Arc + model_circuit_params.clone(), + )?; snarks.push(Snark::load::>( proof_path, Some(¶ms_app), @@ -606,9 +613,11 @@ fn verify( let params = load_params_cmd(params_path, logrows)?; let proof = Snark::load::>(&proof_path, None, None)?; + let model_circuit_params = load_model_circuit_params(); let strategy = KZGSingleStrategy::new(params.verifier_params()); - let vk = load_vk::, Fr, ModelCircuit>(vk_path)?; + let vk = + load_vk::, Fr, ModelCircuit>(vk_path, model_circuit_params)?; let result = verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, transcript, strategy); info!("verified: {}", result.is_ok()); @@ -627,7 +636,7 @@ fn verify_aggr( let proof = Snark::load::>(&proof_path, None, None)?; let strategy = AccumulatorStrategy::new(params.verifier_params()); - let vk = load_vk::, Fr, AggregationCircuit>(vk_path)?; + let vk = load_vk::, Fr, AggregationCircuit>(vk_path, ())?; let result = verify_proof_circuit_kzg(¶ms, proof, &vk, transcript, strategy); info!("verified: {}", result.is_ok()); Ok(()) @@ -641,3 +650,22 @@ fn load_params_cmd(params_path: PathBuf, logrows: u32) -> Result ModelParams { + let model: Arc> = Arc::new(Model::from_arg().expect("model should load")); + + let instance_shapes = model.instance_shapes(); + // this is the total number of variables we will need to allocate + // for the circuit + let num_constraints = if let Some(num_constraints) = model.run_args.allocated_constraints { + num_constraints + } else { + model.dummy_layout(&model.input_shapes()).unwrap() + }; + + ModelParams { + model, + instance_shapes, + num_constraints, + } +} diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 5c0cef8c2..95192b02a 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -78,6 +78,17 @@ pub enum GraphError { PackingExponent, } +/// model parameters +#[derive(Clone, Debug, Default)] +pub struct ModelParams { + /// An onnx model quantized and configured for zkSNARKs + pub model: Arc>, + /// the potential number of constraints in the circuit + pub num_constraints: usize, + /// the shape of public inputs to the circuit (in order of appearance) + pub instance_shapes: Vec>, +} + /// Defines the circuit for a computational graph / model loaded from a `.onnx` file. #[derive(Clone, Debug)] pub struct ModelCircuit { @@ -174,11 +185,31 @@ impl ModelCircuit { impl Circuit for ModelCircuit { type Config = ModelConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = ModelParams; fn without_witnesses(&self) -> Self { self.clone() } + fn configure_with_params(cs: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + let mut vars = ModelVars::new( + cs, + params.model.run_args.logrows as usize, + params.num_constraints, + params.instance_shapes.clone(), + params.model.visibility.clone(), + params.model.run_args.scale, + ); + + let base = params.model.configure(cs, &mut vars).unwrap(); + + ModelConfig { + base, + vars, + model: params.model.clone(), + } + } + fn configure(cs: &mut ConstraintSystem) -> Self::Config { let model: Arc> = Arc::new(Model::from_arg().expect("model should load")); diff --git a/src/graph/model.rs b/src/graph/model.rs index e28dacee2..55db9661f 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -38,9 +38,10 @@ use tabled::Table; use tract_onnx; use tract_onnx::prelude::Framework; /// Mode we're using the model in. -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, Default)] pub enum Mode { /// Initialize the model and display the operations table / graph + #[default] Table, /// Initialize the model and generate a mock proof Mock, @@ -64,7 +65,7 @@ pub struct ModelConfig { } /// A struct for loading from an Onnx file and converting a computational graph to a circuit. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct Model { /// input indices pub inputs: Vec, diff --git a/src/graph/vars.rs b/src/graph/vars.rs index d8d589327..3ca94cc10 100644 --- a/src/graph/vars.rs +++ b/src/graph/vars.rs @@ -11,9 +11,10 @@ use serde::{Deserialize, Serialize}; use super::*; /// Label Enum to track whether model input, model parameters, and model output are public or private -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Default)] pub enum Visibility { /// Mark an item as private to the prover (not in the proof submitted for verification) + #[default] Private, /// Mark an item as public (sent in the proof submitted for verification) Public, @@ -34,7 +35,7 @@ impl std::fmt::Display for Visibility { } /// Whether the model input, model parameters, and model output are Public or Private to the prover. -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize, Default)] pub struct VarVisibility { /// Input to the model or computational graph pub input: Visibility, diff --git a/src/pfsys/evm/aggregation.rs b/src/pfsys/evm/aggregation.rs index 5ba63329c..8fe6fa8e3 100644 --- a/src/pfsys/evm/aggregation.rs +++ b/src/pfsys/evm/aggregation.rs @@ -247,6 +247,7 @@ impl AggregationCircuit { impl Circuit for AggregationCircuit { type Config = AggregationConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = (); fn without_witnesses(&self) -> Self { Self { diff --git a/src/pfsys/mod.rs b/src/pfsys/mod.rs index 7a5a6952b..1cb1be6cb 100644 --- a/src/pfsys/mod.rs +++ b/src/pfsys/mod.rs @@ -385,6 +385,7 @@ where /// Loads a [VerifyingKey] at `path`. pub fn load_vk>( path: PathBuf, + params: >::Params, ) -> Result, Box> where C: Circuit, @@ -394,13 +395,18 @@ where info!("loading verification key from {:?}", path); let f = File::open(path).map_err(Box::::from)?; let mut reader = BufReader::new(f); - VerifyingKey::::read::<_, C>(&mut reader, halo2_proofs::SerdeFormat::RawBytes) - .map_err(Box::::from) + VerifyingKey::::read::<_, C>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + params, + ) + .map_err(Box::::from) } /// Loads a [ProvingKey] at `path`. pub fn load_pk>( path: PathBuf, + params: >::Params, ) -> Result, Box> where C: Circuit, @@ -410,8 +416,12 @@ where info!("loading proving key from {:?}", path); let f = File::open(path).map_err(Box::::from)?; let mut reader = BufReader::new(f); - ProvingKey::::read::<_, C>(&mut reader, halo2_proofs::SerdeFormat::RawBytes) - .map_err(Box::::from) + ProvingKey::::read::<_, C>( + &mut reader, + halo2_proofs::SerdeFormat::RawBytes, + params, + ) + .map_err(Box::::from) } /// Loads the [CommitmentScheme::ParamsVerifier] at `path`. From b7cc8acf81a146aeb437616277212f46a8d0131a Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Thu, 27 Apr 2023 10:41:43 +0100 Subject: [PATCH 2/6] instantiate params with `new()` --- src/execute.rs | 10 ++++---- src/graph/mod.rs | 59 ++++++++++++++++++++++++++++++++---------------- 2 files changed, 43 insertions(+), 26 deletions(-) diff --git a/src/execute.rs b/src/execute.rs index cdf5e3873..70f5954bd 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -363,7 +363,7 @@ fn forward( fn mock(data: String, logrows: u32) -> Result<(), Box> { let data = prepare_data(data)?; - let model = Model::from_arg()?; + let model = Arc::new(Model::from_arg()?); let circuit = ModelCircuit::::new(&data, model)?; let public_inputs = circuit.prepare_public_inputs(&data)?; @@ -390,7 +390,7 @@ fn print_proof_hex(proof_path: PathBuf) -> Result<(), Box> { #[cfg(feature = "render")] fn render(data: String, output: String, logrows: u32) -> Result<(), Box> { let data = prepare_data(data.to_string())?; - let model = Model::from_arg()?; + let model = Arc::new(Model::from_arg()?); let circuit = ModelCircuit::::new(&data, model)?; info!("Rendering circuit"); @@ -418,8 +418,7 @@ fn create_evm_verifier( logrows: u32, ) -> Result<(), Box> { let data = prepare_data(data)?; - - let model = Model::from_arg()?; + let model = Arc::new(Model::from_arg()?); let circuit = ModelCircuit::::new(&data, model)?; let public_inputs = circuit.prepare_public_inputs(&data)?; let num_instance = public_inputs.iter().map(|x| x.len()).collect(); @@ -498,8 +497,7 @@ fn prove( check_mode: CheckMode, ) -> Result<(), Box> { let data = prepare_data(data)?; - - let model = Model::from_arg()?; + let model = Arc::new(Model::from_arg()?); let circuit = ModelCircuit::::new(&data, model)?; let public_inputs = circuit.prepare_public_inputs(&data)?; diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 95192b02a..70e53e043 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -25,7 +25,6 @@ pub use model::*; pub use node::*; // use std::fs::File; // use std::io::{BufReader, BufWriter, Read, Write}; -use std::marker::PhantomData; use std::sync::Arc; // use std::path::PathBuf; use thiserror::Error; @@ -95,16 +94,14 @@ pub struct ModelCircuit { /// Vector of input tensors to the model / graph of computations. pub inputs: Vec>, /// - pub model: Model, - /// Represents the PrimeField we are using. - pub _marker: PhantomData, + pub params: ModelParams, } impl ModelCircuit { /// pub fn new( data: &ModelInput, - model: Model, + model: Arc>, ) -> Result, Box> { // quantize the supplied data using the provided scale. let mut inputs: Vec> = vec![]; @@ -113,17 +110,28 @@ impl ModelCircuit { inputs.push(t); } - Ok(ModelCircuit:: { - inputs, + let instance_shapes = model.instance_shapes(); + // this is the total number of variables we will need to allocate + // for the circuit + let num_constraints = if let Some(num_constraints) = model.run_args.allocated_constraints { + num_constraints + } else { + model.dummy_layout(&model.input_shapes()).unwrap() + }; + + let params = ModelParams { model, - _marker: PhantomData, - }) + instance_shapes, + num_constraints, + }; + + Ok(ModelCircuit:: { inputs, params }) } /// pub fn from_arg(data: &ModelInput) -> Result> { let cli = Cli::create()?; - let model = Model::from_ezkl_conf(cli)?; + let model = Arc::new(Model::from_ezkl_conf(cli)?); Self::new(data, model) } @@ -132,33 +140,39 @@ impl ModelCircuit { &self, data: &ModelInput, ) -> Result>, Box> { - let out_scales = self.model.get_output_scales(); + let out_scales = self.params.model.get_output_scales(); // quantize the supplied data using the provided scale. // the ordering here is important, we want the inputs to come before the outputs // as they are configured in that order as Column let mut public_inputs = vec![]; - if self.model.visibility.input.is_public() { + if self.params.model.visibility.input.is_public() { for v in data.input_data.iter() { - let t = - vector_to_quantized(v, &Vec::from([v.len()]), 0.0, self.model.run_args.scale)?; + let t = vector_to_quantized( + v, + &Vec::from([v.len()]), + 0.0, + self.params.model.run_args.scale, + )?; public_inputs.push(t); } } - if self.model.visibility.output.is_public() { + if self.params.model.visibility.output.is_public() { for (idx, v) in data.output_data.iter().enumerate() { let mut t = vector_to_quantized(v, &Vec::from([v.len()]), 0.0, out_scales[idx])?; let len = t.len(); - if self.model.run_args.pack_base > 1 { + if self.params.model.run_args.pack_base > 1 { let max_exponent = - (((len - 1) as u32) * (self.model.run_args.scale + 1)) as f64; - if max_exponent > (i128::MAX as f64).log(self.model.run_args.pack_base as f64) { + (((len - 1) as u32) * (self.params.model.run_args.scale + 1)) as f64; + if max_exponent + > (i128::MAX as f64).log(self.params.model.run_args.pack_base as f64) + { return Err(Box::new(GraphError::PackingExponent)); } t = pack( &t, - self.model.run_args.pack_base as i128, - self.model.run_args.scale, + self.params.model.run_args.pack_base as i128, + self.params.model.run_args.scale, )?; } public_inputs.push(t); @@ -191,6 +205,11 @@ impl Circuit for ModelCircuit { self.clone() } + fn params(&self) -> Self::Params { + // safe to clone because the model is Arc'd + self.params.clone() + } + fn configure_with_params(cs: &mut ConstraintSystem, params: Self::Params) -> Self::Config { let mut vars = ModelVars::new( cs, From c5b19ec0434ca846cf4b66fa0e68b5887aa8422f Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Thu, 27 Apr 2023 10:47:10 +0100 Subject: [PATCH 3/6] remove nested model --- src/graph/mod.rs | 14 +++----------- src/graph/model.rs | 3 --- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 70e53e043..76dfaa877 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -222,11 +222,7 @@ impl Circuit for ModelCircuit { let base = params.model.configure(cs, &mut vars).unwrap(); - ModelConfig { - base, - vars, - model: params.model.clone(), - } + ModelConfig { base, vars } } fn configure(cs: &mut ConstraintSystem) -> Self::Config { @@ -266,11 +262,7 @@ impl Circuit for ModelCircuit { info!("configured model"); - ModelConfig { - model, - base: config, - vars, - } + ModelConfig { base: config, vars } } fn synthesize( @@ -285,7 +277,7 @@ impl Circuit for ModelCircuit { .map(|i| ValTensor::from( as Into>>>::into(i.clone()))) .collect::>>(); trace!("Laying out model"); - config + self.params .model .layout(config.clone(), &mut layouter, &inputs, &config.vars) .unwrap(); diff --git a/src/graph/model.rs b/src/graph/model.rs index 55db9661f..24f214e8a 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -33,7 +33,6 @@ use log::{debug, info, trace}; use std::collections::BTreeMap; use std::error::Error; use std::path::Path; -use std::sync::Arc; use tabled::Table; use tract_onnx; use tract_onnx::prelude::Framework; @@ -58,8 +57,6 @@ pub enum Mode { pub struct ModelConfig { /// The base configuration for the circuit pub base: PolyConfig, - /// The model struct - pub model: Arc>, /// A wrapper for holding all columns that will be assigned to by the model pub vars: ModelVars, } From 4ae344098c62df05c7d20a0161d249237765e06c Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Thu, 27 Apr 2023 10:50:34 +0100 Subject: [PATCH 4/6] cleanup execute --- src/execute.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/execute.rs b/src/execute.rs index 70f5954bd..b7b4dfd49 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -363,8 +363,7 @@ fn forward( fn mock(data: String, logrows: u32) -> Result<(), Box> { let data = prepare_data(data)?; - let model = Arc::new(Model::from_arg()?); - let circuit = ModelCircuit::::new(&data, model)?; + let circuit = ModelCircuit::::from_arg(&data)?; let public_inputs = circuit.prepare_public_inputs(&data)?; info!("Mock proof"); @@ -390,8 +389,7 @@ fn print_proof_hex(proof_path: PathBuf) -> Result<(), Box> { #[cfg(feature = "render")] fn render(data: String, output: String, logrows: u32) -> Result<(), Box> { let data = prepare_data(data.to_string())?; - let model = Arc::new(Model::from_arg()?); - let circuit = ModelCircuit::::new(&data, model)?; + let circuit = ModelCircuit::::from_arg(&data)?; info!("Rendering circuit"); // Create the area we want to draw on. @@ -418,8 +416,7 @@ fn create_evm_verifier( logrows: u32, ) -> Result<(), Box> { let data = prepare_data(data)?; - let model = Arc::new(Model::from_arg()?); - let circuit = ModelCircuit::::new(&data, model)?; + let circuit = ModelCircuit::::from_arg(&data)?; let public_inputs = circuit.prepare_public_inputs(&data)?; let num_instance = public_inputs.iter().map(|x| x.len()).collect(); let params = load_params_cmd(params_path, logrows)?; @@ -497,8 +494,7 @@ fn prove( check_mode: CheckMode, ) -> Result<(), Box> { let data = prepare_data(data)?; - let model = Arc::new(Model::from_arg()?); - let circuit = ModelCircuit::::new(&data, model)?; + let circuit = ModelCircuit::::from_arg(&data)?; let public_inputs = circuit.prepare_public_inputs(&data)?; let params = load_params_cmd(params_path, logrows)?; From 470ac08c453b1c4f11d08784a4d208bb63d51494 Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Thu, 27 Apr 2023 11:08:02 +0100 Subject: [PATCH 5/6] Update execute.rs --- src/execute.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/execute.rs b/src/execute.rs index b7b4dfd49..868de09d1 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -42,7 +42,6 @@ use std::fs::File; #[cfg(not(target_arch = "wasm32"))] use std::io::Write; use std::path::PathBuf; -#[cfg(not(target_arch = "wasm32"))] use std::sync::Arc; use std::time::Instant; use tabled::Table; From 53fc67b695c3298c7db76b9d56dfc789934b426a Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Thu, 27 Apr 2023 11:16:14 +0100 Subject: [PATCH 6/6] rm regular configure for safety --- src/graph/mod.rs | 40 ++-------------------------------------- 1 file changed, 2 insertions(+), 38 deletions(-) diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 76dfaa877..c0dbd0901 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -225,44 +225,8 @@ impl Circuit for ModelCircuit { ModelConfig { base, vars } } - fn configure(cs: &mut ConstraintSystem) -> Self::Config { - let model: Arc> = Arc::new(Model::from_arg().expect("model should load")); - - let instance_shapes = model.instance_shapes(); - // this is the total number of variables we will need to allocate - // for the circuit - let num_constraints = if let Some(num_constraints) = model.run_args.allocated_constraints { - num_constraints - } else { - model.dummy_layout(&model.input_shapes()).unwrap() - }; - - info!("total num constraints: {:?}", num_constraints); - info!("instance_shapes: {:?}", instance_shapes); - - let mut vars = ModelVars::new( - cs, - model.run_args.logrows as usize, - num_constraints, - instance_shapes.clone(), - model.visibility.clone(), - model.run_args.scale, - ); - info!( - "number of advices used: {:?}", - vars.advices.iter().map(|a| a.num_cols()).sum::() - ); - info!( - "number of fixed used: {:?}", - vars.fixed.iter().map(|a| a.num_cols()).sum::() - ); - info!("number of instances used: {:?}", instance_shapes.len()); - - let config = model.configure(cs, &mut vars).unwrap(); - - info!("configured model"); - - ModelConfig { base: config, vars } + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unimplemented!("you should call configure_with_params instead") } fn synthesize(