chore: support all padding types (#848)
alexander-camuto authored Oct 5, 2024
1 parent 64fbc8a commit e5aa48f
Showing 6 changed files with 57 additions and 50 deletions.
18 changes: 6 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 3 additions & 1 deletion examples/onnx/boolean/gen.py
@@ -9,7 +9,9 @@ def __init__(self):
         super(MyModel, self).__init__()

     def forward(self, w, x, y, z):
-        return [((x & y)) == (x & (y | (z ^ w)))]
+        a = (x & y)
+        b = (y & (z ^ w))
+        return [a & b]


 circuit = MyModel()
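The regenerated model replaces the final Equal comparison with a chain of pure And/Xor gates. One way to see why a gate-only graph is convenient for a proving stack (a sketch of my own, not ezkl's actual gadget code; the `and`/`xor` helpers are hypothetical): boolean gates on {0,1}-valued wires reduce to low-degree polynomials over the field.

```rust
// Illustrative only: how {0,1}-valued boolean gates arithmetize as
// low-degree polynomials. Not ezkl's real lowering.
fn and(a: u64, b: u64) -> u64 {
    a * b // a AND b = a*b
}

fn xor(a: u64, b: u64) -> u64 {
    a + b - 2 * a * b // a XOR b = a + b - 2ab
}

fn main() {
    // exhaustive truth-table check over both bits
    for a in 0..2u64 {
        for b in 0..2u64 {
            assert_eq!(and(a, b), a & b);
            assert_eq!(xor(a, b), a ^ b);
        }
    }
    println!("polynomial forms match the bitwise gates");
}
```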
2 changes: 1 addition & 1 deletion examples/onnx/boolean/input.json
@@ -1 +1 @@
-{"input_data": [[false, true, false], [true, false, false], [true, false, false], [false, false, false]]}
+{"input_data": [[false, true, true], [false, true, true], [true, false, false], [false, true, true]]}
26 changes: 11 additions & 15 deletions examples/onnx/boolean/network.onnx
@@ -1,21 +1,17 @@
(Binary ONNX protobuf; the diff is not human-readable. Recoverable details: the producer string changes from pytorch1.12.1 to pytorch2.2.2, and the graph's And/Xor/Or/And/Equal node chain is replaced by the And/Xor/And/And chain matching the updated gen.py above.)
14 changes: 14 additions & 0 deletions src/graph/node.rs
@@ -125,6 +125,7 @@ impl RebaseScale {
         if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
             && !inner.is_constant()
             && !inner.is_input()
+            && !inner.is_identity()
         {
             let multiplier =
                 scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32);
@@ -326,6 +327,19 @@ impl SupportedOp {
             SupportedOp::RebaseScale(op) => op,
         }
     }
+
+    /// check if is the identity operation
+    /// # Returns
+    /// * `true` if the operation is the identity operation
+    /// * `false` otherwise
+    pub fn is_identity(&self) -> bool {
+        match self {
+            SupportedOp::Linear(op) => matches!(op, PolyOp::Identity { .. }),
+            SupportedOp::Rescaled(op) => op.inner.is_identity(),
+            SupportedOp::RebaseScale(op) => op.inner.is_identity(),
+            _ => false,
+        }
+    }
 }

 impl From<Box<dyn Op<Fp>>> for SupportedOp {
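The `is_identity` guard added to `RebaseScale` above pairs with the power-of-two rewrite in `utilities.rs` below: an `Identity` op produced by that rewrite already encodes its scale shift in `out_scale`, so rebasing it again would presumably double-apply the correction. A minimal sketch of the guard's effect, with hypothetical stand-in types rather than the real `SupportedOp`:

```rust
// Hypothetical stand-ins for ezkl's SupportedOp; only the guard logic is real.
#[derive(Debug)]
enum Op {
    Identity { out_scale: i32 },
    Mul,
}

// Mirrors the new condition: ops above the target scale get rebased,
// except identities, whose out_scale bump is intentional.
fn should_rebase(op: &Op, op_out_scale: i32, target_scale: i32) -> bool {
    op_out_scale > target_scale && !matches!(op, Op::Identity { .. })
}

fn main() {
    let id = Op::Identity { out_scale: 9 };
    assert!(!should_rebase(&id, 9, 7)); // identity left alone
    assert!(should_rebase(&Op::Mul, 9, 7)); // ordinary op still rebased
    println!("guard behaves as expected for {:?}", id);
}
```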
43 changes: 22 additions & 21 deletions src/graph/utilities.rs
@@ -41,7 +41,7 @@ use tract_onnx::tract_hir::{
     ops::konst::Const,
     ops::nn::DataFormat,
     tract_core::ops::cast::Cast,
-    tract_core::ops::cnn::{conv::KernelFormat, MaxPool, PaddingSpec, SumPool},
+    tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
 };

/// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation.
@@ -94,17 +94,18 @@ pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
 /// extract padding from a onnx node.
 pub fn extract_padding(
     pool_spec: &PoolSpec,
-    num_dims: usize,
+    image_size: &[usize],
 ) -> Result<Vec<(usize, usize)>, GraphError> {
-    let padding = match &pool_spec.padding {
-        PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
-            b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
-        }
-        PaddingSpec::Valid => vec![(0, 0); num_dims],
-        _ => {
-            return Err(GraphError::MissingParams("padding".to_string()));
-        }
-    };
+    let num_relevant_dims = pool_spec.kernel_shape.len();
+
+    // get the last num_relevant_dims of the image size
+    let image_size = &image_size[image_size.len() - num_relevant_dims..];
+
+    let dims = pool_spec.computed_padding(image_size);
+    let mut padding = Vec::new();
+    for dim in dims {
+        padding.push((dim.pad_before, dim.pad_after));
+    }
     Ok(padding)
 }
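Delegating to tract's `computed_padding` is the heart of the "support all padding types" change: the old match handled only `Explicit`, `ExplicitOnnxPool`, and `Valid`, and returned a `MissingParams` error for the same-padding variants, while tract resolves every spec against the actual image size. As a back-of-the-envelope check (my own sketch, assuming stride >= 1 and no dilation), the "same-upper" padding for one dimension works out as follows:

```rust
// Sketch of per-dimension "same-upper" padding (assumes stride >= 1, no
// dilation): output length is ceil(input / stride), and any odd total pad
// puts the extra cell after the data. Illustrative, not tract's code.
fn same_upper(input: usize, kernel: usize, stride: usize) -> (usize, usize) {
    let output = input.div_ceil(stride);
    let total = ((output - 1) * stride + kernel).saturating_sub(input);
    (total / 2, total - total / 2) // (pad_before, pad_after)
}

fn main() {
    assert_eq!(same_upper(5, 3, 1), (1, 1)); // k=3, s=1: pad 1 on each side
    assert_eq!(same_upper(5, 2, 2), (0, 1)); // odd total pad lands after
    println!("same-upper padding checks pass");
}
```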

@@ -1016,8 +1017,13 @@ pub fn new_op_from_onnx(
                 if raw_values.log2().fract() == 0.0 {
                     inputs[const_idx].decrement_use();
                     deleted_indices.push(const_idx);
+                    // get the non constant index
+                    let non_const_idx = if const_idx == 0 { 1 } else { 0 };
+
                     op = SupportedOp::Linear(PolyOp::Identity {
-                        out_scale: Some(input_scales[0] + raw_values.log2() as i32),
+                        out_scale: Some(
+                            input_scales[non_const_idx] + raw_values.log2() as i32,
+                        ),
                     });
                 }
             }
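The hunk above also fixes an indexing bug in the power-of-two shortcut: under a fixed-point encoding x ≈ q / 2^scale, multiplying by a constant 2^k is a pure scale shift, so the Mul node collapses to an Identity whose output scale grows by k. That k has to be added to the non-constant input's scale; hard-coding `input_scales[0]` was only correct when the constant sat at index 1. A self-contained sketch of the bookkeeping (illustrative, not the real helper):

```rust
// Illustrative bookkeeping for the power-of-two Mul shortcut: returns the
// Identity's out_scale, or None when the constant is not a power of two.
fn identity_out_scale(input_scales: [i32; 2], const_idx: usize, constant: f64) -> Option<i32> {
    let k = constant.log2();
    if k.fract() != 0.0 {
        return None; // keep the real Mul
    }
    let non_const_idx = 1 - const_idx; // the data input, not the constant
    Some(input_scales[non_const_idx] + k as i32)
}

fn main() {
    // constant 4.0 = 2^2 at index 0; the data input (scale 7) is at index 1
    assert_eq!(identity_out_scale([0, 7], 0, 4.0), Some(9));
    // constant at index 1: the old input_scales[0] happened to be right here
    assert_eq!(identity_out_scale([7, 0], 1, 4.0), Some(9));
    assert_eq!(identity_out_scale([7, 0], 1, 3.0), None);
}
```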
@@ -1108,7 +1114,7 @@
             }

             let stride = extract_strides(pool_spec)?;
-            let padding = extract_padding(pool_spec, input_dims[0].len())?;
+            let padding = extract_padding(pool_spec, &input_dims[0])?;
             let kernel_shape = &pool_spec.kernel_shape;

             SupportedOp::Hybrid(HybridOp::MaxPool {
@@ -1178,7 +1184,7 @@
             let pool_spec = &conv_node.pool_spec;

             let stride = extract_strides(pool_spec)?;
-            let padding = extract_padding(pool_spec, input_dims[0].len())?;
+            let padding = extract_padding(pool_spec, &input_dims[0])?;

             // if bias exists then rescale it to the input + kernel scale
             if input_scales.len() == 3 {
@@ -1236,7 +1242,7 @@
             let pool_spec = &deconv_node.pool_spec;

             let stride = extract_strides(pool_spec)?;
-            let padding = extract_padding(pool_spec, input_dims[0].len())?;
+            let padding = extract_padding(pool_spec, &input_dims[0])?;
             // if bias exists then rescale it to the input + kernel scale
             if input_scales.len() == 3 {
                 let bias_scale = input_scales[2];
@@ -1349,7 +1355,7 @@
             }

             let stride = extract_strides(pool_spec)?;
-            let padding = extract_padding(pool_spec, input_dims[0].len())?;
+            let padding = extract_padding(pool_spec, &input_dims[0])?;

             SupportedOp::Hybrid(HybridOp::SumPool {
                 padding,
@@ -1358,11 +1364,6 @@
                 normalized: sumpool_node.normalize,
             })
         }
-        // "GlobalAvgPool" => SupportedOp::Linear(PolyOp::SumPool {
-        //     padding: [(0, 0); 2],
-        //     stride: (1, 1),
-        //     kernel_shape: (inputs[0].out_dims()[0][1], inputs[0].out_dims()[0][2]),
-        // }),
         "Pad" => {
             let pad_node: &Pad = match node.op().downcast_ref::<Pad>() {
                 Some(b) => b,
