Cherry-pick ONNX export update into release 1.0 (#936)
* Bump up version id

* Fix for ONNX export for quantized BERT models (#935)

* Remove quantization of identity branch on BERT models

* Style and quality fixes.

* Update src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>

* Removed unused function

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
anmarques and bfineran committed Jul 12, 2022
1 parent fc4c771 commit 0fa9f72
Showing 3 changed files with 22 additions and 93 deletions.
86 changes: 0 additions & 86 deletions src/sparseml/onnx/utils/graph_optimizer.py
@@ -40,7 +40,6 @@
__all__ = [
"fold_conv_bns",
"quantize_resnet_identity_add_inputs",
"quantized_residual_add_optim",
]


@@ -202,91 +201,6 @@ def quantize_resnet_identity_add_inputs(quantized_model: onnx.ModelProto) -> bool:
return optimization_made


def quantized_residual_add_optim(quantized_model: onnx.ModelProto) -> bool:
"""
This optimization adds a quant/dequant block to the identity branch of a
residual whose non-identity branch is quantized. This enables the add at the
end of the residual to be fused at runtime.
The function matches any node that has two child nodes - one add node
and one quantize node whose branch eventually leads to that add node.
:param quantized_model: A loaded quantized model to perform this optimization on
:return: True if an in-place optimization was made
"""
graph = ONNXGraph(quantized_model)
optimization_made = False
for node in quantized_model.graph.node:
children_nodes = graph.get_node_children(node)
if len(children_nodes) != 2:
continue

add_node = [node for node in children_nodes if node.op_type == "Add"]
quant_node = [
node for node in children_nodes if node.op_type == "QuantizeLinear"
]
if not add_node or not quant_node:
continue
add_node = add_node[0]
quant_node = quant_node[0]

# verify that quant_node eventually leads to add_node
curr_node = [quant_node]
iter = 0
max_iter = 20 # avoid cycles
while curr_node and curr_node[0] != add_node and iter < max_iter:
curr_node = graph.get_node_children(curr_node[0])
iter += 1
if curr_node[0] != add_node:
continue

# create de-quantize node for identity
dequant_node = _make_dequant_node_for_quant(quant_node)

# update graph
identity_edge_idx = 0 if add_node.input[0] == node.output[0] else 1
graph.add_node(dequant_node)
graph.update_node_input(add_node, dequant_node.output[0], identity_edge_idx)
optimization_made = True

# if any of the add children are quantize ops while others aren't
# add a quant/dequant block to the non quantized paths to allow for fusion
# of the add
add_node_children = graph.get_node_children(add_node)
add_node_quant_child_idx = [
idx
for idx, node in enumerate(add_node_children)
if node.op_type == "QuantizeLinear"
]
if not add_node_quant_child_idx or all(
n.op_type == "Add" or n.op_type == "QuantizeLinear"
for n in add_node_children
):
# no quant child node, or all child nodes are quant/add nodes
continue

# make dequant pair node for quant child and add to graph
add_node_dequant_child = _make_dequant_node_for_quant(
add_node_children[add_node_quant_child_idx[0]]
)
graph.add_node(add_node_dequant_child)

# update all non quant node children to take the quant/dequant block as input
for add_child_node in add_node_children:
if add_child_node.op_type == "QuantizeLinear":
continue
add_node_id_idx = [
idx
for idx, output_id in enumerate(add_child_node.input)
if output_id == add_node.output[0]
][0]
graph.update_node_input(
add_child_node, add_node_dequant_child.output[0], add_node_id_idx
)

return optimization_made


def _make_dequant_node_for_quant(quant_node: onnx.NodeProto) -> onnx.NodeProto:
return onnx.helper.make_node(
"DequantizeLinear",
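The _make_dequant_node_for_quant helper is truncated in this view. As a hedged sketch (the function and variable names below are illustrative, not the repository's exact code), a matching DequantizeLinear node can be built from a QuantizeLinear node by reusing its scale and zero-point inputs:

import onnx


def make_dequant_for_quant_sketch(quant_node: onnx.NodeProto) -> onnx.NodeProto:
    # QuantizeLinear consumes (x, scale, zero_point); the paired DequantizeLinear
    # reuses the same scale and zero point to map the quantized tensor back to float.
    return onnx.helper.make_node(
        "DequantizeLinear",
        inputs=[quant_node.output[0], quant_node.input[1], quant_node.input[2]],
        outputs=[f"{quant_node.output[0]}_dequantized"],
        name=f"{quant_node.name}_dequant",
    )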
27 changes: 21 additions & 6 deletions src/sparseml/pytorch/sparsification/quantization/quantize_qat_export.py
@@ -25,6 +25,7 @@

import numpy
import onnx
import torch
from onnx import ModelProto, NodeProto, numpy_helper

from sparseml.onnx.utils import (
@@ -34,7 +35,6 @@
get_node_attributes,
get_node_output_nodes,
quantize_resnet_identity_add_inputs,
quantized_residual_add_optim,
remove_node_and_params_from_graph,
swap_node_output,
update_model_param,
@@ -323,9 +323,21 @@ def _attribute_to_kwarg(attribute: onnx.AttributeProto):
def _quantize_array(
array: numpy.ndarray, scale: float, zero_point: int, dtype: Any = numpy.uint8
) -> numpy.ndarray:
dmin = numpy.iinfo(dtype).min
dmax = numpy.iinfo(dtype).max
return ((array / scale).round() + zero_point).clip(dmin, dmax).astype(dtype)
if dtype == numpy.uint8:
tensor_dtype = torch.quint8
elif dtype == numpy.int8:
tensor_dtype = torch.qint8
elif dtype == numpy.int32:
tensor_dtype = torch.qint32

tensor = torch.Tensor(array).to(torch.float32)
if isinstance(scale, numpy.ndarray):
scale = scale.item()
if isinstance(zero_point, numpy.ndarray):
zero_point = zero_point.item()

quant_tensor = torch.quantize_per_tensor(tensor, scale, zero_point, tensor_dtype)
return quant_tensor.int_repr().numpy()


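For reference, a minimal standalone sketch of the torch-based quantization path added above (the helper name and example values are illustrative and assumed, not part of the commit):

import numpy
import torch


def quantize_array_sketch(array, scale, zero_point, dtype=numpy.uint8):
    # Map the requested numpy dtype to the corresponding torch quantized dtype.
    torch_dtypes = {
        numpy.uint8: torch.quint8,
        numpy.int8: torch.qint8,
        numpy.int32: torch.qint32,
    }
    tensor = torch.from_numpy(numpy.asarray(array, dtype=numpy.float32))
    quantized = torch.quantize_per_tensor(
        tensor, float(scale), int(zero_point), torch_dtypes[numpy.dtype(dtype).type]
    )
    # int_repr() exposes the stored integer values as a plain (non-quantized) tensor.
    return quantized.int_repr().numpy()


weights = numpy.array([-0.51, 0.0, 0.26, 1.30], dtype=numpy.float32)
print(quantize_array_sketch(weights, scale=0.01, zero_point=128))
# expected to saturate into the uint8 range, e.g. [ 77 128 154 255]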
def _convert_quantizable_conv(
@@ -450,6 +462,7 @@ def _convert_quantizable_gemm(
weight_quantize_params.target,
weight_quantize_params.scale,
weight_quantize_params.zero_point,
weight_quantize_params.zero_point.dtype,
)
quantized_weight = quantized_weight.transpose() # Gemm has implicit transpose
quantized_weight_name = "{}.weight_quantized".format(gemm_node.name)
@@ -732,6 +745,7 @@ def _add_quantized_conv_matmul_add_ops(
weight_quantize_params.target,
weight_quantize_params.scale,
weight_quantize_params.zero_point,
weight_quantize_params.zero_point.dtype,
)
if transpose_weight:
quantized_weight = quantized_weight.transpose()
@@ -1404,7 +1418,9 @@ def _quantize_qat_embedding(model: ModelProto):
embedding = numpy_helper.to_array(embedding_initializer)
scale = numpy_helper.to_array(scale_initializer)
zero_point = numpy_helper.to_array(zp_initializer)
embedding_quant = _quantize_array(embedding, scale, zero_point)
embedding_quant = _quantize_array(
embedding, scale, zero_point, zero_point.dtype
)
embedding_quant_initializer = numpy_helper.from_array(
embedding_quant, name=f"{embedding_initializer.name}_quant"
)
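For context, a small hedged example of what forwarding the zero point's dtype to _quantize_array changes here (presumably the motivation; the values below are illustrative, not taken from the repository):

import numpy

embedding = numpy.array([-0.1, 0.0, 0.1], dtype=numpy.float32)
scale = 0.001
zero_point = numpy.array(0, dtype=numpy.int8)

# With the previous hard-coded default of numpy.uint8, a signed zero point like
# this one would still quantize into [0, 255], clipping every negative value to 0.
# Forwarding zero_point.dtype selects torch.qint8 instead, so the call
# _quantize_array(embedding, scale, zero_point, zero_point.dtype) is expected to
# yield roughly array([-100, 0, 100], dtype=int8).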
@@ -1569,7 +1585,6 @@ def quantize_torch_qat_export(
_convert_quantizable_gemm_no_activations(model)
_quantize_qat_embedding(model)
quantize_resnet_identity_add_inputs(model)
quantized_residual_add_optim(model)
_remove_duplicate_quantize_ops(model)
_cleanup_unused_quants(model)

2 changes: 1 addition & 1 deletion src/sparseml/version.py
@@ -19,7 +19,7 @@
from datetime import date


version_base = "1.0.0"
version_base = "1.0.1"
is_release = False # change to True to set the generated version as a release version


