Skip to content

Commit

Permalink
QAT folding update (#1639)
Browse files Browse the repository at this point in the history
* Add transformation to propagate dequantize op through split

* Remove requirement that QuantizeLinear must be next to DequantizeLinear for input branch of Conv node

* Fixed embedding quantization propagation

* Quality fixes

* Add zero point to dequant node

* Add zero point to initializers

* Style fixes

* Fix data type

* Allow MatMul weight to be on either input 0 or 1

* Style fixes

* Add padding value

* Make initializers distinct

* Style and quality fixes

* Make bias optional for Conv QAT conversion

* Quality fix
  • Loading branch information
anmarques committed Jul 10, 2023
1 parent 0617a9e commit d0ba055
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 18 deletions.
1 change: 1 addition & 0 deletions src/sparseml/exporters/onnx_to_deepsparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
sparseml_transforms.DeleteRepeatedQdq(),
sparseml_transforms.QuantizeQATEmbedding(),
sparseml_transforms.PropagateEmbeddingQuantization(),
sparseml_transforms.PropagateDequantThroughSplit(),
sparseml_transforms.MatMulToQLinearMatMul(),
sparseml_transforms.MatMulAddToMatMulIntegerAddCastMul(),
sparseml_transforms.MatMulToMatMulIntegerCastMul(),
Expand Down
1 change: 1 addition & 0 deletions src/sparseml/exporters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .matmul_add_to_matmulinteger_add_cast_mul import MatMulAddToMatMulIntegerAddCastMul
from .matmul_to_matmulinteger_cast_mul import MatMulToMatMulIntegerCastMul
from .propagate_embedding_quantization import PropagateEmbeddingQuantization
from .propagate_dequant_through_split import PropagateDequantThroughSplit
from .quantize_qat_embedding import QuantizeQATEmbedding
from .quantize_residuals import QuantizeResiduals
from .remove_duplicate_qconv_weights import RemoveDuplicateQConvWeights
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,12 @@ class ConvToConvIntegerAddCastMul(OnnxTransform):

def transform(self, model: ModelProto) -> ModelProto:
graph = ONNXGraph(model)
matches = get_structural_matches(

# Nodes with bias
matches_bias = get_structural_matches(
graph,
parent_ops=[
["QuantizeLinear", "DequantizeLinear"],
["DequantizeLinear"],
[
# weight should be initializer
INITIALIZER_MATCH,
Expand All @@ -78,20 +80,50 @@ def transform(self, model: ModelProto) -> ModelProto:
],
[
# bias should be initializer
INITIALIZER_MATCH
INITIALIZER_MATCH,
],
],
op_type="Conv",
)

# Nodes without bias
matches_no_bias = get_structural_matches(
graph,
parent_ops=[
["DequantizeLinear"],
[
# weight should be initializer
INITIALIZER_MATCH,
"QuantizeLinear",
"DequantizeLinear",
],
],
op_type="Conv",
)

matches = matches_bias
matches_names = [m.node.name for m in matches]
for match in matches_no_bias:
if match.node.name not in matches_names:
matches.append(match)

for match in matches:
self.log_match(match)
self._transform_match(graph, model, match)
return model

def _transform_match(self, graph: ONNXGraph, model: ModelProto, match: MatchResult):
input_quant, input_dequant = match.parents[0]
def _transform_match(
self,
graph: ONNXGraph,
model: ModelProto,
match: MatchResult,
):
(input_dequant,) = match.parents[0]
weight_init, weight_quantize_node, weight_dequantize_node = match.parents[1]
(bias_init,) = match.parents[2]
if len(match.parents) == 3:
(bias_init,) = match.parents[2]
else:
bias_init = None

model = add_quantized_conv_matmul_add_ops(
model=model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,37 @@ class MatMulAddToMatMulIntegerAddCastMul(OnnxTransform):

def transform(self, model: ModelProto) -> ModelProto:
graph = ONNXGraph(model)

# Weight on input 0
matches = get_structural_matches(
graph,
op_type="MatMul",
parent_ops=[
[
# weight should be initializer
INITIALIZER_MATCH,
"QuantizeLinear",
"DequantizeLinear",
optional_node("Transpose"),
],
[any_of("QuantizeLinear", "DequantizeLinear")],
],
children_ops=[[optional_node("Add")]],
)
for match in matches:
add_node = match.children[0][0]
bias_init = None
if add_node:
# NOTE: bias could be either input 0 or 1 of add node
# if add does not have a bias initializer,
# still do conversion, but do not fold the bias add to rescale
bias_init = graph.get_init_by_name(match.children[0][0].input[1])
if bias_init is None:
bias_init = graph.get_init_by_name(match.children[0][0].input[0])
self.log_match(match)
self._transform_match(graph, model, match, bias_init, 0)

# Weight on input 1
matches = get_structural_matches(
graph,
op_type="MatMul",
Expand All @@ -93,7 +124,8 @@ def transform(self, model: ModelProto) -> ModelProto:
if bias_init is None:
bias_init = graph.get_init_by_name(match.children[0][0].input[0])
self.log_match(match)
self._transform_match(graph, model, match, bias_init)
self._transform_match(graph, model, match, bias_init, 1)

return model

def _transform_match(
Expand All @@ -102,10 +134,15 @@ def _transform_match(
model: ModelProto,
match: MatchResult,
bias_init: TensorProto,
weight_parent: int,
):
matmul = match.node
(input_quant,) = match.parents[0]
weight_init, weight_quant, weight_dequant, opt_transpose = match.parents[1]
if weight_parent == 0:
(input_quant,) = match.parents[1]
weight_init, weight_quant, weight_dequant, opt_transpose = match.parents[0]
else:
(input_quant,) = match.parents[0]
weight_init, weight_quant, weight_dequant, opt_transpose = match.parents[1]
(add,) = match.children[0]

input_quantize_params = get_quantization_params(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import onnx
from onnx import ModelProto

from sparseml.exporters.transforms import OnnxTransform
from sparseml.exporters.transforms.utils import MatchResult, get_structural_matches
from sparseml.onnx.utils import ONNXGraph


__all__ = ["PropagateDequantThroughSplit"]


class PropagateDequantThroughSplit(OnnxTransform):
"""
A pass for propagating DequantizeLinear nodes down through a split node
so if there are quantized operations after the split they can
be properly converted.
Starting with:
| INPUT
| |
| DequantizeLinear
| |
| Split
| | | |
Converts to:
| INPUT
| |
| Split
| | | |
| DequantizeLinear DequantizeLinear DequantizeLinear
| | | |
"""

def transform(self, model: ModelProto) -> ModelProto:
graph = ONNXGraph(model)
matches = get_structural_matches(
graph,
parent_ops=[["DequantizeLinear"]],
op_type="Split",
)
for match in matches:
self.log_match(match)
self._transform_match(model, match)
return model

def _transform_match(self, model: ModelProto, match: MatchResult):

# Loop through the nodes that are children of the Split node
# For every child, create a DequantizeLinear node and insert
# between Split and child
for split_output_id in range(len(match.node.output)):
dequant_node_name = match.node.name + f"_dequant.{split_output_id}"
dequant_node_output = match.node.output[split_output_id]
dequant_node_input = dequant_node_name + "_input"

# Input to DequantizeLinear node is the output of the Split node
model.graph.node.append(
onnx.helper.make_node(
"DequantizeLinear",
[
dequant_node_input, # input
match.parents[0][0].input[1], # scale
match.parents[0][0].input[2], # zero point
],
[dequant_node_output],
dequant_node_name,
)
)

# Replace the output of the Split node with the input of
# the new DequantizeLinear node
match.node.output[split_output_id] = dequant_node_input

# Set the input to the Split node to what was the input of the
# original DequantizeLinear node
match.node.input[0] = match.parents[0][0].input[0]

# Remove original DequantizeLinear node
self.delete_node_deferred(match.parents[0][0])
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging

import numpy
import onnx.numpy_helper
from onnx import ModelProto, numpy_helper

from sparseml.exporters.transforms.onnx_transform import OnnxTransform
Expand Down Expand Up @@ -79,16 +80,22 @@ def transform(self, model: ModelProto) -> ModelProto:
["Concat"],
],
)

initializer_dict = {i.name: i for i in model.graph.initializer}

for match in matches:
(gather,) = match.parents[0]
dequant = match.node
slice1, _, concat1 = match.children[0]
slice2, _, concat2 = match.children[1]
slice1, pad1, concat1 = match.children[0]
slice2, pad2, concat2 = match.children[1]
(concat,) = match.children[2]

# check for uint8 initializer
indices = graph.get_init_by_name(gather.input[0])
if indices is None or numpy_helper.to_array(indices).dtype != numpy.uint8:
if indices is None or numpy_helper.to_array(indices).dtype not in [
numpy.uint8,
numpy.int8,
]:
continue

# check that all concats are the same
Expand All @@ -97,11 +104,35 @@ def transform(self, model: ModelProto) -> ModelProto:

self.log_match(match)

assert concat.input[2] == dequant.output[0]
concat.input[2] = gather.output[0]
for id, input_name in enumerate(concat.input):
if input_name == dequant.output[0]:
break

concat.input[id] = gather.output[0]
slice1.input[0] = gather.output[0]
slice2.input[0] = gather.output[0]

zero_point_initializer = initializer_dict[match.node.input[2]]
zero_point = onnx.numpy_helper.to_array(zero_point_initializer)

pad1_value_initializer = initializer_dict[pad1.input[2]]
pad1_value = onnx.numpy_helper.to_array(pad1_value_initializer)
pad1_value = pad1_value.astype(zero_point.dtype) + zero_point
new_pad1_value_initializer = numpy_helper.from_array(
pad1_value, name=pad1_value_initializer.name
)
model.graph.initializer.remove(pad1_value_initializer)
model.graph.initializer.append(new_pad1_value_initializer)

pad2_value_initializer = initializer_dict[pad2.input[2]]
pad2_value = onnx.numpy_helper.to_array(pad2_value_initializer)
pad2_value = pad2_value.astype(zero_point.dtype) + zero_point
new_pad2_value_initializer = numpy_helper.from_array(
pad2_value, name=pad2_value_initializer.name
)
model.graph.initializer.remove(pad2_value_initializer)
model.graph.initializer.append(new_pad2_value_initializer)

tmp = concat.output[0]
concat.output[0] = dequant.output[0]
dequant.output[0] = tmp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,18 @@ def onnx_model():
"output", onnx.TensorProto.FLOAT, (1,)
)
scale = onnx.helper.make_tensor("scale", onnx.TensorProto.FLOAT, (1,), [1.0])
zero_point = onnx.helper.make_tensor(
"zero_point", onnx.TensorProto.UINT8, (1,), [128]
)
starts = onnx.helper.make_tensor("starts", onnx.TensorProto.INT64, (1,), [0])
ends = onnx.helper.make_tensor("ends", onnx.TensorProto.INT64, (1,), [1])
pads = onnx.helper.make_tensor("pads", onnx.TensorProto.INT64, (1,), [1])
padding1_value = onnx.helper.make_tensor(
"padding1_value", onnx.TensorProto.FLOAT, (1,), [0.0]
)
padding2_value = onnx.helper.make_tensor(
"padding2_value", onnx.TensorProto.FLOAT, (1,), [0.0]
)
embeddings = onnx.helper.make_tensor(
"embeddings", onnx.TensorProto.UINT8, (1,), [0]
)
Expand All @@ -43,7 +52,7 @@ def onnx_model():
)
dequant = onnx.helper.make_node(
"DequantizeLinear",
["gather_output", "scale"],
["gather_output", "scale", "zero_point"],
["dequant_output"],
name="dequant",
)
Expand All @@ -52,13 +61,13 @@ def onnx_model():
"Slice", ["dequant_output", "starts", "ends"], ["slice1_output"], name="slice1"
)
pad1 = onnx.helper.make_node(
"Pad", ["slice1_output", "pads"], ["pad1_output"], name="pad1"
"Pad", ["slice1_output", "pads", "padding1_value"], ["pad1_output"], name="pad1"
)
slice2 = onnx.helper.make_node(
"Slice", ["dequant_output", "starts", "ends"], ["slice2_output"], name="slice2"
)
pad2 = onnx.helper.make_node(
"Pad", ["slice2_output", "pads"], ["pad2_output"], name="pad2"
"Pad", ["slice2_output", "pads", "padding2_value"], ["pad2_output"], name="pad2"
)
concat = onnx.helper.make_node(
"Concat",
Expand All @@ -73,7 +82,16 @@ def onnx_model():
name="g",
inputs=[model_input],
outputs=[model_output],
initializer=[scale, starts, ends, embeddings, pads],
initializer=[
scale,
zero_point,
starts,
ends,
embeddings,
pads,
padding1_value,
padding2_value,
],
)

model = onnx.helper.make_model(graph)
Expand Down

0 comments on commit d0ba055

Please sign in to comment.