fix: num groups for conv operations should be specified at load time #828

Merged · 4 commits · Jul 18, 2024
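In short: conv and deconv in src/circuit/ops/layouts.rs take the number of groups as a new trailing argument, PolyOp::Conv and PolyOp::DeConv carry a matching group field, and the ONNX loader fills that field from the node's group attribute at load time instead of inferring it from kernel shapes inside the layout. A minimal before/after sketch of a call site (argument values abridged; see the hunks below for the real signatures):

    // Before: the layout inferred groups as input_channels / kernel_dims[1].
    let out = conv(&config, &mut region, &[x, k, b], &padding, &stride)?;

    // After: the caller passes the group count recorded in the ONNX graph.
    let out = conv(&config, &mut region, &[x, k, b], &padding, &stride, group)?;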
1 change: 1 addition & 0 deletions benches/accum_conv.rs
@@ -72,6 +72,7 @@ impl Circuit<Fr> for MyCircuit {
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
stride: vec![1; 2],
+     group: 1,
}),
)
.unwrap();
1 change: 1 addition & 0 deletions examples/conv2d_mnist/main.rs
@@ -205,6 +205,7 @@ where
let op = PolyOp::Conv {
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
+     group: 1,
};
let x = config
.layer_config
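For context on what group means here (standard grouped-convolution semantics, not code from this repo): with g groups, the c_in input channels are split into g contiguous slices of c_in / g, the kernel has shape [c_out, c_in / g, kH, kW], and each slice of c_out / g filters sees only its own input slice. A runnable sketch of that bookkeeping, with hypothetical names:

    use std::ops::Range;

    /// Input/output channel ranges handled by each group, for c_in input
    /// channels, c_out output channels, and g groups (a sketch of standard
    /// grouped-conv bookkeeping, not ezkl's circuit code).
    fn group_ranges(c_in: usize, c_out: usize, g: usize) -> Vec<(Range<usize>, Range<usize>)> {
        let (ci, co) = (c_in / g, c_out / g);
        (0..g)
            .map(|i| (i * ci..(i + 1) * ci, i * co..(i + 1) * co))
            .collect()
    }

    fn main() {
        // Depthwise conv over 3 channels: group == c_in, one filter per channel.
        for (inp, out) in group_ranges(3, 3, 3) {
            println!("input channels {:?} -> output channels {:?}", inp, out);
        }
    }

Setting group: 1, as these call sites do, keeps the old ungrouped behaviour.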
51 changes: 29 additions & 22 deletions src/circuit/ops/layouts.rs
@@ -3023,7 +3023,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI
.map(|coord| {
let (b, i) = (coord[0], coord[1]);
let input = values[0].get_slice(&[b..b + 1, i..i + 1])?;
-             let output = conv(config, region, &[input, kernel.clone()], padding, stride)?;
+             let output = conv(config, region, &[input, kernel.clone()], padding, stride, 1)?;
res.push(output);
Ok(())
})
@@ -3159,7 +3159,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// &[1, 1, 2, 2],
/// ).unwrap());
///
- /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![1;2], &vec![2;2]).unwrap();
+ /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![1;2], &vec![2;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[0, 32, 0, 32, 0, 6, 0, 12, 0, 4, 0, 8, 0, 4, 0, 8, 0, 0, 0, 3, 0, 0, 0, 2]), &[1, 2, 3, 4]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3171,7 +3171,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
- /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2]).unwrap();
+ /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[6, 14, 4, 2, 17, 21, 0, 1, 5]), &[1, 1, 3, 3]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3184,7 +3184,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
- /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2]).unwrap();
+ /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[17]), &[1, 1, 1, 1]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3197,7 +3197,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
- /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2]).unwrap();
+ /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[10, 4, 0, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3209,7 +3209,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
- /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2]).unwrap();
+ /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[6, 2, 12, 4, 2, 10, 4, 20, 0, 0, 3, 1, 0, 0, 1, 5]), &[1, 1, 4, 4]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3221,7 +3221,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[3, 2]),
/// &[1, 1, 2, 1],
/// ).unwrap());
- /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2]).unwrap();
+ /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[0, 0]), &[1, 1, 2, 1]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3233,7 +3233,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[3, 2]),
/// &[1, 1, 2, 1],
/// ).unwrap());
- /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2]).unwrap();
+ /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 1, 4, 3]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3244,7 +3244,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// &[1, 1, 2, 2],
/// ).unwrap());
///
- /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2]).unwrap();
+ /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
/// let x = ValTensor::from_i64_tensor(Tensor::<i64>::new(
@@ -3259,7 +3259,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[1]),
/// &[1],
/// ).unwrap());
- /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2]).unwrap();
+ /// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[55, 58, 66, 69]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3279,6 +3279,7 @@ pub fn deconv<
padding: &[(usize, usize)],
output_padding: &[usize],
stride: &[usize],
+     num_groups: usize,
) -> Result<ValTensor<F>, CircuitError> {
let has_bias = inputs.len() == 3;
let (image, kernel) = (&inputs[0], &inputs[1]);
@@ -3364,6 +3365,7 @@ pub fn deconv<
&conv_input,
&vec![(0, 0); conv_dim],
&vec![1; conv_dim],
+         num_groups,
)?;

Ok(output)
@@ -3395,7 +3397,7 @@ pub fn deconv<
/// Some(&[0]),
/// &[1],
/// ).unwrap());
- /// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2]).unwrap();
+ /// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[31, 16, 8, 26]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3413,7 +3415,7 @@ pub fn deconv<
/// &[2],
/// ).unwrap());
///
- /// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2]).unwrap();
+ /// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 2).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[32, 17, 9, 27, 34, 20, 13, 26]), &[1, 2, 2, 2]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3431,7 +3433,7 @@ pub fn deconv<
/// &[4],
/// ).unwrap());
///
- /// let result =conv(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2]).unwrap();
+ /// let result =conv(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[65, 36, 21, 52, 73, 48, 37, 48, 65, 36, 21, 52, 73, 48, 37, 48]), &[1, 4, 2, 2]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
/// ```
@@ -3450,6 +3452,7 @@ pub fn conv<
values: &[ValTensor<F>],
padding: &[(usize, usize)],
stride: &[usize],
+     num_groups: usize,
) -> Result<ValTensor<F>, CircuitError> {
let has_bias = values.len() == 3;
let (mut image, mut kernel) = (values[0].clone(), values[1].clone());
@@ -3480,6 +3483,11 @@ pub fn conv<
region.increment(*assigned_len.iter().max().unwrap());
}

+     // if image is 3d add a dummy batch dimension
+     if image.dims().len() == kernel.dims().len() - 1 {
+         image.reshape(&[1, image.dims()[0], image.dims()[1], image.dims()[2]])?;
+     }
+
let image_dims = image.dims();
let kernel_dims = kernel.dims();

@@ -3513,25 +3521,24 @@

log::debug!("slides: {:?}", slides);

-     let num_groups = input_channels / kernel_dims[1];
let input_channels_per_group = input_channels / num_groups;
let output_channels_per_group = output_channels / num_groups;

+     if output_channels_per_group == 0 || input_channels_per_group == 0 {
+         return Err(TensorError::DimMismatch(format!(
+             "Given groups={}, expected input channels and output channels to be divisible by groups, but got input_channels={}, output_channels={}",
+             num_groups, input_channels, output_channels
+         ))
+         .into());
+     }
+
log::debug!(
"num_groups: {}, input_channels_per_group: {}, output_channels_per_group: {}",
num_groups,
input_channels_per_group,
output_channels_per_group
);

-     if output_channels_per_group == 0 {
-         return Err(TensorError::DimMismatch(format!(
-             "Given groups={}, expected kernel to be at least {} at dimension 0 but got {} instead",
-             num_groups, num_groups, output_channels_per_group
-         ))
-         .into());
-     }
-
let num_outputs =
batch_size * num_groups * output_channels_per_group * slides.iter().product::<usize>();

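The added validity check replaces the shape-based inference (num_groups = input_channels / kernel_dims[1]) shown on the deleted line. Its arithmetic, extracted into a standalone runnable sketch (same logic and message as the hunk above; the function name is hypothetical):

    /// Per-group channel counts, erroring when either count would be zero;
    /// mirrors the check added in layouts.rs (integer division, as there).
    fn channels_per_group(
        input_channels: usize,
        output_channels: usize,
        num_groups: usize,
    ) -> Result<(usize, usize), String> {
        let input_channels_per_group = input_channels / num_groups;
        let output_channels_per_group = output_channels / num_groups;
        if output_channels_per_group == 0 || input_channels_per_group == 0 {
            return Err(format!(
                "Given groups={}, expected input channels and output channels to be divisible by groups, but got input_channels={}, output_channels={}",
                num_groups, input_channels, output_channels
            ));
        }
        Ok((input_channels_per_group, output_channels_per_group))
    }

    fn main() {
        assert_eq!(channels_per_group(4, 8, 2), Ok((2, 4)));
        assert!(channels_per_group(1, 8, 2).is_err()); // 1 channel cannot split into 2 groups
    }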
35 changes: 28 additions & 7 deletions src/circuit/ops/poly.rs
@@ -33,6 +33,7 @@ pub enum PolyOp {
Conv {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
+         group: usize,
},
Downsample {
axis: usize,
@@ -43,6 +44,7 @@ pub enum PolyOp {
padding: Vec<(usize, usize)>,
output_padding: Vec<usize>,
stride: Vec<usize>,
+         group: usize,
},
Add,
Sub,
@@ -148,17 +150,25 @@ impl<
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
-             PolyOp::Conv { stride, padding } => {
-                 format!("CONV (stride={:?}, padding={:?})", stride, padding)
+             PolyOp::Conv {
+                 stride,
+                 padding,
+                 group,
+             } => {
+                 format!(
+                     "CONV (stride={:?}, padding={:?}, group={})",
+                     stride, padding, group
+                 )
}
PolyOp::DeConv {
stride,
padding,
output_padding,
+                 group,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?})",
stride, padding, output_padding
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
stride, padding, output_padding, group
)
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
Expand Down Expand Up @@ -212,9 +222,18 @@ impl<
PolyOp::Prod { axes, .. } => {
layouts::prod_axes(config, region, values[..].try_into()?, axes)?
}
-             PolyOp::Conv { padding, stride } => {
-                 layouts::conv(config, region, values[..].try_into()?, padding, stride)?
-             }
+             PolyOp::Conv {
+                 padding,
+                 stride,
+                 group,
+             } => layouts::conv(
+                 config,
+                 region,
+                 values[..].try_into()?,
+                 padding,
+                 stride,
+                 *group,
+             )?,
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
@@ -261,13 +280,15 @@ impl<
padding,
output_padding,
stride,
+                 group,
} => layouts::deconv(
config,
region,
values[..].try_into()?,
padding,
output_padding,
stride,
+                 *group,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
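Constructing the updated variants now requires the extra field, and group: 1 reproduces the previous ungrouped behaviour. A sketch against the enum as amended above (crate imports assumed; not a complete program):

    // A standard 3x3-style convolution, ungrouped.
    let conv_op = PolyOp::Conv {
        padding: vec![(1, 1); 2],
        stride: vec![1; 2],
        group: 1,
    };

    // A transposed convolution with two groups; input and output channels
    // must both split into 2 non-empty groups or layout returns DimMismatch.
    let deconv_op = PolyOp::DeConv {
        padding: vec![(0, 0); 2],
        output_padding: vec![0; 2],
        stride: vec![2; 2],
        group: 2,
    };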
3 changes: 3 additions & 0 deletions src/circuit/tests.rs
@@ -1050,6 +1050,7 @@ mod conv {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
+                 group: 1,
}),
)
.map_err(|_| Error::Synthesis)
@@ -1200,6 +1201,7 @@ mod conv_col_ultra_overflow {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
+                 group: 1,
}),
)
.map_err(|_| Error::Synthesis)
@@ -1345,6 +1347,7 @@ mod conv_relu_col_ultra_overflow {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
+                 group: 1,
}),
)
.map_err(|_| Error::Synthesis);
14 changes: 9 additions & 5 deletions src/graph/utilities.rs
@@ -283,10 +283,7 @@ pub fn new_op_from_onnx(
.flat_map(|x| x.out_scales())
.collect::<Vec<_>>();

-     let input_dims = inputs
-         .iter()
-         .flat_map(|x| x.out_dims())
-         .collect::<Vec<_>>();
+     let input_dims = inputs.iter().flat_map(|x| x.out_dims()).collect::<Vec<_>>();

let mut replace_const = |scale: crate::Scale,
index: usize,
@@ -1192,7 +1189,13 @@ pub fn new_op_from_onnx(
}
}

-             SupportedOp::Linear(PolyOp::Conv { padding, stride })
+             let group = conv_node.group;
+
+             SupportedOp::Linear(PolyOp::Conv {
+                 padding,
+                 stride,
+                 group,
+             })
}
"Not" => SupportedOp::Linear(PolyOp::Not),
"And" => SupportedOp::Linear(PolyOp::And),
@@ -1247,6 +1250,7 @@ pub fn new_op_from_onnx(
padding,
output_padding: deconv_node.adjustments.to_vec(),
stride,
+                 group: deconv_node.group,
})
}
"Downsample" => {
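Taken together, the loader changes mean the group count travels from the ONNX attribute to the circuit layout unchanged. A condensed sketch of that flow, with names taken from the hunks above and all surrounding match arms and error handling omitted:

    // src/graph/utilities.rs: tract's parsed conv node carries the ONNX
    // group attribute, copied into the op at load time...
    let group = conv_node.group;
    let op = SupportedOp::Linear(PolyOp::Conv { padding, stride, group });

    // ...and src/circuit/ops/poly.rs forwards it verbatim at layout time:
    // layouts::conv(config, region, values, padding, stride, *group).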