Skip to content

Commit

Permalink
fix identity as div
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Oct 3, 2024
1 parent 8d54f76 commit 1be9864
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
14 changes: 14 additions & 0 deletions src/graph/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ impl RebaseScale {
if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
&& !inner.is_constant()
&& !inner.is_input()
&& !inner.is_identity()
{
let multiplier =
scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32);
Expand Down Expand Up @@ -326,6 +327,19 @@ impl SupportedOp {
SupportedOp::RebaseScale(op) => op,
}
}

/// check if is the identity operation
/// # Returns
/// * `true` if the operation is the identity operation
/// * `false` otherwise
pub fn is_identity(&self) -> bool {
match self {
SupportedOp::Linear(op) => matches!(op, PolyOp::Identity { .. }),
SupportedOp::Rescaled(op) => op.inner.is_identity(),
SupportedOp::RebaseScale(op) => op.inner.is_identity(),
_ => false,
}
}
}

impl From<Box<dyn Op<Fp>>> for SupportedOp {
Expand Down
15 changes: 10 additions & 5 deletions src/graph/utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1017,8 +1017,13 @@ pub fn new_op_from_onnx(
if raw_values.log2().fract() == 0.0 {
inputs[const_idx].decrement_use();
deleted_indices.push(const_idx);
// get the non constant index
let non_const_idx = if const_idx == 0 { 1 } else { 0 };

op = SupportedOp::Linear(PolyOp::Identity {
out_scale: Some(input_scales[0] + raw_values.log2() as i32),
out_scale: Some(
input_scales[non_const_idx] + raw_values.log2() as i32,
),
});
}
}
Expand All @@ -1027,29 +1032,29 @@ pub fn new_op_from_onnx(
op
}
"Iff" => SupportedOp::Linear(PolyOp::Iff),
"Less" => {
"<" => {
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Less)
} else {
return Err(GraphError::InvalidDims(idx, "less".to_string()));
}
}
"LessEqual" => {
"<=" => {
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::LessEqual)
} else {
return Err(GraphError::InvalidDims(idx, "less equal".to_string()));
}
}
"Greater" => {
">" => {
// Extract the slope layer hyperparams
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::Greater)
} else {
return Err(GraphError::InvalidDims(idx, "greater".to_string()));
}
}
"GreaterEqual" => {
">=" => {
// Extract the slope layer hyperparams
if inputs.len() == 2 {
SupportedOp::Hybrid(HybridOp::GreaterEqual)
Expand Down

0 comments on commit 1be9864

Please sign in to comment.