Support for native PyTorch embedded quantization encodings (quic#1156)
Signed-off-by: Huan Zhao <quic_huzh@quicinc.com>

Support FP16 native torch quantizer using cast

Signed-off-by: Huan Zhao <quic_huzh@quicinc.com>

Fixed some comments

Signed-off-by: Huan Zhao <quic_huzh@quicinc.com>

Moved quantizer-related functions to quantsim_utils

Signed-off-by: Huan Zhao <quic_huzh@quicinc.com>

Removed a check that behaved inconsistently across onnx versions

Signed-off-by: Huan Zhao <quic_huzh@quicinc.com>

Changed the API for exporting embedded encodings from a standalone function to sim.export

Signed-off-by: Huan Zhao <quic_huzh@quicinc.com>

Renamed use_embedded_encodings and raised an error when strict symmetric quantization is used.

Signed-off-by: Huan Zhao <quic_huzh@quicinc.com>

Resolved conflicts and added test cases to verify that the output of native torch quantization nodes is correct

Signed-off-by: Huan Zhao <quic_huzh@quicinc.com>
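For context, a minimal usage sketch of the resulting API (assuming `sim` is an already-calibrated QuantizationSimModel and `dummy_input` matches the model's inputs; both names are placeholders, not part of this diff):

# Minimal sketch: export with native torch quantization nodes embedded.
# `sim` and `dummy_input` are assumed/placeholder names.
sim.export(path='./output',
           filename_prefix='quantized_model',
           dummy_input=dummy_input,
           use_embedded_encodings=True)
# Besides the usual artifacts, this writes quantized_model_embedded.onnx
# (or quantized_model_embedded.torchscript.pth when export_to_torchscript=True).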
quic-huzh committed Jun 30, 2023
1 parent 2bc2856 commit 4d91499
Showing 5 changed files with 483 additions and 8 deletions.
94 changes: 92 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
@@ -38,7 +38,7 @@

# pylint: disable=too-many-lines
""" Custom PyTorch Op for quantizing weights and activations """

# pylint: disable=too-many-lines
import abc
from enum import Enum
from typing import Dict, Tuple, Union, List, Callable, Type
@@ -53,7 +53,7 @@
from aimet_torch.custom import custom_tensor_utils
from aimet_torch import utils
from aimet_torch.tensor_quantizer import StaticGridPerTensorQuantizer, StaticGridPerChannelQuantizer, TensorQuantizer, \
LearnedGridTensorQuantizer, set_encoding_min_max_gating_threshold
LearnedGridTensorQuantizer, set_encoding_min_max_gating_threshold, TorchQuantizer
import aimet_torch.quantsim_straight_through_grad as ste

_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
@@ -967,6 +967,96 @@ def _linear_forward_with_recompute(quant_wrapper: LearnedGridQuantWrapper, input
quant_wrapper.param_quantizers['weight'])


class NativeTorchQuantWrapper(nn.Module):
"""
A custom PyTorch module for inserting native PyTorch quantization nodes
"""
def __init__(self, post_training_module: StaticGridQuantWrapper, module_name: str, device: torch.device):
"""
Constructor
:param post_training_module: StaticGridQuantWrapper whose quantizers are to be converted
:param module_name: attribute name of the wrapped module on post_training_module
:param device: device on which the model resides
"""
super(NativeTorchQuantWrapper, self).__init__()

self._module_to_wrap = getattr(post_training_module, module_name)
# pylint: disable=protected-access
self._mode = post_training_module._mode

self.output_quantizers = [TorchQuantizer(quantizer, device) for quantizer in post_training_module.output_quantizers]

self.input_quantizers = [TorchQuantizer(quantizer, device) for quantizer in post_training_module.input_quantizers]

self.param_quantizers = {}
for name, quantizer in post_training_module.param_quantizers.items():
self.param_quantizers[name] = TorchQuantizer(quantizer, device)

def _quantize_dequantize(self, tensor_quantizers, tensors_to_quantize):
"""
Quantize-dequantize the given tensors using the given tensor quantizers
:param tensor_quantizers: Tensor quantizers to use for quantize-dequantize
:param tensors_to_quantize: Tensors to be quantize-dequantized
:return: Quantize-dequantized tensors (a single tensor if only one was passed in)
"""
outputs = []
for index, input_tensor in enumerate(tensors_to_quantize):
if not isinstance(input_tensor, torch.Tensor):
_logger.error('Expecting quantize activation input of type torch.Tensor but got %s', type(input_tensor))
raise AssertionError
if input_tensor.dtype in utils.torch_dtypes_to_ignore_for_quantization:
# Do not quantize integer tensors
outputs.append(input_tensor)
continue

assert len(tensor_quantizers) > index, \
f"Not enough tensor quantizers ({len(tensor_quantizers)}) allocated"

if self._mode is QcQuantizeOpMode.ACTIVE:
output = tensor_quantizers[index].quantize_dequantize(input_tensor)
else:
output = input_tensor

outputs.append(output)

# Flatten if there is only one output - which is by far the most common case
if len(outputs) == 1:
outputs = outputs[0]

return outputs

def forward(self, *inputs):
"""
Forward-pass routine. Quantize-dequantizes the inputs and parameters, delegates to the wrapped
module, and then quantize-dequantizes the output before returning it
:param inputs: Inputs passed to the module in the forward pass
:return: Quantize-dequantized output of the wrapped module
"""
# Quantize inputs
quantized_inputs = self._quantize_dequantize(self.input_quantizers, inputs)
if isinstance(quantized_inputs, torch.Tensor):
quantized_inputs = [quantized_inputs]

# Quantize params
for name, param in self._module_to_wrap.named_parameters():
param_quantizer = self.param_quantizers[name]
if param_quantizer.enabled:
setattr(self._module_to_wrap, name,
torch.nn.parameter.Parameter(param_quantizer.quantize_dequantize(param)))
wrapped_output = self._module_to_wrap(*quantized_inputs)

# Quantize the outputs
if not self.output_quantizers[0].enabled:
output = wrapped_output
else:
if isinstance(wrapped_output, torch.Tensor):
wrapped_output = [wrapped_output]
output = self._quantize_dequantize(self.output_quantizers, wrapped_output)

return output


class QcQuantizeStandalone(QcQuantizeStandAloneBase):
""" A custom PyTorch module that derives from QcQuantizeStandAloneBase and quantizes inputs """

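As a hedged illustration of the class added above, converting and running a single calibrated wrapper might look like this (`wrapper` is a placeholder for a StaticGridQuantWrapper whose encodings were already computed; '_module_to_wrap' is the attribute name used in this diff):

import torch

# Sketch only: wrap one calibrated StaticGridQuantWrapper (e.g. around nn.Linear(10, 5)).
device = torch.device('cpu')
native = NativeTorchQuantWrapper(wrapper, '_module_to_wrap', device)

x = torch.randn(1, 10)
y = native(x)  # inputs, parameters and outputs pass through torch fake-quant nodes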
74 changes: 68 additions & 6 deletions TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -58,7 +58,7 @@

from aimet_torch.quantsim_config.quantsim_config import QuantSimConfigurator
from aimet_torch.qc_quantize_op import QcQuantizeStandAloneBase, QcQuantizeWrapper, QcQuantizeOpMode, \
StaticGridQuantWrapper, LearnedGridQuantWrapper, QUANTIZER_TYPE_INPUT, QUANTIZER_TYPE_OUTPUT
StaticGridQuantWrapper, LearnedGridQuantWrapper, NativeTorchQuantWrapper, QUANTIZER_TYPE_INPUT, QUANTIZER_TYPE_OUTPUT
from aimet_torch.tensor_quantizer import StaticGridTensorQuantizer, LearnedGridTensorQuantizer, \
initialize_learned_grid_quantizer_attributes
from aimet_torch import torchscript_utils, utils, transformer_utils
@@ -363,7 +363,7 @@ def set_percentile_value(self, percentile_value: float):

def export(self, path: str, filename_prefix: str, dummy_input: Union[torch.Tensor, Tuple],
onnx_export_args: Optional[Union[OnnxExportApiArgs, Dict]] = None, propagate_encodings: bool = False,
export_to_torchscript: bool = False):
export_to_torchscript: bool = False, use_embedded_encodings: bool = False):
"""
This method exports the quant-sim model so that it is ready to be run on-target.
@@ -387,6 +387,7 @@ def export(self, path: str, filename_prefix: str, dummy_input: Union[torch.Tenso
multiple ONNX nodes) are filled with the same BW and data_type as the output tensor for that series of
ops. Defaults to False.
:param export_to_torchscript: If True, export to torchscript. Export to onnx otherwise. Defaults to False.
:param use_embedded_encodings: If True, an additional model with embedded native torch fake-quant nodes is exported (onnx by default, torchscript when export_to_torchscript is True). Defaults to False.
"""
# save the quantized model and encodings
model_filename = filename_prefix + '.pth'
@@ -401,7 +402,7 @@ def export(self, path: str, filename_prefix: str, dummy_input: Union[torch.Tenso

if export_to_torchscript:
self.export_torch_script_model_and_encodings(path, filename_prefix, model_to_export, self.model,
dummy_input, self._excluded_layer_names)
dummy_input, self._excluded_layer_names, use_embedded_encodings)
else:
if onnx_export_args is None:
onnx_export_args = {'opset_version': None,
@@ -414,15 +415,16 @@ def export(self, path: str, filename_prefix: str, dummy_input: Union[torch.Tenso
f'unsupported opt_args type={type(onnx_export_args)}')
self.export_onnx_model_and_encodings(path, filename_prefix, model_to_export, self.model,
dummy_input, onnx_export_args, propagate_encodings,
self._module_marker_map, self._is_conditional,
use_embedded_encodings, self._module_marker_map, self._is_conditional,
self._excluded_layer_names, quantizer_args=self.quant_args)

@staticmethod
def export_torch_script_model_and_encodings(path: str, filename_prefix: str,
original_model: torch.nn.Module,
sim_model: torch.nn.Module,
dummy_input: Union[torch.Tensor, Tuple],
excluded_layer_names: List = None):
excluded_layer_names: List = None,
use_embedded_encodings: bool = False):
"""
This method exports a torchscript model and the corresponding encodings
@@ -432,8 +434,11 @@ def export_torch_script_model_and_encodings(path: str, filename_prefix: str,
:param sim_model: model with the quantsim wrappers
:param dummy_input: Dummy input to the model. Used to parse model graph.
:param excluded_layer_names: List of names of layers that have been excluded from quantization.
:param use_embedded_encodings: If True, an additional torchscript model with embedded fake-quant nodes will be exported
:return: None
"""
if use_embedded_encodings:
QuantizationSimModel.save_model_with_embedded_quantization_nodes(sim_model, path, filename_prefix, dummy_input, None)
with utils.in_eval_mode(original_model), torch.no_grad():
trace = torch.jit.trace(original_model, dummy_input)
ts_path = os.path.join(path, filename_prefix + '.torchscript.pth')
@@ -453,7 +458,7 @@ def export_torch_script_model_and_encodings(path: str, filename_prefix: str,
def export_onnx_model_and_encodings(path: str, filename_prefix: str, original_model: torch.nn.Module,
sim_model: torch.nn.Module, dummy_input: Union[torch.Tensor, Tuple],
onnx_export_args: Union[OnnxExportApiArgs, dict], propagate_encodings: bool,
module_marker_map: Dict[torch.nn.Module, torch.Tensor] = None,
use_embedded_encodings: bool = False, module_marker_map: Dict[torch.nn.Module, torch.Tensor] = None,
is_conditional: bool = False, excluded_layer_names: List = None,
quantizer_args: Dict = None):
"""
@@ -468,6 +473,7 @@ def export_onnx_model_and_encodings(path: str, filename_prefix: str, original_mo
:param propagate_encodings: If True, encoding entries for intermediate ops (when one PyTorch ops results in
multiple ONNX nodes) are filled with the same BW and data_type as the output tensor for that series of
ops.
:param use_embedded_encodings: If True, an additional onnx model with embedded fake-quant nodes will be exported
:param module_marker_map: Maps module names to traced custom markers (only used for conditional models)
:param is_conditional: True if model is conditional, False otherwise
:param excluded_layer_names: List of names of layers that have been excluded from quantization.
@@ -482,6 +488,8 @@ def export_onnx_model_and_encodings(path: str, filename_prefix: str, original_mo

for dropout_type in DROPOUT_TYPES:
utils.replace_modules_of_type1_with_type2(original_model, dropout_type, torch.nn.Identity)
if use_embedded_encodings:
QuantizationSimModel.save_model_with_embedded_quantization_nodes(sim_model, path, filename_prefix, dummy_input, onnx_export_args)

OnnxSaver.set_node_names(onnx_path, original_model, dummy_input, is_conditional, module_marker_map,
onnx_export_args)
@@ -1487,6 +1495,60 @@ def apply_act_rules(act: Tuple[int, QuantizationDataType], allowed_supported_ker
for candidate in set(act_candidates):
apply_act_rules(candidate, supported_kernels, name)

@staticmethod
def _replace_quantization_wrapper_with_native_torch_quantization_nodes(quant_sim_model, device: torch.device):
"""
Recursively replace quantization wrappers with NativeTorchQuantWrapper modules, starting from the given module
:param quant_sim_model: model whose QcQuantizeWrapper modules are replaced with wrappers that use
native torch quantization nodes
:param device: device on which model is present
:return:
"""
# Recursively replace quantization wrappers to native torch quantization nodes
for module_name, module_ref in quant_sim_model.named_children():
# Create a native torch quantization node
if isinstance(module_ref, QcQuantizeWrapper):
embedded_module = NativeTorchQuantWrapper(module_ref, '_module_to_wrap', device)
setattr(quant_sim_model, module_name, embedded_module)

elif isinstance(module_ref, QcQuantizeRecurrent):
logger.error('Exporting a model with embedded native torch quantization nodes is not supported for QcQuantizeRecurrent modules.')
raise AssertionError

# Recursively call children modules if present
if not utils.is_leaf_module(module_ref):
QuantizationSimModel._replace_quantization_wrapper_with_native_torch_quantization_nodes(module_ref, device)

@staticmethod
def save_model_with_embedded_quantization_nodes(sim_model, path: str, filename_prefix: str, dummy_input: Union[torch.Tensor, Tuple],
onnx_export_args: Union[OnnxExportApiArgs, None] = OnnxExportApiArgs()):
"""
Export a model with native torch quantization nodes embedded. These nodes are exported
as standard onnx or torchscript quantized nodes.
:param sim_model: model with the quantsim wrappers
:param path: path where to store model pth and encodings
:param filename_prefix: Prefix to use for filenames of the model pth and encodings files
:param dummy_input: Dummy input to the model. Used to parse model graph
:param onnx_export_args: optional export arguments with onnx-specific overrides; if not provided, the model is
exported via the torchscript graph. Int16 can only be exported via torchscript
:return:
"""

model_filename = filename_prefix + '_embedded' + '.onnx'
model_path = os.path.join(path, model_filename)
quant_sim_model = copy.deepcopy(sim_model)

device = utils.get_device(quant_sim_model)
QuantizationSimModel._replace_quantization_wrapper_with_native_torch_quantization_nodes(quant_sim_model, device)

if onnx_export_args is None:
with utils.in_eval_mode(quant_sim_model), torch.no_grad():
trace = torch.jit.trace(quant_sim_model, dummy_input)
ts_path = os.path.join(path, filename_prefix + '_embedded' + '.torchscript.pth')
trace.save(ts_path)
else:
torch.onnx.export(quant_sim_model, dummy_input, model_path, enable_onnx_checker=False, **onnx_export_args.kwargs)


def save_checkpoint(quant_sim_model: QuantizationSimModel, file_path: str):
"""
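The static helper added above is normally reached through sim.export(..., use_embedded_encodings=True), but a direct call can be sketched as follows (again assuming a calibrated QuantizationSimModel named `sim`):

# Sketch: onnx_export_args=None selects the torchscript branch above and writes
# quantized_model_embedded.torchscript.pth; the default OnnxExportApiArgs() writes
# quantized_model_embedded.onnx via torch.onnx.export instead.
QuantizationSimModel.save_model_with_embedded_quantization_nodes(
    sim.model,                     # quantsim-wrapped model (placeholder name)
    path='./output',
    filename_prefix='quantized_model',
    dummy_input=dummy_input,
    onnx_export_args=None)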
TrainingExtensions/torch/src/python/aimet_torch/tensor_quantizer.py
@@ -1054,6 +1054,46 @@ def backward(ctx, *output_grad):
return (*output_grad, None, *param_encoding_grads)


class TorchQuantizer:
"""
A Quantizer using native torch quantization nodes
"""
def __init__(self, quantizer: Union[StaticGridPerChannelQuantizer, StaticGridPerTensorQuantizer],
device: torch.device):
"""
Constructor
:param quantizer: StaticGridPerTensorQuantizer or StaticGridPerChannelQuantizer to convert
:param device: device on which the model resides
"""
super(TorchQuantizer, self).__init__()
self.device, self.enabled, self.data_type, self.bitwidth = device, quantizer.enabled, quantizer.data_type, quantizer.bitwidth
self.per_channel_enabled = False
if hasattr(quantizer, '_ch_axis'):
# pylint: disable=protected-access
self._ch_axis = quantizer._ch_axis
self.per_channel_enabled = True
if quantizer.enabled and quantizer.encoding:
self.scale, self.zero_point, self.q_max, self.q_min = calc_params_for_native_torch_quantizer(quantizer, self.per_channel_enabled, device)

def quantize_dequantize(self, tensor: torch.Tensor):
"""
Quantize-dequantize the tensor, using the saved encoding for this tensor
:param tensor: Tensor to be quantize-dequantized
:return: Quantize-dequantized tensor
"""
if self.enabled:
if self.data_type == QuantizationDataType.float:
quantized_tensor = tensor.half()
quantized_tensor = quantized_tensor.float()
return quantized_tensor
if self.per_channel_enabled:
return torch.fake_quantize_per_channel_affine(tensor, self.scale, self.zero_point,
self._ch_axis, self.q_min, self.q_max)
return torch.fake_quantize_per_tensor_affine(tensor, self.scale, self.zero_point,
self.q_min, self.q_max)
return tensor


# pylint: disable=abstract-method
class QuantizeDequantize(torch.autograd.Function):
"""
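TorchQuantizer delegates the integer case to torch.fake_quantize_per_tensor_affine / torch.fake_quantize_per_channel_affine; the actual scale and zero-point computation lives in calc_params_for_native_torch_quantizer, which is not shown in this hunk. As a rough illustration only, a conventional asymmetric mapping from an encoding (min, max) to those arguments is:

import torch

# Illustration only: a common asymmetric mapping, not necessarily identical to
# what calc_params_for_native_torch_quantizer computes (symmetric grids and
# per-channel encodings are handled differently).
enc_min, enc_max, bitwidth = -1.0, 1.5, 8
scale = (enc_max - enc_min) / (2 ** bitwidth - 1)
zero_point = int(round(-enc_min / scale))
q_min, q_max = 0, 2 ** bitwidth - 1

x = torch.randn(4, 4)
x_qdq = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, q_min, q_max)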