
Commit 5813331
update tests
Benjamin committed Oct 3, 2023
1 parent a9dd3e4 commit 5813331
Showing 12 changed files with 47 additions and 16 deletions.
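
The changed files all follow one pattern: the tensors listed in a fixture's make_graph(outputs=...) are brought in line with the tensors its nodes actually produce. Either a terminal node's output is renamed to "output" to match the declared graph output, the declared graph output is renamed to match the terminal tensor, or a fixture with several terminal nodes declares one graph output per terminal tensor. A minimal sketch of that relationship, using only the public onnx.helper API (the node and tensor names are illustrative, and validate_onnx in these tests presumably performs a check along these lines):

import onnx
from onnx import TensorProto, helper

# The declared graph output "output" is only valid because the Relu node
# produces a tensor with exactly that name; if the two names drift apart,
# the model no longer describes a well-formed graph.
model_input = helper.make_tensor_value_info("input", TensorProto.FLOAT, (1,))
model_output = helper.make_tensor_value_info("output", TensorProto.FLOAT, (1,))
relu = helper.make_node("Relu", ["input"], ["output"], name="relu")

graph = helper.make_graph(
    nodes=[relu], name="g", inputs=[model_input], outputs=[model_output]
)
onnx.checker.check_model(helper.make_model(graph))  # raises if the model is malformed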

@@ -76,7 +76,7 @@ def _create_test_model():
conv_node = onnx.helper.make_node(
"Conv",
inputs=["dequant_linear_0_output", "dequant_linear_1_output", "bias"],
outputs=["conv_node_output"],
outputs=["output"],
kernel_shape=[3, 3],
name="conv_node",
)

@@ -63,7 +63,7 @@ def onnx_model() -> onnx.ModelProto:
output_quant = helper.make_node(
"QuantizeLinear",
["conv_output", "y_scale", "zero_point"],
["output_quant_output"],
["output"],
name="output_quant",
)


@@ -44,7 +44,7 @@ def onnx_model():
dequant2 = onnx.helper.make_node(
"DequantizeLinear",
["quant2_output", "scale"],
["dequant2_output"],
["output"],
name="dequant2",
)


@@ -47,7 +47,7 @@ def onnx_model():
name="scale", data_type=onnx.TensorProto.FLOAT, dims=(1,), vals=[1.0]
)
quantize = onnx.helper.make_node(
"QuantizeLinear", ["input", "scale", "zero_point"], ["id1_output"], name="id1"
"QuantizeLinear", ["input", "scale", "zero_point"], ["output"], name="id1"
)

graph = onnx.helper.make_graph(

@@ -32,7 +32,7 @@ def onnx_model():
scale = onnx.helper.make_tensor("scale", onnx.TensorProto.FLOAT, (1,), [1])
relu = onnx.helper.make_node("Relu", ["input"], ["relu_output"], name="relu")
quant = onnx.helper.make_node(
"QuantizeLinear", ["relu_output", "scale"], ["quant_output"], name="quant"
"QuantizeLinear", ["relu_output", "scale"], ["output"], name="quant"
)

graph = onnx.helper.make_graph(

@@ -62,7 +62,7 @@ def onnx_model() -> onnx.ModelProto:
gemm = helper.make_node(
"Gemm",
["input_dequant_output", "weight_dequant_output", "bias"],
["gemm_output"],
["output"],
name="gemm",
)


@@ -36,7 +36,9 @@ def onnx_model() -> onnx.ModelProto:
model_input_1 = helper.make_tensor_value_info(
"input_1", onnx.TensorProto.FLOAT, (1,)
)
-model_output = helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, (1,))
+model_output = helper.make_tensor_value_info(
+"output_quant_output", onnx.TensorProto.FLOAT, (1,)
+)

input_dequant = helper.make_node(
"DequantizeLinear",

@@ -159,6 +161,7 @@ def test_gemm_with_bias_dequant_after(onnx_model: onnx.ModelProto):
name="output_dequant",
)
)
onnx_model.graph.output[0].name = "output_dequant_output"
validate_onnx(onnx_model)

onnx_model = GemmToQLinearMatMul().apply(onnx_model)

@@ -201,6 +204,7 @@ def test_gemm_after_changes_nothing(onnx_model: onnx.ModelProto):
name="gemm2",
)
)
onnx_model.graph.output[0].name = "gemm2_output" # update graph output
validate_onnx(onnx_model)
onnx_model = GemmToQLinearMatMul().apply(onnx_model)
validate_onnx(onnx_model)

@@ -224,6 +228,7 @@ def test_gemm_after_changes_nothing(onnx_model: onnx.ModelProto):

# remove the gemm2 node and now things should change
onnx_model.graph.node.pop()
onnx_model.graph.output[0].name = "output_dequant_output" # update graph output
validate_onnx(onnx_model)
onnx_model = GemmToQLinearMatMul().apply(onnx_model)

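The additions above keep the declared graph output in sync when a test mutates the fixture's node list. A hypothetical two-line illustration of the same pattern (the tensor name here is made up, not taken from these tests):

# onnx_model is an onnx.ModelProto, e.g. the pytest fixture above.
onnx_model.graph.node.pop()                       # drop the last node
onnx_model.graph.output[0].name = "new_terminal"  # point the graph output at the tensor that is now terminal
validate_onnx(onnx_model)                         # the tests' own validation helper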

@@ -68,7 +68,7 @@ def onnx_model() -> onnx.ModelProto:
["matmul_output"],
name="matmul",
)
add = helper.make_node("Add", ["matmul_output", "bias"], ["add_output"], name="add")
add = helper.make_node("Add", ["matmul_output", "bias"], ["output"], name="add")

graph = helper.make_graph(
nodes=[input_dequant, weight_quant, weight_dequant, transpose, matmul, add],

@@ -150,6 +150,7 @@ def test_matmul_no_bias_converts(onnx_model: onnx.ModelProto):
# remove "bias" initializer and "add" node
assert onnx_model.graph.initializer.pop().name == "bias"
assert onnx_model.graph.node.pop().name == "add"
onnx_model.graph.output[0].name = "matmul_output" # update graph output name
validate_onnx(onnx_model)

onnx_model = MatMulAddToMatMulIntegerAddCastMul().apply(onnx_model)

@@ -116,7 +116,7 @@ def _create_test_model(with_transpose=False, with_reshape=False):
dequantize_linear_node_2 = onnx.helper.make_node(
"DequantizeLinear",
["quant_linear_2_output", "x_scale", "zero_point"],
["dequant_linear_2_output"],
["output"],
name="dequantize_linear_node_2",
)


@@ -72,7 +72,7 @@ def onnx_model():
concat = onnx.helper.make_node(
"Concat",
["pad1_output", "pad2_output", "dequant_output"],
["concat_output"],
["output"],
name="concat",
axis=0,
)

@@ -27,8 +27,23 @@ def onnx_model():
model_input = onnx.helper.make_tensor_value_info(
"input", onnx.TensorProto.FLOAT, (1,)
)
-model_output = onnx.helper.make_tensor_value_info(
-"output", onnx.TensorProto.FLOAT, (1,)
+model_output_1 = onnx.helper.make_tensor_value_info(
+"add1_output", onnx.TensorProto.FLOAT, (1,)
+)
+model_output_2 = onnx.helper.make_tensor_value_info(
+"add2_output", onnx.TensorProto.FLOAT, (1,)
+)
+model_output_3 = onnx.helper.make_tensor_value_info(
+"add3_output", onnx.TensorProto.FLOAT, (1,)
+)
+model_output_4 = onnx.helper.make_tensor_value_info(
+"conv4_output", onnx.TensorProto.FLOAT, (1,)
+)
+model_output_5 = onnx.helper.make_tensor_value_info(
+"conv5_output", onnx.TensorProto.FLOAT, (1,)
+)
+model_output_6 = onnx.helper.make_tensor_value_info(
+"conv6_output", onnx.TensorProto.FLOAT, (1,)
)
zp = onnx.helper.make_tensor("zp", onnx.TensorProto.UINT8, (1,), [0])
scale = onnx.helper.make_tensor("scale", onnx.TensorProto.FLOAT, (1,), [1.0])

@@ -90,7 +105,14 @@ def onnx_model():
nodes=[conv1, conv2, conv3, conv4, conv5, conv6, add1, add2, add3],
name="g",
inputs=[model_input],
-outputs=[model_output],
+outputs=[
+model_output_1,
+model_output_2,
+model_output_3,
+model_output_4,
+model_output_5,
+model_output_6,
+],
initializer=[
weight1_a,
weight1_b,
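In this fixture several conv/add nodes are terminal, i.e. nothing consumes their outputs, so instead of a single unused "output" the graph now declares one output per terminal tensor. A hypothetical helper, not part of the repository, that computes which tensors such a fixture has to expose:

import onnx

def terminal_tensors(graph: onnx.GraphProto):
    # Tensors produced by some node but consumed by no node; each of these
    # should appear in graph.output for the graph to be fully described.
    consumed = {name for node in graph.node for name in node.input}
    return [name for node in graph.node for name in node.output if name not in consumed]

For the fixture above this would presumably return the six add*/conv* tensors that the new outputs=[...] list declares explicitly.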

@@ -32,8 +32,11 @@ def onnx_model():
model_input = onnx.helper.make_tensor_value_info(
"input", onnx.TensorProto.FLOAT, (1,)
)
-model_output = onnx.helper.make_tensor_value_info(
-"output", onnx.TensorProto.FLOAT, (1,)
+model_output_1 = onnx.helper.make_tensor_value_info(
+"quant1_output", onnx.TensorProto.FLOAT, (1,)
+)
+model_output_2 = onnx.helper.make_tensor_value_info(
+"quant2_output", onnx.TensorProto.FLOAT, (1,)
)
zp = onnx.helper.make_tensor("zp", onnx.TensorProto.UINT8, (1,), [0])
scale1 = onnx.helper.make_tensor("scale1", onnx.TensorProto.FLOAT, (1,), [1.0])

@@ -50,7 +53,7 @@ def onnx_model():
nodes=[quant1, quant2],
name="g",
inputs=[model_input],
-outputs=[model_output],
+outputs=[model_output_1, model_output_2],
initializer=[scale1, zp],
)

