Skip to content

Commit

Permalink
feat: parallel witness gen for matmul and conv (#205)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Apr 25, 2023
1 parent 26ef31b commit d197bdb
Show file tree
Hide file tree
Showing 21 changed files with 348 additions and 419 deletions.
2 changes: 1 addition & 1 deletion benches/accum_conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl Circuit<Fr> for MyCircuit {
|mut region| {
config
.layout(
Some(&mut region),
&mut Some(&mut region),
&[self.image.clone()],
&mut 0,
Box::new(PolyOp::Conv {
Expand Down
2 changes: 1 addition & 1 deletion benches/accum_dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl Circuit<Fr> for MyCircuit {
|mut region| {
config
.layout(
Some(&mut region),
&mut Some(&mut region),
&self.inputs,
&mut 0,
Box::new(PolyOp::Dot),
Expand Down
2 changes: 1 addition & 1 deletion benches/accum_matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
|mut region| {
config
.layout(
Some(&mut region),
&mut Some(&mut region),
&self.inputs,
&mut 0,
Box::new(PolyOp::Matmul { a: None }),
Expand Down
9 changes: 7 additions & 2 deletions benches/accum_matmul_relu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,17 @@ impl Circuit<Fr> for MyCircuit {
let mut offset = 0;
let output = config
.base_config
.layout(Some(&mut region), &self.inputs, &mut offset, Box::new(op))
.layout(
&mut Some(&mut region),
&self.inputs,
&mut offset,
Box::new(op),
)
.unwrap();
let _output = config
.base_config
.layout(
Some(&mut region),
&mut Some(&mut region),
&[output.unwrap()],
&mut offset,
Box::new(LookupOp::ReLU { scale: 1 }),
Expand Down
2 changes: 1 addition & 1 deletion benches/accum_pack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl Circuit<Fr> for MyCircuit {
|mut region| {
config
.layout(
Some(&mut region),
&mut Some(&mut region),
&self.inputs,
&mut 0,
Box::new(PolyOp::Pack(2, 1)),
Expand Down
2 changes: 1 addition & 1 deletion benches/accum_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl Circuit<Fr> for MyCircuit {
|mut region| {
config
.layout(
Some(&mut region),
&mut Some(&mut region),
&self.inputs,
&mut 0,
Box::new(PolyOp::Sum { axes: vec![0] }),
Expand Down
2 changes: 1 addition & 1 deletion benches/accum_sumpool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl Circuit<Fr> for MyCircuit {
|mut region| {
config
.layout(
Some(&mut region),
&mut Some(&mut region),
&[self.image.clone()],
&mut 0,
Box::new(PolyOp::SumPool {
Expand Down
2 changes: 1 addition & 1 deletion benches/pairwise_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl Circuit<Fr> for MyCircuit {
|mut region| {
config
.layout(
Some(&mut region),
&mut Some(&mut region),
&self.inputs,
&mut 0,
Box::new(PolyOp::Add { a: None }),
Expand Down
2 changes: 1 addition & 1 deletion benches/pairwise_pow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl Circuit<Fr> for MyCircuit {
|mut region| {
config
.layout(
Some(&mut region),
&mut Some(&mut region),
&self.inputs,
&mut 0,
Box::new(PolyOp::Pow(4)),
Expand Down
2 changes: 1 addition & 1 deletion benches/relu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl Circuit<Fr> for NLCircuit {
|mut region| {
config
.layout(
Some(&mut region),
&mut Some(&mut region),
&[self.input.clone()],
&mut 0,
Box::new(LookupOp::ReLU { scale: 128 }),
Expand Down
14 changes: 5 additions & 9 deletions examples/conv2d_mnist/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ where
|| "mlp_4d",
|mut region| {
let mut offset = 0;
let region = &mut Some(&mut region);
let op = PolyOp::Conv {
kernel: self.l0_params[0].clone(),
bias: Some(self.l0_params[1].clone()),
Expand All @@ -181,18 +182,13 @@ where
};
let x = config
.layer_config
.layout(
Some(&mut region),
&[self.input.clone()],
&mut offset,
Box::new(op),
)
.layout(region, &[self.input.clone()], &mut offset, Box::new(op))
.unwrap();

let mut x = config
.layer_config
.layout(
Some(&mut region),
region,
&[x.unwrap()],
&mut offset,
Box::new(LookupOp::ReLU { scale: 32 }),
Expand All @@ -204,7 +200,7 @@ where
let x = config
.layer_config
.layout(
Some(&mut region),
region,
&[x],
&mut offset,
Box::new(PolyOp::Matmul {
Expand All @@ -217,7 +213,7 @@ where
let x: ValTensor<F> = config
.layer_config
.layout(
Some(&mut region),
region,
&[x],
&mut offset,
Box::new(PolyOp::Add {
Expand Down
15 changes: 8 additions & 7 deletions examples/mlp_4d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,11 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
|| "mlp_4d",
|mut region| {
let mut offset = 0;
let region = &mut Some(&mut region);
let x = config
.layer_config
.layout(
Some(&mut region),
region,
&[self.input.clone()],
&mut offset,
Box::new(PolyOp::Matmul {
Expand All @@ -110,7 +111,7 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
let x = config
.layer_config
.layout(
Some(&mut region),
region,
&[x],
&mut offset,
Box::new(PolyOp::Add {
Expand All @@ -125,7 +126,7 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
let mut x = config
.layer_config
.layout(
Some(&mut region),
region,
&[x],
&mut offset,
Box::new(LookupOp::ReLU { scale: 1 }),
Expand All @@ -138,7 +139,7 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
let x = config
.layer_config
.layout(
Some(&mut region),
region,
&[x],
&mut offset,
Box::new(PolyOp::Matmul {
Expand All @@ -151,7 +152,7 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
let x = config
.layer_config
.layout(
Some(&mut region),
region,
&[x],
&mut offset,
Box::new(PolyOp::Add {
Expand All @@ -165,7 +166,7 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
let x = config
.layer_config
.layout(
Some(&mut region),
region,
&[x],
&mut offset,
Box::new(LookupOp::ReLU { scale: 1 }),
Expand All @@ -175,7 +176,7 @@ impl<F: FieldExt + TensorType, const LEN: usize, const BITS: usize> Circuit<F>
Ok(config
.layer_config
.layout(
Some(&mut region),
region,
&[x.unwrap()],
&mut offset,
Box::new(LookupOp::Div {
Expand Down
9 changes: 2 additions & 7 deletions src/circuit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,20 +305,15 @@ impl<F: FieldExt + TensorType> BaseConfig<F> {
/// * `op` - The operation being represented.
pub fn layout(
&mut self,
mut region: Option<&mut Region<F>>,
region: &mut Option<&mut Region<F>>,
values: &[ValTensor<F>],
offset: &mut usize,
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
let mut cp_values = vec![];
for v in values.iter() {
if let ValTensor::Instance { .. } = v {
cp_values.push(layouts::identity(
self,
region.as_deref_mut(),
&[v.clone()],
offset,
)?);
cp_values.push(layouts::identity(self, region, &[v.clone()], offset)?);
} else {
cp_values.push(v.clone());
}
Expand Down
2 changes: 1 addition & 1 deletion src/circuit/ops/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl<F: FieldExt + TensorType> Op<F> for HybridOp {
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: Option<&mut Region<F>>,
region: &mut Option<&mut Region<F>>,
values: &[ValTensor<F>],
offset: &mut usize,
) -> Result<Option<ValTensor<F>>, Box<dyn std::error::Error>> {
Expand Down
Loading

0 comments on commit d197bdb

Please sign in to comment.