From 26ef31b2b0908ad39d1fd4795bed3f6516ab0fa8 Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Tue, 25 Apr 2023 07:50:19 +0100 Subject: [PATCH] feat: conv for groups > 1 (#204) --- src/circuit/ops/layouts.rs | 203 ++++++++++++++++++--------- src/graph/utilities.rs | 8 ++ src/tensor/mod.rs | 165 +++++++++++++--------- src/tensor/ops.rs | 274 ++++++++++++++++--------------------- 4 files changed, 361 insertions(+), 289 deletions(-) diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 6ad50dbd4..fa803d2eb 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -883,80 +883,155 @@ pub fn conv( let mut padded_image = image.clone(); padded_image.pad(padding)?; - // calculate value of output - let mut output: Tensor> = - Tensor::new(None, &[output_channels, vert_slides, horz_slides]).unwrap(); + let num_groups = image_dims[0] / kernel_dims[1]; + let input_channels_per_group = input_channels / num_groups; + let output_channels_per_group = output_channels / num_groups; + + if output_channels_per_group == 0 { + return Err(Box::new(TensorError::DimMismatch(format!( + "Given groups={}, expected kernel to be at least {} at dimension 0 but got {} instead", + num_groups, num_groups, output_channels_per_group + )))); + } - let cartesian_coord = vec![(0..output_channels), (0..vert_slides), (0..horz_slides)] - .iter() - .cloned() - .multi_cartesian_product() - .collect::>(); + let mut outputs_per_group = vec![Tensor::new(None, &[0])?; num_groups]; match region { // non-parallel version Some(_) => { - for coord in cartesian_coord.iter() { - let (i, j, k) = (coord[0], coord[1], coord[2]); - let rs = j * stride.0; - let cs = k * stride.1; - - let local_kernel = kernel.get_slice(&[i..i + 1])?; - let local_image = padded_image.get_slice(&[ - 0..input_channels, - rs..(rs + kernel_height), - cs..(cs + kernel_width), - ])?; - - let res = dot( - config, - region.as_deref_mut(), - &[local_kernel, local_image], - offset, - )?; + outputs_per_group + .iter_mut() + .enumerate() + .for_each(|(group, o)| { + let start_channel = group * input_channels_per_group; + let end_channel = start_channel + input_channels_per_group; + let padded_image_per_group = &padded_image + .get_slice(&[start_channel..end_channel]) + .unwrap(); + + let kernel_per_group = &kernel + .get_slice(&[group * output_channels_per_group + ..(group + 1) * output_channels_per_group]) + .unwrap(); + let mut output_per_group = + Tensor::new(None, &[output_channels_per_group, vert_slides, horz_slides]) + .unwrap(); + + let cartesian_coord_per_group = vec![ + (0..output_channels_per_group), + (0..vert_slides), + (0..horz_slides), + ] + .iter() + .cloned() + .multi_cartesian_product() + .collect::>(); - output.set(&[i, j, k], res.get_inner_tensor().unwrap()[0].clone()); - } + output_per_group + .iter_mut() + .enumerate() + .for_each(|(flat_index, o)| { + let coord = &cartesian_coord_per_group[flat_index]; + let (i, j, k) = (coord[0], coord[1], coord[2]); + let rs = j * stride.0; + let cs = k * stride.1; + + let res = dot( + config, + region.as_deref_mut(), + &[ + kernel_per_group.get_slice(&[i..i + 1]).unwrap(), + padded_image_per_group + .get_slice(&[ + 0..input_channels_per_group, + rs..(rs + kernel_height), + cs..(cs + kernel_width), + ]) + .unwrap(), + ], + offset, + ) + .unwrap(); + + *o = res.get_inner_tensor().unwrap()[0].clone(); + }); + + *o = output_per_group; + }); } None => { let offset_thread = Arc::new(Mutex::new(offset.clone())); let config_thread = Arc::new(Mutex::new(config.clone())); - output.par_iter_mut().enumerate().for_each(|(i, x)| { - let (i, j, k) = ( - cartesian_coord[i][0], - cartesian_coord[i][1], - cartesian_coord[i][2], - ); - let rs = j * stride.0; - let cs = k * stride.1; - - let local_kernel = kernel.get_slice(&[i..i + 1]).unwrap(); - let local_image = padded_image - .get_slice(&[ - 0..input_channels, - rs..(rs + kernel_height), - cs..(cs + kernel_width), - ]) - .unwrap(); - - let mut offset_lock = offset_thread.lock().unwrap(); - let mut config_lock = config_thread.lock().unwrap(); - let res = dot( - &mut config_lock, - None, - &[local_kernel, local_image], - &mut offset_lock, - ) - .unwrap(); - - *x = res.get_inner_tensor().unwrap()[0].clone(); - }); + outputs_per_group + .par_iter_mut() + .enumerate() + .for_each(|(group, o)| { + let start_channel = group * input_channels_per_group; + let end_channel = start_channel + input_channels_per_group; + let padded_image_per_group = &padded_image + .get_slice(&[start_channel..end_channel]) + .unwrap(); + + let kernel_per_group = &kernel + .get_slice(&[group * output_channels_per_group + ..(group + 1) * output_channels_per_group]) + .unwrap(); + let mut output_per_group = + Tensor::new(None, &[output_channels_per_group, vert_slides, horz_slides]) + .unwrap(); + + let cartesian_coord_per_group = vec![ + (0..output_channels_per_group), + (0..vert_slides), + (0..horz_slides), + ] + .iter() + .cloned() + .multi_cartesian_product() + .collect::>(); + output_per_group + .par_iter_mut() + .enumerate() + .for_each(|(flat_index, o)| { + let coord = &cartesian_coord_per_group[flat_index]; + let (i, j, k) = (coord[0], coord[1], coord[2]); + let rs = j * stride.0; + let cs = k * stride.1; + + let mut offset_lock = offset_thread.lock().unwrap(); + let mut config_lock = config_thread.lock().unwrap(); + let res = dot( + &mut config_lock, + None, + &[ + kernel_per_group.get_slice(&[i..i + 1]).unwrap(), + padded_image_per_group + .get_slice(&[ + 0..input_channels_per_group, + rs..(rs + kernel_height), + cs..(cs + kernel_width), + ]) + .unwrap(), + ], + &mut offset_lock, + ) + .unwrap(); + + *o = res.get_inner_tensor().unwrap()[0].clone(); + }); + + *o = output_per_group; + }); *offset += offset_thread.lock().unwrap().clone(); } } - let mut res: ValTensor = output.into(); + let mut output: ValTensor = Tensor::new(Some(&outputs_per_group), &[num_groups])? + .combine()? + .into(); + + output.reshape(&[output_channels, vert_slides, horz_slides])?; if has_bias { let tiled_bias = values[2].clone(); @@ -964,15 +1039,13 @@ pub fn conv( return Err(Box::new(TensorError::DimMismatch("conv bias".to_string()))); }; - res = pairwise(config, region, &[res, tiled_bias], offset, BaseOp::Add)? + output = pairwise(config, region, &[output, tiled_bias], offset, BaseOp::Add)? }; - res.reshape(&[output_channels, vert_slides, horz_slides])?; - if matches!(&config.check_mode, CheckMode::SAFE) { // during key generation this will be 0 so we use this as a flag to check // TODO: this isn't very safe and would be better to get the phase directly - let is_assigned = !Into::>::into(res.clone().get_inner()?) + let is_assigned = !Into::>::into(output.clone().get_inner()?) .iter() .all(|&x| x == 0); if is_assigned { @@ -990,13 +1063,13 @@ pub fn conv( })?; assert_eq!( - Into::>::into(res.get_inner()?), + Into::>::into(output.get_inner()?), Into::>::into(safe_conv), ) } } - Ok(res) + Ok(output) } /// Power accumulated layout diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index beb672402..32cb5b1bc 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -391,6 +391,14 @@ pub fn new_op_from_onnx( } }; + if let Some(dilations) = &conv_node.pool_spec.dilations { + if dilations.iter().any(|x| *x != 1) { + return Err(Box::new(GraphError::MisformedParams( + "non unit dilations not supported".to_string(), + ))); + } + } + if (conv_node.pool_spec.data_format != DataFormat::NCHW) || (conv_node.kernel_fmt != KernelFormat::OIHW) { diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 2902eddde..e16f5b9ce 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -1017,7 +1017,7 @@ impl Tensor> { } } -impl> Add for Tensor { +impl + std::marker::Send + std::marker::Sync> Add for Tensor { type Output = Result, TensorError>; /// Adds tensors. /// # Arguments @@ -1072,53 +1072,64 @@ impl> Add for Tensor { if self.len() != rhs.len() { if self.dims().iter().map(|x| (x > &1) as usize).sum::() == 1 && rhs.dims().iter().product::() > 1 + && self.dims().iter().product::() > 1 && self.dims() != rhs.dims() { assert_eq!(rhs.dims()[0], self.dims().iter().product::()); output = rhs.clone(); let lhs = self.clone(); - let full_indices = rhs.dims().iter().map(|d| 0..*d); - for coord in full_indices.multi_cartesian_product() { - let i = self.get_index(&coord); - output[i] = output[i].clone() + lhs[coord[0]].clone(); - } + let full_indices = rhs + .dims() + .iter() + .map(|d| 0..*d) + .multi_cartesian_product() + .collect::>>(); + output.par_iter_mut().enumerate().for_each(|(i, x)| { + let coord = &full_indices[i]; + *x = x.clone() + lhs[coord[0]].clone(); + }); } else if rhs.dims().iter().map(|x| (x > &1) as usize).sum::() == 1 + && rhs.dims().iter().product::() > 1 && self.dims().iter().product::() > 1 && self.dims() != rhs.dims() { assert_eq!(self.dims()[0], rhs.dims().iter().product::()); - let rhs = rhs.clone(); - let full_indices = self.dims().iter().map(|d| 0..*d); - for coord in full_indices.multi_cartesian_product() { - let i = self.get_index(&coord); - output[i] = output[i].clone() + rhs[coord[0]].clone(); - } + let full_indices = self + .dims() + .iter() + .map(|d| 0..*d) + .multi_cartesian_product() + .collect::>>(); + output.par_iter_mut().enumerate().for_each(|(i, x)| { + let coord = &full_indices[i]; + *x = x.clone() + rhs[coord[0]].clone(); + }); } // casts a 1D addition else if rhs.dims().iter().product::() == 1 { - for i in 0..output.len() { - output[i] = output[i].clone() + rhs[0].clone(); - } + output.par_iter_mut().for_each(|o| { + *o = o.clone() + rhs[0].clone(); + }); } // make 1D casting commutative else if self.dims().iter().product::() == 1 { output = rhs.clone(); - for i in 0..rhs.len() { - output[i] = output[i].clone() + self[0].clone(); - } + output.par_iter_mut().for_each(|o| { + *o = o.clone() + self[0].clone(); + }); } else { return Err(TensorError::DimMismatch("add".to_string())); } } else { - for (i, e_i) in rhs.iter().enumerate() { - output[i] = output[i].clone() + e_i.clone() - } + output.par_iter_mut().zip(rhs).for_each(|(o, r)| { + *o = o.clone() + r.clone(); + }); } Ok(output) } } -impl> Sub for Tensor { +impl + std::marker::Send + std::marker::Sync> Sub for Tensor { type Output = Result, TensorError>; /// Subtracts tensors. /// # Arguments @@ -1174,53 +1185,64 @@ impl> Sub for Tensor { if self.len() != rhs.len() { if self.dims().iter().map(|x| (x > &1) as usize).sum::() == 1 && rhs.dims().iter().product::() > 1 + && self.dims().iter().product::() > 1 && self.dims() != rhs.dims() { assert_eq!(rhs.dims()[0], self.dims().iter().product::()); output = rhs.clone(); let lhs = self.clone(); - let full_indices = rhs.dims().iter().map(|d| 0..*d); - for coord in full_indices.multi_cartesian_product() { - let i = self.get_index(&coord); - output[i] = lhs[coord[0]].clone() - output[i].clone(); - } + let full_indices = rhs + .dims() + .iter() + .map(|d| 0..*d) + .multi_cartesian_product() + .collect::>>(); + output.par_iter_mut().enumerate().for_each(|(i, x)| { + let coord = &full_indices[i]; + *x = x.clone() - lhs[coord[0]].clone(); + }); } else if rhs.dims().iter().map(|x| (x > &1) as usize).sum::() == 1 + && rhs.dims().iter().product::() > 1 && self.dims().iter().product::() > 1 && self.dims() != rhs.dims() { assert_eq!(self.dims()[0], rhs.dims().iter().product::()); - let rhs = rhs.clone(); - let full_indices = self.dims().iter().map(|d| 0..*d); - for coord in full_indices.multi_cartesian_product() { - let i = self.get_index(&coord); - output[i] = output[i].clone() - rhs[coord[0]].clone(); - } + let full_indices = self + .dims() + .iter() + .map(|d| 0..*d) + .multi_cartesian_product() + .collect::>>(); + output.par_iter_mut().enumerate().for_each(|(i, x)| { + let coord = &full_indices[i]; + *x = x.clone() - rhs[coord[0]].clone(); + }); } // casts a 1D addition else if rhs.dims().iter().product::() == 1 { - for i in 0..output.len() { - output[i] = output[i].clone() - rhs[0].clone(); - } + output.par_iter_mut().for_each(|o| { + *o = o.clone() - rhs[0].clone(); + }); } // make 1D casting commutative else if self.dims().iter().product::() == 1 { output = rhs.clone(); - for i in 0..rhs.len() { - output[i] = self[0].clone() - output[i].clone(); - } + output.par_iter_mut().for_each(|o| { + *o = self[0].clone() - o.clone(); + }); } else { return Err(TensorError::DimMismatch("sub".to_string())); } } else { - for (i, e_i) in rhs.iter().enumerate() { - output[i] = output[i].clone() - e_i.clone() - } + output.par_iter_mut().zip(rhs).for_each(|(o, r)| { + *o = o.clone() - r.clone(); + }); } Ok(output) } } -impl> Mul for Tensor { +impl + std::marker::Send + std::marker::Sync> Mul for Tensor { type Output = Result, TensorError>; /// Elementwise multiplies tensors. /// # Arguments @@ -1274,53 +1296,64 @@ impl> Mul for Tensor { if self.len() != rhs.len() { if self.dims().iter().map(|x| (x > &1) as usize).sum::() == 1 && rhs.dims().iter().product::() > 1 + && self.dims().iter().product::() > 1 && self.dims() != rhs.dims() { assert_eq!(rhs.dims()[0], self.dims().iter().product::()); output = rhs.clone(); let lhs = self.clone(); - let full_indices = rhs.dims().iter().map(|d| 0..*d); - for coord in full_indices.multi_cartesian_product() { - let i = self.get_index(&coord); - output[i] = lhs[coord[0]].clone() * output[i].clone(); - } + let full_indices = rhs + .dims() + .iter() + .map(|d| 0..*d) + .multi_cartesian_product() + .collect::>>(); + output.par_iter_mut().enumerate().for_each(|(i, x)| { + let coord = &full_indices[i]; + *x = x.clone() * lhs[coord[0]].clone(); + }); } else if rhs.dims().iter().map(|x| (x > &1) as usize).sum::() == 1 + && rhs.dims().iter().product::() > 1 && self.dims().iter().product::() > 1 && self.dims() != rhs.dims() { assert_eq!(self.dims()[0], rhs.dims().iter().product::()); - let rhs = rhs.clone(); - let full_indices = self.dims().iter().map(|d| 0..*d); - for coord in full_indices.multi_cartesian_product() { - let i = self.get_index(&coord); - output[i] = output[i].clone() * rhs[coord[0]].clone(); - } + let full_indices = self + .dims() + .iter() + .map(|d| 0..*d) + .multi_cartesian_product() + .collect::>>(); + output.par_iter_mut().enumerate().for_each(|(i, x)| { + let coord = &full_indices[i]; + *x = rhs[coord[0]].clone() * x.clone(); + }); } - // cast 1D mul + // casts a 1D addition else if rhs.dims().iter().product::() == 1 { - for i in 0..output.len() { - output[i] = output[i].clone() * rhs[0].clone(); - } + output.par_iter_mut().for_each(|o| { + *o = o.clone() * rhs[0].clone(); + }); } // make 1D casting commutative else if self.dims().iter().product::() == 1 { output = rhs.clone(); - for i in 0..rhs.len() { - output[i] = output[i].clone() * self[0].clone(); - } + output.par_iter_mut().for_each(|o| { + *o = self[0].clone() * o.clone(); + }); } else { - return Err(TensorError::DimMismatch("mul".to_string())); + return Err(TensorError::DimMismatch("sub".to_string())); } } else { - for (i, e_i) in rhs.iter().enumerate() { - output[i] = output[i].clone() * e_i.clone() - } + output.par_iter_mut().zip(rhs).for_each(|(o, r)| { + *o = o.clone() * r.clone(); + }); } Ok(output) } } -impl> Tensor { +impl + std::marker::Send + std::marker::Sync> Tensor { /// Elementwise raise a tensor to the nth power. /// # Arguments /// diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index 505d1c6b6..0dfd11a82 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -7,107 +7,6 @@ use rayon::{ }; pub use std::ops::{Add, Div, Mul, Sub}; -/// Matrix multiplies two 2D tensors (and adds an offset). -/// # Arguments -/// -/// * `inputs` - A vector of tensors holding in order: input data, affine kernel, convolution bias. -/// # Examples -/// ``` -/// use ezkl_lib::tensor::Tensor; -/// use ezkl_lib::tensor::ops::affine; -/// -/// let x = Tensor::::new( -/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 2, 1, 1]), -/// &[3, 4], -/// ).unwrap(); -/// let k = Tensor::::new( -/// Some(&[2, 1, 2, 1, 1, 1]), -/// &[2, 3], -/// ).unwrap(); -/// let b = Tensor::::new( -/// Some(&[0, 0]), -/// &[2], -/// ).unwrap(); -/// let result = affine(&[x, k, b]).unwrap(); -/// let expected = Tensor::::new(Some(&[26, 7, 11, 3, 15, 3, 7, 2]), &[2, 4]).unwrap(); -/// assert_eq!(result, expected); -/// ``` -pub fn affine< - T: TensorType + Mul + Add + std::marker::Send + std::marker::Sync, ->( - inputs: &[Tensor], -) -> Result, TensorError> { - let (mut input, kernel, bias) = (inputs[0].clone(), inputs[1].clone(), inputs[2].clone()); - if (inputs.len() != 3) - || (bias.dims()[0] != kernel.dims()[0]) - || (input.dims()[0] != kernel.dims()[1]) - { - return Err(TensorError::DimMismatch("affine".to_string())); - } - - // does matrix to vector multiplication - if input.dims().len() == 1 { - input.reshape(&[input.dims()[0], 1]) - } - - // calculate value of output - let mut output: Tensor = matmul(&[kernel.clone(), input.clone()])?; - - for i in 0..kernel.dims()[0] { - for j in 0..input.dims()[1] { - output.set(&[i, j], output.get(&[i, j]) + bias[i].clone()); - } - } - // does matrix to vector multiplication - if output.dims()[1] == 1 { - output.flatten(); - } - Ok(output) -} - -/// Scales and shifts a tensor. -/// Given inputs (x,k,b) computes k*x + b elementwise -/// # Arguments -/// -/// * `inputs` - Vector of tensors of length 2 -/// # Examples -/// ``` -/// use ezkl_lib::tensor::Tensor; -/// use ezkl_lib::tensor::ops::scale_and_shift; -/// -/// let x = Tensor::::new( -/// Some(&[2, 1, 2, 1, 1, 1]), -/// &[2, 3], -/// ).unwrap(); -/// let k = Tensor::::new( -/// Some(&[2, 1, 2, 1, 1, 1]), -/// &[2, 3], -/// ).unwrap(); -/// let b = Tensor::::new( -/// Some(&[2, 1, 2, 1, 1, 1]), -/// &[2, 3], -/// ).unwrap(); -/// let result = scale_and_shift(&[x, k, b]).unwrap(); -/// let expected = Tensor::::new(Some(&[6, 2, 6, 2, 2, 2]), &[2, 3]).unwrap(); -/// assert_eq!(result, expected); -/// ``` -pub fn scale_and_shift + Add>( - inputs: &[Tensor], -) -> Result, TensorError> { - if (inputs.len() != 3) - || (inputs[1].dims() != inputs[2].dims()) - || (inputs[0].dims() != inputs[1].dims()) - { - return Err(TensorError::DimMismatch("scale and shift".to_string())); - } - let (input, kernel, bias) = (inputs[0].clone(), inputs[1].clone(), inputs[2].clone()); - let mut output: Tensor = input; - for (i, bias_i) in bias.iter().enumerate() { - output[i] = kernel[i].clone() * output[i].clone() + bias_i.clone() - } - Ok(output) -} - /// Matrix multiplies two 2D tensors. /// # Arguments /// @@ -220,12 +119,14 @@ pub fn matmul< /// let expected = Tensor::::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` -pub fn add>(t: &[Tensor]) -> Result, TensorError> { +pub fn add + std::marker::Send + std::marker::Sync>( + t: &[Tensor], +) -> Result, TensorError> { // calculate value of output let mut output: Tensor = t[0].clone(); for e in t[1..].iter() { - output = (output + e.clone())?; + output = output.add(e.clone())?; } Ok(output) @@ -265,7 +166,9 @@ pub fn add>(t: &[Tensor]) -> Result /// let expected = Tensor::::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` -pub fn sub>(t: &[Tensor]) -> Result, TensorError> { +pub fn sub + std::marker::Send + std::marker::Sync>( + t: &[Tensor], +) -> Result, TensorError> { // calculate value of output let mut output: Tensor = t[0].clone(); @@ -309,7 +212,9 @@ pub fn sub>(t: &[Tensor]) -> Result /// let expected = Tensor::::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` -pub fn mult>(t: &[Tensor]) -> Result, TensorError> { +pub fn mult + std::marker::Send + std::marker::Sync>( + t: &[Tensor], +) -> Result, TensorError> { // calculate value of output let mut output: Tensor = t[0].clone(); @@ -338,17 +243,17 @@ pub fn mult>(t: &[Tensor]) -> Result::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap(); /// assert_eq!(result, expected); /// ``` -pub fn rescale>( +pub fn rescale + std::marker::Send + std::marker::Sync>( a: &Tensor, mult: usize, ) -> Result, TensorError> { // calculate value of output let mut output: Tensor = a.clone(); - for (i, a_i) in a.iter().enumerate() { + output.par_iter_mut().enumerate().for_each(|(i, a_i)| { for _ in 1..mult { - output[i] = output[i].clone() + a_i.clone(); + *a_i = a_i.clone() + a[i].clone(); } - } + }); Ok(output) } @@ -574,6 +479,8 @@ pub fn max_axes + std::cmp::Ord>( /// * `stride` - Tuple of stride values in x and y directions. /// # Examples /// ``` +/// // expected ouputs are taken from pytorch torch.nn.functional.conv2d +/// /// use ezkl_lib::tensor::Tensor; /// use ezkl_lib::tensor::ops::conv; /// @@ -599,16 +506,34 @@ pub fn max_axes + std::cmp::Ord>( /// &[2, 3, 3], /// ).unwrap(); /// let k = Tensor::::new( -/// Some(&[5, 1, 1, 1]), -/// &[1, 1, 2, 2], +/// Some(&[5, 1, 1, 1, 5, 2, 1, 1]), +/// &[2, 1, 2, 2], /// ).unwrap(); /// let b = Tensor::::new( -/// Some(&[0]), -/// &[1], +/// Some(&[1, 1]), +/// &[2], /// ).unwrap(); /// /// let result = conv::(&[x, k, b], (0, 0), (1, 1)).unwrap(); -/// let expected = Tensor::::new(Some(&[62, 32, 16, 52]), &[1, 2, 2]).unwrap(); +/// let expected = Tensor::::new(Some(&[32, 17, 9, 27, 34, 20, 13, 26]), &[2, 2, 2]).unwrap(); +/// assert_eq!(result, expected); +/// +/// // Now test multi channel +/// let x = Tensor::::new( +/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6, 5, 2, 3, 0, 4, -1, 3, 1, 6]), +/// &[2, 3, 3], +/// ).unwrap(); +/// let k = Tensor::::new( +/// Some(&[5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1, 5, 1, 1, 1, 5, 2, 1, 1, 5, 3, 1, 1, 5, 4, 1, 1]), +/// &[4, 2, 2, 2], +/// ).unwrap(); +/// let b = Tensor::::new( +/// Some(&[1, 1, 1, 1]), +/// &[4], +/// ).unwrap(); +/// +/// let result = conv::(&[x, k, b], (0, 0), (1, 1)).unwrap(); +/// let expected = Tensor::::new(Some(&[65, 36, 21, 52, 73, 48, 37, 48, 65, 36, 21, 52, 73, 48, 37, 48]), &[4, 2, 2]).unwrap(); /// assert_eq!(result, expected); /// ``` pub fn conv< @@ -619,30 +544,21 @@ pub fn conv< stride: (usize, usize), ) -> Result, TensorError> { let has_bias = inputs.len() == 3; - let (image, mut kernel) = (inputs[0].clone(), inputs[1].clone()); + let (image, kernel) = (&inputs[0], &inputs[1]); if (image.dims().len() != 3) || (kernel.dims().len() != 4) - || ((image.dims()[0] != kernel.dims()[1]) && (kernel.dims()[1] != 1)) + // ensure number of groups makes sense + || (image.dims()[0] % kernel.dims()[1] != 0) { return Err(TensorError::DimMismatch("conv".to_string())); } - if kernel.dims()[1] == 1 && kernel.dims()[1] != image.dims()[0] { - kernel = kernel.repeat_rows(image.dims()[0])?; - kernel.reshape(&[ - kernel.dims()[0], - image.dims()[0], - kernel.dims()[2], - kernel.dims()[3], - ]); - } - let image_dims = image.dims(); let kernel_dims = kernel.dims(); if has_bias { - let bias = inputs[2].clone(); + let bias = &inputs[2]; if (bias.dims().len() != 1) || (bias.dims()[0] != kernel.dims()[0]) { return Err(TensorError::DimMismatch("conv bias".to_string())); } @@ -650,7 +566,7 @@ pub fn conv< let (output_channels, input_channels, kernel_height, kernel_width) = ( kernel_dims[0], - kernel_dims[1], + image_dims[0], kernel_dims[2], kernel_dims[3], ); @@ -662,45 +578,83 @@ pub fn conv< let vert_slides = (image_height + 2 * padding.0 - kernel_height) / stride.0 + 1; let horz_slides = (image_width + 2 * padding.1 - kernel_width) / stride.1 + 1; - // calculate value of output - let mut output: Tensor = - Tensor::new(None, &[output_channels, vert_slides, horz_slides]).unwrap(); + let num_groups = image_dims[0] / kernel_dims[1]; + let input_channels_per_group = input_channels / num_groups; + let output_channels_per_group = output_channels / num_groups; - let cartesian_coord = vec![(0..output_channels), (0..vert_slides), (0..horz_slides)] - .iter() - .cloned() - .multi_cartesian_product() - .collect::>(); + if output_channels_per_group == 0 { + return Err(TensorError::DimMismatch(format!( + "Given groups={}, expected kernel to be at least {} at dimension 0 but got {} instead", + num_groups, num_groups, output_channels_per_group + ))); + } - output + let mut outputs_per_group = vec![Tensor::new(None, &[0])?; num_groups]; + + outputs_per_group .par_iter_mut() .enumerate() - .for_each(|(flat_index, o)| { - let coord = &cartesian_coord[flat_index]; - let (i, j, k) = (coord[0], coord[1], coord[2]); - let rs = j * stride.0; - let cs = k * stride.1; + .for_each(|(group, o)| { + let start_channel = group * input_channels_per_group; + let end_channel = start_channel + input_channels_per_group; + let padded_image_per_group = &padded_image + .get_slice(&[start_channel..end_channel]) + .unwrap(); - let mut res = dot(&vec![ - &kernel.get_slice(&[i..i + 1]).unwrap().clone(), - &padded_image - .get_slice(&[ - 0..input_channels, - rs..(rs + kernel_height), - cs..(cs + kernel_width), + let kernel_per_group = &kernel + .get_slice(&[ + group * output_channels_per_group..(group + 1) * output_channels_per_group + ]) + .unwrap(); + let mut output_per_group = + Tensor::new(None, &[output_channels_per_group, vert_slides, horz_slides]).unwrap(); + + let cartesian_coord_per_group = vec![ + (0..output_channels_per_group), + (0..vert_slides), + (0..horz_slides), + ] + .iter() + .cloned() + .multi_cartesian_product() + .collect::>(); + + output_per_group + .par_iter_mut() + .enumerate() + .for_each(|(flat_index, o)| { + let coord = &cartesian_coord_per_group[flat_index]; + let (i, j, k) = (coord[0], coord[1], coord[2]); + let rs = j * stride.0; + let cs = k * stride.1; + + let res = dot(&vec![ + &kernel_per_group.get_slice(&[i..i + 1]).unwrap().clone(), + &padded_image_per_group + .get_slice(&[ + 0..input_channels_per_group, + rs..(rs + kernel_height), + cs..(cs + kernel_width), + ]) + .unwrap(), ]) - .unwrap(), - ]) - .unwrap(); + .unwrap(); - if has_bias { - // increment result by the bias - res[0] = res[0].clone() + inputs[2][i].clone(); - } + *o = res[0].clone(); + }); - *o = res[0].clone(); + *o = output_per_group; }); + let mut output = Tensor::new(Some(&outputs_per_group), &[num_groups])?.combine()?; + + output.reshape(&[output_channels, vert_slides, horz_slides]); + + if has_bias { + // increment result by the bias + output = (output + inputs[2].clone())?; + } + Ok(output) } @@ -961,7 +915,11 @@ pub fn pad( /// ).unwrap(); /// assert_eq!(result, expected); /// ``` -pub fn pack(a: &Tensor, base: T, scale: u32) -> Result, TensorError> +pub fn pack( + a: &Tensor, + base: T, + scale: u32, +) -> Result, TensorError> where T: Add, T: Mul,