[hotfix 0.10.1] transformers export and QAT flow fixes (#549)
* Update README.md for transformers to note the quantization conversion issue (#539)

* Update README.md

* Update integrations/huggingface-transformers/README.md

Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com>

Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com>

* Enforce order on input keys to export (#545)

* Enforce order on input keys to export

* Warn if input dropped from onnx export

* Restrict mistune version to fix docs build (#547)

* quantization fixes for transformers flows (#548)

* quantization fixes for transformers flows

* match on class name instead

* quality

* set release branch version to 0.10.1

* Revert "Update README.md for transformers to note the quantization conversion issue (#539)"

This reverts commit 9304997.

Co-authored-by: Mark Kurtz <mark@neuralmagic.com>
Co-authored-by: Jeannie Finks <74554921+jeanniefinks@users.noreply.github.com>
Co-authored-by: Tuan Nguyen <tuan@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
5 people committed Feb 9, 2022
1 parent b09c6d0 commit ce8a677
Showing 6 changed files with 91 additions and 20 deletions.
1 change: 1 addition & 0 deletions setup.py
@@ -70,6 +70,7 @@
"flake8==3.9.2",
"isort==5.8.0",
"m2r2~=0.2.7",
"mistune==0.8.4",
"myst-parser~=0.14.0",
"rinohtype~=0.4.2",
"sphinx~=3.5.0",
27 changes: 27 additions & 0 deletions src/sparseml/pytorch/optim/modifier_quantization.py
@@ -100,6 +100,8 @@ class QuantizationModifier(ScheduledModifier):
transformer-based models such as BERT, where the quantized MatMul outputs
are kept at 32 bits of precision and fake quantizing the outputs harms training
recovery. Default is True
:param exclude_module_types: optional list of module class names
to not propagate quantization configs to. Default is None
"""

def __init__(
@@ -114,6 +116,7 @@ def __init__(
quantize_embeddings: bool = True,
reduce_range: bool = False,
quantize_linear_activations: bool = True,
exclude_module_types: Union[List[str], None] = None,
):
if torch_quantization is None or torch_intrinsic is None:
raise RuntimeError(
@@ -138,6 +141,7 @@ def __init__(
self._quantize_embeddings = quantize_embeddings
self._reduce_range = reduce_range
self._quantize_linear_activations = quantize_linear_activations
self._exclude_module_types = exclude_module_types

self._modules_to_quantize = None
self._qat_enabled = False
@@ -278,6 +282,14 @@ def quantize_linear_activations(self) -> bool:
"""
return self._quantize_linear_activations

@ModifierProp()
def exclude_module_types(self) -> Union[List[str], None]:
"""
:return: optional list of module class names to not propagate
quantization configs to. Default is None
"""
return self._exclude_module_types

def initialize(
self,
module: Module,
@@ -423,10 +435,15 @@ def _enable_module_qat(self, module: Module):
if not self._quantize_linear_activations:
remove_activation_qat_by_layer_name(quant_module, ["Linear"])

# remove qconfigs for module types in exclude_module_types
if self._exclude_module_types:
self._strip_excluded_module_qconfigs(module)

# set modules with proper qconfigs to QAT mode
torch_quantization.prepare_qat(module, inplace=True)
if self._quantize_embeddings:
prepare_embeddings_qat(module, reduce_range=self._reduce_range)

self._qat_enabled = True

def _disable_quantization_observer_update_ready(self, epoch: float) -> bool:
@@ -443,6 +460,16 @@ def _freeze_bn_stats_update_ready(self, epoch: float) -> bool:
and not self._bn_stats_frozen
)

def _strip_excluded_module_qconfigs(self, module: Module):
if not self._exclude_module_types:
return
excluded_classes = set(self._exclude_module_types)
for submodule in module.modules():
if submodule.__class__.__name__ in excluded_classes and hasattr(
submodule, "qconfig"
):
submodule.qconfig = None

def _validate_params(self):
if (
self._disable_quantization_observer_epoch is not None
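As a usage note: below is a minimal sketch of building the modifier with the new `exclude_module_types` option, assuming `QuantizationModifier` is importable from `sparseml.pytorch.optim` as in other SparseML examples; the class names and epoch value are illustrative, not taken from this commit.

```python
# Illustrative sketch (not from this commit): keep LayerNorm and Tanh modules at
# full precision by listing their class names in exclude_module_types.
from sparseml.pytorch.optim import QuantizationModifier

modifier = QuantizationModifier(
    start_epoch=0.0,
    quantize_linear_activations=False,
    exclude_module_types=["LayerNorm", "Tanh"],  # module class names to skip
)

# When QAT is enabled, _strip_excluded_module_qconfigs() sets qconfig = None on
# every submodule whose class name is listed, so prepare_qat() leaves it alone.
```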
28 changes: 20 additions & 8 deletions src/sparseml/pytorch/utils/quantization/quantize_qat_export.py
@@ -264,7 +264,11 @@ def _delete_repeated_qat_blocks(model: ModelProto):
nodes_to_delete.append(dequant_node_1)

for n in nodes_to_delete:
delete_quant_node(model, n)
delete_quant_node(model, n, keep_params=True)

# cleanup graph
graph.update()
graph.delete_unused_initializers()


def _attribute_to_kwarg(attribute: onnx.AttributeProto):
@@ -1214,12 +1218,14 @@ def _quantize_qat_embedding(model: ModelProto):
qdq_output = False

if qdq_output:
# forward gather output to dequant input
output_dequant_node.input[0] = gather_node.output[0]
output_dequant_node.input[1] = input_quant_node.input[1]
output_dequant_node.input[2] = input_quant_node.input[2]
# delete unnecessary quantize and dequantize ops
delete_quant_node(model, input_quant_node, keep_params=False)
delete_quant_node(model, input_quant_node, keep_params=True)
delete_quant_node(model, input_dequant_node, keep_params=False)
delete_quant_node(model, output_quant_node, keep_params=False)
# forward gather output to dequant input
output_dequant_node.input[0] = gather_node.output[0]

else:
# use input dequant to dequantize output
@@ -1265,7 +1271,10 @@ def _remove_duplicate_quantize_ops(model: ModelProto):
_replace_input_id_model(
model, remove_node.output[0], keep_node.output[0]
)
remove_node_and_params_from_graph(model, remove_node)
delete_quant_node(model, remove_node, keep_params=True)
# cleanup graph
graph.update()
graph.delete_unused_initializers()


def _cleanup_unused_quants(model: ModelProto):
@@ -1296,15 +1305,18 @@ def _cleanup_unused_quants(model: ModelProto):
continue

# Forward QuantizeLinear input to DequantizeLinear output
for child in dequant_children:
_replace_input_id_model(model, dequant_node.output[0], quant_node.input[0])
_replace_input_id_model(model, dequant_node.output[0], quant_node.input[0])

# Remove QuantizeLinear->DequantizeLinear block
nodes_to_delete.append(quant_node)
nodes_to_delete.append(dequant_node)

for n in nodes_to_delete:
delete_quant_node(model, n)
delete_quant_node(model, n, keep_params=True)

# update graph
graph.update()
graph.delete_unused_initializers()


def quantize_torch_qat_export(
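The recurring pattern in these export fixes: quant/dequant nodes are now deleted with keep_params=True so that scale and zero-point initializers shared with other nodes are not removed prematurely, and the graph helper then drops whatever is left unreferenced. A rough sketch of that cleanup step, assuming the `graph` object used in this module is SparseML's ONNXGraph helper (the import path and node names below are assumptions):

```python
# Rough sketch under stated assumptions; not the code from this commit.
import onnx

from sparseml.onnx.utils import ONNXGraph  # assumed import path for the graph helper

model = onnx.load("model-qat.onnx")  # hypothetical QAT-exported model
graph = ONNXGraph(model)

# Suppose a redundant QuantizeLinear -> DequantizeLinear pair has been identified
# and is removed while its scale/zero-point initializers are kept (the
# keep_params=True behaviour above):
redundant = [
    n for n in model.graph.node if n.name in ("quant_7", "dequant_7")  # hypothetical names
]
for node in redundant:
    model.graph.node.remove(node)

# rebuild the helper's lookup tables, then drop initializers that no longer
# have any consumer in the graph
graph.update()
graph.delete_unused_initializers()
```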
21 changes: 21 additions & 0 deletions src/sparseml/transformers/export.py
@@ -55,6 +55,8 @@
"""

import argparse
import collections
import inspect
import logging
import math
import os
@@ -180,13 +182,32 @@ def export_transformer_to_onnx(
inputs = tokenizer(
"", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value
).data # Dict[Tensor]

# Rearrange the input keys to match the order defined by the model's forward
# function, which appears to determine the input order in the exported model
forward_args_spec = inspect.getfullargspec(model.__class__.forward)
dropped = [f for f in inputs.keys() if f not in forward_args_spec.args]
inputs = collections.OrderedDict(
[
(f, inputs[f][0].reshape(1, -1))
for f in forward_args_spec.args
if f in inputs
]
)
if dropped:
_LOGGER.warning(
"The following inputs were not present in the model forward function "
f"and therefore dropped from ONNX export: {dropped}"
)

inputs_shapes = {
key: (
f"{val.dtype if hasattr(val, 'dtype') else 'unknown'}: "
f"{list(val.shape) if hasattr(val, 'shape') else 'unknown'}"
)
for key, val in inputs.items()
}

_LOGGER.info(f"Created sample inputs for the ONNX export process: {inputs_shapes}")

# run export
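For readers outside the transformers flow, here is a self-contained sketch of the same reordering idea: torch.onnx.export binds sample inputs positionally, so the fix orders the tokenizer output to match the model's forward() signature and warns about any key it has to drop. The toy model and tensor names below are hypothetical.

```python
# Self-contained illustration of the input-reordering trick; toy model only.
import collections
import inspect

import torch


class ToyModel(torch.nn.Module):
    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        return input_ids


model = ToyModel()
inputs = {
    "token_type_ids": torch.zeros(1, 8, dtype=torch.long),
    "input_ids": torch.zeros(1, 8, dtype=torch.long),
    "attention_mask": torch.ones(1, 8, dtype=torch.long),
    "extra_feature": torch.zeros(1, 8),  # not accepted by forward(); gets dropped
}

forward_args = inspect.getfullargspec(model.__class__.forward).args
dropped = [name for name in inputs if name not in forward_args]
ordered_inputs = collections.OrderedDict(
    (name, inputs[name]) for name in forward_args if name in inputs
)

print(list(ordered_inputs))  # ['input_ids', 'attention_mask', 'token_type_ids']
print(dropped)               # ['extra_feature']
```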
2 changes: 1 addition & 1 deletion src/sparseml/version.py
@@ -19,7 +19,7 @@
from datetime import date


version_base = "0.10.0"
version_base = "0.10.1"
is_release = False # change to True to set the generated version as a release version


32 changes: 21 additions & 11 deletions tests/sparseml/pytorch/optim/test_modifier_quantization.py
@@ -54,6 +54,10 @@
start_epoch=0.0,
quantize_linear_activations=False,
),
lambda: QuantizationModifier(
start_epoch=0.0,
exclude_module_types=["Linear"],
),
]


@@ -67,9 +71,13 @@ def _is_quantiable_module(module):
return isinstance(module, (Conv2d, Linear))


def _test_quantizable_module(
module, qat_expected, reduce_range, quantize_linear_activations
):
def _test_quantizable_module(module, qat_expected, modifier):
reduce_range = modifier.reduce_range
quantize_linear_activations = modifier.quantize_linear_activations

excluded_types = modifier.exclude_module_types or []
qat_expected = qat_expected and module.__class__.__name__ not in excluded_types

if qat_expected:
assert hasattr(module, "qconfig") and module.qconfig is not None
assert hasattr(module, "weight_fake_quant") and (
@@ -97,12 +105,7 @@ def _test_qat_applied(modifier, model):
submodules = [""]
for module in model.modules():
if _is_quantiable_module(module):
_test_quantizable_module(
module,
True,
modifier.reduce_range,
modifier.quantize_linear_activations,
)
_test_quantizable_module(module, True, modifier)
else:
assert not hasattr(model, "qconfig") or model.qconfig is None
submodules = modifier.submodules
@@ -112,8 +115,7 @@ def _test_qat_applied(modifier, model):
_test_quantizable_module(
module,
_is_valid_submodule(name, submodules),
modifier.reduce_range,
modifier.quantize_linear_activations,
modifier,
)


@@ -207,6 +209,7 @@ def test_quantization_modifier_yaml():
quantize_embeddings = False
reduce_range = True
quantize_linear_activations = False
exclude_module_types = ["LayerNorm", "Tanh"]
yaml_str = f"""
!QuantizationModifier
start_epoch: {start_epoch}
@@ -217,6 +220,7 @@ def test_quantization_modifier_yaml():
quantize_embeddings: {quantize_embeddings}
reduce_range: {reduce_range}
quantize_linear_activations: {quantize_linear_activations}
exclude_module_types: {exclude_module_types}
"""
yaml_modifier = QuantizationModifier.load_obj(
yaml_str
@@ -233,6 +237,7 @@ def test_quantization_modifier_yaml():
quantize_embeddings=quantize_embeddings,
reduce_range=reduce_range,
quantize_linear_activations=quantize_linear_activations,
exclude_module_types=exclude_module_types,
)

assert isinstance(yaml_modifier, QuantizationModifier)
@@ -276,3 +281,8 @@ def test_quantization_modifier_yaml():
== serialized_modifier.quantize_linear_activations
== obj_modifier.quantize_linear_activations
)
assert (
yaml_modifier.exclude_module_types
== serialized_modifier.exclude_module_types
== obj_modifier.exclude_module_types
)
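Mirroring the YAML string in the test above, a recipe-driven flow could load the new field like this; this is a sketch only, with illustrative values.

```python
# Sketch based on the updated test: load the modifier from a recipe fragment.
from sparseml.pytorch.optim import QuantizationModifier

yaml_str = """
!QuantizationModifier
    start_epoch: 0.0
    quantize_linear_activations: False
    exclude_module_types: ['LayerNorm', 'Tanh']
"""

modifier = QuantizationModifier.load_obj(yaml_str)
assert modifier.exclude_module_types == ["LayerNorm", "Tanh"]
```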
