feat: conv for groups > 1 (#204)
alexander-camuto authored Apr 25, 2023
1 parent 7d8fae2 commit 26ef31b
Showing 4 changed files with 361 additions and 289 deletions.
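
In a grouped convolution the input channels are split into disjoint groups and each group is convolved with its own slice of the kernels, so with OIHW kernels of shape [C_out, C_in / groups, kH, kW] the group count is implied by the shapes themselves. The standalone Rust sketch below (function name and error type are illustrative, not the ezkl API) mirrors the shape bookkeeping this commit adds at the top of conv:

// Standalone sketch of the shape bookkeeping this commit adds; the
// function name and error type are illustrative, not the ezkl API.
// Kernel layout is OIHW: [C_out, C_in / groups, kH, kW], so the group
// count falls out of the shapes as C_in / kernel_dims[1].
fn grouped_conv_shapes(
    input_channels: usize,     // image_dims[0] in the commit
    output_channels: usize,    // kernel_dims[0]
    kernel_in_channels: usize, // kernel_dims[1]
) -> Result<(usize, usize, usize), String> {
    let num_groups = input_channels / kernel_in_channels;
    let input_channels_per_group = input_channels / num_groups;
    let output_channels_per_group = output_channels / num_groups;
    if output_channels_per_group == 0 {
        // Same failure mode the commit guards against: fewer kernels
        // than groups leaves some group with no output channels.
        return Err(format!(
            "Given groups={}, expected kernel to be at least {} at dimension 0",
            num_groups, num_groups
        ));
    }
    Ok((num_groups, input_channels_per_group, output_channels_per_group))
}

fn main() {
    // Depthwise-style conv: 4 input channels, kernel [4, 1, kH, kW] => 4 groups.
    assert_eq!(grouped_conv_shapes(4, 4, 1).unwrap(), (4, 1, 1));
    // Ordinary conv (groups = 1): the kernel's dim 1 equals C_in.
    assert_eq!(grouped_conv_shapes(4, 8, 4).unwrap(), (1, 4, 8));
}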
203 changes: 138 additions & 65 deletions src/circuit/ops/layouts.rs
@@ -883,96 +883,169 @@ pub fn conv<F: FieldExt + TensorType + std::marker::Send + std::marker::Sync>(
     let mut padded_image = image.clone();
     padded_image.pad(padding)?;
 
-    // calculate value of output
-    let mut output: Tensor<ValType<F>> =
-        Tensor::new(None, &[output_channels, vert_slides, horz_slides]).unwrap();
-
-    let cartesian_coord = vec![(0..output_channels), (0..vert_slides), (0..horz_slides)]
-        .iter()
-        .cloned()
-        .multi_cartesian_product()
-        .collect::<Vec<_>>();
+    let num_groups = image_dims[0] / kernel_dims[1];
+    let input_channels_per_group = input_channels / num_groups;
+    let output_channels_per_group = output_channels / num_groups;
+
+    if output_channels_per_group == 0 {
+        return Err(Box::new(TensorError::DimMismatch(format!(
+            "Given groups={}, expected kernel to be at least {} at dimension 0 but got {} instead",
+            num_groups, num_groups, output_channels_per_group
+        ))));
+    }
+
+    let mut outputs_per_group = vec![Tensor::new(None, &[0])?; num_groups];
 
     match region {
         // non-parallel version
         Some(_) => {
-            for coord in cartesian_coord.iter() {
-                let (i, j, k) = (coord[0], coord[1], coord[2]);
-                let rs = j * stride.0;
-                let cs = k * stride.1;
-
-                let local_kernel = kernel.get_slice(&[i..i + 1])?;
-                let local_image = padded_image.get_slice(&[
-                    0..input_channels,
-                    rs..(rs + kernel_height),
-                    cs..(cs + kernel_width),
-                ])?;
-
-                let res = dot(
-                    config,
-                    region.as_deref_mut(),
-                    &[local_kernel, local_image],
-                    offset,
-                )?;
-
-                output.set(&[i, j, k], res.get_inner_tensor().unwrap()[0].clone());
-            }
+            outputs_per_group
+                .iter_mut()
+                .enumerate()
+                .for_each(|(group, o)| {
+                    let start_channel = group * input_channels_per_group;
+                    let end_channel = start_channel + input_channels_per_group;
+                    let padded_image_per_group = &padded_image
+                        .get_slice(&[start_channel..end_channel])
+                        .unwrap();
+
+                    let kernel_per_group = &kernel
+                        .get_slice(&[group * output_channels_per_group
+                            ..(group + 1) * output_channels_per_group])
+                        .unwrap();
+                    let mut output_per_group =
+                        Tensor::new(None, &[output_channels_per_group, vert_slides, horz_slides])
+                            .unwrap();
+
+                    let cartesian_coord_per_group = vec![
+                        (0..output_channels_per_group),
+                        (0..vert_slides),
+                        (0..horz_slides),
+                    ]
+                    .iter()
+                    .cloned()
+                    .multi_cartesian_product()
+                    .collect::<Vec<_>>();
+
+                    output_per_group
+                        .iter_mut()
+                        .enumerate()
+                        .for_each(|(flat_index, o)| {
+                            let coord = &cartesian_coord_per_group[flat_index];
+                            let (i, j, k) = (coord[0], coord[1], coord[2]);
+                            let rs = j * stride.0;
+                            let cs = k * stride.1;
+
+                            let res = dot(
+                                config,
+                                region.as_deref_mut(),
+                                &[
+                                    kernel_per_group.get_slice(&[i..i + 1]).unwrap(),
+                                    padded_image_per_group
+                                        .get_slice(&[
+                                            0..input_channels_per_group,
+                                            rs..(rs + kernel_height),
+                                            cs..(cs + kernel_width),
+                                        ])
+                                        .unwrap(),
+                                ],
+                                offset,
+                            )
+                            .unwrap();
+
+                            *o = res.get_inner_tensor().unwrap()[0].clone();
+                        });
+
+                    *o = output_per_group;
+                });
         }
         None => {
             let offset_thread = Arc::new(Mutex::new(offset.clone()));
             let config_thread = Arc::new(Mutex::new(config.clone()));
-            output.par_iter_mut().enumerate().for_each(|(i, x)| {
-                let (i, j, k) = (
-                    cartesian_coord[i][0],
-                    cartesian_coord[i][1],
-                    cartesian_coord[i][2],
-                );
-                let rs = j * stride.0;
-                let cs = k * stride.1;
-
-                let local_kernel = kernel.get_slice(&[i..i + 1]).unwrap();
-                let local_image = padded_image
-                    .get_slice(&[
-                        0..input_channels,
-                        rs..(rs + kernel_height),
-                        cs..(cs + kernel_width),
-                    ])
-                    .unwrap();
-
-                let mut offset_lock = offset_thread.lock().unwrap();
-                let mut config_lock = config_thread.lock().unwrap();
-                let res = dot(
-                    &mut config_lock,
-                    None,
-                    &[local_kernel, local_image],
-                    &mut offset_lock,
-                )
-                .unwrap();
-
-                *x = res.get_inner_tensor().unwrap()[0].clone();
-            });
+            outputs_per_group
+                .par_iter_mut()
+                .enumerate()
+                .for_each(|(group, o)| {
+                    let start_channel = group * input_channels_per_group;
+                    let end_channel = start_channel + input_channels_per_group;
+                    let padded_image_per_group = &padded_image
+                        .get_slice(&[start_channel..end_channel])
+                        .unwrap();
+
+                    let kernel_per_group = &kernel
+                        .get_slice(&[group * output_channels_per_group
+                            ..(group + 1) * output_channels_per_group])
+                        .unwrap();
+                    let mut output_per_group =
+                        Tensor::new(None, &[output_channels_per_group, vert_slides, horz_slides])
+                            .unwrap();
+
+                    let cartesian_coord_per_group = vec![
+                        (0..output_channels_per_group),
+                        (0..vert_slides),
+                        (0..horz_slides),
+                    ]
+                    .iter()
+                    .cloned()
+                    .multi_cartesian_product()
+                    .collect::<Vec<_>>();
+
+                    output_per_group
+                        .par_iter_mut()
+                        .enumerate()
+                        .for_each(|(flat_index, o)| {
+                            let coord = &cartesian_coord_per_group[flat_index];
+                            let (i, j, k) = (coord[0], coord[1], coord[2]);
+                            let rs = j * stride.0;
+                            let cs = k * stride.1;
+
+                            let mut offset_lock = offset_thread.lock().unwrap();
+                            let mut config_lock = config_thread.lock().unwrap();
+                            let res = dot(
+                                &mut config_lock,
+                                None,
+                                &[
+                                    kernel_per_group.get_slice(&[i..i + 1]).unwrap(),
+                                    padded_image_per_group
+                                        .get_slice(&[
+                                            0..input_channels_per_group,
+                                            rs..(rs + kernel_height),
+                                            cs..(cs + kernel_width),
+                                        ])
+                                        .unwrap(),
+                                ],
+                                &mut offset_lock,
+                            )
+                            .unwrap();
+
+                            *o = res.get_inner_tensor().unwrap()[0].clone();
+                        });
+
+                    *o = output_per_group;
+                });
             *offset += offset_thread.lock().unwrap().clone();
         }
     }
 
-    let mut res: ValTensor<F> = output.into();
+    let mut output: ValTensor<F> = Tensor::new(Some(&outputs_per_group), &[num_groups])?
+        .combine()?
+        .into();
+
+    output.reshape(&[output_channels, vert_slides, horz_slides])?;
 
     if has_bias {
         let tiled_bias = values[2].clone();
         if (tiled_bias.dims().len() != 1) || (tiled_bias.dims()[0] != kernel.dims()[0]) {
             return Err(Box::new(TensorError::DimMismatch("conv bias".to_string())));
         };
 
-        res = pairwise(config, region, &[res, tiled_bias], offset, BaseOp::Add)?
+        output = pairwise(config, region, &[output, tiled_bias], offset, BaseOp::Add)?
     };
 
-    res.reshape(&[output_channels, vert_slides, horz_slides])?;
-
     if matches!(&config.check_mode, CheckMode::SAFE) {
         // during key generation this will be 0 so we use this as a flag to check
         // TODO: this isn't very safe and would be better to get the phase directly
-        let is_assigned = !Into::<Tensor<i32>>::into(res.clone().get_inner()?)
+        let is_assigned = !Into::<Tensor<i32>>::into(output.clone().get_inner()?)
             .iter()
             .all(|&x| x == 0);
         if is_assigned {
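
What the per-group loop above computes, modeled in plain Rust: nested Vecs stand in for ezkl's Tensor, and padding, bias, and circuit wiring are omitted, so this is a sketch of the arithmetic only, with illustrative names. Each group reads a contiguous slice of input channels and its own contiguous slice of kernels; pushing the per-group channels in order is the flat analogue of the commit's combine() followed by the reshape to [output_channels, vert_slides, horz_slides].

// Plain-Rust model of the per-group loop above: nested Vecs stand in for
// ezkl's Tensor, there is no padding, bias, or circuit wiring, and all
// names are illustrative.
type Image = Vec<Vec<Vec<i64>>>;        // [C][H][W]
type Kernels = Vec<Vec<Vec<Vec<i64>>>>; // [C_out][C_in / groups][kH][kW]

fn grouped_conv(image: &Image, kernels: &Kernels, stride: (usize, usize)) -> Image {
    let c_in = image.len();
    let (h, w) = (image[0].len(), image[0][0].len());
    let c_out = kernels.len();
    let icpg = kernels[0].len(); // input channels per group
    let (kh, kw) = (kernels[0][0].len(), kernels[0][0][0].len());
    let num_groups = c_in / icpg;
    let ocpg = c_out / num_groups; // output channels per group
    let vert_slides = (h - kh) / stride.0 + 1;
    let horz_slides = (w - kw) / stride.1 + 1;

    let mut output = Vec::with_capacity(c_out);
    for group in 0..num_groups {
        let img_base = group * icpg; // start_channel in the commit
        let ker_base = group * ocpg;
        for i in 0..ocpg {
            let mut channel = vec![vec![0i64; horz_slides]; vert_slides];
            for j in 0..vert_slides {
                for k in 0..horz_slides {
                    let (rs, cs) = (j * stride.0, k * stride.1);
                    // The dot() call in the circuit: a kernel slice times the
                    // matching image window, summed.
                    let mut acc = 0i64;
                    for c in 0..icpg {
                        for dy in 0..kh {
                            for dx in 0..kw {
                                acc += kernels[ker_base + i][c][dy][dx]
                                    * image[img_base + c][rs + dy][cs + dx];
                            }
                        }
                    }
                    channel[j][k] = acc;
                }
            }
            // Group outputs land in order, which is what combine() + reshape do.
            output.push(channel);
        }
    }
    output
}

With one group this reduces to the previous code path; with input_channels groups of one channel each it is a depthwise convolution.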
@@ -990,13 +1063,13 @@ pub fn conv<F: FieldExt + TensorType + std::marker::Send + std::marker::Sync>(
             })?;
 
             assert_eq!(
-                Into::<Tensor<i32>>::into(res.get_inner()?),
+                Into::<Tensor<i32>>::into(output.get_inner()?),
                 Into::<Tensor<i32>>::into(safe_conv),
             )
         }
     }
 
-    Ok(res)
+    Ok(output)
 }
 
 /// Power accumulated layout
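The None branch above (witness generation with no region assigned yet) runs the same per-group loop under rayon, so the shared advancing offset and the config are wrapped in Arc<Mutex<_>> and the counter is folded back into *offset at the end. Below is a minimal standalone sketch of that pattern, assuming the rayon crate; the loop body is a stand-in for the dot() call.

use rayon::prelude::*; // assumed dependency, as in the crate's parallel path
use std::sync::{Arc, Mutex};

// Minimal sketch of the locking pattern in the `None` branch: workers
// share one offset counter behind Arc<Mutex<_>>, each claims rows under
// the lock, and the final count is read back out when the loop is done.
fn parallel_layout(cells: &mut [usize]) -> usize {
    let offset = Arc::new(Mutex::new(0usize));
    cells.par_iter_mut().for_each(|cell| {
        let mut off = offset.lock().unwrap();
        *cell = *off; // stand-in for dot() assigning at the current offset
        *off += 1;    // dot() advances the shared offset under the lock
    });
    let total = *offset.lock().unwrap();
    total
}

fn main() {
    let mut cells = vec![0usize; 8];
    let used = parallel_layout(&mut cells);
    assert_eq!(used, 8); // every worker advanced the counter exactly once
}

As written (here and in the commit), the lock is held across the whole body, so the innermost loop is largely serialized; the pattern buys a deterministic shared counter rather than speedup inside each dot().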
8 changes: 8 additions & 0 deletions src/graph/utilities.rs
@@ -391,6 +391,14 @@ pub fn new_op_from_onnx<F: FieldExt + TensorType>(
             }
         };
 
+        if let Some(dilations) = &conv_node.pool_spec.dilations {
+            if dilations.iter().any(|x| *x != 1) {
+                return Err(Box::new(GraphError::MisformedParams(
+                    "non unit dilations not supported".to_string(),
+                )));
+            }
+        }
+
         if (conv_node.pool_spec.data_format != DataFormat::NCHW)
             || (conv_node.kernel_fmt != KernelFormat::OIHW)
         {
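The guard added to src/graph/utilities.rs rejects dilated convolutions at ONNX import time: a dilation d makes the kernel sample the image at rs + dy * d instead of the dense kH x kW window the conv layout reads, so any d != 1 would be lowered incorrectly. A standalone sketch of the check, with an illustrative struct in place of tract's pool spec:

// Illustrative stand-in for the imported node's pool spec; the real
// code reads tract's conv_node.pool_spec.dilations.
struct PoolSpec {
    dilations: Option<Vec<usize>>,
}

// Mirror of the added guard: only dilation 1 matches the dense window
// the circuit's conv layout actually reads.
fn check_dilations(spec: &PoolSpec) -> Result<(), String> {
    if let Some(dilations) = &spec.dilations {
        if dilations.iter().any(|&d| d != 1) {
            return Err("non unit dilations not supported".to_string());
        }
    }
    Ok(())
}

fn main() {
    assert!(check_dilations(&PoolSpec { dilations: Some(vec![1, 1]) }).is_ok());
    assert!(check_dilations(&PoolSpec { dilations: Some(vec![2, 1]) }).is_err());
    assert!(check_dilations(&PoolSpec { dilations: None }).is_ok());
}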