QAT folding update #1639

Merged: 26 commits, merged Jul 10, 2023

Commits
031125f  Add transformation to propagate dequantize op through split (anmarques, Jun 23, 2023)
c54c289  Remove requirement that QuantizeLinear must be next to DequantizeLinear (anmarques, Jun 23, 2023)
20632ce  Fixed embedding quantization propagation (anmarques, Jun 23, 2023)
4384a11  Quality fixes (anmarques, Jun 23, 2023)
c64310e  Merge branch 'main' into feature/qat_folding_update (anmarques, Jun 26, 2023)
b5ea337  Merge branch 'main' into feature/qat_folding_update (anmarques, Jun 27, 2023)
a98bd91  Add zero point to dequant node (anmarques, Jun 27, 2023)
5316cc8  Merge branch 'feature/qat_folding_update' of github.com:neuralmagic/sparseml… (anmarques, Jun 27, 2023)
32db97b  Add zero point to initializers (anmarques, Jun 27, 2023)
8e18613  Style fixes (anmarques, Jun 27, 2023)
fdae854  Fix data type (anmarques, Jun 27, 2023)
85c6828  Allow MatMul weight to be on either input 0 or 1 (anmarques, Jun 27, 2023)
f2bf1d7  Style fixes (anmarques, Jun 27, 2023)
f904e49  Add padding value (anmarques, Jun 27, 2023)
8e15c59  Make initializers distinct (anmarques, Jun 27, 2023)
1cdbacb  Style and quality fixes (anmarques, Jun 27, 2023)
d703d4f  Merge branch 'main' into feature/qat_folding_update (anmarques, Jun 27, 2023)
8f47a93  Merge branch 'main' into feature/qat_folding_update (anmarques, Jun 28, 2023)
80d678b  Merge branch 'main' into feature/qat_folding_update (anmarques, Jun 30, 2023)
252ac11  Merge branch 'main' into feature/qat_folding_update (anmarques, Jun 30, 2023)
8b622a4  Merge branch 'main' into feature/qat_folding_update (anmarques, Jul 3, 2023)
7629f5e  Make bias optional for Conv QAT conversion (anmarques, Jul 5, 2023)
d189627  Quality fix (anmarques, Jul 5, 2023)
ce5bf37  Merge branch 'main' into feature/qat_folding_update (anmarques, Jul 7, 2023)
03b4279  Merge branch 'main' into feature/qat_folding_update (anmarques, Jul 10, 2023)
9058a6c  Merge branch 'main' into feature/qat_folding_update (anmarques, Jul 10, 2023)
1 change: 1 addition & 0 deletions src/sparseml/exporters/onnx_to_deepsparse.py
@@ -75,6 +75,7 @@ def __init__(
             sparseml_transforms.DeleteRepeatedQdq(),
             sparseml_transforms.QuantizeQATEmbedding(),
             sparseml_transforms.PropagateEmbeddingQuantization(),
+            sparseml_transforms.PropagateDequantThroughSplit(),
             sparseml_transforms.MatMulToQLinearMatMul(),
             sparseml_transforms.MatMulAddToMatMulIntegerAddCastMul(),
             sparseml_transforms.MatMulToMatMulIntegerCastMul(),
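Note: the new transform is registered ahead of the MatMul and Conv conversion transforms, so dequantize nodes have already been pushed below any Split by the time those patterns are matched. A minimal sketch (not part of this PR) of running the transform on its own; the model path is hypothetical and `apply` is assumed to be the base-transform entry point that wraps `transform` with validation:

import onnx

from sparseml.exporters.transforms import PropagateDequantThroughSplit

model = onnx.load("model.onnx")  # hypothetical path
model = PropagateDequantThroughSplit().apply(model)
onnx.save(model, "model-transformed.onnx")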
1 change: 1 addition & 0 deletions src/sparseml/exporters/transforms/__init__.py
@@ -41,6 +41,7 @@
 from .matmul_add_to_matmulinteger_add_cast_mul import MatMulAddToMatMulIntegerAddCastMul
 from .matmul_to_matmulinteger_cast_mul import MatMulToMatMulIntegerCastMul
 from .propagate_embedding_quantization import PropagateEmbeddingQuantization
+from .propagate_dequant_through_split import PropagateDequantThroughSplit
 from .quantize_qat_embedding import QuantizeQATEmbedding
 from .quantize_residuals import QuantizeResiduals
 from .remove_duplicate_qconv_weights import RemoveDuplicateQConvWeights
@@ -69,7 +69,7 @@ def transform(self, model: ModelProto) -> ModelProto:
         matches = get_structural_matches(
             graph,
             parent_ops=[
-                ["QuantizeLinear", "DequantizeLinear"],
+                ["DequantizeLinear"],
                 [
                     # weight should be initializer
                     INITIALIZER_MATCH,
@@ -89,7 +89,7 @@ def transform(self, model: ModelProto) -> ModelProto:
         return model

     def _transform_match(self, graph: ONNXGraph, model: ModelProto, match: MatchResult):
-        input_quant, input_dequant = match.parents[0]
+        (input_dequant,) = match.parents[0]
         weight_init, weight_quantize_node, weight_dequantize_node = match.parents[1]
         (bias_init,) = match.parents[2]
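Dropping QuantizeLinear from the matched parent chain means an input that reaches this op through a lone DequantizeLinear (for example, one that was just propagated through a Split by the new transform) is still converted. A hedged sketch of the now-sufficient input chain, with illustrative tensor names:

import onnx

# before: the input had to arrive as ... -> QuantizeLinear -> DequantizeLinear -> op
# now a lone dequantize is enough:     ... -> DequantizeLinear -> op
input_dequant = onnx.helper.make_node(
    "DequantizeLinear",
    ["x_quantized", "x_scale", "x_zero_point"],  # illustrative names
    ["x"],
    name="input_dequant",
)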
@@ -67,6 +67,37 @@ class MatMulAddToMatMulIntegerAddCastMul(OnnxTransform):

     def transform(self, model: ModelProto) -> ModelProto:
         graph = ONNXGraph(model)
+
+        # Weight on input 0
+        matches = get_structural_matches(
+            graph,
+            op_type="MatMul",
+            parent_ops=[
+                [
+                    # weight should be initializer
+                    INITIALIZER_MATCH,
+                    "QuantizeLinear",
+                    "DequantizeLinear",
+                    optional_node("Transpose"),
+                ],
+                [any_of("QuantizeLinear", "DequantizeLinear")],
+            ],
+            children_ops=[[optional_node("Add")]],
+        )
+        for match in matches:
+            add_node = match.children[0][0]
+            bias_init = None
+            if add_node:
+                # NOTE: bias could be either input 0 or 1 of add node
+                # if add does not have a bias initializer,
+                # still do conversion, but do not fold the bias add to rescale
+                bias_init = graph.get_init_by_name(match.children[0][0].input[1])
+                if bias_init is None:
+                    bias_init = graph.get_init_by_name(match.children[0][0].input[0])
+            self.log_match(match)
+            self._transform_match(graph, model, match, bias_init, 0)
+
+        # Weight on input 1
         matches = get_structural_matches(
             graph,
             op_type="MatMul",
@@ -93,7 +124,8 @@ def transform(self, model: ModelProto) -> ModelProto:
                 if bias_init is None:
                     bias_init = graph.get_init_by_name(match.children[0][0].input[0])
             self.log_match(match)
-            self._transform_match(graph, model, match, bias_init)
+            self._transform_match(graph, model, match, bias_init, 1)
+
         return model

     def _transform_match(
@@ -102,10 +134,15 @@ def _transform_match(
         model: ModelProto,
         match: MatchResult,
         bias_init: TensorProto,
+        weight_parent: int,
     ):
         matmul = match.node
-        (input_quant,) = match.parents[0]
-        weight_init, weight_quant, weight_dequant, opt_transpose = match.parents[1]
+        if weight_parent == 0:
+            (input_quant,) = match.parents[1]
+            weight_init, weight_quant, weight_dequant, opt_transpose = match.parents[0]
+        else:
+            (input_quant,) = match.parents[0]
+            weight_init, weight_quant, weight_dequant, opt_transpose = match.parents[1]
         (add,) = match.children[0]

         input_quantize_params = get_quantization_params(
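Per the commit "Allow MatMul weight to be on either input 0 or 1", the conversion now fires regardless of which MatMul operand carries the quantized weight, and the new `weight_parent` argument tells `_transform_match` which parent list holds the weight chain. A small illustration of the two operand orders (node and tensor names are hypothetical):

import onnx

# weight_parent == 0: MatMul(dequantized_weight, activation)
mm0 = onnx.helper.make_node("MatMul", ["weight_dq", "act"], ["out0"], name="mm0")
# weight_parent == 1: MatMul(activation, dequantized_weight)
mm1 = onnx.helper.make_node("MatMul", ["act", "weight_dq"], ["out1"], name="mm1")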
@@ -0,0 +1,92 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import onnx
+from onnx import ModelProto
+
+from sparseml.exporters.transforms import OnnxTransform
+from sparseml.exporters.transforms.utils import MatchResult, get_structural_matches
+from sparseml.onnx.utils import ONNXGraph
+
+
+__all__ = ["PropagateDequantThroughSplit"]
+
+
+class PropagateDequantThroughSplit(OnnxTransform):
+    """
+    A pass for propagating DequantizeLinear nodes down through a split node
+    so if there are quantized operations after the split they can
+    be properly converted.
+    Starting with:
+    |        INPUT
+    |          |
+    |   DequantizeLinear
+    |          |
+    |        Split
+    |      |   |   |
+    Converts to:
+    |        INPUT
+    |          |
+    |        Split
+    |      |   |   |
+    | DequantizeLinear DequantizeLinear DequantizeLinear
+    |      |   |   |
+    """
+
+    def transform(self, model: ModelProto) -> ModelProto:
+        graph = ONNXGraph(model)
+        matches = get_structural_matches(
+            graph,
+            parent_ops=[["DequantizeLinear"]],
+            op_type="Split",
+        )
+        for match in matches:
+            self.log_match(match)
+            self._transform_match(model, match)
+        return model
+
+    def _transform_match(self, model: ModelProto, match: MatchResult):
+
+        # Loop through the nodes that are children of the Split node
+        # For every child, create a DequantizeLinear node and insert
+        # between Split and child
+        for split_output_id in range(len(match.node.output)):
+            dequant_node_name = match.node.name + f"_dequant.{split_output_id}"
+            dequant_node_output = match.node.output[split_output_id]
+            dequant_node_input = dequant_node_name + "_input"
+
+            # Input to DequantizeLinear node is the output of the Split node
+            model.graph.node.append(
+                onnx.helper.make_node(
+                    "DequantizeLinear",
+                    [
+                        dequant_node_input,  # input
+                        match.parents[0][0].input[1],  # scale
+                        match.parents[0][0].input[2],  # zero point
+                    ],
+                    [dequant_node_output],
+                    dequant_node_name,
+                )
+            )
+
+            # Replace the output of the Split node with the input of
+            # the new DequantizeLinear node
+            match.node.output[split_output_id] = dequant_node_input
+
+        # Set the input to the Split node to what was the input of the
+        # original DequantizeLinear node
+        match.node.input[0] = match.parents[0][0].input[0]
+
+        # Remove original DequantizeLinear node
+        self.delete_node_deferred(match.parents[0][0])
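A self-contained sketch (not in this PR) of what the pass does to a toy graph: a DequantizeLinear feeding a two-output Split. Each Split output gets its own dequant node reusing the original scale and zero point, so numerics are unchanged while the Split itself now runs on integer data. `apply` is assumed to be the base-transform entry point that also flushes the deferred node deletion:

import onnx
from onnx import TensorProto, helper

from sparseml.exporters.transforms import PropagateDequantThroughSplit

x = helper.make_tensor_value_info("x", TensorProto.UINT8, (4,))
y0 = helper.make_tensor_value_info("y0", TensorProto.FLOAT, (2,))
y1 = helper.make_tensor_value_info("y1", TensorProto.FLOAT, (2,))
scale = helper.make_tensor("scale", TensorProto.FLOAT, (1,), [0.5])
zero_point = helper.make_tensor("zero_point", TensorProto.UINT8, (1,), [128])
split_sizes = helper.make_tensor("split_sizes", TensorProto.INT64, (2,), [2, 2])

dequant = helper.make_node(
    "DequantizeLinear", ["x", "scale", "zero_point"], ["dq"], name="dequant"
)
split = helper.make_node(
    "Split", ["dq", "split_sizes"], ["y0", "y1"], name="split", axis=0
)
graph = helper.make_graph(
    [dequant, split], "g", [x], [y0, y1],
    initializer=[scale, zero_point, split_sizes],
)
model = helper.make_model(graph)

model = PropagateDequantThroughSplit().apply(model)
# Expected node sequence: Split followed by one DequantizeLinear per output
print([n.op_type for n in model.graph.node])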
@@ -15,6 +15,7 @@
 import logging

 import numpy
+import onnx.numpy_helper
 from onnx import ModelProto, numpy_helper

 from sparseml.exporters.transforms.onnx_transform import OnnxTransform
@@ -79,16 +80,22 @@ def transform(self, model: ModelProto) -> ModelProto:
                 ["Concat"],
             ],
         )
+
+        initializer_dict = {i.name: i for i in model.graph.initializer}
+
         for match in matches:
             (gather,) = match.parents[0]
             dequant = match.node
-            slice1, _, concat1 = match.children[0]
-            slice2, _, concat2 = match.children[1]
+            slice1, pad1, concat1 = match.children[0]
+            slice2, pad2, concat2 = match.children[1]
             (concat,) = match.children[2]

             # check for uint8 initializer
             indices = graph.get_init_by_name(gather.input[0])
-            if indices is None or numpy_helper.to_array(indices).dtype != numpy.uint8:
+            if indices is None or numpy_helper.to_array(indices).dtype not in [
+                numpy.uint8,
+                numpy.int8,
+            ]:
                 continue

             # check that all concats are the same
@@ -97,11 +104,35 @@

             self.log_match(match)

-            assert concat.input[2] == dequant.output[0]
-            concat.input[2] = gather.output[0]
+            for id, input_name in enumerate(concat.input):
+                if input_name == dequant.output[0]:
+                    break
+
+            concat.input[id] = gather.output[0]
             slice1.input[0] = gather.output[0]
             slice2.input[0] = gather.output[0]
+
+            zero_point_initializer = initializer_dict[match.node.input[2]]
+            zero_point = onnx.numpy_helper.to_array(zero_point_initializer)
+
+            pad1_value_initializer = initializer_dict[pad1.input[2]]
+            pad1_value = onnx.numpy_helper.to_array(pad1_value_initializer)
+            pad1_value = pad1_value.astype(zero_point.dtype) + zero_point
+            new_pad1_value_initializer = numpy_helper.from_array(
+                pad1_value, name=pad1_value_initializer.name
+            )
+            model.graph.initializer.remove(pad1_value_initializer)
+            model.graph.initializer.append(new_pad1_value_initializer)
+
+            pad2_value_initializer = initializer_dict[pad2.input[2]]
+            pad2_value = onnx.numpy_helper.to_array(pad2_value_initializer)
+            pad2_value = pad2_value.astype(zero_point.dtype) + zero_point
+            new_pad2_value_initializer = numpy_helper.from_array(
+                pad2_value, name=pad2_value_initializer.name
+            )
+            model.graph.initializer.remove(pad2_value_initializer)
+            model.graph.initializer.append(new_pad2_value_initializer)
+
             tmp = concat.output[0]
             concat.output[0] = dequant.output[0]
             dequant.output[0] = tmp
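Why the pad constants gain the zero point: DequantizeLinear computes float = (quantized - zero_point) * scale, and this transform moves the Pad nodes from the dequantized side of the graph to the quantized side, so a float padding value of 0.0 must become zero_point in the quantized domain. A quick numeric check of that identity (values are illustrative, not from the PR):

import numpy

scale = numpy.float32(0.5)
zero_point = numpy.uint8(128)

pad_float = numpy.float32(0.0)  # original float-side Pad constant
pad_quant = pad_float.astype(zero_point.dtype) + zero_point  # as in the diff
# dequantizing the new pad constant recovers the original float value
assert (numpy.float32(pad_quant) - zero_point) * scale == pad_float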
@@ -32,9 +32,18 @@ def onnx_model():
         "output", onnx.TensorProto.FLOAT, (1,)
     )
     scale = onnx.helper.make_tensor("scale", onnx.TensorProto.FLOAT, (1,), [1.0])
+    zero_point = onnx.helper.make_tensor(
+        "zero_point", onnx.TensorProto.UINT8, (1,), [128]
+    )
     starts = onnx.helper.make_tensor("starts", onnx.TensorProto.INT64, (1,), [0])
     ends = onnx.helper.make_tensor("ends", onnx.TensorProto.INT64, (1,), [1])
     pads = onnx.helper.make_tensor("pads", onnx.TensorProto.INT64, (1,), [1])
+    padding1_value = onnx.helper.make_tensor(
+        "padding1_value", onnx.TensorProto.FLOAT, (1,), [0.0]
+    )
+    padding2_value = onnx.helper.make_tensor(
+        "padding2_value", onnx.TensorProto.FLOAT, (1,), [0.0]
+    )
     embeddings = onnx.helper.make_tensor(
         "embeddings", onnx.TensorProto.UINT8, (1,), [0]
     )
@@ -43,7 +52,7 @@
     )
     dequant = onnx.helper.make_node(
         "DequantizeLinear",
-        ["gather_output", "scale"],
+        ["gather_output", "scale", "zero_point"],
         ["dequant_output"],
         name="dequant",
     )
@@ -52,13 +61,13 @@
         "Slice", ["dequant_output", "starts", "ends"], ["slice1_output"], name="slice1"
     )
     pad1 = onnx.helper.make_node(
-        "Pad", ["slice1_output", "pads"], ["pad1_output"], name="pad1"
+        "Pad", ["slice1_output", "pads", "padding1_value"], ["pad1_output"], name="pad1"
    )
     slice2 = onnx.helper.make_node(
         "Slice", ["dequant_output", "starts", "ends"], ["slice2_output"], name="slice2"
     )
     pad2 = onnx.helper.make_node(
-        "Pad", ["slice2_output", "pads"], ["pad2_output"], name="pad2"
+        "Pad", ["slice2_output", "pads", "padding2_value"], ["pad2_output"], name="pad2"
     )
     concat = onnx.helper.make_node(
         "Concat",
@@ -73,7 +82,16 @@
         name="g",
         inputs=[model_input],
         outputs=[model_output],
-        initializer=[scale, starts, ends, embeddings, pads],
+        initializer=[
+            scale,
+            zero_point,
+            starts,
+            ends,
+            embeddings,
+            pads,
+            padding1_value,
+            padding2_value,
+        ],
     )

     model = onnx.helper.make_model(graph)