diff --git a/Cargo.toml b/Cargo.toml index 450ebced5..00593592b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ tokio = { version = "1.26.0", features = ["macros", "rt"] } puruspe = "0.2.0" bincode = "*" + # python binding related deps pyo3 = { version = "0.18.2", features = ["extension-module", "abi3-py37"], optional = true } pyo3-log = { version = "0.8.1", optional = true } diff --git a/benches/accum_affine.rs b/benches/accum_affine.rs index b5511e8c4..a21a4857f 100644 --- a/benches/accum_affine.rs +++ b/benches/accum_affine.rs @@ -1,4 +1,5 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use ezkl_lib::circuit::poly::PolyOp; use ezkl_lib::circuit::*; use ezkl_lib::commands::TranscriptType; use ezkl_lib::execute::create_proof_circuit_kzg; @@ -53,7 +54,12 @@ impl Circuit for MyCircuit { || "", |mut region| { config - .layout(&mut region, &self.inputs, &mut 0, Op::Affine.into()) + .layout( + Some(&mut region), + &self.inputs, + &mut 0, + Box::new(PolyOp::Affine), + ) .unwrap(); Ok(()) }, diff --git a/benches/accum_conv.rs b/benches/accum_conv.rs index fc36d4708..e49f56ac4 100644 --- a/benches/accum_conv.rs +++ b/benches/accum_conv.rs @@ -1,4 +1,5 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use ezkl_lib::circuit::poly::PolyOp; use ezkl_lib::circuit::*; use ezkl_lib::commands::TranscriptType; use ezkl_lib::execute::create_proof_circuit_kzg; @@ -61,14 +62,13 @@ impl Circuit for MyCircuit { |mut region| { config .layout( - &mut region, + Some(&mut region), &[self.image.clone(), self.kernel.clone(), self.bias.clone()], &mut 0, - Op::Conv { + Box::new(PolyOp::Conv { padding: (0, 0), stride: (1, 1), - } - .into(), + }), ) .unwrap(); Ok(()) diff --git a/benches/accum_dot.rs b/benches/accum_dot.rs index da4a9bf0a..09f030b0c 100644 --- a/benches/accum_dot.rs +++ b/benches/accum_dot.rs @@ -1,4 +1,5 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use ezkl_lib::circuit::poly::PolyOp; use ezkl_lib::circuit::*; use ezkl_lib::commands::TranscriptType; use ezkl_lib::execute::create_proof_circuit_kzg; @@ -51,7 +52,12 @@ impl Circuit for MyCircuit { || "", |mut region| { config - .layout(&mut region, &self.inputs, &mut 0, Op::Dot.into()) + .layout( + Some(&mut region), + &self.inputs, + &mut 0, + Box::new(PolyOp::Dot), + ) .unwrap(); Ok(()) }, diff --git a/benches/accum_matmul.rs b/benches/accum_matmul.rs index 70c447d01..c8191c544 100644 --- a/benches/accum_matmul.rs +++ b/benches/accum_matmul.rs @@ -1,4 +1,5 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use ezkl_lib::circuit::poly::PolyOp; use ezkl_lib::circuit::*; use ezkl_lib::commands::TranscriptType; use ezkl_lib::execute::create_proof_circuit_kzg; @@ -53,7 +54,12 @@ impl Circuit for MyCircuit { || "", |mut region| { config - .layout(&mut region, &self.inputs, &mut 0, Op::Matmul.into()) + .layout( + Some(&mut region), + &self.inputs, + &mut 0, + Box::new(PolyOp::Matmul), + ) .unwrap(); Ok(()) }, diff --git a/benches/accum_matmul_relu.rs b/benches/accum_matmul_relu.rs index 5b3a68c9f..f9219820d 100644 --- a/benches/accum_matmul_relu.rs +++ b/benches/accum_matmul_relu.rs @@ -1,6 +1,8 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use ezkl_lib::circuit::*; +use ezkl_lib::circuit::lookup::LookupOp; +use ezkl_lib::circuit::poly::PolyOp; use ezkl_lib::commands::TranscriptType; use ezkl_lib::execute::create_proof_circuit_kzg; use 
ezkl_lib::pfsys::{create_keys, gen_srs}; @@ -65,19 +67,19 @@ impl Circuit for MyCircuit { layouter.assign_region( || "", |mut region| { - let op = Op::Matmul; + let op = PolyOp::Matmul; let mut offset = 0; let output = config .base_config - .layout(&mut region, &self.inputs, &mut offset, op.into()) + .layout(Some(&mut region), &self.inputs, &mut offset, Box::new(op)) .unwrap(); let _output = config .base_config .layout( - &mut region, + Some(&mut region), &[output.unwrap()], &mut offset, - LookupOp::ReLU { scale: 1 }.into(), + Box::new(LookupOp::ReLU { scale: 1 }), ) .unwrap(); Ok(()) diff --git a/benches/accum_pack.rs b/benches/accum_pack.rs index 9afa27b73..a91d55bfd 100644 --- a/benches/accum_pack.rs +++ b/benches/accum_pack.rs @@ -1,4 +1,5 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use ezkl_lib::circuit::poly::PolyOp; use ezkl_lib::circuit::*; use ezkl_lib::commands::TranscriptType; use ezkl_lib::execute::create_proof_circuit_kzg; @@ -51,7 +52,12 @@ impl Circuit for MyCircuit { || "", |mut region| { config - .layout(&mut region, &self.inputs, &mut 0, Op::Pack(2, 1).into()) + .layout( + Some(&mut region), + &self.inputs, + &mut 0, + Box::new(PolyOp::Pack(2, 1)), + ) .unwrap(); Ok(()) }, diff --git a/benches/accum_sum.rs b/benches/accum_sum.rs index 0cd6b5171..d89a85d12 100644 --- a/benches/accum_sum.rs +++ b/benches/accum_sum.rs @@ -1,4 +1,5 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use ezkl_lib::circuit::poly::PolyOp; use ezkl_lib::circuit::*; use ezkl_lib::commands::TranscriptType; use ezkl_lib::execute::create_proof_circuit_kzg; @@ -51,7 +52,12 @@ impl Circuit for MyCircuit { || "", |mut region| { config - .layout(&mut region, &self.inputs, &mut 0, Op::Sum.into()) + .layout( + Some(&mut region), + &self.inputs, + &mut 0, + Box::new(PolyOp::Sum), + ) .unwrap(); Ok(()) }, diff --git a/benches/accum_sumpool.rs b/benches/accum_sumpool.rs index ff9ad1f4b..164a4e7ee 100644 --- a/benches/accum_sumpool.rs +++ b/benches/accum_sumpool.rs @@ -1,4 +1,5 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use ezkl_lib::circuit::poly::PolyOp; use ezkl_lib::circuit::*; use ezkl_lib::commands::TranscriptType; use ezkl_lib::execute::create_proof_circuit_kzg; @@ -56,15 +57,14 @@ impl Circuit for MyCircuit { |mut region| { config .layout( - &mut region, + Some(&mut region), &[self.image.clone()], &mut 0, - Op::SumPool { + Box::new(PolyOp::SumPool { padding: (0, 0), stride: (1, 1), kernel_shape: (2, 2), - } - .into(), + }), ) .unwrap(); Ok(()) diff --git a/benches/pairwise_add.rs b/benches/pairwise_add.rs index f2aa4e9db..0533bf63d 100644 --- a/benches/pairwise_add.rs +++ b/benches/pairwise_add.rs @@ -1,4 +1,5 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use ezkl_lib::circuit::poly::PolyOp; use ezkl_lib::circuit::*; use ezkl_lib::commands::TranscriptType; use ezkl_lib::execute::create_proof_circuit_kzg; @@ -51,7 +52,12 @@ impl Circuit for MyCircuit { || "", |mut region| { config - .layout(&mut region, &self.inputs, &mut 0, Op::Add.into()) + .layout( + Some(&mut region), + &self.inputs, + &mut 0, + Box::new(PolyOp::Add), + ) .unwrap(); Ok(()) }, diff --git a/benches/pairwise_pow.rs b/benches/pairwise_pow.rs index 70189be43..99ff221cd 100644 --- a/benches/pairwise_pow.rs +++ b/benches/pairwise_pow.rs @@ -1,4 +1,5 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use 
ezkl_lib::circuit::poly::PolyOp; use ezkl_lib::circuit::*; use ezkl_lib::commands::TranscriptType; use ezkl_lib::execute::create_proof_circuit_kzg; @@ -51,7 +52,12 @@ impl Circuit for MyCircuit { || "", |mut region| { config - .layout(&mut region, &self.inputs, &mut 0, Op::Pow(4).into()) + .layout( + Some(&mut region), + &self.inputs, + &mut 0, + Box::new(PolyOp::Pow(4)), + ) .unwrap(); Ok(()) }, diff --git a/benches/relu.rs b/benches/relu.rs index 61ee77aa1..f40e7e128 100644 --- a/benches/relu.rs +++ b/benches/relu.rs @@ -1,5 +1,5 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; -use ezkl_lib::circuit::{BaseConfig as Config, CheckMode, LookupOp}; +use ezkl_lib::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode}; use ezkl_lib::commands::TranscriptType; use ezkl_lib::execute::create_proof_circuit_kzg; use ezkl_lib::pfsys::{create_keys, gen_srs}; @@ -59,10 +59,10 @@ impl Circuit for NLCircuit { |mut region| { config .layout( - &mut region, + Some(&mut region), &[self.input.clone()], &mut 0, - LookupOp::ReLU { scale: 128 }.into(), + Box::new(LookupOp::ReLU { scale: 128 }), ) .unwrap(); Ok(()) diff --git a/examples/conv2d_mnist/main.rs b/examples/conv2d_mnist/main.rs index edb4075a5..7c95eb689 100644 --- a/examples/conv2d_mnist/main.rs +++ b/examples/conv2d_mnist/main.rs @@ -1,4 +1,6 @@ -use ezkl_lib::circuit::{BaseConfig as PolyConfig, CheckMode, LookupOp, Op as PolyOp}; +use ezkl_lib::circuit::{ + ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode, +}; use ezkl_lib::fieldutils; use ezkl_lib::fieldutils::i32_to_felt; use ezkl_lib::tensor::*; @@ -178,24 +180,24 @@ where let x = config .layer_config .layout( - &mut region, + Some(&mut region), &[ self.input.clone(), self.l0_params[0].clone(), self.l0_params[1].clone(), ], &mut offset, - op.into(), + Box::new(op), ) .unwrap(); let mut x = config .layer_config .layout( - &mut region, + Some(&mut region), &[x.unwrap()], &mut offset, - LookupOp::ReLU { scale: 32 }.into(), + Box::new(LookupOp::ReLU { scale: 32 }), ) .unwrap() .unwrap(); @@ -203,10 +205,10 @@ where let l2out = config .layer_config .layout( - &mut region, + Some(&mut region), &[x, self.l2_params[0].clone(), self.l2_params[1].clone()], &mut offset, - PolyOp::Affine.into(), + Box::new(PolyOp::Affine), ) .unwrap(); Ok(l2out) diff --git a/examples/mlp_4d.rs b/examples/mlp_4d.rs index e8d733be3..9b69c9ed2 100644 --- a/examples/mlp_4d.rs +++ b/examples/mlp_4d.rs @@ -1,4 +1,6 @@ -use ezkl_lib::circuit::{BaseConfig as PolyConfig, CheckMode, LookupOp, Op as PolyOp}; +use ezkl_lib::circuit::{ + ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode, +}; use ezkl_lib::fieldutils::i32_to_felt; use ezkl_lib::tensor::*; use halo2_proofs::dev::MockProver; @@ -95,14 +97,14 @@ impl Circuit let x = config .layer_config .layout( - &mut region, + Some(&mut region), &[ self.input.clone(), self.l0_params[0].clone(), self.l0_params[1].clone(), ], &mut offset, - PolyOp::Affine.into(), + Box::new(PolyOp::Affine), ) .unwrap() .unwrap(); @@ -112,10 +114,10 @@ impl Circuit let mut x = config .layer_config .layout( - &mut region, + Some(&mut region), &[x], &mut offset, - LookupOp::ReLU { scale: 1 }.into(), + Box::new(LookupOp::ReLU { scale: 1 }), ) .unwrap() .unwrap(); @@ -125,10 +127,10 @@ impl Circuit let x = config .layer_config .layout( - &mut region, + Some(&mut region), &[x, self.l2_params[0].clone(), self.l2_params[1].clone()], &mut offset, - PolyOp::Affine.into(), + Box::new(PolyOp::Affine), ) 
.unwrap() .unwrap(); @@ -137,23 +139,22 @@ impl Circuit let x = config .layer_config .layout( - &mut region, + Some(&mut region), &[x], &mut offset, - LookupOp::ReLU { scale: 1 }.into(), + Box::new(LookupOp::ReLU { scale: 1 }), ) .unwrap(); println!("offset: {}", offset); Ok(config .layer_config .layout( - &mut region, + Some(&mut region), &[x.unwrap()], &mut offset, - LookupOp::Div { + Box::new(LookupOp::Div { denom: ezkl_lib::circuit::utils::F32::from(128.), - } - .into(), + }), ) .unwrap()) }, diff --git a/examples/onnx/1l_instance_norm/gen.py b/examples/onnx/1l_instance_norm/gen.py new file mode 100644 index 000000000..7f8c86ce5 --- /dev/null +++ b/examples/onnx/1l_instance_norm/gen.py @@ -0,0 +1,18 @@ +from torch import nn +from ezkl import export +import torch + +class MyModel(nn.Module): + def __init__(self): + super(MyModel, self).__init__() + + self.layer = nn.InstanceNorm2d(3).eval() + + def forward(self, x): + return [self.layer(x)] + +circuit = MyModel() +export(circuit, input_shape = [3,2,2]) + + + diff --git a/examples/onnx/1l_instance_norm/input.json b/examples/onnx/1l_instance_norm/input.json new file mode 100644 index 000000000..956e4f0c8 --- /dev/null +++ b/examples/onnx/1l_instance_norm/input.json @@ -0,0 +1 @@ +{"input_shapes": [[3, 2, 2]], "input_data": [[0.012968450784683228, 0.07738310843706131, 0.026791423559188843, 0.030387181788682938, 0.0008226811769418418, 0.0008844911935739219, 0.050431013107299805, 0.0018652736907824874, 0.02243981324136257, 0.05826736241579056, 0.03202100470662117, 0.01562164444476366]], "output_data": [[-0.977062463760376, 1.6547390222549438, -0.41229474544525146, -0.26538217067718506, -0.5880738496780396, -0.5852068066596985, 1.7129942178726196, -0.5397135019302368, -0.5845468044281006, 1.586229920387268, -0.004026293754577637, -0.9976569414138794]]} \ No newline at end of file diff --git a/examples/onnx/1l_instance_norm/network.onnx b/examples/onnx/1l_instance_norm/network.onnx new file mode 100644 index 000000000..78c939c3c Binary files /dev/null and b/examples/onnx/1l_instance_norm/network.onnx differ diff --git a/src/circuit/mod.rs b/src/circuit/mod.rs index 82d3b1434..61729c8e1 100644 --- a/src/circuit/mod.rs +++ b/src/circuit/mod.rs @@ -1,12 +1,13 @@ -/// Layouts for specific functions (composed of base ops) -pub mod layouts; - /// pub mod table; /// pub mod utils; +/// +pub mod ops; +pub use ops::*; + /// Tests #[cfg(test)] mod tests; @@ -20,25 +21,17 @@ use halo2_proofs::{ poly::Rotation, }; use halo2curves::FieldExt; -use itertools::Itertools; use log::{trace, warn}; use serde::{Deserialize, Serialize}; use crate::{ - fieldutils::{i128_to_felt, i32_to_felt}, - tensor::{self, Tensor, TensorError, TensorType, ValTensor, VarTensor}, -}; -use std::{ - cell::RefCell, - collections::BTreeMap, - error::Error, - fmt, - marker::PhantomData, - ops::{Add, Mul, Neg, Sub}, - rc::Rc, + circuit::ops::base::BaseOp, + fieldutils::i32_to_felt, + tensor::{Tensor, TensorType, ValTensor, VarTensor}, }; +use std::{cell::RefCell, collections::BTreeMap, error::Error, marker::PhantomData, rc::Rc}; -use self::table::Table; +use self::{ops::lookup::LookupOp, table::Table}; /// circuit related errors. 
#[derive(Debug, Error)] @@ -57,22 +50,6 @@ pub enum CircuitError { UnsupportedOp, } -#[allow(missing_docs)] -/// An enum representing the operations that can be used to express more complex operations via accumulation -#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub enum BaseOp { - Dot, - Identity, - Add, - Mult, - Sub, - Sum, - Neg, - Range { tol: i32 }, - IsZero, - IsBoolean, -} - #[allow(missing_docs)] /// An enum representing activating the sanity checks we can perform on the accumulated arguments #[derive( @@ -94,592 +71,6 @@ impl From for CheckMode { } } -/// Matches a [BaseOp] to an operation over inputs -impl BaseOp { - /// forward func - pub fn f< - T: TensorType + Add + Sub + Mul + Neg, - >( - &self, - inputs: (T, T, T), - ) -> T { - let (a, b, m) = inputs; - match &self { - BaseOp::Dot => a * b + m, - BaseOp::Add => a + b, - BaseOp::Identity => b, - BaseOp::Sum => b + m, - BaseOp::Neg => -b, - BaseOp::Sub => a - b, - BaseOp::Mult => a * b, - BaseOp::Range { .. } => b, - BaseOp::IsZero => b, - BaseOp::IsBoolean => b, - } - } - - fn as_str(&self) -> &'static str { - match self { - BaseOp::Identity => "IDENTITY", - BaseOp::Dot => "DOT", - BaseOp::Add => "ADD", - BaseOp::Neg => "NEG", - BaseOp::Sub => "SUB", - BaseOp::Mult => "MULT", - BaseOp::Sum => "SUM", - BaseOp::Range { .. } => "RANGE", - BaseOp::IsZero => "ISZERO", - BaseOp::IsBoolean => "ISBOOLEAN", - } - } - fn query_offset_rng(&self) -> (i32, usize) { - match self { - BaseOp::Identity => (0, 1), - BaseOp::Neg => (0, 1), - BaseOp::Dot => (-1, 2), - BaseOp::Add => (0, 1), - BaseOp::Sub => (0, 1), - BaseOp::Mult => (0, 1), - BaseOp::Sum => (-1, 2), - BaseOp::Range { .. } => (0, 1), - BaseOp::IsZero => (0, 1), - BaseOp::IsBoolean => (0, 1), - } - } - fn num_inputs(&self) -> usize { - match self { - BaseOp::Identity => 1, - BaseOp::Neg => 1, - BaseOp::Dot => 2, - BaseOp::Add => 2, - BaseOp::Sub => 2, - BaseOp::Mult => 2, - BaseOp::Sum => 1, - BaseOp::Range { .. } => 1, - BaseOp::IsZero => 1, - BaseOp::IsBoolean => 1, - } - } - fn constraint_idx(&self) -> usize { - match self { - BaseOp::Identity => 0, - BaseOp::Neg => 0, - BaseOp::Dot => 1, - BaseOp::Add => 0, - BaseOp::Sub => 0, - BaseOp::Mult => 0, - BaseOp::Range { .. } => 0, - BaseOp::Sum => 1, - BaseOp::IsZero => 0, - BaseOp::IsBoolean => 0, - } - } -} - -impl fmt::Display for BaseOp { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.as_str()) - } -} - -#[allow(missing_docs)] -/// An enum representing the operations that can be used to express more complex operations via accumulation -#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)] -pub enum LookupOp { - Div { - denom: utils::F32, - }, - ReLU { - scale: usize, - }, - Sqrt { - scales: (usize, usize), - }, - LeakyReLU { - scale: usize, - slope: utils::F32, - }, - PReLU { - scale: usize, - slopes: Vec, - }, - Sigmoid { - scales: (usize, usize), - }, - Tanh { - scales: (usize, usize), - }, - Erf { - scales: (usize, usize), - }, - Mean { - scale: usize, - }, - Max, - MaxPool2d { - padding: (usize, usize), - stride: (usize, usize), - pool_dims: (usize, usize), - }, - Min, -} - -impl LookupOp { - /// Matches a [Op] to an operation in the `tensor::ops` module. 
- pub fn f(&self, x: Tensor) -> Result, TensorError> { - match &self { - LookupOp::Div { denom } => Ok(tensor::ops::nonlinearities::const_div( - &x, - f32::from(*denom), - )), - LookupOp::ReLU { scale } => { - Ok(tensor::ops::nonlinearities::leakyrelu(&x, *scale, 0_f32)) - } - LookupOp::LeakyReLU { scale, slope } => { - Ok(tensor::ops::nonlinearities::leakyrelu(&x, *scale, slope.0)) - } - LookupOp::PReLU { scale, slopes } => Ok(tensor::ops::nonlinearities::prelu( - &x, - *scale, - &slopes.iter().map(|e| e.0).collect_vec(), - )), - LookupOp::Sigmoid { scales } => { - Ok(tensor::ops::nonlinearities::sigmoid(&x, scales.0, scales.1)) - } - LookupOp::Sqrt { scales } => { - Ok(tensor::ops::nonlinearities::sqrt(&x, scales.0, scales.1)) - } - LookupOp::Tanh { scales } => { - Ok(tensor::ops::nonlinearities::tanh(&x, scales.0, scales.1)) - } - LookupOp::Erf { scales } => { - Ok(tensor::ops::nonlinearities::erffunc(&x, scales.0, scales.1)) - } - LookupOp::Max { .. } => Tensor::new(Some(&[*x.iter().max().unwrap()]), &[1]), - LookupOp::Min { .. } => Tensor::new(Some(&[*x.iter().min().unwrap()]), &[1]), - LookupOp::MaxPool2d { - padding, - stride, - pool_dims, - } => tensor::ops::max_pool2d(&x, padding, stride, pool_dims), - - LookupOp::Mean { scale } => Ok(tensor::ops::nonlinearities::mean(&x, *scale)), - } - } - - fn as_str(&self) -> &'static str { - match self { - LookupOp::Min { .. } => "MIN", - LookupOp::Max { .. } => "MAX", - LookupOp::MaxPool2d { .. } => "MAXPOOL", - LookupOp::Div { .. } => "DIV", - LookupOp::ReLU { .. } => "RELU", - LookupOp::LeakyReLU { .. } => "LEAKY_RELU", - LookupOp::PReLU { .. } => "PRELU", - LookupOp::Sigmoid { .. } => "SIGMOID", - LookupOp::Sqrt { .. } => "SQRT", - LookupOp::Tanh { .. } => "TANH", - LookupOp::Erf { .. } => "ERF", - LookupOp::Mean { .. 
} => "MEAN", - } - } - - /// a value which is always in the table - pub fn default_pair(&self) -> (F, F) { - let x = vec![0_i128].into_iter().into(); - (F::zero(), i128_to_felt(self.f(x).unwrap()[0])) - } -} - -#[allow(missing_docs)] -/// An enum representing the operations that can be used to express more complex operations via accumulation -#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)] -pub enum Op { - Dot, - Matmul, - Affine, - Conv { - padding: (usize, usize), - stride: (usize, usize), - }, - SumPool { - padding: (usize, usize), - stride: (usize, usize), - kernel_shape: (usize, usize), - }, - Add, - Sub, - Mult, - Identity, - Reshape(Vec), - Flatten(Vec), - BatchNorm, - ScaleAndShift, - Pad(usize, usize), - Sum, - Pow(u32), - Pack(u32, u32), - GlobalSumPool, - Rescaled { - inner: Box, - scale: Vec<(usize, usize)>, - }, - RangeCheck(i32), -} - -impl Op { - /// circuit shape - pub fn circuit_shapes(&self, input_shapes: Vec>) -> Vec { - let mut shapes = match &self { - Op::Identity => vec![0, input_shapes[0].iter().product()], - Op::Reshape(_) => vec![0; 2], - Op::Flatten(_) => vec![0; 2], - Op::Pad(_, _) => vec![0; 2], - Op::Add => input_shapes.iter().map(|x| x.iter().product()).collect(), - Op::Mult => input_shapes.iter().map(|x| x.iter().product()).collect(), - Op::Sub => input_shapes.iter().map(|x| x.iter().product()).collect(), - Op::Sum => vec![0, input_shapes[0].iter().product()], - Op::Dot => input_shapes.iter().map(|x| x.iter().product()).collect(), - Op::Pow(_) => input_shapes.iter().map(|x| x.iter().product()).collect(), - Op::Pack(_, _) => input_shapes.iter().map(|x| x.iter().product()).collect(), - Op::GlobalSumPool => unreachable!("should be handled by sumpool"), - Op::ScaleAndShift => input_shapes.iter().map(|x| x.iter().product()).collect(), - Op::BatchNorm => input_shapes.iter().map(|x| x.iter().product()).collect(), - Op::Conv { padding, stride } => { - let image_dims = &input_shapes[0]; - let kernel_dims = &input_shapes[1]; - - let (output_channels, _input_channels, kernel_height, kernel_width) = ( - kernel_dims[0], - kernel_dims[1], - kernel_dims[2], - kernel_dims[3], - ); - - let (image_height, image_width) = (image_dims[1], image_dims[2]); - - let padded_height = image_height + 2 * padding.0; - let padded_width = image_width + 2 * padding.1; - - let vert_slides = (padded_height - kernel_height) / stride.0 + 1; - let horz_slides = (padded_width - kernel_width) / stride.1 + 1; - - let input_shapes = vec![ - vec![ - output_channels * vert_slides * horz_slides, - (padded_height * padded_width * image_dims[0] + 1), - ], - vec![(padded_height * padded_width * image_dims[0] + 1), 1], - ]; - let op = Op::Matmul; - let output_len = op.circuit_shapes(input_shapes); - - vec![*output_len.last().unwrap(); 2] - } - Op::SumPool { - padding, - stride, - kernel_shape, - } => { - let image_dims = &input_shapes[0]; - - let (image_height, image_width) = (image_dims[1], image_dims[2]); - - let padded_height = image_height + 2 * padding.0; - let padded_width = image_width + 2 * padding.1; - - let vert_slides = (padded_height - kernel_shape.0) / stride.0 + 1; - let horz_slides = (padded_width - kernel_shape.1) / stride.1 + 1; - - let input_shapes = vec![ - vec![ - image_dims[0] * vert_slides * horz_slides, - (padded_height * padded_width * image_dims[0] + 1), - ], - vec![(padded_height * padded_width * image_dims[0] + 1), 1], - ]; - let op = Op::Matmul; - let output_len = op.circuit_shapes(input_shapes); - - vec![*output_len.last().unwrap(); 2] - } - 
Op::Affine => { - let s = input_shapes; - // add 1 cause of bias - let output_len = s[1][0] * (s[1][1] + 1); - vec![output_len; 2] - } - Op::Matmul => { - let output_len = input_shapes[0].iter().product::() * input_shapes[1][1]; - - vec![output_len; 2] - } - Op::Rescaled { inner, .. } => inner.circuit_shapes(input_shapes), - Op::RangeCheck(..) => input_shapes.iter().map(|x| x.iter().product()).collect(), - }; - match shapes.last() { - // add output - Some(s) => shapes.push(*s), - _ => {} - }; - shapes - } - - /// Matches a [Op] to an operation in the `tensor::ops` module. - pub fn f(&self, mut inputs: Vec>) -> Result, TensorError> { - match &self { - Op::Identity => Ok(inputs[0].clone()), - Op::Reshape(new_dims) => { - let mut t = inputs[0].clone(); - t.reshape(new_dims); - Ok(t) - } - Op::Flatten(new_dims) => { - let mut t = inputs[0].clone(); - t.reshape(new_dims); - Ok(t) - } - Op::Pad(dim1, dim2) => { - if 1 != inputs.len() { - return Err(TensorError::DimMismatch("pad inputs".to_string())); - } - tensor::ops::pad(&inputs[0], (*dim1, *dim2)) - } - Op::Add => tensor::ops::add(&inputs), - Op::Sub => tensor::ops::sub(&inputs), - Op::Mult => tensor::ops::mult(&inputs), - Op::Affine => tensor::ops::affine(&inputs), - Op::BatchNorm => tensor::ops::scale_and_shift(&inputs), - Op::ScaleAndShift => tensor::ops::scale_and_shift(&inputs), - Op::Matmul => tensor::ops::matmul(&inputs), - Op::Dot => tensor::ops::dot(&inputs.iter().collect()), - Op::Conv { padding, stride } => tensor::ops::convolution(&inputs, *padding, *stride), - Op::SumPool { - padding, - stride, - kernel_shape, - } => tensor::ops::sumpool(&inputs[0], *padding, *stride, *kernel_shape), - Op::Pack(base, scale) => { - if 1 != inputs.len() { - return Err(TensorError::DimMismatch("pack inputs".to_string())); - } - - tensor::ops::pack(&inputs[0], *base as i128, *scale) - } - Op::Pow(u) => { - if 1 != inputs.len() { - return Err(TensorError::DimMismatch("pow inputs".to_string())); - } - inputs[0].pow(*u) - } - Op::Sum => { - if 1 != inputs.len() { - return Err(TensorError::DimMismatch("sum inputs".to_string())); - } - tensor::ops::sum(&inputs[0]) - } - Op::Rescaled { inner, scale } => { - if scale.len() != inputs.len() { - return Err(TensorError::DimMismatch("rescaled inputs".to_string())); - } - - let mut rescaled_inputs = vec![]; - for (i, ri) in inputs.iter_mut().enumerate() { - rescaled_inputs.push(tensor::ops::rescale(ri, scale[i].1)?); - } - Ok(inner.f(rescaled_inputs)?) - } - Op::GlobalSumPool => unreachable!(), - Op::RangeCheck(..) 
=> Ok(inputs[0].clone()), - } - } -} - -impl fmt::Display for Op { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Op::Identity => write!(f, "identity"), - Op::Reshape(new_dims) => write!(f, "reshape to {:?}", new_dims), - Op::Flatten(new_dims) => write!(f, "flatten to {:?}", new_dims), - Op::Pad(dim1, dim2) => write!(f, "padding: ({:?}, {:?})", dim1, dim2), - Op::Add => write!(f, "add"), - Op::Sub => write!(f, "sub"), - Op::Sum => write!(f, "sum"), - Op::Mult => write!(f, "mult"), - Op::Matmul => write!(f, "matmul"), - Op::Dot => write!(f, "dot"), - Op::Pack(base, _) => write!(f, "pack with base {:?}", base), - Op::Affine => write!(f, "affine"), - Op::BatchNorm => write!(f, "batchnorm"), - Op::ScaleAndShift => write!(f, "scale & shift"), - Op::Conv { padding, stride } => { - write!(f, "conv w/ padding: {:?}, stride: {:?}", padding, stride) - } - Op::SumPool { - padding, - stride, - kernel_shape, - } => { - write!( - f, - "avg pl w/ padding: {:?}, stride: {:?}, kernel shape: {:?}", - padding, stride, kernel_shape, - ) - } - Op::GlobalSumPool => write!(f, "globalsumpool"), - Op::Pow(s) => write!(f, "pow {}", s), - Op::Rescaled { inner, scale } => { - write!( - f, - "rescaled {} w/ scalings: {:?}", - **inner, - scale.iter().map(|e| e.1).collect_vec() - ) - } - Op::RangeCheck(tol) => write!(f, "range check w/ tol {}", tol), - } - } -} - -// Initially, some of these OpKinds will be folded into others (for example, Const nodes that -// contain parameters will be handled at the consuming self. -// Eventually, though, we probably want to keep them and treat them directly (layouting and configuring -// at each type of node) -/// Enum of the different kinds of operations `ezkl` can support. -#[derive(Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd, Deserialize, Serialize)] -pub enum OpKind { - /// A nonlinearity - Lookup(LookupOp), - /// A fused op, combining affine layers or other arithmetic - Poly(Op), - /// Constant - Const, - /// Input node - Input, - /// Unable to parse the node type - Unknown(String), - #[allow(missing_docs)] - #[default] - None, -} - -impl From for OpKind { - fn from(op: Op) -> Self { - OpKind::Poly(op) - } -} - -impl From for OpKind { - fn from(op: LookupOp) -> Self { - OpKind::Lookup(op) - } -} - -impl OpKind { - /// Produce an OpKind from a `&str` onnx name - pub fn new(name: &str) -> Self { - match name { - "Reduce" => OpKind::Lookup(LookupOp::Min), - "Reduce" => OpKind::Lookup(LookupOp::Max), - "Clip" => OpKind::Lookup(LookupOp::ReLU { scale: 1 }), - "Prelu" => OpKind::Lookup(LookupOp::PReLU { - scale: 1, - slopes: vec![], - }), - "LeakyRelu" => OpKind::Lookup(LookupOp::LeakyReLU { - scale: 1, - slope: utils::F32(0.0), - }), - "Sigmoid" => OpKind::Lookup(LookupOp::Sigmoid { scales: (1, 1) }), - "Sqrt" => OpKind::Lookup(LookupOp::Sqrt { scales: (1, 1) }), - "Tanh" => OpKind::Lookup(LookupOp::Tanh { scales: (1, 1) }), - "onnx.Erf" => OpKind::Lookup(LookupOp::Erf { scales: (1, 1) }), - "Div" => OpKind::Lookup(LookupOp::Div { - denom: utils::F32(1.0), - }), - "Const" => OpKind::Const, - "Source" => OpKind::Input, - "Add" => OpKind::Poly(Op::Add), - "Sub" => OpKind::Poly(Op::Sub), - "Mul" => OpKind::Poly(Op::Mult), - "Gemm" => OpKind::Poly(Op::Affine), - "MatMulInference" => OpKind::Poly(Op::Matmul), - "MaxPool" => OpKind::Lookup(LookupOp::MaxPool2d { - padding: (1, 1), - stride: (1, 1), - pool_dims: (1, 1), - }), - "Dot" => OpKind::Poly(Op::Dot), - "Reduce" => OpKind::Poly(Op::Sum), - "Reduce" => OpKind::Lookup(LookupOp::Mean { scale: 1 }), - 
"Pow" => OpKind::Poly(Op::Pow(1)), - "Conv" | "ConvHir" => OpKind::Poly(Op::Conv { - padding: (1, 1), - stride: (1, 1), - }), - "SumPool" => OpKind::Poly(Op::SumPool { - padding: (1, 1), - stride: (1, 1), - kernel_shape: (1, 1), - }), - "GlobalAvgPool" => OpKind::Poly(Op::GlobalSumPool), - "Pad" => OpKind::Poly(Op::Pad(0, 0)), - "Reshape" => OpKind::Poly(Op::Reshape(Vec::new())), - "Flatten" => OpKind::Poly(Op::Flatten(Vec::new())), - "BatchNorm" => OpKind::Poly(Op::BatchNorm), - c => { - warn!("{:?} is not currently supported", c); - OpKind::Unknown(c.to_string()) - } - } - } - /// is ploy type constrant - pub fn is_poly(&self) -> bool { - matches!(self, OpKind::Poly(_)) - } - - /// is lookup based op - pub fn is_lookup(&self) -> bool { - matches!(self, OpKind::Lookup(_)) - } - - /// is lookup based op - pub fn is_parameterized(&self) -> bool { - match self { - OpKind::Poly(Op::Affine) | OpKind::Poly(Op::Conv { .. }) => true, - _ => false, - } - } - - /// is rescaled op - pub fn is_rescaled(&self) -> bool { - matches!(self, OpKind::Poly(Op::Rescaled { .. })) - } - - /// is input - pub fn is_input(&self) -> bool { - matches!(self, OpKind::Input) - } - - /// is const - pub fn is_const(&self) -> bool { - matches!(self, OpKind::Const) - } -} - -impl fmt::Display for OpKind { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - OpKind::Const => write!(f, "const"), - OpKind::Input => write!(f, "input"), - OpKind::Lookup(s) => write!(f, "{:#?}", s), - OpKind::Poly(s) => write!(f, "{}", s), - OpKind::Unknown(c) => write!(f, "? {}", c), - OpKind::None => write!(f, "n/a",), - } - } -} - /// Configuration for an accumulated arg. #[derive(Clone, Debug, Default)] pub struct BaseConfig { @@ -705,6 +96,21 @@ pub struct BaseConfig { } impl BaseConfig { + /// Returns a new [BaseConfig] with no inputs, no selectors, and no tables. + pub fn dummy(col_size: usize) -> Self { + Self { + inputs: vec![VarTensor::dummy(col_size), VarTensor::dummy(col_size)], + lookup_input: VarTensor::dummy(col_size), + output: VarTensor::dummy(col_size), + lookup_output: VarTensor::dummy(col_size), + selectors: BTreeMap::new(), + lookup_selectors: BTreeMap::new(), + tables: BTreeMap::new(), + check_mode: CheckMode::SAFE, + _marker: PhantomData, + } + } + /// Configures [BaseOp]s for a given [ConstraintSystem]. /// # Arguments /// * `inputs` - The explicit inputs to the operations. 
@@ -728,10 +134,11 @@ impl BaseConfig { selectors.insert((BaseOp::Sub, i), meta.selector()); selectors.insert((BaseOp::Dot, i), meta.selector()); selectors.insert((BaseOp::Sum, i), meta.selector()); + selectors.insert((BaseOp::Neg, i), meta.selector()); selectors.insert((BaseOp::Mult, i), meta.selector()); - selectors.insert((BaseOp::Identity, i), meta.selector()); selectors.insert((BaseOp::Range { tol }, i), meta.selector()); selectors.insert((BaseOp::IsZero, i), meta.selector()); + selectors.insert((BaseOp::Identity, i), meta.selector()); selectors.insert((BaseOp::IsBoolean, i), meta.selector()); } @@ -801,8 +208,8 @@ impl BaseConfig { selectors, lookup_selectors: BTreeMap::new(), inputs: inputs.to_vec(), - lookup_input: VarTensor::None, - lookup_output: VarTensor::None, + lookup_input: VarTensor::Empty, + lookup_output: VarTensor::Empty, tables: BTreeMap::new(), output: output.clone(), check_mode, @@ -831,7 +238,7 @@ impl BaseConfig { for x in 0..input.num_cols() { let qlookup = cs.complex_selector(); selectors.insert((nl.clone(), x), qlookup); - let _ = cs.lookup(nl.as_str(), |cs| { + let _ = cs.lookup(Op::::as_str(nl), |cs| { let qlookup = cs.query_selector(qlookup); let not_qlookup = Expression::Constant(::one()) - qlookup.clone(); let (default_x, default_y): (F, F) = nl.default_pair(); @@ -869,11 +276,12 @@ impl BaseConfig { } self.lookup_selectors.extend(selectors); // if we haven't previously initialized the input/output, do so now - if let VarTensor::None = self.lookup_input { + if let VarTensor::Empty = self.lookup_input { warn!("assigning lookup input"); self.lookup_input = input.clone(); } - if let VarTensor::None = self.lookup_output { + if let VarTensor::Empty = self.lookup_output { + warn!("assigning lookup output"); self.lookup_output = output.clone(); } Ok(()) @@ -881,7 +289,7 @@ impl BaseConfig { /// layout_tables must be called before layout. pub fn layout_tables(&mut self, layouter: &mut impl Layouter) -> Result<(), Box> { - for (_, table) in &self.tables { + for table in self.tables.values() { if !table.borrow().is_assigned { table.borrow_mut().layout(layouter)?; } @@ -897,153 +305,27 @@ impl BaseConfig { /// Assigns variables to the regions created when calling `configure`. /// # Arguments /// * `values` - The explicit values to the operations. /// * `op` - The operation being represented. pub fn layout( &mut self, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor], offset: &mut usize, - op: OpKind, + op: Box>, ) -> Result>, Box> { let mut cp_values = vec![]; for v in values.iter() { if let ValTensor::Instance { .. } = v { - cp_values.push(layouts::identity(self, region, &[v.clone()], offset)?); + cp_values.push(layouts::identity( + self, + region.as_deref_mut(), + &[v.clone()], + offset, + )?); } else { cp_values.push(v.clone()); } } trace!("laying out {:?}", op); - let res = match op { - OpKind::Poly(op) => Some(match op { - Op::Dot => layouts::dot(self, region, cp_values[..].try_into()?, offset)?, - Op::Sum => layouts::sum(self, region, cp_values[..].try_into()?, offset)?, - Op::Matmul => layouts::matmul(self, region, cp_values[..].try_into()?, offset)?, - Op::Affine => layouts::affine(self, region, cp_values[..].try_into()?, offset)?, - Op::Conv { padding, stride } => layouts::conv( - self, - region, - cp_values[..].try_into()?, - padding, - stride, - offset, - )?, - Op::SumPool { - padding, - stride, - kernel_shape, - } => layouts::sumpool( - self, - region, - cp_values[..].try_into()?, - padding, - stride, - kernel_shape, - offset, - )?, - Op::Add => { - layouts::pairwise(self, region, cp_values[..].try_into()?, offset, BaseOp::Add)?
- } - Op::Sub => { - layouts::pairwise(self, region, cp_values[..].try_into()?, offset, BaseOp::Sub)? - } - Op::Mult => layouts::pairwise( - self, - region, - cp_values[..].try_into()?, - offset, - BaseOp::Mult, - )?, - Op::Identity => layouts::identity(self, region, cp_values[..].try_into()?, offset)?, - Op::Reshape(d) | Op::Flatten(d) => layouts::reshape(cp_values[..].try_into()?, &d)?, - Op::BatchNorm => { - layouts::scale_and_shift(self, region, cp_values[..].try_into()?, offset)? - } - Op::ScaleAndShift => { - layouts::scale_and_shift(self, region, cp_values[..].try_into()?, offset)? - } - Op::Pad(p1, p2) => { - if values.len() != 1 { - return Err(Box::new(TensorError::DimError)); - } - let mut input = cp_values[0].clone(); - input.pad((p1, p2))?; - input - } - Op::Pow(exp) => layouts::pow(self, region, cp_values[..].try_into()?, exp, offset)?, - Op::Pack(base, scale) => { - layouts::pack(self, region, cp_values[..].try_into()?, base, scale, offset)? - } - Op::Rescaled { inner, scale } => { - if scale.len() != values.len() { - return Err(Box::new(TensorError::DimMismatch( - "rescaled inputs".to_string(), - ))); - } - - let res = - &layouts::rescale(self, region, cp_values[..].try_into()?, &scale, offset)? - [..]; - self.layout(region, res, offset, OpKind::Poly(*inner))? - .unwrap() - } - Op::RangeCheck(tol) => { - layouts::range_check(self, region, cp_values[..].try_into()?, offset, tol)? - } - Op::GlobalSumPool => unreachable!(), - }), - OpKind::Lookup(nl) => match nl { - LookupOp::PReLU { scale, .. } => Some(layouts::prelu( - self, - region, - cp_values[..].try_into()?, - scale, - offset, - )?), - LookupOp::Mean { scale, .. } => Some(layouts::mean( - self, - region, - cp_values[..].try_into()?, - scale, - offset, - )?), - LookupOp::MaxPool2d { - padding, - stride, - pool_dims, - } => Some(layouts::max_pool2d( - self, - region, - cp_values[..].try_into()?, - padding, - stride, - pool_dims, - offset, - )?), - LookupOp::Max => Some(layouts::max( - self, - region, - cp_values[..].try_into()?, - offset, - )?), - LookupOp::Min => Some(layouts::min( - self, - region, - cp_values[..].try_into()?, - offset, - )?), - _ => Some(layouts::nonlinearity( - self, - region, - cp_values[..].try_into()?, - nl, - offset, - )?), - }, - OpKind::Const => None, - OpKind::Input => None, - _ => { - return Err(Box::new(CircuitError::UnsupportedOp)); - } - }; - Ok(res) + let res = op.layout(self, region, &cp_values, offset); + res } } diff --git a/src/circuit/ops/base.rs b/src/circuit/ops/base.rs new file mode 100644 index 000000000..5cce69876 --- /dev/null +++ b/src/circuit/ops/base.rs @@ -0,0 +1,116 @@ +use crate::tensor::TensorType; +use std::{ + fmt, + ops::{Add, Mul, Neg, Sub}, +}; + +#[allow(missing_docs)] +/// An enum representing the operations that can be used to express more complex operations via accumulation +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum BaseOp { + Dot, + Identity, + Add, + Mult, + Sub, + Sum, + Neg, + Range { tol: i32 }, + IsZero, + IsBoolean, +} + +/// Matches a [BaseOp] to an operation over inputs +impl BaseOp { + /// forward func + pub fn f< + T: TensorType + Add + Sub + Mul + Neg, + >( + &self, + inputs: (T, T, T), + ) -> T { + let (a, b, m) = inputs; + match &self { + BaseOp::Dot => a * b + m, + BaseOp::Add => a + b, + BaseOp::Identity => b, + BaseOp::Sum => b + m, + BaseOp::Neg => -b, + BaseOp::Sub => a - b, + BaseOp::Mult => a * b, + BaseOp::Range { .. 
} => b, + BaseOp::IsZero => b, + BaseOp::IsBoolean => b, + } + } + + /// + pub fn as_str(&self) -> &'static str { + match self { + BaseOp::Identity => "IDENTITY", + BaseOp::Dot => "DOT", + BaseOp::Add => "ADD", + BaseOp::Neg => "NEG", + BaseOp::Sub => "SUB", + BaseOp::Mult => "MULT", + BaseOp::Sum => "SUM", + BaseOp::Range { .. } => "RANGE", + BaseOp::IsZero => "ISZERO", + BaseOp::IsBoolean => "ISBOOLEAN", + } + } + + /// + pub fn query_offset_rng(&self) -> (i32, usize) { + match self { + BaseOp::Identity => (0, 1), + BaseOp::Neg => (0, 1), + BaseOp::Dot => (-1, 2), + BaseOp::Add => (0, 1), + BaseOp::Sub => (0, 1), + BaseOp::Mult => (0, 1), + BaseOp::Sum => (-1, 2), + BaseOp::Range { .. } => (0, 1), + BaseOp::IsZero => (0, 1), + BaseOp::IsBoolean => (0, 1), + } + } + + /// + pub fn num_inputs(&self) -> usize { + match self { + BaseOp::Identity => 1, + BaseOp::Neg => 1, + BaseOp::Dot => 2, + BaseOp::Add => 2, + BaseOp::Sub => 2, + BaseOp::Mult => 2, + BaseOp::Sum => 1, + BaseOp::Range { .. } => 1, + BaseOp::IsZero => 1, + BaseOp::IsBoolean => 1, + } + } + + /// + pub fn constraint_idx(&self) -> usize { + match self { + BaseOp::Identity => 0, + BaseOp::Neg => 0, + BaseOp::Dot => 1, + BaseOp::Add => 0, + BaseOp::Sub => 0, + BaseOp::Mult => 0, + BaseOp::Range { .. } => 0, + BaseOp::Sum => 1, + BaseOp::IsZero => 0, + BaseOp::IsBoolean => 0, + } + } +} + +impl fmt::Display for BaseOp { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} diff --git a/src/circuit/ops/hybrid.rs b/src/circuit/ops/hybrid.rs new file mode 100644 index 000000000..8818ce000 --- /dev/null +++ b/src/circuit/ops/hybrid.rs @@ -0,0 +1,219 @@ +use halo2curves::FieldExt; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; + +use crate::{ + circuit::{layouts, utils}, + graph::scale_to_multiplier, + tensor::{self, Tensor, TensorError, TensorType}, +}; + +use super::{lookup::LookupOp, Op}; + +#[allow(missing_docs)] +/// An enum representing the operations that can be used to express more complex operations via accumulation +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)] +pub enum HybridOp { + Mean { + scale: usize, + num_inputs: usize, + }, + Max, + MaxPool2d { + padding: (usize, usize), + stride: (usize, usize), + pool_dims: (usize, usize), + }, + InstanceNorm2d { + epsilon: crate::circuit::utils::F32, + }, + Min, + PReLU { + scale: usize, + slopes: Vec, + }, +} + +impl Op for HybridOp { + /// Matches a [Op] to an operation in the `tensor::ops` module. + fn f(&self, inputs: &[Tensor]) -> Result, TensorError> { + match &self { + HybridOp::Mean { scale, .. } => { + Ok(tensor::ops::nonlinearities::mean(&inputs[0], *scale)) + } + HybridOp::Max => Ok(Tensor::new( + Some(&[inputs[0].clone().into_iter().max().unwrap()]), + &[1], + )?), + HybridOp::MaxPool2d { + padding, + stride, + pool_dims, + } => tensor::ops::max_pool2d(&inputs[0], padding, stride, pool_dims), + HybridOp::InstanceNorm2d { epsilon } => Ok(tensor::ops::nonlinearities::instance_norm( + inputs.to_vec().try_into().unwrap(), + epsilon.0, + )), + HybridOp::Min => Ok(Tensor::new( + Some(&[inputs[0].clone().into_iter().min().unwrap()]), + &[1], + )?), + HybridOp::PReLU { scale, slopes } => Ok(tensor::ops::nonlinearities::prelu( + &inputs[0], + *scale, + &slopes.iter().map(|e| e.0).collect_vec(), + )), + } + } + + fn as_str(&self) -> &'static str { + match &self { + HybridOp::Mean { .. } => "MEAN", + HybridOp::Max => "MAX", + HybridOp::MaxPool2d { .. 
} => "MAXPOOL2D", + HybridOp::InstanceNorm2d { .. } => "INSTANCENORM", + HybridOp::Min => "MIN", + HybridOp::PReLU { .. } => "PRELU", + } + } + + fn layout( + &self, + config: &mut crate::circuit::BaseConfig, + region: Option<&mut halo2_proofs::circuit::Region>, + values: &[tensor::ValTensor], + offset: &mut usize, + ) -> Result>, Box> { + Ok(match self { + HybridOp::PReLU { scale, .. } => Some(layouts::prelu( + config, + region, + values[..].try_into()?, + *scale, + offset, + )?), + HybridOp::Mean { scale, .. } => Some(layouts::mean( + config, + region, + values[..].try_into()?, + *scale, + offset, + )?), + HybridOp::MaxPool2d { + padding, + stride, + pool_dims, + } => Some(layouts::max_pool2d( + config, + region, + values[..].try_into()?, + *padding, + *stride, + *pool_dims, + offset, + )?), + HybridOp::Max => Some(layouts::max( + config, + region, + values[..].try_into()?, + offset, + )?), + HybridOp::Min => Some(layouts::min( + config, + region, + values[..].try_into()?, + offset, + )?), + HybridOp::InstanceNorm2d { epsilon } => Some(layouts::instance_norm( + config, + region, + values[..].try_into()?, + 1, + epsilon.0 as u64, + offset, + )?), + }) + } + + fn out_scale(&self, in_scales: Vec, _: u32) -> u32 { + in_scales[0] + } + + fn out_dims(&self, in_dims: Vec>) -> Vec { + match self { + HybridOp::Mean { .. } => vec![1], + HybridOp::Max => vec![1], + HybridOp::MaxPool2d { + padding, + stride, + pool_dims, + } => { + let (out_channels, kernel_height, kernel_width) = + (in_dims[0][0], pool_dims.0, pool_dims.1); + + let (padding_h, padding_w, stride_h, stride_w) = + (padding.0, padding.1, stride.0, stride.1); + + let input_height = in_dims[0][1]; + let input_width = in_dims[0][2]; + + let out_height = (input_height + 2 * padding_h - kernel_height) / stride_h + 1; + let out_width = (input_width + 2 * padding_w - kernel_width) / stride_w + 1; + + vec![out_channels, out_height, out_width] + } + HybridOp::InstanceNorm2d { .. } => in_dims[0].clone(), + HybridOp::Min => vec![1], + HybridOp::PReLU { .. } => in_dims[0].clone(), + } + } + + fn has_3d_input(&self) -> bool { + matches!( + self, + HybridOp::MaxPool2d { .. } | HybridOp::InstanceNorm2d { .. } + ) + } + + fn rescale(&self, inputs_scale: Vec, global_scale: u32) -> Box> { + let mult = scale_to_multiplier(inputs_scale[0] - global_scale); + match self { + HybridOp::PReLU { scale: _, slopes } => Box::new(HybridOp::PReLU { + scale: mult as usize, + slopes: slopes.to_vec(), + }), + HybridOp::Mean { + scale: _, + num_inputs, + } => Box::new(HybridOp::Mean { + scale: mult as usize, + num_inputs: *num_inputs, + }), + _ => Box::new(self.clone()), + } + } + + fn required_lookup(&self) -> Option { + match self { + HybridOp::PReLU { scale, .. } => Some(LookupOp::ReLU { scale: *scale }), + HybridOp::Max | HybridOp::Min | HybridOp::MaxPool2d { .. } => { + Some(LookupOp::ReLU { scale: 1 }) + } + HybridOp::Mean { scale, num_inputs } => Some(LookupOp::Div { + denom: utils::F32((*scale * *num_inputs) as f32), + }), + HybridOp::InstanceNorm2d { .. } => Some(LookupOp::Sqrt { scales: (1, 1) }), + } + } + + fn bias_variable(&self) -> Option { + match self { + HybridOp::InstanceNorm2d { .. 
} => Some(2), + _ => None, + } + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } +} diff --git a/src/circuit/layouts.rs b/src/circuit/ops/layouts.rs similarity index 69% rename from src/circuit/layouts.rs rename to src/circuit/ops/layouts.rs index c8d2ab3a8..41af56781 100644 --- a/src/circuit/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -2,29 +2,32 @@ use core::panic; use std::error::Error; use halo2_proofs::circuit::{Region, Value}; -use log::error; +use itertools::Itertools; +use log::{error, trace}; use crate::{ - circuit::{utils, CircuitError}, + circuit::{ops::base::BaseOp, utils, BaseConfig, CheckMode, CircuitError}, fieldutils::i128_to_felt, tensor::{ ops::{ accumulated, add, affine as non_accum_affine, convolution as non_accum_conv, dot as non_accum_dot, matmul as non_accum_matmul, max_pool2d as non_accum_max_pool2d, - mult, nonlinearities::prelu as ref_prelu, pack as non_accum_pack, - rescale as ref_rescaled, scale_and_shift as ref_scale_and_shift, sub, - sum as non_accum_sum, sumpool as non_accum_sumpool, + mult, nonlinearities::instance_norm as ref_instance_norm, + nonlinearities::prelu as ref_prelu, pack as non_accum_pack, rescale as ref_rescaled, + scale_and_shift as ref_scale_and_shift, sub, sum as non_accum_sum, + sumpool as non_accum_sumpool, }, Tensor, TensorError, ValType, }, }; use super::*; +use crate::circuit::ops::lookup::LookupOp; /// Dot product accumulated layout pub fn dot( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 2], offset: &mut usize, ) -> Result, Box> { @@ -33,13 +36,13 @@ pub fn dot( for (i, input) in values.iter().enumerate() { let inp = { let (res, len) = config.inputs[i].assign_with_duplication( - region, + region.as_deref_mut(), *offset, input, &config.check_mode, )?; assigned_len = len; - res.map(|e| e.value_field().evaluate()) + res.get_inner()? 
}; inputs.push(inp); } @@ -48,7 +51,7 @@ pub fn dot( let accumulated_dot = accumulated::dot(&[inputs[0].clone(), inputs[1].clone()]) .expect("accum poly: dot op failed"); let (output, output_assigned_len) = config.output.assign_with_duplication( - region, + region.as_deref_mut(), *offset, &accumulated_dot.into(), &config.check_mode, @@ -56,24 +59,26 @@ pub fn dot( assert_eq!(assigned_len, output_assigned_len); - for i in 0..assigned_len { - let (x, y) = config.output.cartesian_coord(*offset + i); - // hop over duplicates at start of column - if y == 0 && i > 0 { - continue; - } - if i == 0 { - config - .selectors - .get(&(BaseOp::Mult, x)) - .unwrap() - .enable(region, y)?; - } else { - config - .selectors - .get(&(BaseOp::Dot, x)) - .unwrap() - .enable(region, y)?; + if let Some(region) = region { + for i in 0..assigned_len { + let (x, y) = config.output.cartesian_coord(*offset + i); + // hop over duplicates at start of column + if y == 0 && i > 0 { + continue; + } + if i == 0 { + config + .selectors + .get(&(BaseOp::Mult, x)) + .unwrap() + .enable(region, y)?; + } else { + config + .selectors + .get(&(BaseOp::Dot, x)) + .unwrap() + .enable(region, y)?; + } } } @@ -88,39 +93,39 @@ pub fn dot( })?; assert_eq!( - Into::>::into(last_elem.clone()), + Into::>::into(last_elem.get_inner()?), Into::>::into(safe_dot), ); } *offset += assigned_len; // last element is the result - Ok(ValTensor::from(last_elem)) + Ok(last_elem) } /// Sum accumulated layout pub fn sum( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 1], offset: &mut usize, ) -> Result, Box> { let assigned_len: usize; let input = { let (res, len) = config.inputs[1].assign_with_duplication( - region, + region.as_deref_mut(), *offset, &values[0], &config.check_mode, )?; assigned_len = len; - res.map(|e| e.value_field().evaluate()) + res.get_inner()? 
}; // Now we can assign the dot product let accumulated_sum = accumulated::sum(&input).expect("accum poly: sum op failed"); let (output, output_assigned_len) = config.output.assign_with_duplication( - region, + region.as_deref_mut(), *offset, &accumulated_sum.into(), &config.check_mode, @@ -128,24 +133,26 @@ pub fn sum( assert_eq!(assigned_len, output_assigned_len); - for i in 0..assigned_len { - let (x, y) = config.output.cartesian_coord(*offset + i); - // skip over duplicates at start of column - if y == 0 && i > 0 { - continue; - } - if i == 0 { - config - .selectors - .get(&(BaseOp::Identity, x)) - .unwrap() - .enable(region, y)?; - } else { - config - .selectors - .get(&(BaseOp::Sum, x)) - .unwrap() - .enable(region, y)?; + if let Some(region) = region { + for i in 0..assigned_len { + let (x, y) = config.output.cartesian_coord(*offset + i); + // skip over duplicates at start of column + if y == 0 && i > 0 { + continue; + } + if i == 0 { + config + .selectors + .get(&(BaseOp::Identity, x)) + .unwrap() + .enable(region, y)?; + } else { + config + .selectors + .get(&(BaseOp::Sum, x)) + .unwrap() + .enable(region, y)?; + } } } @@ -160,20 +167,20 @@ pub fn sum( })?; assert_eq!( - Into::>::into(last_elem.clone()), + Into::>::into(last_elem.get_inner()?), Into::>::into(safe_dot), ) } *offset += assigned_len; // last element is the result - Ok(ValTensor::from(last_elem)) + Ok(last_elem) } /// Pairwise (elementwise) op layout pub fn pairwise( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 2], offset: &mut usize, op: BaseOp, @@ -202,8 +209,8 @@ pub fn pairwise( for (i, input) in [lhs.clone(), rhs.clone()].iter().enumerate() { let inp = { - let res = config.inputs[i].assign(region, *offset, input)?; - res.map(|e| e.value_field().evaluate()) + let res = config.inputs[i].assign(region.as_deref_mut(), *offset, input)?; + res.get_inner()? }; inputs.push(inp); } @@ -220,28 +227,32 @@ pub fn pairwise( halo2_proofs::plonk::Error::Synthesis })?; - let mut output = config.output.assign(region, *offset, &op_result.into())?; + let mut output = config + .output + .assign(region.as_deref_mut(), *offset, &op_result.into())?; - for i in 0..inputs[0].len() { - let (x, y) = config.inputs[0].cartesian_coord(*offset + i); - config - .selectors - .get(&(op.clone(), x)) - .unwrap() - .enable(region, y)?; + if let Some(region) = region { + for i in 0..inputs[0].len() { + let (x, y) = config.inputs[0].cartesian_coord(*offset + i); + config + .selectors + .get(&(op.clone(), x)) + .unwrap() + .enable(region, y)?; + } } *offset += output.len(); - output.reshape(lhs.dims().clone()); + output.reshape(lhs.dims())?; - Ok(ValTensor::from(output)) + Ok(output) } /// Matrix multiplication accumulated layout pub fn matmul( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 2], offset: &mut usize, ) -> Result, Box> { @@ -309,13 +320,13 @@ pub fn matmul( for (i, elem) in vec![a.clone(), b.clone()].iter().enumerate() { let inp = { let (res, len) = config.inputs[i].assign_with_duplication( - region, + region.as_deref_mut(), *offset, elem, &config.check_mode, )?; assigned_len = len; - res.map(|e| e.value_field().evaluate()) + res.get_inner()? 
}; inputs.push(inp); } @@ -346,7 +357,7 @@ pub fn matmul( .expect("accum poly: matmul op failed"); let (output, output_assigned_len) = config.output.assign_with_duplication( - region, + region.as_deref_mut(), *offset, &accumulated_matmul.into(), &config.check_mode, @@ -354,27 +365,29 @@ pub fn matmul( assert_eq!(assigned_len, output_assigned_len); - let mut idx_wo_duplicates = 0; - for i in 0..assigned_len { - let (x, y) = config.output.cartesian_coord(*offset + i); - // skip over duplicates at start of column - if y == 0 && i > 0 { - continue; - } - if idx_wo_duplicates % b_row_len > 0 { - config - .selectors - .get(&(BaseOp::Dot, x)) - .unwrap() - .enable(region, y)?; - } else { - config - .selectors - .get(&(BaseOp::Mult, x)) - .unwrap() - .enable(region, y)?; + if let Some(region) = region.as_deref_mut() { + let mut idx_wo_duplicates = 0; + for i in 0..assigned_len { + let (x, y) = config.output.cartesian_coord(*offset + i); + // skip over duplicates at start of column + if y == 0 && i > 0 { + continue; + } + if idx_wo_duplicates % b_row_len > 0 { + config + .selectors + .get(&(BaseOp::Dot, x)) + .unwrap() + .enable(region, y)?; + } else { + config + .selectors + .get(&(BaseOp::Mult, x)) + .unwrap() + .enable(region, y)?; + } + idx_wo_duplicates += 1; } - idx_wo_duplicates += 1; } let dims = output.dims(); @@ -390,7 +403,7 @@ pub fn matmul( .get_slice(&last_dims) .expect("accum poly: failed to fetch last elem"); - last_elem.reshape(&[original_a_dims[0], original_b_dims[1]]); + last_elem.reshape(&[original_a_dims[0], original_b_dims[1]])?; if matches!(&config.check_mode, CheckMode::SAFE) { let safe_mm = non_accum_matmul(&inputs).map_err(|e| { @@ -399,14 +412,14 @@ pub fn matmul( })?; assert_eq!( - Into::>::into(last_elem.clone()), + Into::>::into(last_elem.get_inner()?), Into::>::into(safe_mm), ) } *offset += assigned_len; - res.push(last_elem.clone()); + res.push(last_elem.get_inner_tensor()?); } let mut res = Tensor::new(Some(&res), &[res.len()])?.combine()?; @@ -423,11 +436,12 @@ pub fn matmul( /// Affine operation accumulated layout pub fn affine( config: &mut BaseConfig, - region: &mut Region, + region: Option<&mut Region>, values: &[ValTensor; 3], offset: &mut usize, ) -> Result, Box> { let (mut input, kernel, mut bias) = (values[0].clone(), values[1].clone(), values[2].clone()); + if input.dims().len() == 1 { input.reshape(&[input.len(), 1])?; } @@ -470,28 +484,41 @@ pub fn affine( /// Negation operation accumulated layout pub fn neg( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 1], offset: &mut usize, ) -> Result, Box> { let input = { - let res = config.inputs[1].assign(region, *offset, &values[0])?; - res.map(|e| e.value_field().evaluate()) + let res = config.inputs[1].assign(region.as_deref_mut(), *offset, &values[0])?; + res.get_inner()? 
}; let neg = input.map(|e| -e); - let output = config.output.assign(region, *offset, &neg.into())?; + let output = config + .output + .assign(region.as_deref_mut(), *offset, &neg.into())?; + + if let Some(region) = region { + for i in 0..values[0].len() { + let (x, y) = config.inputs[1].cartesian_coord(*offset + i); + config + .selectors + .get(&(BaseOp::Neg, x)) + .unwrap() + .enable(region, y)?; + } + } *offset += output.len(); - Ok(output.into()) + Ok(output) } /// Sumpool accumulated layout pub fn sumpool( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor], padding: (usize, usize), stride: (usize, usize), @@ -500,7 +527,15 @@ pub fn sumpool( ) -> Result, Box> { let image_channels = values[0].dims()[0]; - let unit = config.inputs[1].assign_constant(region, *offset, F::from(1))?; + let unit: ValType = if let Some(region) = region.as_deref_mut() { + config.inputs[1] + .assign_constant(region, *offset, F::from(1))? + .into() + } else { + // for dummy run throughs + Value::known(F::from(1)).into() + }; + *offset += 1; let mut kernel = Tensor::from(0..kernel_shape.0 * kernel_shape.1).map(|_| unit.clone()); @@ -510,7 +545,7 @@ pub fn sumpool( for i in 0..image_channels { res.push(conv( config, - region, + region.as_deref_mut(), &[values[0].get_slice(&[i..i + 1])?, kernel.clone().into()], padding, stride, @@ -550,7 +585,7 @@ pub fn sumpool( /// Convolution accumulated layout pub fn max_pool2d( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 1], padding: (usize, usize), stride: (usize, usize), @@ -586,7 +621,7 @@ pub fn max_pool2d( rs..(rs + pool_dims.0), cs..(cs + pool_dims.1), ])?; - let max_w = max(config, region, &[slice], offset)?; + let max_w = max(config, region.as_deref_mut(), &[slice], offset)?; let max_w = &max_w.get_inner_tensor()?[0]; output.set(&[i, j, k], max_w.clone()); } @@ -598,7 +633,7 @@ pub fn max_pool2d( if matches!(&config.check_mode, CheckMode::SAFE) { // during key generation this will be 0 so we use this as a flag to check // TODO: this isn't very safe and would be better to get the phase directly - let is_assigned = !Into::>::into(res.clone().get_inner()?) + let is_assigned = !Into::>::into(res.get_inner()?) 
.iter() .all(|&x| x == 0); if is_assigned { @@ -618,7 +653,7 @@ pub fn max_pool2d( /// Convolution accumulated layout pub fn conv( config: &mut BaseConfig, - region: &mut Region, + region: Option<&mut Region>, values: &[ValTensor], padding: (usize, usize), stride: (usize, usize), @@ -720,7 +755,7 @@ pub fn conv( /// Power accumulated layout pub fn pow( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 1], exponent: u32, offset: &mut usize, @@ -730,7 +765,7 @@ pub fn pow( for _ in 1..exponent { t = pairwise( config, - region, + region.as_deref_mut(), &[t, values[0].clone()], offset, BaseOp::Mult, @@ -762,7 +797,7 @@ pub fn pow( /// Rescaled op accumulated layout pub fn rescale( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor], scales: &[(usize, usize)], offset: &mut usize, @@ -774,7 +809,7 @@ pub fn rescale( let mult_tensor = Tensor::new(Some(&vec![mult; num_elems]), ri.dims())?; let scaled_input = pairwise( config, - region, + region.as_deref_mut(), &[ri.clone(), mult_tensor.into()], offset, BaseOp::Mult, @@ -807,7 +842,7 @@ pub fn rescale( /// Pack accumulated layout pub fn pack( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 1], base: u32, scale: u32, @@ -831,7 +866,7 @@ pub fn pack( let base_tensor = Tensor::new(Some(&accum_base), &[accum_base.len()])?; let base_prod = pairwise( config, - region, + region.as_deref_mut(), &[t.clone(), base_tensor.into()], offset, BaseOp::Mult, @@ -875,7 +910,7 @@ pub fn reshape( /// Identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon. pub fn identity( config: &mut BaseConfig, - region: &mut Region, + region: Option<&mut Region>, values: &[ValTensor; 1], offset: &mut usize, ) -> Result, Box> { @@ -883,18 +918,24 @@ pub fn identity( *offset += output.len(); - Ok(output.into()) + Ok(output) } /// Scale and shift accumulated layout pub fn scale_and_shift( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 3], offset: &mut usize, ) -> Result, Box> { let (input, kernel, bias) = (values[0].clone(), values[1].clone(), values[2].clone()); - let prod = pairwise(config, region, &[input, kernel], offset, BaseOp::Mult)?; + let prod = pairwise( + config, + region.as_deref_mut(), + &[input, kernel], + offset, + BaseOp::Mult, + )?; let res = pairwise(config, region, &[prod, bias], offset, BaseOp::Add)?; if matches!(&config.check_mode, CheckMode::SAFE) { @@ -927,36 +968,40 @@ pub fn scale_and_shift( /// Layout for range check. pub fn range_check( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 2], offset: &mut usize, tol: i32, ) -> Result, Box> { // assigns the instance to the advice. 
- config.inputs[1].assign(region, *offset, &values[0])?; + config.inputs[1].assign(region.as_deref_mut(), *offset, &values[0])?; - let output = config.output.assign(region, *offset, &values[1])?; + let output = config + .output + .assign(region.as_deref_mut(), *offset, &values[1])?; - for i in 0..values[0].len() { - let (x, y) = config.inputs[1].cartesian_coord(*offset + i); - config - .selectors - .get(&(BaseOp::Range { tol }, x)) - .unwrap() - .enable(region, y)?; + if let Some(region) = region { + for i in 0..values[0].len() { + let (x, y) = config.inputs[1].cartesian_coord(*offset + i); + config + .selectors + .get(&(BaseOp::Range { tol }, x)) + .unwrap() + .enable(region, y)?; + } } *offset += output.len(); - Ok(output.into()) + Ok(output) } /// pub fn nonlinearity( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 1], - nl: LookupOp, + nl: &LookupOp, offset: &mut usize, ) -> Result, Box> { let region_name = format!("Lookup for {:#?}", nl); @@ -965,7 +1010,9 @@ pub fn nonlinearity( trace!("laying out {}", region_name); - let w = ValTensor::from(config.lookup_input.assign(region, *offset, x)?); + let w = config + .lookup_input + .assign(region.as_deref_mut(), *offset, x)?; // extract integer_valuations let integer_evals: Tensor = w .get_int_evals() @@ -982,39 +1029,38 @@ pub fn nonlinearity( 0 => Tensor::from((0..x.dims().iter().product::()).map(|_| Value::unknown())), // if not empty apply the nonlinearity ! _ => { - let x = nl.f(integer_evals)?; + let x = Op::::f(nl, &[integer_evals])?; x.map(|elem| Value::known(i128_to_felt(elem))) } }; let mut output = config .lookup_output - .assign(region, *offset, &output.into())?; - - println!("lookup_selectors {:?}", config.lookup_selectors); - println!("nk {:?}", nl); + .assign(region.as_deref_mut(), *offset, &output.into())?; - for i in 0..x.len() { - let (x, y) = config.lookup_input.cartesian_coord(*offset + i); - config - .lookup_selectors - .get(&(nl.clone(), x)) - .unwrap() - .enable(region, y)?; + if let Some(region) = region { + for i in 0..x.len() { + let (x, y) = config.lookup_input.cartesian_coord(*offset + i); + config + .lookup_selectors + .get(&(nl.clone(), x)) + .unwrap() + .enable(region, y)?; + } } - output.reshape(x.dims()); + output.reshape(x.dims())?; *offset += x.len(); // constrain the calculated output to a column - Ok(ValTensor::from(output)) + Ok(output) } /// PrElu layout pub fn prelu( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 2], scale: usize, offset: &mut usize, @@ -1032,17 +1078,29 @@ pub fn prelu( let relu = nonlinearity( config, - region, + region.as_deref_mut(), &[values[0].clone()], - LookupOp::ReLU { scale }, + &LookupOp::ReLU { scale }, offset, )?; // -x - let neg_x = neg(config, region, &[values[0].clone()], offset)?; + let neg_x = neg(config, region.as_deref_mut(), &[values[0].clone()], offset)?; // relu(-x) - let relu_neg_x = nonlinearity(config, region, &[neg_x], LookupOp::ReLU { scale }, offset)?; + let relu_neg_x = nonlinearity( + config, + region.as_deref_mut(), + &[neg_x], + &LookupOp::ReLU { scale }, + offset, + )?; // relu(-x) * slope - let scaled_relu_neg_x = pairwise(config, region, &[relu_neg_x, slopes], offset, BaseOp::Mult)?; + let scaled_relu_neg_x = pairwise( + config, + region.as_deref_mut(), + &[relu_neg_x, slopes], + offset, + BaseOp::Mult, + )?; let prelu = pairwise( config, @@ -1085,47 +1143,93 @@ pub fn prelu( /// mean function layout pub fn mean( config: &mut BaseConfig, 
- region: &mut Region, + mut region: Option<&mut Region>, + values: &[ValTensor; 1], + scale: usize, + offset: &mut usize, +) -> Result, Box> { + let x = &values[0]; + + let sum_x = sum(config, region.as_deref_mut(), &[x.clone()], offset)?; + let nl = LookupOp::Div { + denom: utils::F32((scale * x.len()) as f32), + }; + nonlinearity(config, region, &[sum_x], &nl, offset) +} + +/// variance function layout +pub fn variance( + config: &mut BaseConfig, + mut region: Option<&mut Region>, values: &[ValTensor; 1], scale: usize, offset: &mut usize, ) -> Result, Box> { let x = &values[0]; - let sum_x = sum(config, region, &[x.clone()], offset)?; + let mean = mean(config, region.as_deref_mut(), &[x.clone()], scale, offset)?; + + let sub = pairwise( + config, + region.as_deref_mut(), + &[x.clone(), mean], + offset, + BaseOp::Sub, + )?; + + let square = pairwise( + config, + region.as_deref_mut(), + &[sub.clone(), sub], + offset, + BaseOp::Mult, + )?; + + let sum_square = sum(config, region.as_deref_mut(), &[square], offset)?; + + // biased estimator let nl = LookupOp::Div { denom: utils::F32((scale * x.len()) as f32), }; - nonlinearity(config, region, &[sum_x], nl, offset) + + let variance = nonlinearity(config, region, &[sum_square], &nl, offset)?; + + Ok(variance) } /// max layout pub fn max( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 1], offset: &mut usize, ) -> Result, Box> { // this is safe because we later constrain it let max_int = values[0].get_int_evals()?.into_iter().max(); let max_val: ValTensor = match max_int { - None => Tensor::new(Some(&vec![Value::::unknown()]), &[1])?.into(), - Some(i) => Tensor::new(Some(&vec![Value::known(i128_to_felt::(i))]), &[1])?.into(), + None => Tensor::new(Some(&[Value::::unknown()]), &[1])?.into(), + Some(i) => Tensor::new(Some(&[Value::known(i128_to_felt::(i))]), &[1])?.into(), }; - let assigned_max_val: ValTensor = config.inputs[1].assign(region, *offset, &max_val)?.into(); + let assigned_max_val: ValTensor = + config.inputs[1].assign(region.as_deref_mut(), *offset, &max_val)?; *offset += 1; - let unit: ValTensor = Tensor::from( - vec![config.inputs[1].assign_constant(region, *offset, F::from(1))?].into_iter(), - ) - .into(); + let unit: ValTensor = if let Some(region) = region.as_deref_mut() { + Tensor::from( + vec![config.inputs[1].assign_constant(region, *offset, F::from(1))?].into_iter(), + ) + .into() + } else { + // for dummy run throughs + Tensor::from(vec![Value::known(F::from(1))].into_iter()).into() + }; *offset += 1; // max(x - 1) let max_minus_1 = pairwise( config, - region, + region.as_deref_mut(), &[assigned_max_val.clone(), unit.clone()], offset, BaseOp::Sub, @@ -1134,56 +1238,66 @@ pub fn max( // x - max(x - 1) let diff = pairwise( config, - region, - &[values[0].clone(), max_minus_1.clone()], + region.as_deref_mut(), + &[values[0].clone(), max_minus_1], offset, BaseOp::Sub, )?; // relu(x - max(x - 1)) - let relu = nonlinearity(config, region, &[diff], LookupOp::ReLU { scale: 1 }, offset)?; + let relu = nonlinearity( + config, + region.as_deref_mut(), + &[diff], + &LookupOp::ReLU { scale: 1 }, + offset, + )?; let len = relu.dims().iter().product(); // y_i*(1 - y_i) =0 // assert the values are either 0 or 1 - config.inputs[1].assign(region, *offset, &relu)?; - for i in 0..len { - let (x, y) = config.output.cartesian_coord(*offset + i); - config - .selectors - .get(&(BaseOp::IsBoolean, x)) - .unwrap() - .enable(region, y)?; + config.inputs[1].assign(region.as_deref_mut(), *offset, 
&relu)?; + if let Some(region) = region.as_deref_mut() { + for i in 0..len { + let (x, y) = config.output.cartesian_coord(*offset + i); + config + .selectors + .get(&(BaseOp::IsBoolean, x)) + .unwrap() + .enable(region, y)?; + } } *offset += len; // sum(relu(x - max(x - 1))) - let sum_relu = sum(config, region, &[relu], offset)?; + let sum_relu = sum(config, region.as_deref_mut(), &[relu], offset)?; // 1 - sum(relu(x - max(x - 1))) let one_minus_sum_relu = pairwise( config, - region, - &[unit.clone(), sum_relu.clone()], + region.as_deref_mut(), + &[unit, sum_relu], offset, BaseOp::Sub, )?; // relu(1 - sum(relu(x - max(x - 1)))) let relu_one_minus_sum_relu = nonlinearity( config, - region, + region.as_deref_mut(), &[one_minus_sum_relu], - LookupOp::ReLU { scale: 1 }, + &LookupOp::ReLU { scale: 1 }, offset, )?; // constraining relu(sum(relu(x - max(x - 1)) - len(x))) = 0 - config.inputs[1].assign(region, *offset, &relu_one_minus_sum_relu)?; - - let (x, y) = config.output.cartesian_coord(*offset); - config - .selectors - .get(&(BaseOp::IsZero, x)) - .unwrap() - .enable(region, y)?; + config.inputs[1].assign(region.as_deref_mut(), *offset, &relu_one_minus_sum_relu)?; + + if let Some(region) = region { + let (x, y) = config.output.cartesian_coord(*offset); + config + .selectors + .get(&(BaseOp::IsZero, x)) + .unwrap() + .enable(region, y)?; + } *offset += relu_one_minus_sum_relu.len(); if matches!(&config.check_mode, CheckMode::SAFE) { @@ -1207,7 +1321,7 @@ pub fn max( /// min layout pub fn min( config: &mut BaseConfig, - region: &mut Region, + mut region: Option<&mut Region>, values: &[ValTensor; 1], offset: &mut usize, ) -> Result, Box> { @@ -1215,23 +1329,29 @@ pub fn min( let min_int = values[0].get_int_evals()?.into_iter().min(); let min_val: ValTensor = match min_int { - None => Tensor::new(Some(&vec![Value::::unknown()]), &[1])?.into(), - Some(i) => Tensor::new(Some(&vec![Value::known(i128_to_felt::(i))]), &[1])?.into(), + None => Tensor::new(Some(&[Value::::unknown()]), &[1])?.into(), + Some(i) => Tensor::new(Some(&[Value::known(i128_to_felt::(i))]), &[1])?.into(), }; - let assigned_min_val: ValTensor = config.inputs[1].assign(region, *offset, &min_val)?.into(); + let assigned_min_val: ValTensor = + config.inputs[1].assign(region.as_deref_mut(), *offset, &min_val)?; *offset += 1; - let unit: ValTensor = Tensor::from( - vec![config.inputs[1].assign_constant(region, *offset, F::from(1))?].into_iter(), - ) - .into(); + let unit: ValTensor = if let Some(region) = region.as_deref_mut() { + Tensor::from( + vec![config.inputs[1].assign_constant(region, *offset, F::from(1))?].into_iter(), + ) + .into() + } else { + // for dummy run throughs + Tensor::from(vec![Value::known(F::from(1))].into_iter()).into() + }; *offset += 1; // min(x + 1) let min_plus_1 = pairwise( config, - region, + region.as_deref_mut(), &[assigned_min_val.clone(), unit.clone()], offset, BaseOp::Add, @@ -1240,58 +1360,68 @@ pub fn min( // min(x + 1) - x let diff = pairwise( config, - region, - &[min_plus_1.clone(), values[0].clone()], + region.as_deref_mut(), + &[min_plus_1, values[0].clone()], offset, BaseOp::Sub, )?; // relu(min(x + 1) - x) - let relu = nonlinearity(config, region, &[diff], LookupOp::ReLU { scale: 1 }, offset)?; + let relu = nonlinearity( + config, + region.as_deref_mut(), + &[diff], + &LookupOp::ReLU { scale: 1 }, + offset, + )?; let len = relu.dims().iter().product(); // y_i*(1 - y_i) =0 // assert the values are either 0 or 1 - config.inputs[1].assign(region, *offset, &relu)?; - for i in 0..len { - let 
(x, y) = config.output.cartesian_coord(*offset + i); - config - .selectors - .get(&(BaseOp::IsBoolean, x)) - .unwrap() - .enable(region, y)?; + config.inputs[1].assign(region.as_deref_mut(), *offset, &relu)?; + if let Some(region) = region.as_deref_mut() { + for i in 0..len { + let (x, y) = config.output.cartesian_coord(*offset + i); + config + .selectors + .get(&(BaseOp::IsBoolean, x)) + .unwrap() + .enable(region, y)?; + } } *offset += len; // sum(relu(min(x + 1) - x)) - let sum_relu = sum(config, region, &[relu], offset)?; + let sum_relu = sum(config, region.as_deref_mut(), &[relu], offset)?; // 1 - sum(relu(min(x + 1) - x)) let one_minus_sum_relu = pairwise( config, - region, - &[unit.into(), sum_relu.clone()], + region.as_deref_mut(), + &[unit, sum_relu], offset, BaseOp::Sub, )?; // relu(1 - sum(relu(min(x + 1) - x))) let relu_one_minus_sum_relu = nonlinearity( config, - region, + region.as_deref_mut(), &[one_minus_sum_relu], - LookupOp::ReLU { scale: 1 }, + &LookupOp::ReLU { scale: 1 }, offset, )?; // constraining product to 0 - config.inputs[1].assign(region, *offset, &relu_one_minus_sum_relu)?; - - let (x, y) = config.output.cartesian_coord(*offset); - config - .selectors - .get(&(BaseOp::IsZero, x)) - .unwrap() - .enable(region, y)?; + config.inputs[1].assign(region.as_deref_mut(), *offset, &relu_one_minus_sum_relu)?; + + if let Some(region) = region { + let (x, y) = config.output.cartesian_coord(*offset); + config + .selectors + .get(&(BaseOp::IsZero, x)) + .unwrap() + .enable(region, y)?; + } *offset += relu_one_minus_sum_relu.len(); if matches!(&config.check_mode, CheckMode::SAFE) { @@ -1309,5 +1439,143 @@ pub fn min( assert_eq!(Into::>::into(min_val.get_inner()?), ref_min,) } }; - Ok(assigned_min_val.into()) + Ok(assigned_min_val) +} + +/// +pub fn instance_norm( + config: &mut BaseConfig, + mut region: Option<&mut Region>, + values: &[ValTensor; 3], + scale: usize, + epsilon: u64, + offset: &mut usize, +) -> Result, Box> { + let gamma = values[1].clone(); + let beta = values[2].clone(); + + if gamma.len() != values[0].dims()[0] { + return Err("gamma and x channels must have the same length".into()); + }; + if beta.len() != values[0].dims()[0] { + return Err("beta and x channels must have the same length".into()); + }; + + let mut channel_norms = vec![]; + + // iterate over inner channel + for i in 0..values[0].dims()[0] { + let x = values[0].get_slice(&[i..i + 1])?; + let mean = mean(config, region.as_deref_mut(), &[x.clone()], scale, offset)?; + let variance = variance(config, region.as_deref_mut(), &[x.clone()], scale, offset)?; + + let numerator = pairwise( + config, + region.as_deref_mut(), + &[x.clone(), mean.clone()], + offset, + BaseOp::Sub, + )?; + + let denominator = pairwise( + config, + region.as_deref_mut(), + &[ + variance.clone(), + // TODO: should we make this a constant ? doesn't matter I think + Tensor::from(vec![F::from(epsilon)].into_iter()).into(), + ], + offset, + BaseOp::Add, + )?; + + let denominator = nonlinearity( + config, + region.as_deref_mut(), + &[denominator], + &LookupOp::Sqrt { + scales: (scale, scale), + }, + offset, + )?; + + let product = pairwise( + config, + region.as_deref_mut(), + &[numerator.clone(), denominator.clone()], + offset, + BaseOp::Mult, + )?; + + let numerator_evals = numerator.get_int_evals()?; + let result = numerator_evals + .iter() + .zip(denominator.get_int_evals()?) 
+ .map(|(x, y)| F::from((x / y) as u64)); + + let result = Tensor::from(result).into(); + + // constraining product to 0 + let result = config.inputs[1].assign(region.as_deref_mut(), *offset, &result)?; + config + .output + .assign(region.as_deref_mut(), *offset, &product)?; + + if let Some(region) = region.as_deref_mut() { + let (x, y) = config.output.cartesian_coord(*offset); + config + .selectors + .get(&(BaseOp::Identity, x)) + .unwrap() + .enable(region, y)?; + } + *offset += result.len(); + + let scaled_fraction = pairwise( + config, + region.as_deref_mut(), + &[result, gamma.clone()], + offset, + BaseOp::Mult, + )?; + + let instance_norm = pairwise( + config, + region.as_deref_mut(), + &[scaled_fraction, beta.clone()], + offset, + BaseOp::Add, + )?; + + channel_norms.push(instance_norm.get_inner_tensor()?); + } + + let mut instance_norm = Tensor::from(channel_norms.into_iter()).combine()?; + instance_norm.reshape(values[0].dims()); + let instance_norm: ValTensor = instance_norm.into(); + + if matches!(&config.check_mode, CheckMode::SAFE) { + // during key generation this will be 0 so we use this as a flag to check + // TODO: this isn't very safe and would be better to get the phase directly + let is_assigned = !Into::>::into(instance_norm.get_inner()?) + .iter() + .all(|&x| x == 0); + if is_assigned { + let ref_instance_norm: Tensor = ref_instance_norm( + [ + Tensor::from(values[0].get_int_evals()?.into_iter()), + Tensor::from(gamma.get_int_evals()?.into_iter()), + Tensor::from(beta.get_int_evals()?.into_iter()), + ], + epsilon as f32, + ) + .map(|x| x as i32); + assert_eq!( + Into::>::into(instance_norm.get_inner()?), + ref_instance_norm, + ) + } + }; + + Ok(instance_norm) } diff --git a/src/circuit/ops/lookup.rs b/src/circuit/ops/lookup.rs new file mode 100644 index 000000000..3e960fc52 --- /dev/null +++ b/src/circuit/ops/lookup.rs @@ -0,0 +1,146 @@ +use super::*; +use halo2_proofs::circuit::Region; +use halo2curves::FieldExt; +use serde::{Deserialize, Serialize}; +use std::error::Error; + +use crate::{ + circuit::{layouts, utils}, + fieldutils::i128_to_felt, + graph::scale_to_multiplier, + tensor::{self, Tensor, TensorError, TensorType}, +}; + +use super::Op; + +#[allow(missing_docs)] +/// An enum representing the operations that can be used to express more complex operations via accumulation +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)] +pub enum LookupOp { + Div { denom: utils::F32 }, + ReLU { scale: usize }, + Sqrt { scales: (usize, usize) }, + LeakyReLU { scale: usize, slope: utils::F32 }, + Sigmoid { scales: (usize, usize) }, + Tanh { scales: (usize, usize) }, + Erf { scales: (usize, usize) }, +} + +impl LookupOp { + /// a value which is always in the table + pub fn default_pair(&self) -> (F, F) { + let x = vec![0_i128].into_iter().into(); + ( + ::zero().unwrap(), + i128_to_felt(Op::::f(self, &[x]).unwrap()[0]), + ) + } +} + +impl Op for LookupOp { + /// Matches a [Op] to an operation in the `tensor::ops` module. 
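+    ///
+    /// A rough usage sketch (evaluation on quantized `i128` tensors; the
+    /// concrete values are illustrative assumptions, not fixtures from the
+    /// test suite):
+    ///
+    /// ```ignore
+    /// let x = Tensor::<i128>::new(Some(&[-3, 0, 5]), &[3])?;
+    /// let y = Op::<F>::f(&LookupOp::ReLU { scale: 1 }, &[x])?;
+    /// // with slope 0 and scale 1 this is plain ReLU: y == [0, 0, 5]
+    /// ```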
+ fn f(&self, x: &[Tensor]) -> Result, TensorError> { + match &self { + LookupOp::Div { denom } => Ok(tensor::ops::nonlinearities::const_div( + &x[0], + f32::from(*denom), + )), + LookupOp::ReLU { scale } => { + Ok(tensor::ops::nonlinearities::leakyrelu(&x[0], *scale, 0_f32)) + } + LookupOp::LeakyReLU { scale, slope } => Ok(tensor::ops::nonlinearities::leakyrelu( + &x[0], *scale, slope.0, + )), + LookupOp::Sigmoid { scales } => Ok(tensor::ops::nonlinearities::sigmoid( + &x[0], scales.0, scales.1, + )), + LookupOp::Sqrt { scales } => { + Ok(tensor::ops::nonlinearities::sqrt(&x[0], scales.0, scales.1)) + } + LookupOp::Tanh { scales } => { + Ok(tensor::ops::nonlinearities::tanh(&x[0], scales.0, scales.1)) + } + LookupOp::Erf { scales } => Ok(tensor::ops::nonlinearities::erffunc( + &x[0], scales.0, scales.1, + )), + } + } + + /// Returns the name of the operation + fn as_str(&self) -> &'static str { + match self { + LookupOp::Div { .. } => "DIV", + LookupOp::ReLU { .. } => "RELU", + LookupOp::LeakyReLU { .. } => "LEAKY_RELU", + LookupOp::Sigmoid { .. } => "SIGMOID", + LookupOp::Sqrt { .. } => "SQRT", + LookupOp::Tanh { .. } => "TANH", + LookupOp::Erf { .. } => "ERF", + } + } + + fn layout( + &self, + config: &mut crate::circuit::BaseConfig, + region: Option<&mut Region>, + values: &[ValTensor], + offset: &mut usize, + ) -> Result>, Box> { + Ok(Some(layouts::nonlinearity( + config, + region, + values[..].try_into()?, + self, + offset, + )?)) + } + + fn rescale(&self, inputs_scale: Vec, global_scale: u32) -> Box> { + match self { + LookupOp::Div { denom } => Box::new(LookupOp::Div { + denom: crate::circuit::utils::F32( + denom.0 * scale_to_multiplier(inputs_scale[0] - global_scale), + ), + }), + LookupOp::ReLU { .. } => Box::new(LookupOp::ReLU { + scale: scale_to_multiplier(inputs_scale[0] - global_scale) as usize, + }), + LookupOp::LeakyReLU { slope, .. } => Box::new(LookupOp::LeakyReLU { + scale: scale_to_multiplier(inputs_scale[0] - global_scale) as usize, + slope: *slope, + }), + LookupOp::Sigmoid { .. } => Box::new(LookupOp::Sigmoid { + scales: ( + scale_to_multiplier(inputs_scale[0]) as usize, + scale_to_multiplier(global_scale) as usize, + ), + }), + LookupOp::Sqrt { .. } => Box::new(LookupOp::Sqrt { + scales: ( + scale_to_multiplier(inputs_scale[0]) as usize, + scale_to_multiplier(global_scale) as usize, + ), + }), + LookupOp::Tanh { .. } => Box::new(LookupOp::Tanh { + scales: ( + scale_to_multiplier(inputs_scale[0]) as usize, + scale_to_multiplier(global_scale) as usize, + ), + }), + LookupOp::Erf { .. 
} => Box::new(LookupOp::Erf { + scales: ( + scale_to_multiplier(inputs_scale[0]) as usize, + scale_to_multiplier(global_scale) as usize, + ), + }), + } + } + + fn required_lookup(&self) -> Option { + Some(self.clone()) + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } +} diff --git a/src/circuit/ops/mod.rs b/src/circuit/ops/mod.rs new file mode 100644 index 000000000..d2c77223f --- /dev/null +++ b/src/circuit/ops/mod.rs @@ -0,0 +1,216 @@ +use std::error::Error; + +use halo2_proofs::circuit::Region; +use halo2curves::FieldExt; +use serde::Serialize; + +use crate::tensor::{Tensor, TensorError, TensorType, ValTensor}; + +use self::{lookup::LookupOp, poly::PolyOp}; + +/// +pub mod base; +/// +pub mod hybrid; +/// Layouts for specific functions (composed of base ops) +pub mod layouts; +/// +pub mod lookup; +/// +pub mod poly; + +/// +pub trait Op: std::fmt::Debug + Send + Sync { + /// + fn f(&self, x: &[Tensor]) -> Result, TensorError>; + /// + fn as_str(&self) -> &'static str; + + /// + fn layout( + &self, + config: &mut crate::circuit::BaseConfig, + region: Option<&mut Region>, + values: &[ValTensor], + offset: &mut usize, + ) -> Result>, Box>; + + /// + fn out_scale(&self, _: Vec, global_scale: u32) -> u32 { + global_scale + } + + /// + fn out_dims(&self, in_dims: Vec>) -> Vec { + in_dims[0].clone() + } + + /// + fn has_3d_input(&self) -> bool { + false + } + + /// + fn requires_homogenous_input_scales(&self) -> bool { + false + } + + /// + fn required_poly(&self) -> Option { + None + } + + /// + fn required_lookup(&self) -> Option { + None + } + + /// + fn rescale(&self, inputs_scale: Vec, global_scale: u32) -> Box>; + + /// + fn is_input(&self) -> bool { + false + } + + /// + fn is_const(&self) -> bool { + false + } + + /// + fn const_value(&self) -> Option> { + None + } + + /// + fn raw_const_value(&self) -> Option> { + None + } + + /// bias variable index (if any) + fn bias_variable(&self) -> Option { + None + } + + /// + fn clone_dyn(&self) -> Box>; +} + +impl Clone for Box> { + fn clone(&self) -> Self { + self.clone_dyn() + } +} + +/// +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize)] +pub struct Input; + +impl Op for Input { + fn f(&self, x: &[Tensor]) -> Result, TensorError> { + Ok(x[0].clone()) + } + + fn as_str(&self) -> &'static str { + "Input" + } + fn layout( + &self, + _: &mut crate::circuit::BaseConfig, + _: Option<&mut Region>, + _: &[ValTensor], + _: &mut usize, + ) -> Result>, Box> { + Ok(None) + } + + fn rescale(&self, _: Vec, _: u32) -> Box> { + Box::new(self.clone()) + } + + fn is_input(&self) -> bool { + true + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } +} + +/// +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize)] +pub struct Const { + /// The quantized constants potentially associated with this self. + pub const_value: Tensor, + /// The un-quantized constants potentially associated with this self. 
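+    /// (Kept so the original floating-point values stay available alongside
+    /// the quantized ones, e.g. if a node later needs to be re-quantized at a
+    /// different scale.)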
+ pub raw_const_value: Option>, +} + +impl Op for Const { + fn f(&self, _: &[Tensor]) -> Result, TensorError> { + Ok(self.const_value.clone()) + } + + fn as_str(&self) -> &'static str { + "Const" + } + fn layout( + &self, + _: &mut crate::circuit::BaseConfig, + _: Option<&mut Region>, + _: &[ValTensor], + _: &mut usize, + ) -> Result>, Box> { + Ok(None) + } + fn rescale(&self, _: Vec, _: u32) -> Box> { + Box::new(self.clone()) + } + + fn is_const(&self) -> bool { + true + } + + fn const_value(&self) -> Option> { + Some(self.const_value.clone()) + } + + fn raw_const_value(&self) -> Option> { + self.raw_const_value.clone() + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } +} + +/// +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize)] +pub struct Unknown; + +impl Op for Unknown { + fn f(&self, _: &[Tensor]) -> Result, TensorError> { + Err(TensorError::WrongMethod) + } + + fn as_str(&self) -> &'static str { + "Unknown" + } + fn layout( + &self, + _: &mut crate::circuit::BaseConfig, + _: Option<&mut Region>, + _: &[ValTensor], + _: &mut usize, + ) -> Result>, Box> { + Err(Box::new(super::CircuitError::UnsupportedOp)) + } + fn rescale(&self, _: Vec, _: u32) -> Box> { + Box::new(self.clone()) + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } +} diff --git a/src/circuit/ops/poly.rs b/src/circuit/ops/poly.rs new file mode 100644 index 000000000..32d0209c5 --- /dev/null +++ b/src/circuit/ops/poly.rs @@ -0,0 +1,416 @@ +use std::fmt; + +use itertools::Itertools; +use serde::{Deserialize, Serialize}; + +use crate::{ + circuit::layouts, + tensor::{self, Tensor, TensorError}, +}; + +use super::{base::BaseOp, *}; + +#[allow(missing_docs)] +/// An enum representing the operations that can be used to express more complex operations via accumulation +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)] +pub enum PolyOp { + Dot, + Matmul, + Affine, + Conv { + padding: (usize, usize), + stride: (usize, usize), + }, + SumPool { + padding: (usize, usize), + stride: (usize, usize), + kernel_shape: (usize, usize), + }, + Add, + Sub, + Mult, + Identity, + Reshape(Vec), + Flatten(Vec), + BatchNorm, + ScaleAndShift, + Pad(usize, usize), + Sum, + Pow(u32), + Pack(u32, u32), + GlobalSumPool, + Rescaled { + inner: Box, + scale: Vec<(usize, usize)>, + }, + RangeCheck(i32), +} + +impl Op for PolyOp { + fn as_str(&self) -> &'static str { + match &self { + PolyOp::Identity => "IDENTITY", + PolyOp::Reshape(_) => "RESHAPE", + PolyOp::Flatten(_) => "FLATTEN", + PolyOp::Pad(_, _) => "PAD", + PolyOp::Add => "ADD", + PolyOp::Mult => "MULT", + PolyOp::Sub => "SUB", + PolyOp::Sum => "SUM", + PolyOp::Dot => "DOT", + PolyOp::Pow(_) => "POW", + PolyOp::Pack(_, _) => "PACK", + PolyOp::GlobalSumPool => "GLOBALSUMPOOL", + PolyOp::ScaleAndShift => "SCALESHIFT", + PolyOp::BatchNorm => "BATCHNORM", + PolyOp::Conv { .. } => "CONV", + PolyOp::SumPool { .. } => "SUMPOOL", + PolyOp::Affine => "AFFINE", + PolyOp::Matmul => "MATMUL", + PolyOp::Rescaled { inner, .. } => Op::::as_str(&**inner), + PolyOp::RangeCheck(..) => "RANGECHECK", + } + } + + /// Matches a [Op] to an operation in the `tensor::ops` module. 
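+    ///
+    /// A rough usage sketch (element-wise add of two equally-shaped `i128`
+    /// tensors; the concrete values are illustrative assumptions):
+    ///
+    /// ```ignore
+    /// let a = Tensor::<i128>::new(Some(&[1, 2]), &[2])?;
+    /// let b = Tensor::<i128>::new(Some(&[3, 4]), &[2])?;
+    /// let sum = Op::<F>::f(&PolyOp::Add, &[a, b])?;
+    /// // sum == [4, 6]
+    /// ```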
+ fn f(&self, inputs: &[Tensor]) -> Result, TensorError> { + match &self { + PolyOp::Identity => Ok(inputs[0].clone()), + PolyOp::Reshape(new_dims) => { + let mut t = inputs[0].clone(); + t.reshape(new_dims); + Ok(t) + } + PolyOp::Flatten(new_dims) => { + let mut t = inputs[0].clone(); + t.reshape(new_dims); + Ok(t) + } + PolyOp::Pad(dim1, dim2) => { + if 1 != inputs.len() { + return Err(TensorError::DimMismatch("pad inputs".to_string())); + } + tensor::ops::pad(&inputs[0], (*dim1, *dim2)) + } + PolyOp::Add => tensor::ops::add(inputs), + PolyOp::Sub => tensor::ops::sub(inputs), + PolyOp::Mult => tensor::ops::mult(inputs), + PolyOp::Affine => tensor::ops::affine(inputs), + PolyOp::BatchNorm => tensor::ops::scale_and_shift(inputs), + PolyOp::ScaleAndShift => tensor::ops::scale_and_shift(inputs), + PolyOp::Matmul => tensor::ops::matmul(inputs), + PolyOp::Dot => tensor::ops::dot(&inputs.iter().collect()), + PolyOp::Conv { padding, stride } => tensor::ops::convolution(inputs, *padding, *stride), + PolyOp::SumPool { + padding, + stride, + kernel_shape, + } => tensor::ops::sumpool(&inputs[0], *padding, *stride, *kernel_shape), + PolyOp::Pack(base, scale) => { + if 1 != inputs.len() { + return Err(TensorError::DimMismatch("pack inputs".to_string())); + } + + tensor::ops::pack(&inputs[0], *base as i128, *scale) + } + PolyOp::Pow(u) => { + if 1 != inputs.len() { + return Err(TensorError::DimMismatch("pow inputs".to_string())); + } + inputs[0].pow(*u) + } + PolyOp::Sum => { + if 1 != inputs.len() { + return Err(TensorError::DimMismatch("sum inputs".to_string())); + } + tensor::ops::sum(&inputs[0]) + } + PolyOp::Rescaled { inner, scale } => { + if scale.len() != inputs.len() { + return Err(TensorError::DimMismatch("rescaled inputs".to_string())); + } + + let mut rescaled_inputs = vec![]; + let inputs = &mut inputs.to_vec(); + for (i, ri) in inputs.iter_mut().enumerate() { + rescaled_inputs.push(tensor::ops::rescale(ri, scale[i].1)?); + } + Ok(Op::::f(&**inner, &rescaled_inputs)?) + } + PolyOp::GlobalSumPool => unreachable!(), + PolyOp::RangeCheck(..) => Ok(inputs[0].clone()), + } + } + + fn layout( + &self, + config: &mut crate::circuit::BaseConfig, + mut region: Option<&mut Region>, + values: &[ValTensor], + offset: &mut usize, + ) -> Result>, Box> { + Ok(Some(match self { + PolyOp::Dot => layouts::dot(config, region, values[..].try_into()?, offset)?, + PolyOp::Sum => layouts::sum(config, region, values[..].try_into()?, offset)?, + PolyOp::Matmul => layouts::matmul(config, region, values[..].try_into()?, offset)?, + PolyOp::Affine => layouts::affine(config, region, values[..].try_into()?, offset)?, + PolyOp::Conv { padding, stride } => layouts::conv( + config, + region, + values[..].try_into()?, + *padding, + *stride, + offset, + )?, + PolyOp::SumPool { + padding, + stride, + kernel_shape, + } => layouts::sumpool( + config, + region, + values[..].try_into()?, + *padding, + *stride, + *kernel_shape, + offset, + )?, + PolyOp::Add => { + layouts::pairwise(config, region, values[..].try_into()?, offset, BaseOp::Add)? + } + PolyOp::Sub => { + layouts::pairwise(config, region, values[..].try_into()?, offset, BaseOp::Sub)? + } + PolyOp::Mult => { + layouts::pairwise(config, region, values[..].try_into()?, offset, BaseOp::Mult)? 
+ } + PolyOp::Identity => layouts::identity(config, region, values[..].try_into()?, offset)?, + PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?, + PolyOp::BatchNorm => { + layouts::scale_and_shift(config, region, values[..].try_into()?, offset)? + } + PolyOp::ScaleAndShift => { + layouts::scale_and_shift(config, region, values[..].try_into()?, offset)? + } + PolyOp::Pad(p1, p2) => { + if values.len() != 1 { + return Err(Box::new(TensorError::DimError)); + } + let mut input = values[0].clone(); + input.pad((*p1, *p2))?; + input + } + PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp, offset)?, + PolyOp::Pack(base, scale) => layouts::pack( + config, + region, + values[..].try_into()?, + *base, + *scale, + offset, + )?, + PolyOp::Rescaled { inner, scale } => { + if scale.len() != values.len() { + return Err(Box::new(TensorError::DimMismatch( + "rescaled inputs".to_string(), + ))); + } + + let res = &layouts::rescale( + config, + region.as_deref_mut(), + values[..].try_into()?, + scale, + offset, + )?[..]; + inner.layout(config, region, res, offset)?.unwrap() + } + PolyOp::RangeCheck(tol) => { + layouts::range_check(config, region, values[..].try_into()?, offset, *tol)? + } + PolyOp::GlobalSumPool => unreachable!(), + })) + } + + fn out_scale(&self, in_scales: Vec, _g: u32) -> u32 { + match self { + PolyOp::Dot => in_scales[0] + in_scales[1], + PolyOp::Sum => in_scales[0], + PolyOp::Matmul => in_scales[0] + in_scales[1], + PolyOp::Affine => in_scales[0] + in_scales[1], + PolyOp::Conv { .. } => in_scales[0] + in_scales[1], + PolyOp::SumPool { .. } => in_scales[0], + PolyOp::Add => in_scales[0], + PolyOp::Sub => in_scales[0], + PolyOp::Mult => in_scales[0] + in_scales[1], + PolyOp::Identity => in_scales[0], + PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0], + PolyOp::BatchNorm => 2 * in_scales[0], + PolyOp::ScaleAndShift => 2 * in_scales[0], + PolyOp::Pad(_, _) => in_scales[0], + PolyOp::Pow(pow) => in_scales[0] * (*pow), + PolyOp::Pack(_, _) => in_scales[0], + PolyOp::Rescaled { inner, .. 
} => Op::::out_scale(&**inner, in_scales, _g), + PolyOp::RangeCheck(_) => in_scales[0], + PolyOp::GlobalSumPool => in_scales[0], + } + } + + fn out_dims(&self, in_dims: Vec>) -> Vec { + match self { + PolyOp::Dot => vec![1], + PolyOp::Sum => vec![1], + PolyOp::Matmul => { + let a_dims = in_dims[0].clone(); + let b_dims = in_dims[1].clone(); + let mut dims = Vec::from(&a_dims[0..a_dims.len() - 2]); + dims.push(a_dims[a_dims.len() - 2]); + dims.push(b_dims[a_dims.len() - 1]); + dims + } + PolyOp::Affine => { + let weight_node = &in_dims[1]; + let out_dim = weight_node.clone()[0]; + vec![out_dim] + } + PolyOp::Conv { padding, stride } => { + let oihw = in_dims[1].clone(); + let (out_channels, _, kernel_height, kernel_width) = + (oihw[0], oihw[1], oihw[2], oihw[3]); + + let (padding_h, padding_w, stride_h, stride_w) = + (padding.0, padding.1, stride.0, stride.1); + + println!("in_dims: {:?}", in_dims); + + let input_height = in_dims[0][1]; + let input_width = in_dims[0][2]; + + let out_height = (input_height + 2 * padding_h - kernel_height) / stride_h + 1; + let out_width = (input_width + 2 * padding_w - kernel_width) / stride_w + 1; + + vec![out_channels, out_height, out_width] + } + PolyOp::SumPool { + padding, + stride, + kernel_shape, + } => { + let (input_channels, kernel_height, kernel_width) = + (in_dims[0][0], kernel_shape.0, kernel_shape.1); + + let (padding_h, padding_w, stride_h, stride_w) = + (padding.0, padding.1, stride.0, stride.1); + + let input_height = in_dims[0][1]; + let input_width = in_dims[0][2]; + + let out_height = (input_height + 2 * padding_h - kernel_height) / stride_h + 1; + let out_width = (input_width + 2 * padding_w - kernel_width) / stride_w + 1; + + vec![input_channels, out_height, out_width] + } + PolyOp::Add => in_dims[0].clone(), + PolyOp::Sub => in_dims[0].clone(), + PolyOp::Mult => in_dims[0].clone(), + PolyOp::Identity => in_dims[0].clone(), + PolyOp::Reshape(d) | PolyOp::Flatten(d) => d.clone(), + PolyOp::BatchNorm => in_dims[0].clone(), + PolyOp::ScaleAndShift => in_dims[0].clone(), + PolyOp::Pad(padding_h, padding_w) => { + let input_channels = in_dims[0][0]; + + let out_height = in_dims[0][1] + 2 * padding_h; + let out_width = in_dims[0][2] + 2 * padding_w; + vec![input_channels, out_height, out_width] + } + PolyOp::Pow(_) => in_dims[0].clone(), + PolyOp::Pack(_, _) => vec![1], + PolyOp::Rescaled { inner, .. } => Op::::out_dims(&**inner, in_dims), + PolyOp::RangeCheck(_) => in_dims[0].clone(), + PolyOp::GlobalSumPool => { + let input_channels = in_dims[0][0]; + vec![input_channels, 1, 1] + } + } + } + + fn has_3d_input(&self) -> bool { + matches!( + self, + PolyOp::Conv { .. } + | PolyOp::SumPool { .. } + | PolyOp::GlobalSumPool + | PolyOp::Pad { .. } + ) + } + + fn rescale(&self, _: Vec, _: u32) -> Box> { + Box::new(self.clone()) + } + + fn requires_homogenous_input_scales(&self) -> bool { + matches!(self, PolyOp::Add | PolyOp::Sub) + } + + fn bias_variable(&self) -> Option { + match self { + PolyOp::Affine | PolyOp::ScaleAndShift | PolyOp::Conv { .. 
} => Some(2), + _ => None, + } + } + + fn required_poly(&self) -> Option { + Some(self.clone()) + } + + fn clone_dyn(&self) -> Box> { + Box::new(self.clone()) // Forward to the derive(Clone) impl + } +} + +impl fmt::Display for PolyOp { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + PolyOp::Identity => write!(f, "identity"), + PolyOp::Reshape(new_dims) => write!(f, "reshape to {:?}", new_dims), + PolyOp::Flatten(new_dims) => write!(f, "flatten to {:?}", new_dims), + PolyOp::Pad(dim1, dim2) => write!(f, "padding: ({:?}, {:?})", dim1, dim2), + PolyOp::Add => write!(f, "add"), + PolyOp::Sub => write!(f, "sub"), + PolyOp::Sum => write!(f, "sum"), + PolyOp::Mult => write!(f, "mult"), + PolyOp::Matmul => write!(f, "matmul"), + PolyOp::Dot => write!(f, "dot"), + PolyOp::Pack(base, _) => write!(f, "pack with base {:?}", base), + PolyOp::Affine => write!(f, "affine"), + PolyOp::BatchNorm => write!(f, "batchnorm"), + PolyOp::ScaleAndShift => write!(f, "scale & shift"), + PolyOp::Conv { padding, stride } => { + write!(f, "conv w/ padding: {:?}, stride: {:?}", padding, stride) + } + PolyOp::SumPool { + padding, + stride, + kernel_shape, + } => { + write!( + f, + "avg pl w/ padding: {:?}, stride: {:?}, kernel shape: {:?}", + padding, stride, kernel_shape, + ) + } + PolyOp::GlobalSumPool => write!(f, "globalsumpool"), + PolyOp::Pow(s) => write!(f, "pow {}", s), + PolyOp::Rescaled { inner, scale } => { + write!( + f, + "rescaled {} w/ scalings: {:?}", + **inner, + scale.iter().map(|e| e.1).collect_vec() + ) + } + PolyOp::RangeCheck(tol) => write!(f, "range check w/ tol {}", tol), + } + } +} diff --git a/src/circuit/table.rs b/src/circuit/table.rs index e659101a4..d857f77c7 100644 --- a/src/circuit/table.rs +++ b/src/circuit/table.rs @@ -6,9 +6,15 @@ use halo2_proofs::{ }; use halo2curves::FieldExt; -use crate::{circuit::CircuitError, fieldutils::i128_to_felt, tensor::Tensor}; +use crate::{ + circuit::CircuitError, + fieldutils::i128_to_felt, + tensor::{Tensor, TensorType}, +}; + +use crate::circuit::lookup::LookupOp; -use super::LookupOp; +use super::Op; /// Halo2 lookup table for element wise non-linearities. // Table that should be reused across all lookups (so no Clone) @@ -27,7 +33,7 @@ pub struct Table { _marker: PhantomData, } -impl Table { +impl Table { /// Configures the table. 
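+    ///
+    /// The assigned table pairs every quantized input in
+    /// `[-2^(bits - 1), 2^(bits - 1))` with `nonlinearity.f(input)`, so a
+    /// lookup succeeds exactly when the (input, output) pair lies on the
+    /// op's graph.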
pub fn configure( cs: &mut ConstraintSystem, @@ -52,12 +58,9 @@ impl Table { let base = 2i128; let smallest = -base.pow(self.bits as u32 - 1); let largest = base.pow(self.bits as u32 - 1); - // let smallest = -base.pow(3); - // let largest = base.pow(3); + let inputs = Tensor::from(smallest..largest); - // println!("Are we here Tuesday input {:?}", inputs); - let evals = self.nonlinearity.f(inputs.clone())?; - // println!("Tuesday If we are here then evals {:?}", evals); + let evals = Op::::f(&self.nonlinearity, &[inputs.clone()])?; self.is_assigned = true; layouter @@ -81,7 +84,6 @@ impl Table { row_offset, || Value::known(i128_to_felt::(evals[row_offset])), )?; - // println!("All good here inside assign table, Tuesday"); Ok(()) }) .collect::, halo2_proofs::plonk::Error>>()?; diff --git a/src/circuit/tests.rs b/src/circuit/tests.rs index b0897cf75..5ee0614dd 100644 --- a/src/circuit/tests.rs +++ b/src/circuit/tests.rs @@ -1,3 +1,4 @@ +use crate::circuit::ops::poly::PolyOp; use crate::circuit::*; use halo2_proofs::{ arithmetic::FieldExt, @@ -48,7 +49,12 @@ mod matmul { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Matmul.into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Matmul), + ) .map_err(|_| Error::Synthesis) }, ) @@ -116,7 +122,12 @@ mod matmul_col_overflow { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Matmul.into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Matmul), + ) .map_err(|_| Error::Synthesis) }, ) @@ -146,6 +157,8 @@ mod matmul_col_overflow { #[cfg(test)] mod dot { + use ops::poly::PolyOp; + use super::*; const K: usize = 4; @@ -183,7 +196,12 @@ mod dot { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Dot.into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Dot), + ) .map_err(|_| Error::Synthesis) }, ) @@ -248,7 +266,12 @@ mod dot_col_overflow { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Dot.into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Dot), + ) .map_err(|_| Error::Synthesis) }, ) @@ -313,7 +336,12 @@ mod sum { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Sum.into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Sum), + ) .map_err(|_| Error::Synthesis) }, ) @@ -376,7 +404,12 @@ mod sum_col_overflow { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Sum.into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Sum), + ) .map_err(|_| Error::Synthesis) }, ) @@ -440,10 +473,10 @@ mod batchnorm { |mut region| { config .layout( - &mut region, + Some(&mut region), &self.inputs.clone(), &mut 0, - Op::BatchNorm.into(), + Box::new(PolyOp::BatchNorm), ) .map_err(|_| Error::Synthesis) }, @@ -515,7 +548,12 @@ mod affine { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Affine.into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Affine), + ) .map_err(|_| Error::Synthesis) }, ) @@ -586,7 +624,12 @@ mod affine_col_overflow { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Affine.into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Affine), + ) .map_err(|_| 
Error::Synthesis) }, ) @@ -659,26 +702,26 @@ mod composition { let mut offset = 0; let _ = config .layout( - &mut region, + Some(&mut region), &self.inputs.clone(), &mut offset, - Op::Dot.into(), + Box::new(PolyOp::Dot), ) .unwrap(); let _ = config .layout( - &mut region, + Some(&mut region), &self.inputs.clone(), &mut offset, - Op::Dot.into(), + Box::new(PolyOp::Dot), ) .unwrap(); config .layout( - &mut region, + Some(&mut region), &self.inputs.clone(), &mut offset, - Op::Dot.into(), + Box::new(PolyOp::Dot), ) .map_err(|_| Error::Synthesis) }, @@ -746,14 +789,13 @@ mod conv { |mut region| { config .layout( - &mut region, + Some(&mut region), &self.inputs.clone(), &mut 0, - Op::Conv { + Box::new(PolyOp::Conv { padding: (1, 1), stride: (2, 2), - } - .into(), + }), ) .map_err(|_| Error::Synthesis) }, @@ -872,15 +914,14 @@ mod sumpool { |mut region| { config .layout( - &mut region, + Some(&mut region), &self.inputs.clone(), &mut 0, - Op::SumPool { + Box::new(PolyOp::SumPool { padding: (0, 0), stride: (1, 1), kernel_shape: (3, 3), - } - .into(), + }), ) .map_err(|_| Error::Synthesis) }, @@ -951,7 +992,12 @@ mod add_w_shape_casting { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Add.into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Add), + ) .map_err(|_| Error::Synthesis) }, ) @@ -1016,7 +1062,12 @@ mod add { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Add.into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Add), + ) .map_err(|_| Error::Synthesis) }, ) @@ -1081,7 +1132,12 @@ mod add_with_overflow { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Add.into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Add), + ) .map_err(|_| Error::Synthesis) }, ) @@ -1146,7 +1202,12 @@ mod sub { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Sub.into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Sub), + ) .map_err(|_| Error::Synthesis) }, ) @@ -1211,7 +1272,12 @@ mod mult { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Mult.into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Mult), + ) .map_err(|_| Error::Synthesis) }, ) @@ -1276,7 +1342,12 @@ mod pow { || "", |mut region| { config - .layout(&mut region, &self.inputs.clone(), &mut 0, Op::Pow(5).into()) + .layout( + Some(&mut region), + &self.inputs.clone(), + &mut 0, + Box::new(PolyOp::Pow(5)), + ) .map_err(|_| Error::Synthesis) }, ) @@ -1340,10 +1411,10 @@ mod pack { |mut region| { config .layout( - &mut region, + Some(&mut region), &self.inputs.clone(), &mut 0, - Op::Pack(2, 1).into(), + Box::new(PolyOp::Pack(2, 1)), ) .map_err(|_| Error::Synthesis) }, @@ -1408,14 +1479,13 @@ mod rescaled { |mut region| { config .layout( - &mut region, + Some(&mut region), &self.inputs.clone(), &mut 0, - Op::Rescaled { - inner: Box::new(Op::Sum), + Box::new(PolyOp::Rescaled { + inner: Box::new(PolyOp::Sum), scale: vec![(0, 5)], - } - .into(), + }), ) .map_err(|_| Error::Synthesis) }, @@ -1447,7 +1517,7 @@ mod matmul_relu { const K: usize = 18; const LEN: usize = 32; - use crate::circuit::LookupOp; + use crate::circuit::lookup::LookupOp; #[derive(Clone)] struct MyCircuit { @@ -1493,19 +1563,19 @@ mod matmul_relu { layouter.assign_region( || "", |mut region| { - let op = 
Op::Matmul; + let op = PolyOp::Matmul; let mut offset = 0; let output = config .base_config - .layout(&mut region, &self.inputs, &mut offset, op.into()) + .layout(Some(&mut region), &self.inputs, &mut offset, Box::new(op)) .unwrap(); let _output = config .base_config .layout( - &mut region, + Some(&mut region), &[output.unwrap()], &mut offset, - LookupOp::ReLU { scale: 1 }.into(), + Box::new(LookupOp::ReLU { scale: 1 }), ) .unwrap(); Ok(()) @@ -1587,10 +1657,10 @@ mod rangecheck { |mut region| { config .layout( - &mut region, + Some(&mut region), &[self.input.clone(), self.output.clone()], &mut 0, - Op::RangeCheck(RANGE as i32).into(), + Box::new(PolyOp::RangeCheck(RANGE as i32)), ) .map_err(|_| Error::Synthesis) }, @@ -1688,10 +1758,10 @@ mod relu { |mut region| { config .layout( - &mut region, + Some(&mut region), &[self.input.clone()], &mut 0, - LookupOp::ReLU { scale: 1 }.into(), + Box::new(LookupOp::ReLU { scale: 1 }), ) .map_err(|_| Error::Synthesis) }, diff --git a/src/eth.rs b/src/eth.rs index e1e73d574..7e740512c 100644 --- a/src/eth.rs +++ b/src/eth.rs @@ -410,7 +410,7 @@ pub fn fix_verifier_sol(input_file: PathBuf) -> Result> { if let Some(m) = m { let mstore = m.get(1).unwrap().as_str(); let addr = m.get(2).unwrap().as_str(); - let addr_as_num = u32::from_str_radix(addr, 10)?; + let addr_as_num = addr.parse::()?; let transcript_addr = format!("{:#x}", addr_as_num); transcript_addrs.push(addr_as_num); line = line.replace( @@ -423,7 +423,7 @@ pub fn fix_verifier_sol(input_file: PathBuf) -> Result> { if let Some(m) = m { let mstore = m.get(1).unwrap().as_str(); let addr = m.get(2).unwrap().as_str(); - let addr_as_num = u32::from_str_radix(addr, 10)?; + let addr_as_num = addr.parse::()?; let transcript_addr = format!("{:#x}", addr_as_num); transcript_addrs.push(addr_as_num); line = line.replace( diff --git a/src/execute.rs b/src/execute.rs index 4cf52c402..c4d6c01fe 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -331,7 +331,7 @@ fn gen_srs_cmd(params_path: PathBuf, logrows: u32) -> Result<(), Box> } fn table(cli: Cli) -> Result<(), Box> { - let om = Model::from_ezkl_conf(cli)?; + let om = Model::::from_ezkl_conf(cli)?; info!("{}", Table::new(om.nodes.iter())); Ok(()) } @@ -342,7 +342,7 @@ fn forward( output: String, args: RunArgs, ) -> Result<(), Box> { - let mut data = prepare_data(data.to_string())?; + let mut data = prepare_data(data)?; // quantize the supplied data using the provided scale. 
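+    // (here "quantize" means mapping each f32 input to a fixed-point integer
+    // using the power-of-two multiplier derived from `scale`; see
+    // `graph::scale_to_multiplier`)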
let mut model_inputs = vec![]; @@ -351,7 +351,7 @@ fn forward( model_inputs.push(t); } - let res = Model::forward(model, &model_inputs, args)?; + let res = Model::::forward(model, &model_inputs, args)?; let float_res: Vec> = res.iter().map(|t| t.to_vec()).collect(); trace!("forward pass output: {:?}", float_res); @@ -362,7 +362,7 @@ fn forward( } fn mock(data: String, logrows: u32) -> Result<(), Box> { - let data = prepare_data(data.to_string())?; + let data = prepare_data(data)?; let model = Model::from_arg()?; let circuit = ModelCircuit::::new(&data, model)?; let public_inputs = circuit.prepare_public_inputs(&data)?; @@ -417,7 +417,7 @@ fn create_evm_verifier( sol_code_path: Option, logrows: u32, ) -> Result<(), Box> { - let data = prepare_data(data.to_string())?; + let data = prepare_data(data)?; let model = Model::from_arg()?; let circuit = ModelCircuit::::new(&data, model)?; @@ -425,7 +425,7 @@ fn create_evm_verifier( let num_instance = public_inputs.iter().map(|x| x.len()).collect(); let params = load_params_cmd(params_path, logrows)?; - let vk = load_vk::, Fr, ModelCircuit>(vk_path.to_path_buf())?; + let vk = load_vk::, Fr, ModelCircuit>(vk_path)?; trace!("params computed"); let (deployment_code, yul_code) = gen_evm_verifier(¶ms, &vk, num_instance)?; @@ -495,7 +495,7 @@ fn prove( logrows: u32, check_mode: CheckMode, ) -> Result<(), Box> { - let data = prepare_data(data.to_string())?; + let data = prepare_data(data)?; let model = Model::from_arg()?; let circuit = ModelCircuit::::new(&data, model)?; diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 0610a8af3..061134746 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -1,6 +1,5 @@ /// Helper functions pub mod utilities; -use serde::{Deserialize, Serialize}; pub use utilities::*; /// Crate for defining a computational graph and building a ZK-circuit from it. pub mod model; @@ -9,7 +8,6 @@ pub mod node; /// Representations of a computational graph's variables. 
pub mod vars; -use crate::circuit::OpKind; use crate::commands::Cli; use crate::fieldutils::i128_to_felt; use crate::pfsys::ModelInput; @@ -25,10 +23,10 @@ use halo2_proofs::{ use log::{info, trace}; pub use model::*; pub use node::*; -use std::fs::File; -use std::io::{BufReader, BufWriter, Read, Write}; +// use std::fs::File; +// use std::io::{BufReader, BufWriter, Read, Write}; use std::marker::PhantomData; -use std::path::PathBuf; +// use std::path::PathBuf; use thiserror::Error; pub use vars::*; @@ -40,16 +38,16 @@ pub enum GraphError { InvalidLookupInputs, /// Shape mismatch in circuit construction #[error("invalid dimensions used for node {0} ({1})")] - InvalidDims(usize, OpKind), + InvalidDims(usize, String), /// Wrong method was called to configure an op #[error("wrong method was called to configure node {0} ({1})")] - WrongMethod(usize, OpKind), + WrongMethod(usize, String), /// A requested node is missing in the graph #[error("a requested node is missing in the graph: {0}")] MissingNode(usize), /// The wrong method was called on an operation #[error("an unsupported method was called on node {0} ({1})")] - OpMismatch(usize, OpKind), + OpMismatch(usize, String), /// This operation is unsupported #[error("unsupported operation in graph")] UnsupportedOp, @@ -70,7 +68,7 @@ pub enum GraphError { NonConstantPower, /// Error when attempting to rescale an operation #[error("failed to rescale inputs for {0}")] - RescalingError(OpKind), + RescalingError(String), /// Error when attempting to load a model #[error("failed to load model")] ModelLoad, @@ -80,12 +78,12 @@ pub enum GraphError { } /// Defines the circuit for a computational graph / model loaded from a `.onnx` file. -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct ModelCircuit { +#[derive(Clone, Debug)] +pub struct ModelCircuit { /// Vector of input tensors to the model / graph of computations. pub inputs: Vec>, /// - pub model: Model, + pub model: Model, /// Represents the Field we are using. pub _marker: PhantomData, } @@ -94,7 +92,7 @@ impl ModelCircuit { /// pub fn new( data: &ModelInput, - model: Model, + model: Model, ) -> Result, Box> { // quantize the supplied data using the provided scale. 
let mut inputs: Vec> = vec![]; @@ -110,39 +108,6 @@ impl ModelCircuit { }) } - /// - pub fn write( - &self, - mut writer: BufWriter, - ) -> Result<(), Box> { - let circuit_bytes = bincode::serialize(&self)?; - writer.write(&circuit_bytes)?; - writer.flush()?; - Ok(()) - } - - /// - pub fn write_to_file(&self, path: PathBuf) -> Result<(), Box> { - let fs = File::create(path)?; - let buffer = BufWriter::new(fs); - self.write(buffer) - } - - /// - pub fn read(mut reader: BufReader) -> Result> { - let buffer: &mut Vec = &mut vec![]; - reader.read_to_end(buffer)?; - - let circuit = bincode::deserialize(&buffer)?; - Ok(circuit) - } - /// - pub fn read_from_file(path: PathBuf) -> Result> { - let f = File::open(path)?; - let reader = BufReader::new(f); - Self::read(reader) - } - /// pub fn from_arg(data: &ModelInput) -> Result> { let cli = Cli::create()?; @@ -214,11 +179,11 @@ impl Circuit for ModelCircuit { } fn configure(cs: &mut ConstraintSystem) -> Self::Config { - let model = Model::from_arg().expect("model should load"); + let model: Model = Model::from_arg().expect("model should load"); // for now the number of instances corresponds to the number of graph / model outputs let instance_shapes = model.instance_shapes(); - let var_len = model.total_var_len(); + let var_len = model.dummy_layout(&model.input_shapes()).unwrap(); info!("total var len: {:?}", var_len); info!("instance_shapes: {:?}", instance_shapes); diff --git a/src/graph/model.rs b/src/graph/model.rs index d10155ca8..f4e7317c2 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -1,22 +1,20 @@ use super::node::*; use super::vars::*; use super::GraphError; +use crate::circuit::ops::poly::PolyOp; use crate::circuit::BaseConfig as PolyConfig; -use crate::circuit::LookupOp; -use crate::circuit::Op as PolyOp; -use crate::circuit::OpKind; +use crate::circuit::Op; + use crate::commands::RunArgs; use crate::commands::{Cli, Commands}; use crate::fieldutils::i128_to_felt; use crate::graph::scale_to_multiplier; use crate::tensor::TensorType; use crate::tensor::{Tensor, ValTensor}; -use anyhow::Context; use serde::Deserialize; use serde::Serialize; //use clap::Parser; use core::panic; -use halo2_proofs::circuit::Region; use halo2_proofs::{ arithmetic::FieldExt, circuit::{Layouter, Value}, @@ -25,15 +23,9 @@ use halo2_proofs::{ use itertools::Itertools; use log::error; use log::{debug, info, trace}; -use std::cell::RefCell; -use std::cmp::max; use std::collections::BTreeMap; use std::error::Error; -use std::fs::File; -use std::io::{BufReader, BufWriter, Read, Write}; use std::path::Path; -use std::path::PathBuf; -use std::rc::Rc; use tabled::Table; use tract_onnx; use tract_onnx::prelude::Framework; @@ -55,26 +47,22 @@ pub enum Mode { /// A circuit configuration for the entirety of a model loaded from an Onnx file. #[derive(Clone, Debug)] pub struct ModelConfig { - configs: BTreeMap>, + base: PolyConfig, /// The model struct - pub model: Model, - /// (optional) range checked outputs of the model graph - pub range_checks: Vec>>>, - /// (optional) packed outputs of the model graph - pub packed_outputs: Vec>>>, + pub model: Model, /// A wrapper for holding all columns that will be assigned to by the model pub vars: ModelVars, } /// A struct for loading from an Onnx file and converting a computational graph to a circuit. 
-#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Model { +#[derive(Clone, Debug)] +pub struct Model { /// input indices pub inputs: Vec, /// output indices pub outputs: Vec, /// Graph of nodes we are loading from Onnx. - pub nodes: NodeGraph, // Wrapped nodes with additional methods and data (e.g. inferred shape, quantization) + pub nodes: NodeGraph, // Wrapped nodes with additional methods and data (e.g. inferred shape, quantization) /// The [RunArgs] being used pub run_args: RunArgs, /// The [Mode] we're using the model in. @@ -83,7 +71,7 @@ pub struct Model { pub visibility: VarVisibility, } -impl Model { +impl Model { /// Creates an `Model` from a specified path to an Onnx file. /// # Arguments /// @@ -102,9 +90,9 @@ impl Model { .map_err(|_| GraphError::ModelLoad)?; info!("visibility: {}", visibility); - let mut nodes = BTreeMap::::new(); + let mut nodes = BTreeMap::>::new(); for (i, n) in model.nodes.iter().enumerate() { - let n = Node::new(n.clone(), &mut nodes, run_args.scale, i)?; + let n = Node::::new(n.clone(), &mut nodes, run_args.scale, i)?; nodes.insert(i, n); } let om = Model { @@ -116,7 +104,7 @@ impl Model { visibility, }; - debug!("{}", Table::new(om.nodes.iter()).to_string()); + debug!("\n {}", Table::new(om.nodes.iter()).to_string()); Ok(om) } @@ -136,7 +124,7 @@ impl Model { .map_err(|_| GraphError::ModelLoad)?; info!("running forward pass"); - let mut nodes = BTreeMap::::new(); + let mut nodes = BTreeMap::>::new(); for (i, n) in model.nodes.iter().enumerate() { let n = Node::new(n.clone(), &mut nodes, run_args.scale, i)?; nodes.insert(i, n); @@ -147,32 +135,19 @@ impl Model { let mut results: BTreeMap<&usize, Tensor> = BTreeMap::new(); for (i, n) in nodes.iter() { let mut inputs = vec![]; - for i in n.inputs.iter() { - match results.get(&i) { - Some(value) => inputs.push(value.clone()), - None => return Err(Box::new(GraphError::MissingNode(*i))), - } - } - match &n.opkind { - OpKind::Lookup(op) => { - // assert_eq!(inputs.len(), 1); - results.insert(i, op.f(inputs[0].clone())?); - } - OpKind::Poly(op) => { - results.insert(i, op.f(inputs)?); - } - OpKind::Input => { - let mut t = model_inputs[*i].clone(); - t.reshape(&n.out_dims); - results.insert(i, t); - } - OpKind::Const => { - results.insert(i, n.const_value.as_ref().unwrap().clone()); - } - _ => { - panic!("unsupported op") + if n.opkind.is_input() { + let mut t = model_inputs[*i].clone(); + t.reshape(&n.out_dims); + inputs.push(t); + } else { + for i in n.inputs.iter() { + match results.get(&i) { + Some(value) => inputs.push(value.clone()), + None => return Err(Box::new(GraphError::MissingNode(*i))), + } } - } + }; + results.insert(i, Op::::f(&*n.opkind, &inputs)?); } let output_nodes = model.outputs.iter(); @@ -219,36 +194,6 @@ impl Model { } } - /// - pub fn write(&self, mut writer: BufWriter) -> Result<(), Box> { - let circuit_bytes = bincode::serialize(&self)?; - writer.write(&circuit_bytes)?; - writer.flush()?; - Ok(()) - } - - /// - pub fn write_to_file(&self, path: PathBuf) -> Result<(), Box> { - let fs = File::create(path)?; - let buffer = BufWriter::new(fs); - self.write(buffer) - } - - /// - pub fn read(mut reader: BufReader) -> Result> { - let buffer: &mut Vec = &mut vec![]; - reader.read_to_end(buffer)?; - - let circuit = bincode::deserialize(&buffer)?; - Ok(circuit) - } - /// - pub fn read_from_file(path: PathBuf) -> Result> { - let f = File::open(path)?; - let reader = BufReader::new(f); - Self::read(reader) - } - /// Creates a `Model` based on CLI arguments pub fn from_arg() -> Result> 
     /// Creates a `Model` based on CLI arguments
     pub fn from_arg() -> Result<Self, Box<dyn Error>> {
         let conf = Cli::create()?;
@@ -262,150 +207,38 @@ impl Model {
     ///
     /// * `meta` - Halo2 ConstraintSystem.
     /// * `advices` - A `VarTensor` holding columns of advices. Must be sufficiently large to configure all the nodes loaded in `self.nodes`.
-    pub fn configure<F: FieldExt + TensorType>(
+    pub fn configure(
         &self,
         meta: &mut ConstraintSystem<F>,
         vars: &mut ModelVars<F>,
     ) -> Result<ModelConfig<F>, Box<dyn Error>> {
         info!("configuring model");
-        let mut results = BTreeMap::new();
-        let mut base_gate = Rc::new(RefCell::new(PolyConfig::configure(
+        let mut base_gate = PolyConfig::configure(
             meta,
             vars.advices[0..2].try_into()?,
             &vars.advices[2],
             self.run_args.check_mode,
             self.run_args.tolerance as i32,
-        )));
-
-        let non_op_nodes: BTreeMap<&usize, &Node> = self
-            .nodes
-            .iter()
-            .filter(|(_, n)| n.opkind.is_const() || n.opkind.is_input())
-            .collect();
-        if !non_op_nodes.is_empty() {
-            for (i, node) in non_op_nodes {
-                let config = self.conf_non_op_node(node)?;
-                results.insert(*i, config);
-            }
-        }
-
-        // preserves ordering
-        let poly_ops: BTreeMap<&usize, &Node> = self
-            .nodes
-            .iter()
-            .filter(|(_, n)| n.opkind.is_poly())
-            .collect();
-        // preserves ordering
-        if !poly_ops.is_empty() {
-            for (i, node) in poly_ops {
-                let config = self.conf_poly_ops(node, &mut base_gate)?;
-                results.insert(*i, config);
-
-                let mut display: String = "Poly nodes: ".to_string();
-                display.push_str(&format!("| {} ({:?}) | ", i, node.opkind));
-
-                trace!("{}", display);
-            }
-        }
+        );
 
-        let lookup_ops: BTreeMap<&usize, &Node> = self
+        let lookup_ops: BTreeMap<&usize, &Node<F>> = self
             .nodes
             .iter()
-            .filter(|(_, n)| n.opkind.is_lookup())
+            .filter(|(_, n)| n.opkind.required_lookup().is_some())
             .collect();
 
-        if !lookup_ops.is_empty() {
-            for (i, node) in lookup_ops {
-                let config = self.conf_lookup(base_gate.clone(), node, meta, vars)?;
-                results.insert(*i, config);
-            }
+        for node in lookup_ops.values() {
+            self.conf_lookup(&mut base_gate, node, meta, vars)?;
         }
 
-        let mut range_checks = vec![];
-        let mut packed_outputs = vec![];
-        if self.run_args.pack_base > 1 {
-            info!("packing outputs...");
-            packed_outputs = self.output_ops(&mut base_gate);
-        }
-        if self.visibility.output.is_public() {
-            range_checks = self.output_ops(&mut base_gate);
-        };
-
         Ok(ModelConfig {
-            configs: results,
+            base: base_gate,
             model: self.clone(),
-            range_checks,
-            packed_outputs,
             vars: vars.clone(),
         })
     }
 
-    fn output_ops<F: FieldExt + TensorType>(
-        &self,
-        base_gate: &mut Rc<RefCell<PolyConfig<F>>>,
-    ) -> Vec<Rc<RefCell<PolyConfig<F>>>> {
-        let mut configs = vec![];
-
-        for _ in self.output_shapes() {
-            configs.push(base_gate.clone());
-        }
-
-        configs
-    }
-
-    /// Configures non op related nodes (eg. representing an input or const value)
-    pub fn conf_non_op_node<F: FieldExt + TensorType>(
-        &self,
-        node: &Node,
-    ) -> Result<NodeConfig<F>, Box<dyn Error>> {
-        match &node.opkind {
-            OpKind::Const => {
-                // Typically parameters for one or more layers.
-                // Currently this is handled in the consuming node(s), but will be moved here.
-                Ok(NodeConfig::Const)
-            }
-            OpKind::Input => {
-                // This is the input to the model (e.g. the image).
-                // Currently this is handled in the consuming node(s), but will be moved here.
-                Ok(NodeConfig::Input)
-            }
-            OpKind::Unknown(_c) => {
-                unimplemented!()
-            }
-            c => Err(Box::new(GraphError::WrongMethod(node.idx, c.clone()))),
-        }
-    }
-
-    /// Configures a [BTreeMap] of operations that can be constrained using polynomials. These correspond to operations that are represented in
-    /// the `circuit::polynomial` module. A single configuration is output, representing the amalgamation of these operations into
-    /// a single Halo2 gate.
- /// # Arguments - /// - /// * `nodes` - A [BTreeMap] of (node index, [Node] pairs). The [Node] must represent a polynomial op. - /// * `meta` - Halo2 ConstraintSystem. - /// * `vars` - [ModelVars] for the model. - fn conf_poly_ops( - &self, - node: &Node, - base_gate: &mut Rc>>, - ) -> Result, Box> { - let input_nodes = node - .inputs - .iter() - .map(|i| self.nodes.get(&i).unwrap()) - .collect_vec(); - - let input_idx = input_nodes.iter().map(|f| f.idx).collect_vec(); - - let config = NodeConfig::Op { - config: base_gate.clone(), - inputs: input_idx, - op: node.opkind.clone(), - }; - Ok(config) - } - /// Configures a lookup table based operation. These correspond to operations that are represented in /// the `circuit::eltwise` module. /// # Arguments @@ -413,59 +246,29 @@ impl Model { /// * `node` - The [Node] must represent a lookup based op. /// * `meta` - Halo2 ConstraintSystem. /// * `vars` - [ModelVars] for the model. - fn conf_lookup( + fn conf_lookup( &self, - config: Rc>>, - node: &Node, + config: &mut PolyConfig, + node: &Node, meta: &mut ConstraintSystem, vars: &mut ModelVars, - ) -> Result, Box> { + ) -> Result<(), Box> { let input = &vars.advices[0]; let output = &vars.advices[1]; - let input_nodes = node - .inputs - .iter() - .map(|i| self.nodes.get(&i).unwrap()) - .collect_vec(); - - let input_idx = input_nodes.iter().map(|f| f.idx).collect_vec(); - let mut op = match &node.opkind { - OpKind::Lookup(l) => l.clone(), - c => { - return Err(Box::new(GraphError::WrongMethod(node.idx, c.clone()))); + let op = match &node.opkind.required_lookup() { + Some(nl) => nl.clone(), + None => { + return Err(Box::new(GraphError::WrongMethod( + node.idx, + node.opkind.as_str().to_string(), + ))); } }; - match op { - LookupOp::PReLU { scale, .. } => { - op = LookupOp::ReLU { scale }; - } - LookupOp::Max | LookupOp::Min | LookupOp::MaxPool2d { .. } => { - op = LookupOp::ReLU { scale: 1 }; - } - LookupOp::Mean { scale } => { - assert_eq!(input_nodes.len(), 1); - op = LookupOp::Div { - denom: crate::circuit::utils::F32( - // we need to scale the denom by the number of elements in the input tensor and the calculated scale diff - (scale * input_nodes[0].out_dims.iter().product::()) as f32, - ), - }; - } - _ => {} - } - - config - .borrow_mut() - .configure_lookup(meta, input, output, self.run_args.bits, &op)?; + config.configure_lookup(meta, input, output, self.run_args.bits, &op)?; - let config = NodeConfig::Op { - config, - inputs: input_idx, - op: node.opkind.clone(), - }; - Ok(config) + Ok(()) } /// Assigns values to the regions created when calling `configure`. @@ -474,7 +277,7 @@ impl Model { /// * `config` - [ModelConfig] holding all node configs. /// * `layouter` - Halo2 Layouter. /// * `inputs` - The values to feed into the circuit. - pub fn layout( + pub fn layout( &self, mut config: ModelConfig, layouter: &mut impl Layouter, @@ -491,30 +294,53 @@ impl Model { } } - // layout any lookup tables - let _: Vec<()> = config - .configs - .values() - .map(|c| match c { - // only lays out tables if they exist so this can be called safely - NodeConfig::Op { config, .. 
} => config.borrow_mut().layout_tables(layouter), - _ => Ok(()), - }) - .collect::, _>>()?; + config.base.layout_tables(layouter)?; layouter.assign_region( || "model", |mut region| { let mut offset: usize = 0; - for (idx, config) in config.configs.iter() { + for (idx, node) in self.nodes.iter() { + let values: Vec> = node + .inputs + .iter() + .map(|i| match &self.nodes.get(i).unwrap().opkind.const_value() { + Some(const_value) => { + if self.visibility.params.is_public() { + const_value + .map(|x| { + crate::tensor::ValType::Constant(i128_to_felt::(x)) + }) + .into() + } else { + const_value + .map(|x| { + crate::tensor::ValType::Value(Value::known( + i128_to_felt::(x), + )) + }) + .into() + } + } + _ => results.get(i).unwrap().clone(), + }) + .collect_vec(); + trace!("laying out offset {}", offset); - if let Some(vt) = self - .layout_config(&mut region, &mut results, config, &mut offset) + let res = config + .base + .layout( + Some(&mut region), + &values, + &mut offset, + node.opkind.clone_dyn(), + ) .map_err(|e| { error!("{}", e); halo2_proofs::plonk::Error::Synthesis - })? - { + })?; + + if let Some(vt) = res { // we get the max as for fused nodes this corresponds to the node output results.insert(*idx, vt); //only use with mock prover @@ -534,49 +360,55 @@ impl Model { output_nodes.clone().collect_vec() ); let mut outputs = output_nodes - .map(|o| results.get(&o).unwrap().clone()) + .map(|o| results.get(o).unwrap().clone()) .collect_vec(); // pack outputs if need be - for (i, packed_output) in config.packed_outputs.iter_mut().enumerate() { - info!("packing outputs..."); - outputs[i] = packed_output - .borrow_mut() - .layout( - &mut region, - &outputs[i..i + 1], - &mut offset, - PolyOp::Pack(self.run_args.pack_base, self.run_args.scale).into(), - ) - .map_err(|e| { - error!("{}", e); - halo2_proofs::plonk::Error::Synthesis - })? - .unwrap(); - // only use with mock prover - if matches!(self.mode, Mode::Mock) { - trace!("------------ packed output {:?}", outputs[i].show()); + if self.run_args.pack_base > 1 { + for i in 0..outputs.len() { + info!("packing outputs..."); + outputs[i] = config + .base + .layout( + Some(&mut region), + &outputs[i..i + 1], + &mut offset, + Box::new(PolyOp::Pack( + self.run_args.pack_base, + self.run_args.scale, + )), + ) + .map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + })? + .unwrap(); + // only use with mock prover + if matches!(self.mode, Mode::Mock) { + trace!("------------ packed output {:?}", outputs[i].show()); + } } } - let _ = config - .range_checks - .iter() - .zip(outputs) - .enumerate() - .map(|(i, (range_check, output))| { - let mut instance_offset = 0; - if self.visibility.input.is_public() { - instance_offset += inputs.len(); - }; - range_check.borrow_mut().layout( - &mut region, - &[output, vars.instances[instance_offset + i].clone()], - &mut offset, - PolyOp::RangeCheck(self.run_args.tolerance as i32).into(), - ) - }) - .collect_vec(); + if self.run_args.public_outputs { + let _ = outputs + .into_iter() + .enumerate() + .map(|(i, output)| { + let mut instance_offset = 0; + if self.visibility.input.is_public() { + instance_offset += inputs.len(); + }; + config.base.layout( + Some(&mut region), + &[output, vars.instances[instance_offset + i].clone()], + &mut offset, + Box::new(PolyOp::RangeCheck(self.run_args.tolerance as i32)), + ) + }) + .collect_vec(); + } + Ok(()) }, )?; @@ -584,67 +416,99 @@ impl Model { Ok(()) } - /// Assigns values to a single region, represented as a [NodeConfig]. 
+ /// Assigns values to the regions created when calling `configure`. /// # Arguments /// - /// * `config` - [NodeConfig] the single region we will layout. + /// * `config` - [ModelConfig] holding all node configs. /// * `layouter` - Halo2 Layouter. - /// * `inputs` - [BTreeMap] of values to feed into the [NodeConfig], can also include previous intermediate results, i.e the output of other nodes. - fn layout_config( - &self, - region: &mut Region, - inputs: &mut BTreeMap>, - config: &NodeConfig, - offset: &mut usize, - ) -> Result>, Box> { - // The node kind and the config should be the same. - let res = match config.clone() { - NodeConfig::Op { - config, - inputs: idx, - op, - } => { - let values: Vec> = idx - .iter() - .map(|i| { - let node = &self.nodes.get(i).unwrap(); - match node.opkind { - OpKind::Const => { - let val = node - .const_value - .clone() - .context("Tensor should already be loaded") - .unwrap(); - if self.visibility.params.is_public() { - val.map(|x| { - crate::tensor::ValType::Constant(i128_to_felt::(x)) - }) - .into() - } else { - val.map(|x| { - crate::tensor::ValType::Value(Value::known( - i128_to_felt::(x), - )) - }) - .into() - } - } - _ => inputs.get(i).unwrap().clone(), - } - }) - .collect_vec(); + /// * `inputs` - The values to feed into the circuit. + pub fn dummy_layout(&self, input_shapes: &[Vec]) -> Result> { + info!("model layout"); + let mut results = BTreeMap::>::new(); + + let inputs: Vec> = input_shapes + .iter() + .map(|shape| { + let t: Tensor> = Tensor::new(None, shape).unwrap(); + t.into() + }) + .collect_vec(); - let res = config.borrow_mut().layout(region, &values, offset, op)?; + for (i, input_value) in inputs.iter().enumerate() { + results.insert(i, input_value.clone()); + } - res + let mut dummy_config = PolyConfig::dummy(self.run_args.logrows as usize); + + let mut offset: usize = 0; + for (idx, node) in self.nodes.iter() { + let values: Vec> = node + .inputs + .iter() + .map(|i| match &self.nodes.get(i).unwrap().opkind.const_value() { + Some(const_value) => const_value + .map(|x| crate::tensor::ValType::Value(Value::known(i128_to_felt::(x)))) + .into(), + _ => results.get(i).unwrap().clone(), + }) + .collect_vec(); + + let res = dummy_config + .layout(None, &values, &mut offset, node.opkind.clone_dyn()) + .map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + })?; + + if let Some(vt) = res { + // we get the max as for fused nodes this corresponds to the node output + results.insert(*idx, vt); } - NodeConfig::Input => None, - NodeConfig::Const => None, - _ => { - return Err(Box::new(GraphError::UnsupportedOp)); + } + + let output_nodes = self.outputs.iter(); + info!( + "model outputs are nodes: {:?}", + output_nodes.clone().collect_vec() + ); + let mut outputs = output_nodes + .map(|o| results.get(o).unwrap().clone()) + .collect_vec(); + + // pack outputs if need be + if self.run_args.pack_base > 1 { + for i in 0..outputs.len() { + info!("packing outputs..."); + outputs[i] = dummy_config + .layout( + None, + &outputs[i..i + 1], + &mut offset, + Box::new(PolyOp::Pack(self.run_args.pack_base, self.run_args.scale)), + ) + .map_err(|e| { + error!("{}", e); + halo2_proofs::plonk::Error::Synthesis + })? 
+ .unwrap(); } - }; - Ok(res) + } + + if self.run_args.public_outputs { + let _ = outputs + .into_iter() + .map(|output| { + dummy_config.layout( + None, + &[output.clone(), output], + &mut offset, + Box::new(PolyOp::RangeCheck(self.run_args.tolerance as i32)), + ) + }) + .collect_vec(); + } + + Ok(offset) } /// Returns the number of the computational graph's inputs @@ -657,7 +521,7 @@ impl Model { pub fn input_shapes(&self) -> Vec> { self.inputs .iter() - .map(|o| self.nodes.get(&o).unwrap().out_dims.clone()) + .map(|o| self.nodes.get(o).unwrap().out_dims.clone()) .collect_vec() } @@ -671,7 +535,7 @@ impl Model { pub fn output_shapes(&self) -> Vec> { self.outputs .iter() - .map(|o| self.nodes.get(&o).unwrap().out_dims.clone()) + .map(|o| self.nodes.get(o).unwrap().out_dims.clone()) .collect_vec() } @@ -679,106 +543,10 @@ impl Model { pub fn get_output_scales(&self) -> Vec { let output_nodes = self.outputs.iter(); output_nodes - .map(|o| self.nodes.get(&o).unwrap().out_scale) + .map(|o| self.nodes.get(o).unwrap().out_scale) .collect_vec() } - /// Total number of variables in lookup layers - pub fn num_vars_lookup_op(&self, lookup_op: &LookupOp) -> Vec { - let mut count = BTreeMap::::new(); - for (_, n) in self.nodes.iter() { - if n.opkind == OpKind::Lookup(lookup_op.clone()) { - match &n.opkind { - OpKind::Lookup(op) => { - let elem = count.get_mut(op); - // handle output variables - let output_size: usize = n.out_dims.iter().product(); - let input_size = output_size; - match elem { - None => { - count.insert(op.clone(), (input_size, output_size)); - } - Some(m) => { - m.0 += input_size; - m.1 += output_size; - } - } - } - // should never reach here - _ => panic!(), - } - } - } - // now get the max across all ops - let (mut num_inputs, mut num_outputs) = (0, 0); - for (_, v) in count.iter() { - num_inputs = max(num_inputs, v.0); - num_outputs = max(num_outputs, v.1); - } - vec![num_inputs, num_outputs] - } - - /// Maximum number of input variables - pub fn total_var_len(&self) -> usize { - let mut maximum_var_len = 0; - - let poly_ops: BTreeMap<&usize, &Node> = self - .nodes - .iter() - .filter(|(_, n)| n.opkind.is_poly()) - .collect(); - - let _: Vec<_> = poly_ops - .values() - .map(|n| match &n.opkind { - OpKind::Poly(p) => { - let in_dims = n - .inputs - .iter() - .map(|i| self.nodes.get(&i).unwrap().out_dims.clone()); - let layout_shape = p.circuit_shapes(in_dims.collect_vec()); - maximum_var_len += layout_shape.last().unwrap(); - } - _ => panic!(), - }) - .collect(); - - let lookup_ops: BTreeMap<&usize, &Node> = self - .nodes - .iter() - .filter(|(_, n)| n.opkind.is_lookup()) - .collect(); - - for op in lookup_ops { - let len = (*op.1.out_dims).iter().product::(); - maximum_var_len += len; - } - - let output_lens: usize = self - .output_shapes() - .iter() - .map(|s| s.iter().product::()) - .sum::(); - - let input_lens: usize = self - .input_shapes() - .iter() - .map(|s| s.iter().product::()) - .sum::(); - - if self.run_args.pack_base > 1 { - maximum_var_len += output_lens; - } - if matches!(self.visibility.output, Visibility::Public) { - maximum_var_len += output_lens; - } - if matches!(self.visibility.output, Visibility::Public) { - maximum_var_len += input_lens; - } - - maximum_var_len - } - /// Number of instances used by the circuit pub fn instance_shapes(&self) -> Vec> { // for now the number of instances corresponds to the number of graph / model outputs diff --git a/src/graph/node.rs b/src/graph/node.rs index c7f552818..d35589cff 100644 --- a/src/graph/node.rs +++ 
b/src/graph/node.rs
@@ -1,77 +1,34 @@
 use super::utilities::{node_output_shapes, scale_to_multiplier, vector_to_quantized};
-use crate::circuit::BaseConfig;
-use crate::circuit::LookupOp;
-use crate::circuit::Op as PolyOp;
-use crate::circuit::OpKind;
+use crate::circuit::ops::poly::PolyOp;
+use crate::circuit::Op;
+use crate::graph::new_op_from_onnx;
 use crate::graph::GraphError;
-use crate::tensor::Tensor;
 use crate::tensor::TensorType;
 use anyhow::Result;
 use halo2_proofs::arithmetic::FieldExt;
 use itertools::Itertools;
-use log::{info, trace, warn};
-use serde::{Deserialize, Serialize};
-use std::cell::RefCell;
+use log::{info, trace};
 use std::collections::BTreeMap;
 use std::error::Error;
 use std::fmt;
-use std::ops::Deref;
-use std::rc::Rc;
 use tabled::Tabled;
 use tract_onnx;
-use tract_onnx::prelude::{DatumType, InferenceFact, Node as OnnxNode};
-use tract_onnx::tract_hir::{
-    infer::Factoid,
-    internal::InferenceOp,
-    ops::activations::LeakyRelu,
-    ops::array::{Pad, PadMode},
-    ops::cnn::{Conv, MaxPool, PoolSpec, SumPool},
-    ops::expandable::Expansion,
-    ops::nn::DataFormat,
-    tract_core::ops::{
-        cnn::{conv::KernelFormat, PaddingSpec},
-        konst::Const,
-    },
-};
-
-/// Enum of the different kinds of node configurations `ezkl` can support.
-#[allow(missing_docs)]
-#[derive(Clone, Default, Debug)]
-pub enum NodeConfig<F: FieldExt + TensorType> {
-    Op {
-        config: Rc<RefCell<BaseConfig<F>>>,
-        inputs: Vec<usize>,
-        op: OpKind,
-    },
-    Const,
-    Input,
-    #[default]
-    NotConfigured,
-}
+use tract_onnx::prelude::{InferenceFact, Node as OnnxNode};
+use tract_onnx::tract_hir::{infer::Factoid, internal::InferenceOp};
 
 /// Representation of an execution graph divided into execution 'buckets'.
-pub type NodeGraph = BTreeMap<usize, Node>;
+pub type NodeGraph<F> = BTreeMap<usize, Node<F>>;
 
 fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
-    if v.len() > 0 {
+    if !v.is_empty() {
         format!("{:?}", v)
     } else {
-        format!("")
+        String::new()
     }
 }
 
-fn display_tensor(o: &Option<Tensor<i128>>) -> String {
-    match o {
-        Some(s) => format!("[{:#?}...]", s[0]),
-        None => String::new(),
-    }
-}
-
-fn display_tensorf32(o: &Option<Tensor<f32>>) -> String {
-    match o {
-        Some(s) => format!("[{:#?}...]", s[0]),
-        None => String::new(),
-    }
+fn display_opkind<F: FieldExt + TensorType>(v: &Box<dyn Op<F>>) -> String {
+    v.as_str().to_string()
 }
 
 /// A single operation in a Model.
@@ -84,36 +41,26 @@ fn display_tensorf32(o: &Option<Tensor<f32>>) -> String {
 /// * `const_value` - The constants potentially associated with this self.
 /// * `idx` - The node's unique identifier.
 /// * `bucket` - The execution bucket this node has been assigned to.
-#[derive(Clone, Debug, Default, Tabled, Serialize, Deserialize)]
-pub struct Node {
+#[derive(Clone, Debug, Tabled)]
+pub struct Node<F: FieldExt + TensorType> {
     /// [OpKind] enum, i.e what operation this node represents.
-    pub opkind: OpKind,
-    /// The denominator in the fixed point representation for the node's input. Tensors of differing scales should not be combined.
-    pub in_scale: u32,
+    #[tabled(display_with = "display_opkind")]
+    pub opkind: Box<dyn Op<F>>,
     /// The denominator in the fixed point representation for the node's output. Tensors of differing scales should not be combined.
     pub out_scale: u32,
-    #[tabled(display_with = "display_tensor")]
-    /// The quantized constants potentially associated with this self.
-    pub const_value: Option<Tensor<i128>>,
-    #[tabled(display_with = "display_tensorf32")]
-    /// The un-quantized constants potentially associated with this self.
-    pub raw_const_value: Option<Tensor<f32>>,
     // Usually there is a simple in and out shape of the node as an operator.
For example, an Affine node has three input_shapes (one for the input, weight, and bias), // but in_dim is [in], out_dim is [out] #[tabled(display_with = "display_vector")] /// The indices of the node's inputs. pub inputs: Vec, #[tabled(display_with = "display_vector")] - /// Dimensions of input. - pub in_dims: Vec>, - #[tabled(display_with = "display_vector")] /// Dimensions of output. pub out_dims: Vec, /// The node's unique identifier. pub idx: usize, } -impl Node { +impl Node { /// Converts a tract [OnnxNode] into an ezkl [Node]. /// # Arguments: /// * `node` - [OnnxNode] @@ -122,17 +69,14 @@ impl Node { /// * `idx` - The node's unique identifier. pub fn new( mut node: OnnxNode>, - other_nodes: &mut BTreeMap, + other_nodes: &mut BTreeMap>, scale: u32, idx: usize, ) -> Result> { trace!("Create {:?}", node); trace!("Create op {:?}", node.op); - let output_shapes = match node_output_shapes(&node) { - Ok(s) => Some(s), - _ => None, - }; + // load the node inputs let mut inputs = vec![]; for i in node.inputs.iter_mut() { match other_nodes.get(&i.node) { @@ -141,994 +85,103 @@ impl Node { } } - let mut opkind = OpKind::new(node.op().name().as_ref()); // parses the op name - - let mn = match opkind { - OpKind::Lookup(ref s) => { - match s { - LookupOp::Min { .. } => { - let input_node = &inputs[0]; + let mut opkind = new_op_from_onnx(idx, scale, node.clone(), &mut inputs)?; // parses the op name - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: vec![1], - in_scale: input_node.out_scale, - out_scale: input_node.out_scale, - const_value: None, - raw_const_value: None, - } - } - LookupOp::Max { .. } => { - let input_node = &inputs[0]; - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: vec![1], - in_scale: input_node.out_scale, - out_scale: input_node.out_scale, - const_value: None, - raw_const_value: None, - } - } - LookupOp::Sigmoid { .. } => { - let input_node = &inputs[0]; - let scale_diff = input_node.out_scale; - if scale_diff > 0 { - let mult = scale_to_multiplier(scale_diff); - opkind = OpKind::Lookup(LookupOp::Sigmoid { - scales: (mult as usize, scale_to_multiplier(scale) as usize), - }); - } else { - opkind = OpKind::Lookup(LookupOp::Sigmoid { - scales: (1, scale_to_multiplier(scale) as usize), - }); - } - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: input_node.out_dims.clone(), - in_scale: input_node.out_scale, - out_scale: scale, - ..Default::default() - } - } - LookupOp::MaxPool2d { .. 
} => { - // input_nodes come in all shapes and sizes we gotta homogenize, especially for 2D (single channel images) - let input_node = other_nodes.get_mut(&node.inputs[0].node).unwrap(); - inputs[0] = Self::format_3d_inputs(input_node)?.clone(); - - let input_node = &inputs[0]; - - // Extract the padding and stride layer hyperparams - let op = Box::new(node.op()); - let sumpool_node: &MaxPool = match op.downcast_ref() { - Some(b) => b, - None => { - return Err(Box::new(GraphError::OpMismatch(idx, opkind))); - } - }; + // if the op requires 3d inputs, we need to make sure the input shape is consistent with that + if opkind.has_3d_input() { + let input_node = other_nodes.get_mut(&node.inputs[0].node).unwrap(); + Self::format_3d_inputs(input_node)?; + inputs[0] = input_node.clone(); + }; - let pool_spec: &PoolSpec = &sumpool_node.pool_spec; + // creates a rescaled op if the inputs are not homogenous + if opkind.requires_homogenous_input_scales() { + opkind = Self::homogenize_input_scales(opkind, inputs.clone())?; + } - // only support pytorch type formatting for now - if pool_spec.data_format != DataFormat::NCHW { - return Err(Box::new(GraphError::MissingParams( - "data in wrong format".to_string(), - ))); - } + // rescale the inputs if necessary to get consistent fixed points + let in_scales: Vec = inputs.iter().map(|i| i.out_scale).collect(); + opkind = opkind.rescale(in_scales.clone(), scale); + let out_scale = match in_scales.len() { + 0 => scale, + _ => opkind.out_scale(in_scales, scale), + }; - let stride = pool_spec.strides.clone().unwrap(); - let padding = match &pool_spec.padding { - PaddingSpec::Explicit(p, _, _) => p, - _ => { - return Err(Box::new(GraphError::MissingParams( - "padding".to_string(), - ))); - } + // get the output shape + let in_dims: Vec> = inputs.iter().map(|i| i.out_dims.clone()).collect(); + let out_dims = match in_dims.len() { + // if there are no inputs, we need to get the output shape from the node + 0 => { + // remove batch dim for now + match opkind.const_value() { + Some(ref const_value) => const_value.dims().to_vec(), + _ => { + let output_shapes = match node_output_shapes(&node) { + Ok(s) => Some(s), + _ => None, }; - let kernel_shape = &pool_spec.kernel_shape; - let (padding_h, padding_w, stride_h, stride_w) = - (padding[0], padding[1], stride[0], stride[1]); - let (kernel_height, kernel_width) = (kernel_shape[0], kernel_shape[1]); - - let input_channels = input_node.out_dims[0]; - let input_height = input_node.out_dims[1]; - let input_width = input_node.out_dims[2]; - - let out_height = - (input_height + 2 * padding_h - kernel_height) / stride_h + 1; - let out_width = (input_width + 2 * padding_w - kernel_width) / stride_w + 1; - - Node { - idx, - opkind: OpKind::Lookup(LookupOp::MaxPool2d { - padding: (padding_h, padding_w), - stride: (stride_h, stride_w), - pool_dims: (kernel_height, kernel_width), - }), - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: vec![input_channels, out_height, out_width], - in_scale: input_node.out_scale, - out_scale: input_node.out_scale, - ..Default::default() - } - } - - LookupOp::Sqrt { .. 
} => { - let input_node = &inputs[0]; - let scale_diff = input_node.out_scale; - if scale_diff > 0 { - let mult = scale_to_multiplier(scale_diff); - opkind = OpKind::Lookup(LookupOp::Sqrt { - scales: (mult as usize, scale_to_multiplier(scale) as usize), - }); - } else { - opkind = OpKind::Lookup(LookupOp::Sqrt { - scales: (1, scale_to_multiplier(scale) as usize), - }); - } - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: input_node.out_dims.clone(), - in_scale: input_node.out_scale, - out_scale: scale, - ..Default::default() - } - } - - LookupOp::Tanh { .. } => { - let input_node = &inputs[0]; - let scale_diff = input_node.out_scale; - if scale_diff > 0 { - let mult = scale_to_multiplier(scale_diff); - opkind = OpKind::Lookup(LookupOp::Tanh { - scales: (mult as usize, scale_to_multiplier(scale) as usize), - }); + let dims = if let Some([Some(v)]) = output_shapes.as_deref() { + v.to_vec() } else { - opkind = OpKind::Lookup(LookupOp::Tanh { - scales: (1, scale_to_multiplier(scale) as usize), - }); - } - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: input_node.out_dims.clone(), - in_scale: input_node.out_scale, - out_scale: scale, - ..Default::default() - } - } - - LookupOp::Erf { .. } => { - let input_node = &inputs[0]; - let scale_diff = input_node.out_scale; - if scale_diff > 0 { - let mult = scale_to_multiplier(scale_diff); - opkind = OpKind::Lookup(LookupOp::Erf { - scales: (mult as usize, scale_to_multiplier(scale) as usize), - }); - } else { - opkind = OpKind::Lookup(LookupOp::Erf { - scales: (1, scale_to_multiplier(scale) as usize), - }); - } - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: input_node.out_dims.clone(), - in_scale: input_node.out_scale, - out_scale: scale, - ..Default::default() - } - } - - LookupOp::ReLU { .. } => { - let input_node = &inputs[0]; - let scale_diff = input_node.out_scale - scale; - // We can also consider adjusting the scale of all inputs and the output in a more custom way. - if scale_diff > 0 { - let mult = scale_to_multiplier(scale_diff); - opkind = OpKind::Lookup(LookupOp::ReLU { - scale: mult as usize, - }); // now the input will be scaled down to match - } - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: input_node.out_dims.clone(), - in_scale: input_node.out_scale, - out_scale: scale, - ..Default::default() - } - } - LookupOp::Mean { .. } => { - let input_node = &inputs[0]; - let scale_diff = input_node.out_scale - scale; - // We can also consider adjusting the scale of all inputs and the output in a more custom way. - if scale_diff > 0 { - let mult = scale_to_multiplier(scale_diff); - opkind = OpKind::Lookup(LookupOp::Mean { - scale: mult as usize, - }); // now the input will be scaled down to match - } - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: vec![1], - in_scale: input_node.out_scale, - out_scale: scale, - ..Default::default() - } - } - LookupOp::LeakyReLU { - scale: mut layer_scale, - .. 
- } => { - let input_node = &inputs[0]; - - // Extract the slope layer hyperparams - let op = Box::new(node.op()); - - let leaky_op: &LeakyRelu = match op.downcast_ref::>() { - Some(b) => match (*b).as_any().downcast_ref() { - Some(b) => b, - None => { - return Err(Box::new(GraphError::OpMismatch(idx, opkind))); - } - }, - None => { - return Err(Box::new(GraphError::OpMismatch(idx, opkind))); - } + // Turn `outputs: [?,3,32,32,F32 >3/0]` into `vec![3,32,32]` in two steps + let the_shape: Result> = node.outputs[0] + .fact + .shape + .dims() + .filter_map(|x| x.concretize()) + .map(|x| x.to_i64()) + .collect(); + + the_shape + .unwrap() + .iter() + .map(|x| (*x as i128) as usize) + .collect() }; - - let scale_diff = input_node.out_scale - scale; - // We can also consider adjusting the scale of all inputs and the output in a more custom way. - if scale_diff > 0 { - layer_scale = scale_to_multiplier(scale_diff) as usize; - } - - opkind = OpKind::Lookup(LookupOp::LeakyReLU { - scale: layer_scale, - slope: crate::circuit::utils::F32(leaky_op.0), - }); // now the input will be scaled down to match - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: input_node.out_dims.clone(), - in_scale: input_node.out_scale, - out_scale: scale, - ..Default::default() - } - } - LookupOp::PReLU { - scale: mut layer_scale, - .. - } => { - let input_node = &inputs[0]; - // Extract the slope layer hyperparams - let slopes = inputs[1] - .clone() - .raw_const_value - .unwrap() - .deref() - .iter() - .map(|value| crate::circuit::utils::F32(*value)) - .collect_vec(); - // node.inputs.pop(); - - let scale_diff = input_node.out_scale - scale; - // We can also consider adjusting the scale of all inputs and the output in a more custom way. - if scale_diff > 0 { - layer_scale = scale_to_multiplier(scale_diff) as usize; - } - - opkind = OpKind::Lookup(LookupOp::PReLU { - scale: layer_scale, - slopes, - }); // now the input will be scaled down to match - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: input_node.out_dims.clone(), - in_scale: input_node.out_scale, - out_scale: scale, - ..Default::default() - } - } - LookupOp::Div { .. } => { - if (inputs[1].out_dims.clone() != [1]) - || !matches!(inputs[1].opkind, OpKind::Const) - { - return Err(Box::new(GraphError::NonConstantDiv)); - } - - let input_node = &inputs[0]; - let mut input_outlets = node.inputs.clone(); - input_outlets.pop(); - - let denom = inputs[1].raw_const_value.as_ref().unwrap()[0]; - - let scale_diff = input_node.out_scale - scale; - // We can also consider adjusting the scale of all inputs and the output in a more custom way. 
- if scale_diff > 0 { - let mult = scale_to_multiplier(scale_diff); - opkind = OpKind::Lookup(LookupOp::Div { - denom: crate::circuit::utils::F32(denom * mult), - }); // now the input will be scaled down to match + if !dims.is_empty() && dims[0] == 1 && dims.len() > 1 { + dims[1..].to_vec() } else { - opkind = OpKind::Lookup(LookupOp::Div { - denom: crate::circuit::utils::F32(denom), - }); // now the input will be scaled down to match - } - - Node { - idx, - opkind, - inputs: input_outlets.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: input_node.out_dims.clone(), - // in scale is the same as the input - in_scale: input_node.out_scale, - // same for the output scale - out_scale: scale, - ..Default::default() - } - } - } - } - OpKind::Poly(ref s) => { - match s { - PolyOp::Pack(_, _) => { - return Err(Box::new(GraphError::MisformedParams( - "pack op should not be configured here".to_string(), - ))); - } - PolyOp::Pad(..) => { - let input_node = other_nodes.get_mut(&node.inputs[0].node).unwrap(); - // we only support padding for 3D images - inputs[0] = Self::format_3d_inputs(input_node)?.clone(); - - let pad_node: &Pad = match node.op().downcast_ref::() { - Some(b) => b, - None => { - return Err(Box::new(GraphError::OpMismatch(idx, opkind))); - } - }; - // we only support constant 0 padding - if pad_node.mode - != PadMode::Constant(tract_onnx::prelude::Arc::new( - tract_onnx::prelude::Tensor::zero::(&[])?, - )) - { - return Err(Box::new(GraphError::MisformedParams( - "pad mode or pad type".to_string(), - ))); - } - - let padding_len = pad_node.pads.len(); - - // we only support symmetrical padding that affects the last 2 dims (height and width params) - for (i, pad_params) in pad_node.pads.iter().enumerate() { - if (i < padding_len - 2) && ((pad_params.0 != 0) || (pad_params.1 != 0)) - { - return Err(Box::new(GraphError::MisformedParams( - "ezkl currently only supports padding height and width dimensions".to_string(), - ))); - } - if pad_params.0 != pad_params.1 { - return Err(Box::new(GraphError::MisformedParams( - "ezkl currently only supports symmetric padding".to_string(), - ))); - } - } - - let (padding_h, padding_w) = ( - pad_node.pads[padding_len - 2].0, - pad_node.pads[padding_len - 1].0, - ); - - let input_channels = input_node.out_dims[0]; - - let out_height = input_node.out_dims[1] + 2 * padding_h; - let out_width = input_node.out_dims[2] + 2 * padding_w; - - Node { - idx, - opkind: OpKind::Poly(PolyOp::Pad(padding_h, padding_w)), - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: vec![input_channels, out_height, out_width], - in_scale: input_node.out_scale, - out_scale: input_node.out_scale, - ..Default::default() - } - } - PolyOp::Dot => todo!(), - PolyOp::Conv { .. 
} => { - let input_node = other_nodes.get_mut(&node.inputs[0].node).unwrap(); - inputs[0] = Self::format_3d_inputs(input_node)?.clone(); - - let (input_node, weight_node) = (&inputs[0], &inputs[1]); - - // Extract the padding and stride layer hyperparams - let op = Box::new(node.op()); - - let conv_node: &Conv = match op.downcast_ref::>() { - Some(b) => match (*b).as_any().downcast_ref() { - Some(b) => b, - None => { - return Err(Box::new(GraphError::OpMismatch(idx, opkind))); - } - }, - None => { - return Err(Box::new(GraphError::OpMismatch(idx, opkind))); - } - }; - - if (conv_node.data_format != DataFormat::NCHW) - || (conv_node.kernel_fmt != KernelFormat::OIHW) - { - return Err(Box::new(GraphError::MisformedParams( - "data or kernel in wrong format".to_string(), - ))); - } - - let stride = match conv_node.strides.clone() { - Some(s) => s, - None => { - return Err(Box::new(GraphError::MissingParams( - "strides".to_string(), - ))); - } - }; - let padding = match &conv_node.padding { - PaddingSpec::Explicit(p, _, _) => p, - _ => { - return Err(Box::new(GraphError::MissingParams( - "padding".to_string(), - ))); - } - }; - - if inputs.len() == 3 { - let bias_node = &inputs[2]; - let scale_diff = - weight_node.out_scale + input_node.out_scale - bias_node.out_scale; - let mut bias_node = other_nodes.get_mut(&node.inputs[2].node).unwrap(); - bias_node = Self::scale_up_const_node(bias_node, scale + scale_diff)?; - if (input_node.out_scale + weight_node.out_scale) != bias_node.out_scale - { - return Err(Box::new(GraphError::RescalingError(opkind))); - } - } - - let oihw = weight_node.out_dims.clone(); - let (out_channels, _, kernel_height, kernel_width) = - (oihw[0], oihw[1], oihw[2], oihw[3]); - - let (padding_h, padding_w, stride_h, stride_w) = - (padding[0], padding[1], stride[0], stride[1]); - - let input_height = input_node.out_dims[1]; - let input_width = input_node.out_dims[2]; - - let out_height = - (input_height + 2 * padding_h - kernel_height) / stride_h + 1; - let out_width = (input_width + 2 * padding_w - kernel_width) / stride_w + 1; - - Node { - idx, - opkind: OpKind::Poly(PolyOp::Conv { - padding: (padding_h, padding_w), - stride: (stride_h, stride_w), - }), - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: vec![out_channels, out_height, out_width], - in_scale: input_node.out_scale, - out_scale: weight_node.out_scale + input_node.out_scale, - ..Default::default() - } - } - - PolyOp::SumPool { .. 
} => { - // input_nodes come in all shapes and sizes we gotta homogenize, especially for 2D (single channel images) - let input_node = other_nodes.get_mut(&node.inputs[0].node).unwrap(); - inputs[0] = Self::format_3d_inputs(input_node)?.clone(); - - let input_node = &inputs[0]; - - // Extract the padding and stride layer hyperparams - let op = Box::new(node.op()); - let sumpool_node: &SumPool = match op.downcast_ref() { - Some(b) => b, - None => { - return Err(Box::new(GraphError::OpMismatch(idx, opkind))); - } - }; - - let pool_spec: &PoolSpec = &sumpool_node.pool_spec; - - // only support pytorch type formatting for now - if pool_spec.data_format != DataFormat::NCHW { - return Err(Box::new(GraphError::MissingParams( - "data in wrong format".to_string(), - ))); - } - - let stride = pool_spec.strides.clone().unwrap(); - let padding = match &pool_spec.padding { - PaddingSpec::Explicit(p, _, _) => p, - _ => { - return Err(Box::new(GraphError::MissingParams( - "padding".to_string(), - ))); - } - }; - let kernel_shape = &pool_spec.kernel_shape; - - let (padding_h, padding_w, stride_h, stride_w) = - (padding[0], padding[1], stride[0], stride[1]); - let (kernel_height, kernel_width) = (kernel_shape[0], kernel_shape[1]); - - let input_channels = input_node.out_dims[0]; - let input_height = input_node.out_dims[1]; - let input_width = input_node.out_dims[2]; - - let out_height = - (input_height + 2 * padding_h - kernel_height) / stride_h + 1; - let out_width = (input_width + 2 * padding_w - kernel_width) / stride_w + 1; - - Node { - idx, - opkind: OpKind::Poly(PolyOp::SumPool { - padding: (padding_h, padding_w), - stride: (stride_h, stride_w), - kernel_shape: (kernel_height, kernel_width), - }), - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: vec![input_channels, out_height, out_width], - in_scale: input_node.out_scale, - out_scale: input_node.out_scale, - ..Default::default() - } - } - - PolyOp::GlobalSumPool => { - // input_nodes come in all shapes and sizes we gotta homogenize, especially for 2D (single channel images) - let input_node = other_nodes.get_mut(&node.inputs[0].node).unwrap(); - inputs[0] = Self::format_3d_inputs(input_node)?.clone(); - - let input_node = &inputs[0]; - let input_channels = input_node.out_dims[0]; - let input_height = input_node.out_dims[1]; - let input_width = input_node.out_dims[2]; - - let (padding_h, padding_w, stride_h, stride_w) = (0, 0, 1, 1); - let (kernel_height, kernel_width) = (input_height, input_width); - - // These are 1 if padding is 0,0 and stride is 1,1 - let out_height = - (input_height + 2 * padding_h - kernel_height) / stride_h + 1; - let out_width = (input_width + 2 * padding_w - kernel_width) / stride_w + 1; - - Node { - idx, - opkind: OpKind::Poly(PolyOp::SumPool { - padding: (padding_h, padding_w), - stride: (stride_h, stride_w), - kernel_shape: (kernel_height, kernel_width), - }), - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![input_node.out_dims.clone()], - out_dims: vec![input_channels, out_height, out_width], - in_scale: input_node.out_scale, - out_scale: input_node.out_scale, - ..Default::default() - } - } - - PolyOp::Matmul => { - let (a_node, b_node) = (&inputs[0], &inputs[1]); - let a_dims = a_node.out_dims.clone(); - let b_dims = b_node.out_dims.clone(); - let in_dim = a_dims[1]; - - let mut dims = Vec::from(&a_dims[0..a_dims.len() - 2]); - dims.push(a_dims[a_dims.len() - 2]); - dims.push(b_dims[a_dims.len() - 1]); - - Node { - idx, - opkind, - 
inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![vec![in_dim]], - out_dims: dims.clone(), - in_scale: a_node.out_scale, - out_scale: a_node.out_scale + b_node.out_scale, - ..Default::default() - } - } - PolyOp::Affine | PolyOp::ScaleAndShift => { - let (input_node, weight_node, bias_node) = - (&inputs[0], &inputs[1], &inputs[2]); - - let scale_diff = - weight_node.out_scale + input_node.out_scale - bias_node.out_scale; - let mut bias_node = other_nodes.get_mut(&node.inputs[2].node).unwrap(); - bias_node = Self::scale_up_const_node(bias_node, scale + scale_diff)?; - if (input_node.out_scale + weight_node.out_scale) != bias_node.out_scale { - return Err(Box::new(GraphError::RescalingError(opkind))); - } - - let out_dim = weight_node.out_dims.clone()[0]; - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), - out_dims: vec![out_dim], - in_scale: input_node.out_scale, - out_scale: weight_node.out_scale + input_node.out_scale, - ..Default::default() + dims } } - // BatchNorm take four parameters, does some f32 arithmetic and then quantizes - // while ScaleAndShift takes the final two parameters immediately. - // We will also reach back and quantize - PolyOp::BatchNorm => { - //Compute scale and shift from the four inputs, - // then replace the first two, and change this node to a ScaleAndShift - let gamma = inputs[1].raw_const_value.as_ref().unwrap(); - let beta = inputs[2].raw_const_value.as_ref().unwrap(); - let mu = inputs[3].raw_const_value.as_ref().unwrap(); - let sigma = inputs[4].raw_const_value.as_ref().unwrap(); - // let num_entries = gamma.len(); - - let a = (gamma.clone() / sigma.clone())?; - let amu: Tensor = (a.clone() * mu.clone())?; - let amupb: Tensor = (amu + beta.clone())?; - let b = (amupb * Tensor::new(Some(&[-1f32]), &[1])?)?; - - let in_scale = inputs[0].out_scale; - let out_scale = 2 * inputs[0].out_scale; - // gamma node becomes the scale (weigh) in scale and shift - inputs[1].raw_const_value = Some(a); - inputs[1].quantize_const_to_scale(in_scale)?; - - // beta node becomes the shift (bias) - inputs[2].raw_const_value = Some(b); - inputs[2].quantize_const_to_scale(out_scale)?; - - Node { - idx, - opkind: OpKind::Poly(PolyOp::ScaleAndShift), - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), - out_dims: inputs[0].out_dims.clone(), - in_scale, - out_scale, - ..Default::default() - } - } - - PolyOp::Add => { - opkind = Self::homogenize_input_scales(opkind, inputs.clone())?; - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), - out_dims: inputs.iter().map(|e| e.out_dims.clone()).max().unwrap(), - in_scale: inputs.iter().map(|input| input.out_scale).max().unwrap(), - out_scale: inputs.iter().map(|input| input.out_scale).max().unwrap(), - ..Default::default() - } - } - PolyOp::Sum => { - if inputs.len() != 1 { - return Err(Box::new(GraphError::InvalidDims(idx, opkind))); - }; - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), - out_dims: vec![1], - in_scale: inputs.iter().map(|input| input.out_scale).max().unwrap(), - out_scale: inputs.iter().map(|input| input.out_scale).max().unwrap(), - ..Default::default() - } - } - PolyOp::Sub => { - opkind = Self::homogenize_input_scales(opkind, 
inputs.clone())?; - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), - out_dims: inputs.iter().map(|e| e.out_dims.clone()).max().unwrap(), - in_scale: inputs.iter().map(|input| input.out_scale).max().unwrap(), - out_scale: inputs.iter().map(|input| input.out_scale).max().unwrap(), - ..Default::default() - } - } - PolyOp::Mult => { - let input_node = &inputs[0]; - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), - out_dims: inputs.iter().map(|e| e.out_dims.clone()).max().unwrap(), - in_scale: input_node.out_scale, - out_scale: inputs.iter().map(|input| input.out_scale).sum::(), - ..Default::default() - } - } - PolyOp::Pow(_) => { - let input_node = &inputs[0]; - let pow = inputs[1].clone().raw_const_value.unwrap()[0]; - node.inputs.pop(); - if inputs[1].out_dims != [1] { - { - return Err(Box::new(GraphError::NonConstantPower)); - } - } - - Node { - idx, - opkind: OpKind::Poly(PolyOp::Pow(pow as u32)), - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), - out_dims: inputs.iter().map(|e| e.out_dims.clone()).max().unwrap(), - in_scale: input_node.out_scale, - out_scale: input_node.out_scale * (pow as u32), - ..Default::default() - } - } - PolyOp::Rescaled { .. } => { - return Err(Box::new(GraphError::RescalingError(opkind))); - } - PolyOp::Identity => { - let input_node = &inputs[0]; - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), - out_dims: input_node.out_dims.clone(), - in_scale: input_node.out_scale, - out_scale: input_node.out_scale, - ..Default::default() - } - } - PolyOp::Flatten(_) => { - let input_node = &inputs[0]; - let new_dims: Vec = - vec![inputs[0].out_dims.iter().product::()]; - Node { - idx, - opkind: OpKind::Poly(PolyOp::Flatten(new_dims.clone())), - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), - out_dims: new_dims, - in_scale: input_node.out_scale, - out_scale: input_node.out_scale, - ..Default::default() - } - } - PolyOp::Reshape(_) => { - let input_node = &inputs[0]; - let shape_const_node = &inputs[1]; - let shape_const = match shape_const_node.const_value.as_ref() { - Some(sc) => sc, - None => { - return Err(Box::new(GraphError::MissingParams( - "shape constant".to_string(), - ))); - } - }; - - let mut shapes = &shape_const[0..]; - - // we remove batch dims as we assume single elem batches - if shapes[0] == -1 && shapes.len() > 1 { - shapes = &shapes[1..]; - } - - let new_dims: Result, Box> = - if shapes.iter().all(|x| x > &0) { - let mut res = vec![]; - for x in shapes.iter() { - if x <= &0 { - return Err(Box::new(GraphError::InvalidDims(idx, opkind))); - } - res.push(*x as usize); - } - Ok(res) - } else { - let num_entries: usize = input_node.out_dims.iter().product(); - let explicit_prod: i128 = - shapes.iter().filter(|x| *x > &0).product(); - if explicit_prod <= 0 { - return Err(Box::new(GraphError::InvalidDims(idx, opkind))); - } - let inferred = num_entries / (explicit_prod as usize); - let mut new_dims: Vec = Vec::new(); - for i in shapes { - match i { - -1 => new_dims.push(inferred), - 0 => continue, - x => new_dims.push(*x as usize), - } - } - Ok(new_dims) - }; - - let new_dims = new_dims?; - - Node { - idx, - opkind: 
OpKind::Poly(PolyOp::Reshape(new_dims.clone())), - inputs: node.inputs[0..1].iter().map(|i| i.node).collect(), - in_dims: inputs.iter().map(|inp| inp.out_dims.clone()).collect(), - out_dims: new_dims, - in_scale: input_node.out_scale, - out_scale: input_node.out_scale, - ..Default::default() - } - } - _ => unreachable!(""), } } - OpKind::Const => { - let op = Box::new(node.op()); - let const_node: &Const = match op.as_any().downcast_ref() { - Some(b) => b, - None => { - return Err(Box::new(GraphError::OpMismatch(idx, opkind))); - } - }; - let dt = const_node.0.datum_type(); - let mut dims = const_node.0.shape().to_vec(); - if dims.is_empty() { - dims.push(1) - } - - match dt { - DatumType::F32 => { - let vec = const_node.0.as_slice::().unwrap().to_vec(); - let raw: Tensor = Tensor::new(Some(&vec), &dims).unwrap(); - let t = vector_to_quantized(&vec, &dims, 0f32, scale).unwrap(); - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![dims.clone()], - out_dims: dims, - in_scale: scale, - out_scale: scale, - const_value: Some(t), - raw_const_value: Some(raw), - ..Default::default() - } - } - - DatumType::I64 => { - // Generally a shape or hyperparam - let vec = const_node.0.as_slice::().unwrap().to_vec(); - let cast: Vec = vec.iter().map(|x| *x as i128).collect(); - let t = Tensor::::new(Some(&cast), &dims).unwrap(); - - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![dims.clone()], - out_dims: dims, - in_scale: scale, - out_scale: 0, - const_value: Some(t), - raw_const_value: None, - ..Default::default() - } - } - _ => todo!(), - } - } - OpKind::Input => { - let dims = if let Some([Some(v)]) = output_shapes.as_deref() { - v.to_vec() - } else { - // Turn `outputs: [?,3,32,32,F32 >3/0]` into `vec![3,32,32]` in two steps - let the_shape: Result> = node.outputs[0] - .fact - .shape - .dims() - .filter_map(|x| x.concretize()) - .map(|x| x.to_i64()) - .collect(); - - the_shape - .unwrap() - .iter() - .map(|x| (*x as i128) as usize) - .collect() - }; - // remove batch dim for now - let out_dims = if dims[0] == 1 && dims.len() > 1 { - dims[1..].to_vec() - } else { - dims - }; + // else calculate the output shape from the inputs + _ => opkind.out_dims(in_dims), + }; - Node { - idx, - opkind, - inputs: node.inputs.iter().map(|i| i.node).collect(), - in_dims: vec![out_dims.clone()], - out_dims, - in_scale: scale, - out_scale: scale, - ..Default::default() + // we now run a forward pass to re-quantize the inputs to the node + // this is necessary because the inputs to the node may have been quantized differently + if let Some(idx) = opkind.bias_variable() { + if idx >= inputs.len() { + } else { + let bias_node = &inputs[idx]; + let scale_diff = out_scale - bias_node.out_scale; + let mut bias_node = other_nodes.get_mut(&inputs[idx].idx).unwrap(); + bias_node = Self::scale_up_const_node(bias_node, scale + scale_diff)?; + if (out_scale) != bias_node.out_scale { + return Err(Box::new(GraphError::RescalingError( + opkind.as_str().to_string(), + ))); } } + } - OpKind::Unknown(_) => { - warn!("{:?} is unknown", opkind); - Node::default() - } - _ => { - return Err(Box::new(GraphError::UnsupportedOp)); - } - }; - Ok(mn) + Ok(Node { + idx, + opkind, + inputs: inputs.iter().map(|i| i.idx).collect(), + out_dims, + out_scale, + }) } /// Ensures all inputs to a node have the same fixed point denominator. 
     fn homogenize_input_scales(
-        opkind: OpKind,
-        inputs: Vec<Node>,
-    ) -> Result<OpKind, Box<dyn Error>> {
+        opkind: Box<dyn Op<F>>,
+        inputs: Vec<Node<F>>,
+    ) -> Result<Box<dyn Op<F>>, Box<dyn Error>> {
         let mut multipliers = vec![1; inputs.len()];
         let out_scales = inputs.windows(1).map(|w| w[0].out_scale).collect_vec();
         if !out_scales.windows(2).all(|w| w[0] == w[1]) {
@@ -1152,10 +205,10 @@ impl Node {
                 .collect_vec();
         }
 
-        if let OpKind::Poly(c) = &opkind {
+        if let Some(c) = &opkind.required_poly() {
             // only rescale if need to
             if multipliers.iter().sum::<usize>() > multipliers.len() {
-                Ok(OpKind::Poly(PolyOp::Rescaled {
+                Ok(Box::new(PolyOp::Rescaled {
                     inner: Box::new(c.clone()),
                     scale: (0..inputs.len()).zip(multipliers).collect_vec(),
                 }))
@@ -1163,52 +216,78 @@ impl Node {
                 Ok(opkind)
             }
         } else {
-            Err(Box::new(GraphError::RescalingError(opkind)))
+            Err(Box::new(GraphError::RescalingError(
+                opkind.as_str().to_string(),
+            )))
         }
     }
 
-    fn quantize_const_to_scale(&mut self, scale: u32) -> Result<(), Box<dyn Error>> {
-        if !self.opkind.is_const() {
-            return Err(Box::new(GraphError::WrongMethod(
-                self.idx,
-                self.opkind.clone(),
-            )));
-        };
-        let raw = self.raw_const_value.as_ref().unwrap();
-        self.out_scale = scale;
-        let t = vector_to_quantized(raw, raw.dims(), 0f32, self.out_scale).unwrap();
-        self.const_value = Some(t);
-        Ok(())
+    /// Quantizes a constant node to a given scale.
+    pub fn quantize_const_to_scale(&mut self, scale: u32) -> Result<(), Box<dyn Error>> {
+        match &self.opkind.raw_const_value() {
+            Some(raw) => {
+                self.out_scale = scale;
+                let t = vector_to_quantized(&raw.map(|e| e.0), raw.dims(), 0f32, self.out_scale)
+                    .unwrap();
+                self.opkind = Box::new(crate::circuit::ops::Const {
+                    const_value: t,
+                    raw_const_value: Some(raw.clone()),
+                });
+                Ok(())
+            }
+            _ => {
+                return Err(Box::new(GraphError::WrongMethod(
+                    self.idx,
+                    self.opkind.as_str().to_string(),
+                )))
+            }
+        }
     }
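The multiplier bookkeeping in `homogenize_input_scales` above is ordinary fixed-point arithmetic: a quantized value q represents x ≈ q / 2^scale, so an operand at a lower scale must be multiplied up before it can be combined with one at a higher scale. A small worked example (plain Rust, not ezkl code):

    // q represents x at the given scale: q = round(x * 2^scale)
    fn quantize(x: f32, scale: u32) -> i128 {
        (x * (1u64 << scale) as f32).round() as i128
    }

    fn main() {
        let a = quantize(0.75, 7); // 96 at scale 7
        let b = quantize(0.5, 3); // 4 at scale 3
        let multiplier: i128 = 1 << (7 - 3); // bring b up to scale 7
        let sum = a + b * multiplier; // 96 + 64 = 160
        assert_eq!(sum, quantize(1.25, 7)); // 1.25 * 128 = 160
    }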
     /// Re-quantizes a constant value node to a new scale.
-    fn scale_up_const_node(node: &mut Node, scale: u32) -> Result<&mut Node, Box<dyn Error>> {
+    fn scale_up_const_node(node: &mut Self, scale: u32) -> Result<&mut Self, Box<dyn Error>> {
         if !node.opkind.is_const() {
             return Err(Box::new(GraphError::WrongMethod(
                 node.idx,
-                node.opkind.clone(),
+                node.opkind.as_str().to_string(),
             )));
         };
         if scale > 0 {
-            if let Some(val) = &node.raw_const_value {
-                let t = vector_to_quantized(val, val.dims(), 0f32, scale)?;
-                node.const_value = Some(t);
-                info!(
-                    "------ scaled const node {:?}: {:?} -> {:?}",
-                    node.idx, node.in_scale, scale
-                );
-                node.out_scale = scale;
+            match &node.opkind.raw_const_value() {
+                Some(raw_const_value) => {
+                    let t = vector_to_quantized(
+                        &raw_const_value.map(|f| f.0),
+                        raw_const_value.dims(),
+                        0f32,
+                        scale,
+                    )?;
+                    info!(
+                        "------ scaled const node {:?}: {:?} -> {:?}",
+                        node.idx, node.out_scale, scale
+                    );
+                    node.out_scale = scale;
+                    node.opkind = Box::new(crate::circuit::ops::Const {
+                        const_value: t,
+                        raw_const_value: Some(raw_const_value.clone()),
+                    });
+                }
+                _ => {
+                    return Err(Box::new(GraphError::WrongMethod(
+                        node.idx,
+                        node.opkind.as_str().to_string(),
+                    )))
+                }
             }
         }
         Ok(node)
     }
 
     /// Formats 3d inputs if they have under or overspecified dims (casting 2D -> 3D and nD -> 3D)
-    fn format_3d_inputs(mut node: &mut Node) -> Result<&mut Node, Box<dyn Error>> {
+    fn format_3d_inputs(mut node: &mut Self) -> Result<(), Box<dyn Error>> {
         if node.opkind.is_const() {
             return Err(Box::new(GraphError::WrongMethod(
                 node.idx,
-                node.opkind.clone(),
+                node.opkind.as_str().to_string(),
             )));
         };
         // input_nodes come in all shapes and sizes we gotta homogenize, especially for 2D (single channel images)
@@ -1221,18 +300,18 @@ impl Node {
         if node.out_dims.len() != 3 {
             return Err(Box::new(GraphError::InvalidDims(
                 node.idx,
-                node.clone().opkind,
+                node.clone().opkind.as_str().to_string(),
             )));
         }
-        Ok(node)
+        Ok(())
     }
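A toy version of the dimension homogenization `format_3d_inputs` performs, shown only to make the shape contract concrete (the real helpers above also reject const nodes and mutate the node in place rather than returning a new shape):

    // Cast under- and over-specified shapes to 3D: pad missing channel dims
    // with 1s, and drop redundant leading dims such as a batch of 1.
    fn to_3d(mut dims: Vec<usize>) -> Vec<usize> {
        while dims.len() > 3 && dims[0] == 1 {
            dims.remove(0);
        }
        while dims.len() < 3 {
            dims.insert(0, 1);
        }
        dims
    }

    fn main() {
        assert_eq!(to_3d(vec![28, 28]), vec![1, 28, 28]); // 2D image -> CHW
        assert_eq!(to_3d(vec![1, 3, 32, 32]), vec![3, 32, 32]); // drop batch dim
    }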
     /// Adds an extra channel dim to nodes that need it.
-    fn pad_channel_input_node(node: &mut Node) -> Result<&mut Node, Box<dyn Error>> {
+    fn pad_channel_input_node(node: &mut Self) -> Result<&mut Self, Box<dyn Error>> {
         if node.opkind.is_const() {
             return Err(Box::new(GraphError::WrongMethod(
                 node.idx,
-                node.opkind.clone(),
+                node.opkind.as_str().to_string(),
             )));
         };
         let mut dims = vec![1];
@@ -1242,11 +321,11 @@ impl Node {
     }
 
     /// Removes excess channels for an image
-    fn rm_redundant_3d_channels(node: &mut Node) -> Result<&mut Node, Box<dyn Error>> {
+    fn rm_redundant_3d_channels(node: &mut Self) -> Result<&mut Self, Box<dyn Error>> {
         if node.opkind.is_const() {
             return Err(Box::new(GraphError::WrongMethod(
                 node.idx,
-                node.opkind.clone(),
+                node.opkind.as_str().to_string(),
             )));
         };
         let dims = &node.out_dims;
@@ -1256,7 +335,7 @@ impl Node {
             if *dim != 1 {
                 return Err(Box::new(GraphError::InvalidDims(
                     node.idx,
-                    node.opkind.clone(),
+                    node.opkind.as_str().to_string(),
                 )));
             }
         }
diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs
index 84558138d..7d2b7fdda 100644
--- a/src/graph/utilities.rs
+++ b/src/graph/utilities.rs
@@ -1,7 +1,25 @@
-use crate::tensor::{Tensor, TensorError};
+use super::{node::*, GraphError};
+use crate::circuit::hybrid::HybridOp;
+use crate::circuit::lookup::LookupOp;
+use crate::circuit::poly::PolyOp;
+use crate::circuit::utils;
+use crate::tensor::{Tensor, TensorError, TensorType};
 use anyhow::Result;
-use tract_onnx::prelude::{InferenceFact, Node};
-use tract_onnx::tract_hir::internal::InferenceOp;
+use halo2curves::FieldExt;
+use log::warn;
+use tract_onnx::prelude::{DatumType, InferenceFact, Node as OnnxNode};
+use tract_onnx::tract_hir::{
+    internal::InferenceOp,
+    ops::activations::LeakyRelu,
+    ops::array::{Pad, PadMode},
+    ops::cnn::{Conv, MaxPool, PoolSpec, SumPool},
+    ops::expandable::Expansion,
+    ops::nn::DataFormat,
+    tract_core::ops::{
+        cnn::{conv::KernelFormat, PaddingSpec},
+        konst::Const,
+    },
+};
 
 // Warning: currently ignores stride information
 /// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation.
@@ -32,7 +50,7 @@ pub fn scale_to_multiplier(scale: u32) -> f32 {
 /// Gets the shape of a onnx node's outlets.
 pub fn node_output_shapes(
-    node: &Node<InferenceFact, Box<dyn InferenceOp>>,
+    node: &OnnxNode<InferenceFact, Box<dyn InferenceOp>>,
 ) -> Result<Vec<Option<Vec<usize>>>> {
     let mut shapes = Vec::new();
     let outputs = node.outputs.to_vec();
@@ -47,3 +65,445 @@ pub fn node_output_shapes(
     }
     Ok(shapes)
 }
+
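For reference, the fixed-point convention used by the quantization helpers in this file: a float x becomes round(x * 2^scale), and `scale_to_multiplier(scale)` is the corresponding multiplier 2^scale. A stand-alone illustration (the real `vector_to_quantized` also takes target dims and a shift, and returns a `Tensor<i128>`):

    fn scale_to_multiplier(scale: u32) -> f32 {
        2f32.powi(scale as i32)
    }

    // Simplified sketch: quantize a slice of f32s at a fixed point scale.
    fn vector_to_quantized(v: &[f32], scale: u32) -> Vec<i128> {
        let m = scale_to_multiplier(scale);
        v.iter().map(|x| (x * m).round() as i128).collect()
    }

    fn main() {
        // scale 7 => multiplier 128; 0.5 -> 64, -1.25 -> -160
        assert_eq!(vector_to_quantized(&[0.5, -1.25], 7), vec![64, -160]);
    }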
+/// Matches an onnx node to an [OpKind] and returns a [Node] with the corresponding [OpKind].
+pub fn new_op_from_onnx<F: FieldExt + TensorType>(
+    idx: usize,
+    scale: u32,
+    node: OnnxNode<InferenceFact, Box<dyn InferenceOp>>,
+    inputs: &mut Vec<Node<F>>,
+) -> Result<Box<dyn Op<F>>, Box<dyn Error>> {
+    Ok(match node.op().name().as_ref() {
+        "Reduce<Min>" => Box::new(HybridOp::Min),
+        "Reduce<Max>" => Box::new(HybridOp::Max),
+        "Clip" => Box::new(LookupOp::ReLU { scale: 1 }),
+        "Prelu" => {
+            let slopes = match inputs[1].opkind.raw_const_value() {
+                Some(raw_const_value) => raw_const_value,
+                _ => {
+                    return Err(Box::new(GraphError::MissingParams("slopes".to_string())));
+                }
+            };
+
+            Box::new(HybridOp::PReLU {
+                scale: 1,
+                slopes: slopes.to_vec(),
+            })
+        }
+        "LeakyRelu" => {
+            // Extract the slope layer hyperparams
+            let op = Box::new(node.op());
+
+            let leaky_op: &LeakyRelu = match op.downcast_ref::<Box<dyn Expansion>>() {
+                Some(b) => match (*b).as_any().downcast_ref() {
+                    Some(b) => b,
+                    None => {
+                        return Err(Box::new(GraphError::OpMismatch(
+                            idx,
+                            "leaky relu".to_string(),
+                        )));
+                    }
+                },
+                None => {
+                    return Err(Box::new(GraphError::OpMismatch(
+                        idx,
+                        "leaky relu".to_string(),
+                    )));
+                }
+            };
+            Box::new(LookupOp::LeakyReLU {
+                scale: 1,
+                slope: crate::circuit::utils::F32(leaky_op.0),
+            })
+        }
+        "Sigmoid" => Box::new(LookupOp::Sigmoid { scales: (1, 1) }),
+        "Sqrt" => Box::new(LookupOp::Sqrt { scales: (1, 1) }),
+        "Tanh" => Box::new(LookupOp::Tanh { scales: (1, 1) }),
+        "onnx.Erf" => Box::new(LookupOp::Erf { scales: (1, 1) }),
+        "Div" => {
+            if (inputs[1].out_dims.clone() != [1]) || !inputs[1].opkind.is_const() {
+                return Err(Box::new(GraphError::NonConstantDiv));
+            }
+
+            let denom = match &inputs[1].opkind.raw_const_value() {
+                Some(raw_const_value) => raw_const_value.map(|x| x.0)[0],
+                _ => {
+                    return Err(Box::new(GraphError::MissingParams("denom".to_string())));
+                }
+            };
+
+            inputs.pop();
+
+            Box::new(LookupOp::Div {
+                denom: crate::circuit::utils::F32(denom),
+            })
+        }
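The `Div` arm above only accepts a constant, single-element denominator (`LookupOp::Div`). The reason is structural: division is lowered to an elementwise lookup, and a lookup table has to be fixed when the circuit is defined. A toy illustration of that constraint, with a hypothetical helper over quantized values:

```rust
// Toy illustration: a division-by-constant lookup table. The denominator is
// baked into the table at circuit-definition time, which is why a
// non-constant divisor must be rejected (GraphError::NonConstantDiv above).
fn build_div_table(denom: f32, input_range: std::ops::Range<i64>) -> Vec<(i64, i64)> {
    input_range
        .map(|x| (x, ((x as f32) / denom).round() as i64))
        .collect()
}

fn main() {
    let table = build_div_table(2.0, -4..5);
    assert_eq!(table[0], (-4, -2));
    assert_eq!(table[8], (4, 2));
}
```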
+        "Const" => {
+            let op = Box::new(node.op());
+            let const_node: &Const = match op.as_any().downcast_ref() {
+                Some(b) => b,
+                None => {
+                    return Err(Box::new(GraphError::OpMismatch(idx, "const".to_string())));
+                }
+            };
+            let dt = const_node.0.datum_type();
+            let mut dims = const_node.0.shape().to_vec();
+            if dims.is_empty() {
+                dims.push(1)
+            }
+
+            let const_value: Tensor<i128>;
+            let mut raw_const_value = None;
+            match dt {
+                DatumType::F32 => {
+                    let vec = const_node.0.as_slice::<f32>().unwrap().to_vec();
+                    let raw: Tensor<f32> = Tensor::new(Some(&vec), &dims).unwrap();
+                    let t = vector_to_quantized(&vec, &dims, 0f32, scale).unwrap();
+                    const_value = t;
+                    raw_const_value = Some(raw.map(utils::F32));
+                }
+
+                DatumType::I64 => {
+                    // Generally a shape or hyperparam
+                    let vec = const_node.0.as_slice::<i64>().unwrap().to_vec();
+                    let cast: Vec<i128> = vec.iter().map(|x| *x as i128).collect();
+                    let t = Tensor::<i128>::new(Some(&cast), &dims).unwrap();
+                    const_value = t;
+                }
+                _ => todo!(),
+            }
+            Box::new(crate::circuit::ops::Const {
+                const_value,
+                raw_const_value,
+            })
+        }
+        "Source" => Box::new(crate::circuit::ops::Input),
+        "Add" => Box::new(PolyOp::Add),
+        "Sub" => Box::new(PolyOp::Sub),
+        "Mul" => Box::new(PolyOp::Mult),
+        "Gemm" => Box::new(PolyOp::Affine),
+        "MatMulInference" => Box::new(PolyOp::Matmul),
+        "MaxPool" => {
+            // Extract the padding and stride layer hyperparams
+            let op = Box::new(node.op());
+            let maxpool_node: &MaxPool = match op.downcast_ref() {
+                Some(b) => b,
+                None => {
+                    return Err(Box::new(GraphError::OpMismatch(idx, "Maxpool".to_string())));
+                }
+            };
+
+            let pool_spec: &PoolSpec = &maxpool_node.pool_spec;
+
+            // only support pytorch type formatting for now
+            if pool_spec.data_format != DataFormat::NCHW {
+                return Err(Box::new(GraphError::MissingParams(
+                    "data in wrong format".to_string(),
+                )));
+            }
+
+            let stride = pool_spec.strides.clone().unwrap();
+            let padding = match &pool_spec.padding {
+                PaddingSpec::Explicit(p, _, _) => p,
+                _ => {
+                    return Err(Box::new(GraphError::MissingParams("padding".to_string())));
+                }
+            };
+            let kernel_shape = &pool_spec.kernel_shape;
+
+            let (padding_h, padding_w, stride_h, stride_w) =
+                (padding[0], padding[1], stride[0], stride[1]);
+            let (kernel_height, kernel_width) = (kernel_shape[0], kernel_shape[1]);
+
+            Box::new(HybridOp::MaxPool2d {
+                padding: (padding_h, padding_w),
+                stride: (stride_h, stride_w),
+                pool_dims: (kernel_height, kernel_width),
+            })
+        }
+        "Dot" => Box::new(PolyOp::Dot),
+        "Reduce<Sum>" => {
+            if inputs.len() != 1 {
+                return Err(Box::new(GraphError::InvalidDims(idx, "sum".to_string())));
+            };
+
+            Box::new(PolyOp::Sum)
+        }
+        "Reduce<Mean>" => Box::new(HybridOp::Mean {
+            scale: 1,
+            num_inputs: inputs[0].out_dims.iter().product::<usize>(),
+        }),
+        "Pow" => match &inputs[1].opkind.raw_const_value() {
+            Some(raw_const_value) => {
+                let pow = &raw_const_value[0].0;
+                if inputs[1].out_dims != [1] {
+                    return Err(Box::new(GraphError::NonConstantPower));
+                }
+                inputs.pop();
+                Box::new(PolyOp::Pow(*pow as u32))
+            }
+            _ => return Err(Box::new(GraphError::MissingParams("pow".to_string()))),
+        },
+        "Conv" | "ConvHir" => {
+            // Extract the padding and stride layer hyperparams
+            let op = Box::new(node.op());
+
+            let conv_node: &Conv = match op.downcast_ref::<Box<dyn Expansion>>() {
+                Some(b) => match (*b).as_any().downcast_ref() {
+                    Some(b) => b,
+                    None => {
+                        return Err(Box::new(GraphError::OpMismatch(idx, "conv".to_string())));
+                    }
+                },
+                None => {
+                    return Err(Box::new(GraphError::OpMismatch(idx, "conv".to_string())));
+                }
+            };
+
+            if (conv_node.data_format != DataFormat::NCHW)
+                || (conv_node.kernel_fmt != KernelFormat::OIHW)
+            {
+                return Err(Box::new(GraphError::MisformedParams(
+                    "data or kernel in wrong format".to_string(),
+                )));
+            }
+
+            let stride = match conv_node.strides.clone() {
+                Some(s) => s,
+                None => {
+                    return Err(Box::new(GraphError::MissingParams("strides".to_string())));
+                }
+            };
+            let padding = match &conv_node.padding {
+                PaddingSpec::Explicit(p, _, _) => p,
+                _ => {
+                    return Err(Box::new(GraphError::MissingParams("padding".to_string())));
+                }
+            };
+
+            let (padding_h, padding_w, stride_h, stride_w) =
+                (padding[0], padding[1], stride[0], stride[1]);
+            Box::new(PolyOp::Conv {
+                padding: (padding_h, padding_w),
+                stride: (stride_h, stride_w),
+            })
+        }
+
+        "SumPool" => {
+            // Extract the padding and stride layer hyperparams
+            let op = Box::new(node.op());
+            let sumpool_node: &SumPool = match op.downcast_ref() {
+                Some(b) => b,
+                None => {
+                    return Err(Box::new(GraphError::OpMismatch(idx, "sumpool".to_string())));
+                }
+            };
+
+            let pool_spec: &PoolSpec = &sumpool_node.pool_spec;
+
+            // only support pytorch type formatting for now
+            if pool_spec.data_format != DataFormat::NCHW {
+                return Err(Box::new(GraphError::MissingParams(
+                    "data in wrong format".to_string(),
+                )));
+            }
+
+            let stride = pool_spec.strides.clone().unwrap();
+            let padding = match &pool_spec.padding {
+                PaddingSpec::Explicit(p, _, _) => p,
+                _ => {
+                    return Err(Box::new(GraphError::MissingParams("padding".to_string())));
+                }
+            };
+            let kernel_shape = &pool_spec.kernel_shape;
+
+            let (padding_h, padding_w, stride_h, stride_w) =
+                (padding[0], padding[1], stride[0], stride[1]);
+            let (kernel_height, kernel_width) = (kernel_shape[0], kernel_shape[1]);
+
+            Box::new(PolyOp::SumPool {
+                padding: (padding_h, padding_w),
+                stride: (stride_h, stride_w),
+                kernel_shape: (kernel_height, kernel_width),
+            })
+        }
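The Conv, MaxPool, and SumPool arms above all extract the same (padding, stride, kernel) hyperparameter triples. For reference, the output spatial size these parameters imply follows the usual sliding-window formula; a quick sanity-check sketch:

```rust
// Output spatial size of a strided window op (conv / pooling) given the
// (padding, stride, kernel) hyperparameters extracted in the arms above.
fn out_dim(input: usize, padding: usize, kernel: usize, stride: usize) -> usize {
    (input + 2 * padding - kernel) / stride + 1
}

fn main() {
    // a 5x5 image with a 3x3 kernel, no padding, stride 1 -> 3x3 output
    assert_eq!(out_dim(5, 0, 3, 1), 3);
    // the same image with padding 1 keeps its size: (5 + 2 - 3) / 1 + 1 = 5
    assert_eq!(out_dim(5, 1, 3, 1), 5);
}
```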
+        "InstanceNorm" => Box::new(HybridOp::InstanceNorm2d {
+            epsilon: utils::F32(1e-5),
+        }),
+        "GlobalAvgPool" => Box::new(PolyOp::SumPool {
+            padding: (0, 0),
+            stride: (1, 1),
+            kernel_shape: (inputs[0].out_dims[1], inputs[0].out_dims[2]),
+        }),
+        "Pad" => {
+            let pad_node: &Pad = match node.op().downcast_ref::<Pad>() {
+                Some(b) => b,
+                None => {
+                    return Err(Box::new(GraphError::OpMismatch(idx, "pad".to_string())));
+                }
+            };
+            // we only support constant 0 padding
+            if pad_node.mode
+                != PadMode::Constant(tract_onnx::prelude::Arc::new(
+                    tract_onnx::prelude::Tensor::zero::<f32>(&[])?,
+                ))
+            {
+                return Err(Box::new(GraphError::MisformedParams(
+                    "pad mode or pad type".to_string(),
+                )));
+            }
+
+            let padding_len = pad_node.pads.len();
+
+            // we only support symmetrical padding that affects the last 2 dims (height and width params)
+            for (i, pad_params) in pad_node.pads.iter().enumerate() {
+                if (i < padding_len - 2) && ((pad_params.0 != 0) || (pad_params.1 != 0)) {
+                    return Err(Box::new(GraphError::MisformedParams(
+                        "ezkl currently only supports padding height and width dimensions"
+                            .to_string(),
+                    )));
+                }
+                if pad_params.0 != pad_params.1 {
+                    return Err(Box::new(GraphError::MisformedParams(
+                        "ezkl currently only supports symmetric padding".to_string(),
+                    )));
+                }
+            }
+
+            let (padding_h, padding_w) = (
+                pad_node.pads[padding_len - 2].0,
+                pad_node.pads[padding_len - 1].0,
+            );
+            Box::new(PolyOp::Pad(padding_h, padding_w))
+        }
+        "Reshape" => {
+            let input_node = &inputs[0];
+            let shape_const_node = &inputs[1];
+            let shape_const = match shape_const_node.opkind.const_value() {
+                Some(const_value) => const_value,
+                _ => {
+                    return Err(Box::new(GraphError::MissingParams(
+                        "shape constant".to_string(),
+                    )));
+                }
+            };
+
+            let mut shapes = &shape_const[0..];
+
+            // we remove batch dims as we assume single elem batches
+            if shapes[0] == -1 && shapes.len() > 1 {
+                shapes = &shapes[1..];
+            }
+
+            let new_dims: Result<Vec<usize>, Box<dyn Error>> =
+                if shapes.iter().all(|x| x > &0) {
+                    let mut res = vec![];
+                    for x in shapes.iter() {
+                        if x <= &0 {
+                            return Err(Box::new(GraphError::InvalidDims(
+                                idx,
+                                "reshape".to_string(),
+                            )));
+                        }
+                        res.push(*x as usize);
+                    }
+                    Ok(res)
+                } else {
+                    let num_entries: usize = input_node.out_dims.iter().product();
+                    let explicit_prod: i128 = shapes.iter().filter(|x| *x > &0).product();
+                    if explicit_prod <= 0 {
+                        return Err(Box::new(GraphError::InvalidDims(
+                            idx,
+                            "reshape".to_string(),
+                        )));
+                    }
+                    let inferred = num_entries / (explicit_prod as usize);
+                    let mut new_dims: Vec<usize> = Vec::new();
+                    for i in shapes {
+                        match i {
+                            -1 => new_dims.push(inferred),
+                            0 => continue,
+                            x => new_dims.push(*x as usize),
+                        }
+                    }
+                    Ok(new_dims)
+                };
+
+            let new_dims = new_dims?;
+            inputs.pop();
+
+            Box::new(PolyOp::Reshape(new_dims))
+        }
+        "Flatten" => {
+            let new_dims: Vec<usize> = vec![inputs[0].out_dims.iter().product::<usize>()];
+            Box::new(PolyOp::Flatten(new_dims))
+        }
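The Reshape arm implements ONNX's `-1` convention: one dimension may be left unspecified and is inferred from the total element count, while `0` dims are skipped here since single-element batches are assumed. A condensed sketch of the same inference rule:

```rust
// Condensed sketch of the `-1` inference rule in the Reshape arm above:
// the unspecified dimension is total_elements / product(explicit dims).
fn infer_reshape(input_dims: &[usize], shapes: &[i64]) -> Vec<usize> {
    let num_entries: usize = input_dims.iter().product();
    let explicit: usize = shapes
        .iter()
        .filter(|&&x| x > 0)
        .map(|&x| x as usize)
        .product();
    shapes
        .iter()
        .filter_map(|&x| match x {
            -1 => Some(num_entries / explicit),
            0 => None, // zero dims are skipped, mirroring the code above
            d => Some(d as usize),
        })
        .collect()
}

fn main() {
    // 3 * 4 * 5 = 60 elements reshaped to [-1, 5] -> [12, 5]
    assert_eq!(infer_reshape(&[3, 4, 5], &[-1, 5]), vec![12, 5]);
}
```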
+        // BatchNorm takes four parameters, does some f32 arithmetic and then quantizes,
+        // while ScaleAndShift takes the final two parameters immediately.
+        // We will also reach back and quantize
+        "BatchNorm" => {
+            // Compute scale and shift from the four inputs,
+            // then replace the first two, and change this node to a ScaleAndShift
+            let gamma = match &inputs[1].opkind.raw_const_value() {
+                Some(raw_const_value) => raw_const_value.map(|x| x.0),
+                _ => {
+                    return Err(Box::new(GraphError::MissingParams("bn_gamma".to_string())));
+                }
+            };
+
+            let beta = match &inputs[2].opkind.raw_const_value() {
+                Some(raw_const_value) => raw_const_value.map(|x| x.0),
+                _ => {
+                    return Err(Box::new(GraphError::MissingParams("bn_beta".to_string())));
+                }
+            };
+
+            let mu = match &inputs[3].opkind.raw_const_value() {
+                Some(raw_const_value) => raw_const_value.map(|x| x.0),
+                _ => {
+                    return Err(Box::new(GraphError::MissingParams("bn_mu".to_string())));
+                }
+            };
+
+            let sigma = match &inputs[4].opkind.raw_const_value() {
+                Some(raw_const_value) => raw_const_value.map(|x| x.0),
+                _ => {
+                    return Err(Box::new(GraphError::MissingParams("bn_sigma".to_string())));
+                }
+            };
+
+            let a = (gamma / sigma)?;
+            let amu: Tensor<f32> = (a.clone() * mu)?;
+            let amupb: Tensor<f32> = (amu + beta)?;
+            let b = (amupb * Tensor::new(Some(&[-1f32]), &[1])?)?;
+
+            let in_scale = inputs[0].out_scale;
+            let out_scale = 2 * inputs[0].out_scale;
+            // gamma node becomes the scale (weight) in scale and shift
+            inputs[1].opkind = Box::new(crate::circuit::ops::Const {
+                const_value: Tensor::new(None, &[1])?,
+                raw_const_value: Some(a.map(utils::F32)),
+            });
+            inputs[1].quantize_const_to_scale(in_scale)?;
+
+            // beta node becomes the shift (bias)
+            inputs[2].opkind = Box::new(crate::circuit::ops::Const {
+                const_value: Tensor::new(None, &[1])?,
+                raw_const_value: Some(b.map(utils::F32)),
+            });
+            inputs[2].quantize_const_to_scale(out_scale)?;
+
+            inputs.pop();
+            inputs.pop();
+
+            Box::new(PolyOp::ScaleAndShift)
+        }
+        c => {
+            warn!("{:?} is not currently supported", c);
+            Box::new(crate::circuit::ops::Unknown)
+        }
+    })
+}
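For reference, the textbook batch-norm fold (ignoring epsilon and quantization) collapses the four parameters into one affine map, which is what lets the node become a ScaleAndShift. A small check of that identity:

```rust
// Textbook batch-norm folding:
// y = gamma * (x - mu) / sigma + beta  ==  a * x + b,
// with a = gamma / sigma and b = beta - a * mu.
fn fold_bn(gamma: f32, beta: f32, mu: f32, sigma: f32) -> (f32, f32) {
    let a = gamma / sigma;
    (a, beta - a * mu)
}

fn main() {
    let (a, b) = fold_bn(2.0, 1.0, 3.0, 4.0);
    let x = 5.0;
    // folded and unfolded forms agree
    assert!((a * x + b - (2.0 * (x - 3.0) / 4.0 + 1.0)).abs() < 1e-6);
}
```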
diff --git a/src/lib.rs b/src/lib.rs
index 4ae0a8e7e..22e37dfe2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -44,8 +44,8 @@ pub mod fieldutils;
 pub mod graph;
 /// Tools for proofs and verification used by cli
 pub mod pfsys;
-/// An implementation of multi-dimensional tensors.
-pub mod tensor;
 /// Python bindings
 #[cfg(feature = "python-bindings")]
 pub mod python;
+/// An implementation of multi-dimensional tensors.
+pub mod tensor;
diff --git a/src/python.rs b/src/python.rs
index 9b3bc3565..a745b4fdf 100644
--- a/src/python.rs
+++ b/src/python.rs
@@ -1,18 +1,17 @@
+use crate::circuit::CheckMode;
+use crate::commands::RunArgs;
+use crate::graph::{vector_to_quantized, Mode, Model, VarVisibility, Visibility};
+use crate::pfsys::{gen_srs as ezkl_gen_srs, prepare_data, save_params};
+use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
+use halo2curves::bn256::{Bn256, Fr};
+use log::trace;
+use pyo3::exceptions::{PyIOError, PyRuntimeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::wrap_pyfunction;
-use pyo3::exceptions::{PyIOError, PyRuntimeError, PyValueError};
 use pyo3_log;
-use tabled::Table;
-use crate::graph::{Model, Visibility, VarVisibility, Mode, vector_to_quantized};
-use crate::commands::RunArgs;
-use crate::circuit::CheckMode;
-use crate::pfsys::{gen_srs as ezkl_gen_srs, save_params, prepare_data};
-use std::path::PathBuf;
 use std::fs::File;
-use log::trace;
-use halo2curves::bn256::Bn256;
-use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
-
+use std::path::PathBuf;
+use tabled::Table;

 // See commands.rs and execute.rs
 // RenderCircuit
@@ -34,9 +33,7 @@ use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
 // Table

 #[pyfunction]
-fn table(
-    model: String,
-) -> Result<String, PyErr> {
+fn table(model: String) -> Result<String, PyErr> {
     // use default values to initialize model
     let run_args = RunArgs {
         tolerance: 0,
@@ -57,28 +54,16 @@ fn table(
         output: Visibility::Public,
     };

-    let result = Model::new(
-        model,
-        run_args,
-        Mode::Mock,
-        visibility,
-    );
+    let result = Model::<Fr>::new(model, run_args, Mode::Mock, visibility);

     match result {
-        Ok(m) => {
-            Ok(Table::new(m.nodes.iter()).to_string())
-        },
-        Err(_) => {
-            Err(PyIOError::new_err("Failed to import model"))
-        },
+        Ok(m) => Ok(Table::new(m.nodes.iter()).to_string()),
+        Err(_) => Err(PyIOError::new_err("Failed to import model")),
     }
 }

 #[pyfunction]
-fn gen_srs(
-    params_path: PathBuf,
-    logrows: u32,
-) -> PyResult<()> {
+fn gen_srs(params_path: PathBuf, logrows: u32) -> PyResult<()> {
     let run_args = RunArgs {
         tolerance: 0,
         scale: 7,
@@ -121,7 +106,7 @@ fn forward(
     public_outputs: bool,
     public_params: bool,
     pack_base: u32,
-    check_mode: &str
+    check_mode: &str,
 ) -> PyResult<()> {
     let data = prepare_data(data);

@@ -142,23 +127,12 @@ fn forward(
     let mut model_inputs = vec![];
     // quantize the supplied data using the provided scale.
     for v in new_data.input_data.iter() {
-        match vector_to_quantized(
-            v,
-            &Vec::from([v.len()]),
-            0.0,
-            run_args.scale
-        ) {
+        match vector_to_quantized(v, &Vec::from([v.len()]), 0.0, run_args.scale) {
             Ok(t) => model_inputs.push(t),
-            Err(_) => {
-                return Err(PyValueError::new_err("Failed to quantize vector"))
-            }
+            Err(_) => return Err(PyValueError::new_err("Failed to quantize vector")),
         }
     }

-    let res = Model::forward(
-        model,
-        &model_inputs,
-        run_args
-    );
+    let res = Model::<Fr>::forward(model, &model_inputs, run_args);

     match res {
         Ok(r) => {
@@ -176,20 +150,14 @@ fn forward(
                     // let py = gil.python();
                     // return Ok(new_data.to_object(py))
                     Ok(())
-                },
-                Err(_) => {
-                    return Err(PyIOError::new_err("Failed to create output file"))
                 }
+                Err(_) => return Err(PyIOError::new_err("Failed to create output file")),
             }
         }
-        Err(_) => {
-            Err(PyRuntimeError::new_err("Failed to compute forward pass"))
-        }
+        Err(_) => Err(PyRuntimeError::new_err("Failed to compute forward pass")),
     }
-    },
-    Err(_) => {
-        Err(PyIOError::new_err("Failed to import files"))
-    },
+    }
+    Err(_) => Err(PyIOError::new_err("Failed to import files")),
 }

@@ -213,4 +181,4 @@ fn ezkl_lib(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(gen_srs, m)?)?;
     m.add_function(wrap_pyfunction!(forward, m)?)?;
     Ok(())
-}
\ No newline at end of file
+}
diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs
index aa9310864..6c4c4c018 100644
--- a/src/tensor/mod.rs
+++ b/src/tensor/mod.rs
@@ -9,7 +9,10 @@ use serde::{Deserialize, Serialize};
 pub use val::*;
 pub use var::*;

-use crate::fieldutils::{felt_to_i32, i128_to_felt, i32_to_felt};
+use crate::{
+    circuit::utils,
+    fieldutils::{felt_to_i32, i128_to_felt, i32_to_felt},
+};

 use halo2_proofs::{
     arithmetic::FieldExt,
@@ -100,6 +103,7 @@ tensor_type!(i128, Int128, 0, 1);
 tensor_type!(i32, Int32, 0, 1);
 tensor_type!(usize, USize, 0, 1);
 tensor_type!((), Empty, (), ());
+tensor_type!(utils::F32, F32, utils::F32(0.0), utils::F32(1.0));

 impl<T: TensorType> TensorType for Tensor<T> {
     fn zero() -> Option<Self> {
@@ -223,7 +227,7 @@ impl TensorType for halo2curves::bn256::Fr {
 /// A generic multi-dimensional array representation of a Tensor.
 /// The `inner` attribute contains a vector of values whereas `dims` corresponds to the dimensionality of the array
 /// and as such determines how we index, query for values, or slice a Tensor.
-#[derive(Clone, Debug, Eq, Serialize, Deserialize)]
+#[derive(Clone, Debug, Eq, Serialize, Deserialize, PartialOrd, Ord)]
 pub struct Tensor<T: TensorType> {
     inner: Vec<T>,
     dims: Vec<usize>,
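Note the new `tensor_type!` instantiation for `utils::F32`, a wrapper around `f32`. The wrapper matters because bare `f32` implements neither `Eq` nor `Ord`, so tensors of raw float constants could not participate in the comparison derives added to `Tensor` above. An illustrative newtype, assuming the real `utils::F32` follows roughly this shape:

```rust
use std::cmp::Ordering;

// Illustrative stand-in for `utils::F32` (the real one lives in
// src/circuit/utils.rs and may differ in detail).
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
struct F32(pub f32);

impl Eq for F32 {}

impl Ord for F32 {
    fn cmp(&self, other: &Self) -> Ordering {
        // assumption: NaNs are treated as equal rather than panicking
        self.partial_cmp(other).unwrap_or(Ordering::Equal)
    }
}

fn main() {
    let mut xs = vec![F32(2.0), F32(0.5), F32(1.0)];
    xs.sort(); // possible because F32 is Ord, unlike bare f32
    assert_eq!(xs[0], F32(0.5));
}
```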
diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs
index 10118209a..f9f75119d 100644
--- a/src/tensor/ops.rs
+++ b/src/tensor/ops.rs
@@ -547,7 +547,7 @@ pub fn max_pool2d(
     let input_channels = image_dims[0];
     let (image_height, image_width) = (image_dims[1], image_dims[2]);

-    let padded_image = pad::<T>(image, padding.clone())?;
+    let padded_image = pad::<T>(image, *padding)?;

     let horz_slides = (image_height + 2 * padding.0 - pool_dims.0) / stride.0 + 1;
     let vert_slides = (image_width + 2 * padding.1 - pool_dims.1) / stride.1 + 1;
@@ -828,7 +828,7 @@ pub mod nonlinearities {
         let mut output = a.clone();

         for i in 0..a.len() {
-            let mut z = a[i].clone() as f32 / (scale_input as f32);
+            let mut z = a[i] as f32 / (scale_input as f32);
             z = (scale_output as f32) * (erf(z as f64) as f32);
             output[i] = z as i128;
         }
@@ -869,6 +869,71 @@ pub mod nonlinearities {
         output
     }

+    /// Elementwise applies instance norm to a tensor of integers.
+    /// # Arguments
+    ///
+    /// * `a` - Tensor
+    /// * `gamma` - vector of scale values
+    /// * `beta` - vector of offset values
+    /// # Examples
+    /// ```
+    /// use ezkl_lib::tensor::Tensor;
+    /// use ezkl_lib::tensor::ops::nonlinearities::instance_norm;
+    /// let x = Tensor::<i128>::new(
+    ///     Some(&[4, 2, 8, 1, 1, 2, 2, 2, 3]),
+    ///     &[1, 3, 3],
+    /// ).unwrap();
+    ///
+    /// let gamma = Tensor::<i128>::new(
+    ///     Some(&[1]),
+    ///     &[1],
+    /// ).unwrap();
+    ///
+    /// let beta = Tensor::<i128>::new(
+    ///     Some(&[23]),
+    ///     &[1],
+    /// ).unwrap();
+    ///
+    /// let result = instance_norm([x, gamma, beta], 1.0);
+    /// let expected = Tensor::<i128>::new(Some(&[25, 23, 29, 22, 22, 23, 23, 23, 24]), &[1, 3, 3]).unwrap();
+    /// assert_eq!(result, expected);
+    /// ```
+    pub fn instance_norm(inputs: [Tensor<i128>; 3], epsilon: f32) -> Tensor<i128> {
+        let a = &inputs[0];
+        let gamma = &inputs[1];
+        let beta = &inputs[2];
+        // calculate value of output
+        assert!(a.len() > 1);
+        assert!(a.dims().len() == 3);
+        assert_eq!(gamma.len(), beta.len());
+        // assert num channels is same as num of parameters
+        assert_eq!(gamma.len(), a.dims()[0]);
+        let mut output = vec![];
+        for i in 0..gamma.len() {
+            let row = a.get_slice(&[i..i + 1]).unwrap();
+            let sum = sum(&row).unwrap();
+            let mean = sum.map(|x| (x) / (row.len() as i128));
+
+            // unbiased = false in pytorch definition. if it was unbiased we would divide by row.len() - 1
+            let var = (row.clone() - mean.clone())
+                .unwrap()
+                .map(|e| (e as f32) / (row.len() as f32));
+
+            let denom = var.map(|e| (e + epsilon).sqrt().round() as i128);
+            let numerator = (row - mean).unwrap() * vec![gamma[i]].into_iter().into();
+
+            let result =
+                ((numerator.unwrap() / denom).unwrap() + vec![beta[i]].into_iter().into()).unwrap();
+
+            output.push(result);
+        }
+
+        let mut output = Tensor::from(output.into_iter()).combine().unwrap();
+        output.reshape(a.dims());
+
+        output
+    }
+
     /// Elementwise applies prelu to a tensor of integers.
     /// # Arguments
     ///
diff --git a/src/tensor/val.rs b/src/tensor/val.rs
index ad9c1cad8..a6e060ff6 100644
--- a/src/tensor/val.rs
+++ b/src/tensor/val.rs
@@ -352,6 +352,11 @@ impl ValTensor {
     }

+    /// Returns `true` if the inner [Tensor] has no elements.
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
     /// Calls `pad_row_ones` on the inner [Tensor].
     pub fn pad_row_ones(&mut self) -> Result<(), TensorError> {
         match self {
diff --git a/src/tensor/var.rs b/src/tensor/var.rs
index 18fa4fff0..4680c9414 100644
--- a/src/tensor/var.rs
+++ b/src/tensor/var.rs
@@ -11,7 +11,7 @@ use super::*;
 /// about the column layout. This enum is generally used to configure and layout circuit variables / advices.
 /// For instance can be used to represent neural network parameters within a circuit that we later assign to
 /// using the `assign` method called on a [ValTensor].
-#[derive(Clone, Debug, Default, PartialEq, Eq)]
+#[derive(Clone, Default, Debug, PartialEq, Eq)]
 pub enum VarTensor {
     /// A VarTensor for holding Advice values, which are assigned at proving time.
     Advice {
@@ -31,9 +31,14 @@ pub enum VarTensor {
         /// Total capacity (number of advice cells), usually inner.len()*col_size
         capacity: usize,
     },
+    /// Dummy var
+    Dummy {
+        /// Number of rows available to be used in each column of the storage
+        col_size: usize,
+    },
+    /// Empty var
     #[default]
-    /// Dummy / empty var
-    None
+    Empty
 }

 impl VarTensor {
@@ -71,6 +76,15 @@ impl VarTensor {
         }
     }

+    /// Create a new VarTensor::Dummy
+    pub fn dummy(logrows: usize) -> Self {
+        let base = 2u32;
+        let max_rows = base.pow(logrows as u32) as usize - 6;
+        VarTensor::Dummy {
+            col_size: max_rows,
+        }
+    }
+
     /// Create a new VarTensor::Fixed
     /// `cs` is the `ConstraintSystem` from which the columns will be allocated.
     /// `k` is the log2 number of rows in the matrix, including any system and blinding rows.
@@ -115,7 +129,7 @@ impl VarTensor {
     /// Gets the size of each column
     pub fn col_size(&self) -> usize {
         match self {
-            VarTensor::Advice { col_size, .. } | VarTensor::Fixed { col_size, .. } => *col_size,
+            VarTensor::Advice { col_size, .. } | VarTensor::Fixed { col_size, .. } | VarTensor::Dummy { col_size } => *col_size,
             _ => 0
         }
     }
@@ -182,7 +196,7 @@ impl VarTensor {
     ///
     pub fn assign_constant<F: FieldExt + TensorType>(
         &self,
-        region: &mut Region<'_, F>,
+        region: &mut Region<F>,
         offset: usize,
         constant: F
     ) -> Result<AssignedCell<F, F>, halo2_proofs::plonk::Error>{
@@ -203,10 +217,12 @@ impl VarTensor {
     /// Assigns specific values [ValTensor] to the columns of the inner tensor.
     pub fn assign<F: FieldExt + TensorType>(
         &self,
-        region: &mut Region<'_, F>,
+        region: Option<&mut Region<F>>,
         offset: usize,
         values: &ValTensor<F>,
-    ) -> Result<Tensor<ValType<F>>, halo2_proofs::plonk::Error> {
+    ) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
+        match region {
+            Some(region) => {
         match values {
             ValTensor::Instance {
                 inner: instance,
@@ -215,7 +231,7 @@ impl VarTensor {
                 VarTensor::Advice { inner: v, .. } => {
                     // this should never ever fail
                     let t: Tensor<i32> = Tensor::new(None, dims).unwrap();
-                    t.enum_map(|coord, _| {
+                    Ok(t.enum_map(|coord, _| {
                         let (x, y) = self.cartesian_coord(offset + coord);
                         region.assign_advice_from_instance(
                             || "pub input anchor",
                             *instance,
                             coord,
                             v[x],
                             y,
                         )
-                    })
+                    })?.into())
                 }
                 _ => {
                     error!("Instance is only supported for advice columns");
                     Err(halo2_proofs::plonk::Error::Synthesis)
                 },
             },
-            ValTensor::Value { inner: v, .. } => v.enum_map(|coord, k| {
+            ValTensor::Value { inner: v, .. } => Ok(v.enum_map(|coord, k| {
                 let (x, y) = self.cartesian_coord(offset + coord);

                 match k {
@@ -265,31 +281,37 @@ impl VarTensor {
                         self.assign_constant(region, offset + coord, v)
                     }
                 }
-            }),
+            })?.into()),
         }
     }
+            None => Ok(values.clone())
+        }
+    }
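`assign` now takes `Option<&mut Region<F>>`: with `Some` it behaves as before, while `None` (paired with the new `VarTensor::Dummy`) turns layout into a dry run that just threads values through, so row usage can be measured without synthesizing a real halo2 region. A toy model of that pattern, with illustrative names:

```rust
// Toy model of the `Option<&mut Region>` pattern introduced here. With
// `Some`, values are actually written; with `None` the same code path simply
// returns them unchanged. (Hypothetical helper, not the library's API.)
fn assign_or_passthrough<T: Clone>(region: Option<&mut Vec<T>>, values: &[T]) -> Vec<T> {
    match region {
        Some(region) => {
            // real path: write the values into backing storage
            region.extend_from_slice(values);
            values.to_vec()
        }
        // dry-run path: no storage, values are returned unchanged
        None => values.to_vec(),
    }
}

fn main() {
    let mut rows: Vec<u64> = vec![];
    assert_eq!(assign_or_passthrough(Some(&mut rows), &[1, 2]), vec![1, 2]);
    assert_eq!(assign_or_passthrough(None, &[3]), vec![3]);
    assert_eq!(rows.len(), 2); // only the real path consumed rows
}
```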
     /// Assigns specific values (`ValTensor`) to the columns of the inner tensor.
     pub fn assign_with_duplication<F: FieldExt + TensorType>(
         &self,
-        region: &mut Region<'_, F>,
+        region: Option<&mut Region<F>>,
         offset: usize,
         values: &ValTensor<F>,
         check_mode: &CheckMode
-    ) -> Result<(Tensor<ValType<F>>, usize), halo2_proofs::plonk::Error> {
+    ) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
+
         match values {
             ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
             ValTensor::Value { inner: v, dims } => {
                 // duplicates every nth element to adjust for column overflow
                 let v = v.duplicate_every_n(self.col_size(), offset).unwrap();
-                let res = v.enum_map(|coord, k| {
+                let mut res: ValTensor<F> = if let Some(region) = region {
+                    v.enum_map(|coord, k| {
                     let (x, y) = self.cartesian_coord(offset + coord);
                     if matches!(check_mode, CheckMode::SAFE) && x > 0 && y == 0 {
                         // assert that duplication occurred correctly
                         assert_eq!(Into::<i32>::into(k.clone()), Into::<i32>::into(v[coord - 1].clone()));
                     };
+
                     match k {
                         ValType::Value(v) => match &self {
                             VarTensor::Fixed { inner: fixed, .. } => {
@@ -322,26 +344,29 @@ impl VarTensor {
                             self.assign_constant(region, offset + coord, v)
                         }
                     }
-                })?;
-                let mut non_duplicated_res = res.remove_every_n(self.col_size(), offset).unwrap();
+                })?.into()} else {
+                    v.into()
+                };
+                let total_used_len = res.len();
+                res.remove_every_n(self.col_size(), offset).unwrap();

-                non_duplicated_res.reshape(dims);
+                res.reshape(dims).unwrap();

                 if matches!(check_mode, CheckMode::SAFE) {
                     // during key generation this will be 0 so we use this as a flag to check
                     // TODO: this isn't very safe and would be better to get the phase directly
-                    let is_assigned = !Into::<Vec<i32>>::into(ValTensor::from(non_duplicated_res.clone()).get_inner().unwrap())
+                    let is_assigned = !Into::<Vec<i32>>::into(res.clone().get_inner().unwrap())
                         .iter()
                         .all(|&x| x == 0);
                     if is_assigned {
                         assert_eq!(
                             Into::<Vec<i32>>::into(values.get_inner().unwrap()),
-                            Into::<Vec<i32>>::into(non_duplicated_res.clone())
+                            Into::<Vec<i32>>::into(res.get_inner().unwrap())
                         )};
                 }

-                Ok((non_duplicated_res, res.len()))
+                Ok((res, total_used_len))
             }
         }
     }
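`assign_with_duplication` copes with column overflow: when a running computation spills into a new column, the last element of the previous column is repeated at the top of the next one so chained constraints still line up, and the duplicates are stripped again before the result is returned. A toy model of the duplicate/strip round trip (simplified; the real `duplicate_every_n` and `remove_every_n` also take an offset):

```rust
// Toy model of column-overflow duplication (assumption: no offset handling).
fn duplicate_every_n<T: Clone>(vals: &[T], n: usize) -> Vec<T> {
    let mut out = Vec::new();
    for (i, v) in vals.iter().enumerate() {
        out.push(v.clone());
        // repeat each n-th element as the first entry of the next column
        if (i + 1) % n == 0 && i + 1 != vals.len() {
            out.push(v.clone());
        }
    }
    out
}

fn remove_every_n<T: Clone>(vals: &[T], n: usize) -> Vec<T> {
    // strip the duplicates re-introduced above
    vals.iter()
        .enumerate()
        .filter(|(i, _)| i % (n + 1) != n)
        .map(|(_, v)| v.clone())
        .collect()
}

fn main() {
    // with a column size of 2, the 2nd and 4th values are duplicated
    let dup = duplicate_every_n(&[1, 2, 3, 4, 5], 2);
    assert_eq!(dup, vec![1, 2, 2, 3, 4, 4, 5]);
    assert_eq!(remove_every_n(&dup, 2), vec![1, 2, 3, 4, 5]);
}
```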
diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs
index 3d68bed3c..2afba4780 100644
--- a/tests/integration_tests.rs
+++ b/tests/integration_tests.rs
@@ -441,9 +441,9 @@ fn tutorial() {
             "-K=17",
             "mock",
             "-D",
-            format!("./examples/onnx/tutorial/input.json").as_str(),
+            "./examples/onnx/tutorial/input.json",
             "-M",
-            format!("./examples/onnx/tutorial/network.onnx").as_str(),
+            "./examples/onnx/tutorial/network.onnx",
         ])
         .status()
         .expect("failed to execute process");
diff --git a/tests/python/binding_tests.py b/tests/python/binding_tests.py
index 9aeb947e1..b238d11e0 100644
--- a/tests/python/binding_tests.py
+++ b/tests/python/binding_tests.py
@@ -11,16 +11,17 @@
 )

 examples_path = os.path.abspath(
-        os.path.join(
-            folder_path,
-            '..',
-            '..',
-            'examples',
-        )
+    os.path.join(
+        folder_path,
+        '..',
+        '..',
+        'examples',
     )
+)

 params_path = os.path.join(folder_path, 'kzg_test.params')

+
 def test_table_1l_average():
     """
     Test for table() with 1l_average.onnx
@@ -33,16 +34,15 @@ def test_table_1l_average():
     )

     expected_table = \
-        """+-------+-----------------------------------------------------------------+----------+-----------+-------------+-----------------+--------+-------------+-----------+-----+
-| usize | opkind                                                          | in_scale | out_scale | const_value | raw_const_value | inputs | in_dims     | out_dims  | idx |
-+-------+-----------------------------------------------------------------+----------+-----------+-------------+-----------------+--------+-------------+-----------+-----+
-| 0     | input                                                           | 7        | 7         |             |                 |        | [[1, 5, 5]] | [1, 5, 5] | 0   |
-+-------+-----------------------------------------------------------------+----------+-----------+-------------+-----------------+--------+-------------+-----------+-----+
-| 1     | padding: (0, 0)                                                 | 7        | 7         |             |                 | [0]    | [[1, 5, 5]] | [1, 5, 5] | 1   |
-+-------+-----------------------------------------------------------------+----------+-----------+-------------+-----------------+--------+-------------+-----------+-----+
-| 2     | avg pl w/ padding: (0, 0), stride: (1, 1), kernel shape: (3, 3) | 7        | 7         |             |                 | [1]    | [[1, 5, 5]] | [1, 3, 3] | 2   |
-+-------+-----------------------------------------------------------------+----------+-----------+-------------+-----------------+--------+-------------+-----------+-----+"""
-
+        """+-------+---------+-----------+--------+-----------+-----+
+| usize | opkind  | out_scale | inputs | out_dims  | idx |
++-------+---------+-----------+--------+-----------+-----+
+| 0     | Input   | 7         |        | [1, 5, 5] | 0   |
++-------+---------+-----------+--------+-----------+-----+
+| 1     | PAD     | 7         | [0]    | [1, 5, 5] | 1   |
++-------+---------+-----------+--------+-----------+-----+
+| 2     | SUMPOOL | 7         | [1]    | [1, 3, 3] | 2   |
++-------+---------+-----------+--------+-----------+-----+"""
     assert ezkl_lib.table(path) == expected_table

@@ -54,6 +54,7 @@ def test_gen_srs():
     ezkl_lib.gen_srs(params_path, 17)
     assert os.path.isfile(params_path)

+
 def test_forward():
     """
     Test for vanilla forward pass
@@ -81,6 +82,7 @@ def test_forward():
     with open(output_path, "r") as f:
         data = json.load(f)

-    assert data == {"input_data":[[0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1]],"input_shapes":[[1,5,5]],"output_data":[[0.9140625,0.9140625,0.9140625,0.9140625,0.9140625,0.9140625,0.9140625,0.9140625,0.9140625]]}
+    assert data == {"input_data": [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], "input_shapes": [
+        [1, 5, 5]], "output_data": [[0.9140625, 0.9140625, 0.9140625, 0.9140625, 0.9140625, 0.9140625, 0.9140625, 0.9140625, 0.9140625]]}

-    os.remove(output_path)
\ No newline at end of file
+    os.remove(output_path)