From 1be9864790f1914aab45240b077f4ad8c8fb6e81 Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Thu, 3 Oct 2024 16:10:43 -0400 Subject: [PATCH] fix identity as div --- src/graph/node.rs | 14 ++++++++++++++ src/graph/utilities.rs | 15 ++++++++++----- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/graph/node.rs b/src/graph/node.rs index 539f10d79..a46654752 100644 --- a/src/graph/node.rs +++ b/src/graph/node.rs @@ -125,6 +125,7 @@ impl RebaseScale { if (op_out_scale > (global_scale * scale_rebase_multiplier as i32)) && !inner.is_constant() && !inner.is_input() + && !inner.is_identity() { let multiplier = scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32); @@ -326,6 +327,19 @@ impl SupportedOp { SupportedOp::RebaseScale(op) => op, } } + + /// check if is the identity operation + /// # Returns + /// * `true` if the operation is the identity operation + /// * `false` otherwise + pub fn is_identity(&self) -> bool { + match self { + SupportedOp::Linear(op) => matches!(op, PolyOp::Identity { .. }), + SupportedOp::Rescaled(op) => op.inner.is_identity(), + SupportedOp::RebaseScale(op) => op.inner.is_identity(), + _ => false, + } + } } impl From>> for SupportedOp { diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index a7d11c5dc..dddf142ad 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -1017,8 +1017,13 @@ pub fn new_op_from_onnx( if raw_values.log2().fract() == 0.0 { inputs[const_idx].decrement_use(); deleted_indices.push(const_idx); + // get the non constant index + let non_const_idx = if const_idx == 0 { 1 } else { 0 }; + op = SupportedOp::Linear(PolyOp::Identity { - out_scale: Some(input_scales[0] + raw_values.log2() as i32), + out_scale: Some( + input_scales[non_const_idx] + raw_values.log2() as i32, + ), }); } } @@ -1027,21 +1032,21 @@ pub fn new_op_from_onnx( op } "Iff" => SupportedOp::Linear(PolyOp::Iff), - "Less" => { + "<" => { if inputs.len() == 2 { SupportedOp::Hybrid(HybridOp::Less) } else { return Err(GraphError::InvalidDims(idx, "less".to_string())); } } - "LessEqual" => { + "<=" => { if inputs.len() == 2 { SupportedOp::Hybrid(HybridOp::LessEqual) } else { return Err(GraphError::InvalidDims(idx, "less equal".to_string())); } } - "Greater" => { + ">" => { // Extract the slope layer hyperparams if inputs.len() == 2 { SupportedOp::Hybrid(HybridOp::Greater) @@ -1049,7 +1054,7 @@ pub fn new_op_from_onnx( return Err(GraphError::InvalidDims(idx, "greater".to_string())); } } - "GreaterEqual" => { + ">=" => { // Extract the slope layer hyperparams if inputs.len() == 2 { SupportedOp::Hybrid(HybridOp::GreaterEqual)