Skip to content

Commit

Permalink
[ONNXTransform] always delete orphaned node branches after running an onnx transform (#1746)
Browse files Browse the repository at this point in the history

* [ONNXTransform] always delete orphaned node branches after running an onnx transform

* update tests
  • Loading branch information
bfineran committed Oct 3, 2023
1 parent 38fe044 commit 2d0f8a0
Show file tree
Hide file tree
Showing 13 changed files with 48 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/sparseml/exporters/transforms/onnx_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def post_validate(self, model: ModelProto) -> ModelProto:
model.graph.node.remove(node)
graph = ONNXGraph(model)
graph.delete_unused_initializers()
graph.delete_orphaned_node_branches()
graph.sort_nodes_topologically()
validate_onnx(model)
return model
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _create_test_model():
conv_node = onnx.helper.make_node(
"Conv",
inputs=["dequant_linear_0_output", "dequant_linear_1_output", "bias"],
outputs=["conv_node_output"],
outputs=["output"],
kernel_shape=[3, 3],
name="conv_node",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def onnx_model() -> onnx.ModelProto:
output_quant = helper.make_node(
"QuantizeLinear",
["conv_output", "y_scale", "zero_point"],
["output_quant_output"],
["output"],
name="output_quant",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def onnx_model():
dequant2 = onnx.helper.make_node(
"DequantizeLinear",
["quant2_output", "scale"],
["dequant2_output"],
["output"],
name="dequant2",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def onnx_model():
name="scale", data_type=onnx.TensorProto.FLOAT, dims=(1,), vals=[1.0]
)
quantize = onnx.helper.make_node(
"QuantizeLinear", ["input", "scale", "zero_point"], ["id1_output"], name="id1"
"QuantizeLinear", ["input", "scale", "zero_point"], ["output"], name="id1"
)

graph = onnx.helper.make_graph(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def onnx_model():
scale = onnx.helper.make_tensor("scale", onnx.TensorProto.FLOAT, (1,), [1])
relu = onnx.helper.make_node("Relu", ["input"], ["relu_output"], name="relu")
quant = onnx.helper.make_node(
"QuantizeLinear", ["relu_output", "scale"], ["quant_output"], name="quant"
"QuantizeLinear", ["relu_output", "scale"], ["output"], name="quant"
)

graph = onnx.helper.make_graph(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def onnx_model() -> onnx.ModelProto:
gemm = helper.make_node(
"Gemm",
["input_dequant_output", "weight_dequant_output", "bias"],
["gemm_output"],
["output"],
name="gemm",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def onnx_model() -> onnx.ModelProto:
model_input_1 = helper.make_tensor_value_info(
"input_1", onnx.TensorProto.FLOAT, (1,)
)
model_output = helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, (1,))
model_output = helper.make_tensor_value_info(
"output_quant_output", onnx.TensorProto.FLOAT, (1,)
)

input_dequant = helper.make_node(
"DequantizeLinear",
Expand Down Expand Up @@ -159,6 +161,7 @@ def test_gemm_with_bias_dequant_after(onnx_model: onnx.ModelProto):
name="output_dequant",
)
)
onnx_model.graph.output[0].name = "output_dequant_output"
validate_onnx(onnx_model)

onnx_model = GemmToQLinearMatMul().apply(onnx_model)
Expand Down Expand Up @@ -201,6 +204,7 @@ def test_gemm_after_changes_nothing(onnx_model: onnx.ModelProto):
name="gemm2",
)
)
onnx_model.graph.output[0].name = "gemm2_output" # update graph output
validate_onnx(onnx_model)
onnx_model = GemmToQLinearMatMul().apply(onnx_model)
validate_onnx(onnx_model)
Expand All @@ -224,6 +228,7 @@ def test_gemm_after_changes_nothing(onnx_model: onnx.ModelProto):

# remove the gemm2 node and now things should change
onnx_model.graph.node.pop()
onnx_model.graph.output[0].name = "output_dequant_output" # update graph output
validate_onnx(onnx_model)
onnx_model = GemmToQLinearMatMul().apply(onnx_model)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def onnx_model() -> onnx.ModelProto:
["matmul_output"],
name="matmul",
)
add = helper.make_node("Add", ["matmul_output", "bias"], ["add_output"], name="add")
add = helper.make_node("Add", ["matmul_output", "bias"], ["output"], name="add")

graph = helper.make_graph(
nodes=[input_dequant, weight_quant, weight_dequant, transpose, matmul, add],
Expand Down Expand Up @@ -150,6 +150,7 @@ def test_matmul_no_bias_converts(onnx_model: onnx.ModelProto):
# remove "bias" initializer and "add" node
assert onnx_model.graph.initializer.pop().name == "bias"
assert onnx_model.graph.node.pop().name == "add"
onnx_model.graph.output[0].name = "matmul_output" # update graph output name
validate_onnx(onnx_model)

onnx_model = MatMulAddToMatMulIntegerAddCastMul().apply(onnx_model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _create_test_model(with_transpose=False, with_reshape=False):
dequantize_linear_node_2 = onnx.helper.make_node(
"DequantizeLinear",
["quant_linear_2_output", "x_scale", "zero_point"],
["dequant_linear_2_output"],
["output"],
name="dequantize_linear_node_2",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def onnx_model():
concat = onnx.helper.make_node(
"Concat",
["pad1_output", "pad2_output", "dequant_output"],
["concat_output"],
["output"],
name="concat",
axis=0,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,23 @@ def onnx_model():
model_input = onnx.helper.make_tensor_value_info(
"input", onnx.TensorProto.FLOAT, (1,)
)
model_output = onnx.helper.make_tensor_value_info(
"output", onnx.TensorProto.FLOAT, (1,)
model_output_1 = onnx.helper.make_tensor_value_info(
"add1_output", onnx.TensorProto.FLOAT, (1,)
)
model_output_2 = onnx.helper.make_tensor_value_info(
"add2_output", onnx.TensorProto.FLOAT, (1,)
)
model_output_3 = onnx.helper.make_tensor_value_info(
"add3_output", onnx.TensorProto.FLOAT, (1,)
)
model_output_4 = onnx.helper.make_tensor_value_info(
"conv4_output", onnx.TensorProto.FLOAT, (1,)
)
model_output_5 = onnx.helper.make_tensor_value_info(
"conv5_output", onnx.TensorProto.FLOAT, (1,)
)
model_output_6 = onnx.helper.make_tensor_value_info(
"conv6_output", onnx.TensorProto.FLOAT, (1,)
)
zp = onnx.helper.make_tensor("zp", onnx.TensorProto.UINT8, (1,), [0])
scale = onnx.helper.make_tensor("scale", onnx.TensorProto.FLOAT, (1,), [1.0])
Expand Down Expand Up @@ -90,7 +105,14 @@ def onnx_model():
nodes=[conv1, conv2, conv3, conv4, conv5, conv6, add1, add2, add3],
name="g",
inputs=[model_input],
outputs=[model_output],
outputs=[
model_output_1,
model_output_2,
model_output_3,
model_output_4,
model_output_5,
model_output_6,
],
initializer=[
weight1_a,
weight1_b,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ def onnx_model():
model_input = onnx.helper.make_tensor_value_info(
"input", onnx.TensorProto.FLOAT, (1,)
)
model_output = onnx.helper.make_tensor_value_info(
"output", onnx.TensorProto.FLOAT, (1,)
model_output_1 = onnx.helper.make_tensor_value_info(
"quant1_output", onnx.TensorProto.FLOAT, (1,)
)
model_output_2 = onnx.helper.make_tensor_value_info(
"quant2_output", onnx.TensorProto.FLOAT, (1,)
)
zp = onnx.helper.make_tensor("zp", onnx.TensorProto.UINT8, (1,), [0])
scale1 = onnx.helper.make_tensor("scale1", onnx.TensorProto.FLOAT, (1,), [1.0])
Expand All @@ -50,7 +53,7 @@ def onnx_model():
nodes=[quant1, quant2],
name="g",
inputs=[model_input],
outputs=[model_output],
outputs=[model_output_1, model_output_2],
initializer=[scale1, zp],
)

Expand Down

0 comments on commit 2d0f8a0

Please sign in to comment.