Skip to content

Commit

Permalink
ONNX export update for YOLOv8 (#1497)
Browse files Browse the repository at this point in the history
* Changes for YOLOv8 quantization

* Changes for YOLOv8 quantization

* Quality fixes

* Update src/sparseml/onnx/utils/helpers.py

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>

* Revert changes: using a `set` did not work with ONNX node objects (they are unhashable)

---------

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
  • Loading branch information
anmarques and bfineran committed Mar 30, 2023
1 parent df20cf7 commit 93ba39c
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/sparseml/onnx/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,11 @@ def get_node_output_nodes(model: ModelProto, node: NodeProto) -> List[NodeProto]
for output_id in get_node_outputs(model, node):
nodes.extend(get_nodes_by_input_id(model, output_id))

return nodes
unique_nodes = []
for node in nodes:
if node not in unique_nodes:
unique_nodes.append(node)
return unique_nodes


def is_prunable_node(model: ModelProto, node: NodeProto) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,7 @@ def quantize_torch_qat_export(
_delete_repeated_qat_blocks(model)
_quantize_qat_embedding(model)
_propagate_mobilebert_embedding_quantization(model)
_propagate_through_split(model)
_convert_quantizable_matmul(model)
_convert_quantizable_matmul_and_add(model)
_fold_relu_quants(model)
Expand Down Expand Up @@ -1836,3 +1837,78 @@ def _propagate_mobilebert_embedding_quantization(model: ModelProto):
_LOGGER.info(
f"Propagated {converted_nodes} DequantizeLinear node(s) through Concat"
)


def _propagate_through_split(model: ModelProto):
    """
    Graph-rewrite pass that propagates a DequantizeLinear node down through a
    Split node so that quantized operations after the split can be properly
    converted by later passes.

    Starting with:
    |   INPUT
    |     |
    |   DequantizeLinear
    |     |
    |   Split
    |   |  |  |
    Converts to:
    |   INPUT
    |     |
    |   Split
    |   |  |  |
    |   DequantizeLinear  DequantizeLinear  DequantizeLinear
    |   |                 |                 |

    :param model: ONNX model to transform in place
    """
    new_nodes = []
    to_remove = []
    split_nodes = [n for n in model.graph.node if n.op_type in ["Split"]]
    graph = ONNXGraph(model)
    for split_node in split_nodes:
        dequant_node = graph.get_node_single_parent(split_node, 0)
        if not dequant_node or dequant_node.op_type != "DequantizeLinear":
            continue

        # Re-route the quantized tensor (the dequantize node's input) directly
        # into the split node, bypassing the dequantize node
        split_node.input[0] = dequant_node.input[0]

        # For every consumed output of the split, create a fresh
        # DequantizeLinear node reusing the original scale and zero point
        dequant_id = 0
        for other_node in get_node_output_nodes(model, split_node):
            # outputs of this split that feed the current consumer node
            split_node_output = []
            for out in split_node.output:
                if out in other_node.input:
                    split_node_output.append(out)
            for out in split_node_output:
                dequant_node_name = split_node.name + f"_dequant.{dequant_id}"
                dequant_id += 1
                dequant_node_output = dequant_node_name + "_output"
                new_nodes.append(
                    onnx.helper.make_node(
                        "DequantizeLinear",
                        [
                            out,  # input
                            dequant_node.input[1],  # scale
                            dequant_node.input[2],  # zero point
                        ],
                        [dequant_node_output],
                        dequant_node_name,
                    )
                )
                # Point the consumer's matching input at the new dequantize
                # output; break leaves the index of the first match bound
                for other_node_input_index, other_node_input in enumerate(
                    other_node.input
                ):
                    if other_node_input == out:
                        break
                other_node.input[other_node_input_index] = dequant_node_output

        # Mark the original dequantize node for deletion exactly once per
        # split; duplicate entries would make the removal loop below raise
        # ValueError on the second remove() of the same node and would
        # inflate the logged count
        if dequant_node not in to_remove:
            to_remove.append(dequant_node)

    model.graph.node.extend(new_nodes)
    for node in to_remove:
        model.graph.node.remove(node)

    if len(to_remove) > 0:
        _LOGGER.info(
            f"Propagated {len(to_remove)} DequantizeLinear node(s) through Split"
        )

0 comments on commit 93ba39c

Please sign in to comment.