Support native torch embedded quantization encodings #2305

TrainingExtensions/torch/src/python/aimet_torch/qc_quantize_op.py
@@ -38,7 +38,7 @@

# pylint: disable=too-many-lines
""" Custom PyTorch Op for quantizing weights and activations """

# pylint: disable=too-many-lines
import abc
from enum import Enum
from typing import Dict, Tuple, Union, List, Callable, Type
@@ -54,6 +54,7 @@
from aimet_torch import utils
from aimet_torch.tensor_quantizer import StaticGridPerTensorQuantizer, StaticGridPerChannelQuantizer, TensorQuantizer, \
LearnedGridTensorQuantizer, set_encoding_min_max_gating_threshold
from aimet_torch.torch_quantizer import TorchQuantizer
import aimet_torch.quantsim_straight_through_grad as ste

_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
@@ -967,6 +968,96 @@ def _linear_forward_with_recompute(quant_wrapper: LearnedGridQuantWrapper, input
quant_wrapper.param_quantizers['weight'])


class NativeTorchQuantWrapper(nn.Module):
"""
A custom PyTorch module for inserting native PyTorch quantization nodes
"""
def __init__(self, post_training_module: Union[StaticGridQuantWrapper, LearnedGridQuantWrapper], module_name: str, device: torch.device):
"""
Constructor
:param post_training_module: StaticGridQuantWrapper or LearnedGridQuantWrapper that wraps the module
:param module_name: name of the wrapped module attribute on post_training_module
:param device: device on which the model resides
"""
super(NativeTorchQuantWrapper, self).__init__()

self._module_to_wrap = getattr(post_training_module, module_name)
if isinstance(post_training_module, StaticGridQuantWrapper):
if post_training_module._mode != QcQuantizeOpMode.ACTIVE: # pylint: disable=protected-access
raise ValueError('Only ACTIVE QcQuantizeOpMode is supported while using StaticGridQuantWrapper')

self.output_quantizers = [TorchQuantizer(quantizer, device) for quantizer in post_training_module.output_quantizers]

self.input_quantizers = [TorchQuantizer(quantizer, device) for quantizer in post_training_module.input_quantizers]

self.param_quantizers = {}
for name, quantizer in post_training_module.param_quantizers.items():
self.param_quantizers[name] = TorchQuantizer(quantizer, device)

@staticmethod
def _quantize_dequantize(tensor_quantizers, tensors_to_quantize):
"""
Forward-pass routine. This quantizes the weights before delegating to the wrapped module and
then quantizes the output before returning the same
:param tensor_quantizers: Tensor quantizers to use for updating stats or quantizing
:param tensors_to_quantize: Inputs passed to the module in the forward pass
:return: Quantized output from the wrapped module
"""
outputs = []
for index, input_tensor in enumerate(tensors_to_quantize):
if not isinstance(input_tensor, torch.Tensor):
_logger.error('Expecting quantize activation input of type torch.Tensor but got %s', type(input_tensor))
raise AssertionError
if input_tensor.dtype in utils.torch_dtypes_to_ignore_for_quantization:
# Do not quantize integer tensors
outputs.append(input_tensor)
continue

assert len(tensor_quantizers) > index, \
f"Not enough tensor quantizers ({len(tensor_quantizers)}) allocated"

output = tensor_quantizers[index].quantize_dequantize(input_tensor)

outputs.append(output)

# Flatten if there is only one output - which is by far the most common case
if len(outputs) == 1:
outputs = outputs[0]

return outputs

def forward(self, *inputs):
"""
Forward-pass routine. This quantizes the inputs and parameters before delegating to the wrapped module and
then quantizes the output before returning it
:param inputs: Inputs passed to the module in the forward pass
:return: Quantized output from the wrapped module
"""
# Quantize inputs
quantized_inputs = self._quantize_dequantize(self.input_quantizers, inputs)
if isinstance(quantized_inputs, torch.Tensor):
quantized_inputs = [quantized_inputs]

# Quantize params
for name, param in self._module_to_wrap.named_parameters():
param_quantizer = self.param_quantizers[name]
if param_quantizer.enabled:
setattr(self._module_to_wrap, name,
torch.nn.Parameter(param_quantizer.quantize_dequantize(param), requires_grad=True))

wrapped_output = self._module_to_wrap(*quantized_inputs)

# Quantize the outputs
if not self.output_quantizers[0].enabled:
output = wrapped_output
else:
if isinstance(wrapped_output, torch.Tensor):
wrapped_output = [wrapped_output]
output = self._quantize_dequantize(self.output_quantizers, wrapped_output)

return output


class QcQuantizeStandalone(QcQuantizeStandAloneBase):
""" A custom PyTorch module that derives from QcQuantizeStandAloneBase and quantizes inputs """

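For context, a minimal sketch of how the new `NativeTorchQuantWrapper` could be exercised on its own. This is not part of the diff; the `sim` object, the `conv1` layer name, and `dummy_input` are hypothetical, while `'_module_to_wrap'` is the attribute name used by the quantsim wrappers:

```python
# Hedged sketch: swap one StaticGridQuantWrapper (which must be in ACTIVE mode,
# i.e. after compute_encodings) for a NativeTorchQuantWrapper and run a forward
# pass through native torch fake-quantization nodes.
import torch
from aimet_torch.qc_quantize_op import NativeTorchQuantWrapper

device = torch.device('cpu')
quant_wrapper = sim.model.conv1            # hypothetical QcQuantizeWrapper around an nn.Conv2d
native_wrapper = NativeTorchQuantWrapper(quant_wrapper, '_module_to_wrap', device)
sim.model.conv1 = native_wrapper
with torch.no_grad():
    output = sim.model(dummy_input)        # inputs, params and outputs are quantize-dequantized
```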
119 changes: 101 additions & 18 deletions TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -58,7 +58,7 @@

from aimet_torch.quantsim_config.quantsim_config import QuantSimConfigurator
from aimet_torch.qc_quantize_op import QcQuantizeStandAloneBase, QcQuantizeWrapper, QcQuantizeOpMode, \
StaticGridQuantWrapper, LearnedGridQuantWrapper, QUANTIZER_TYPE_INPUT, QUANTIZER_TYPE_OUTPUT
StaticGridQuantWrapper, LearnedGridQuantWrapper, NativeTorchQuantWrapper, QUANTIZER_TYPE_INPUT, QUANTIZER_TYPE_OUTPUT
from aimet_torch.tensor_quantizer import StaticGridTensorQuantizer, LearnedGridTensorQuantizer, \
initialize_learned_grid_quantizer_attributes
from aimet_torch import torchscript_utils, utils, transformer_utils
@@ -363,7 +363,7 @@ def set_percentile_value(self, percentile_value: float):

def export(self, path: str, filename_prefix: str, dummy_input: Union[torch.Tensor, Tuple],
onnx_export_args: Optional[Union[OnnxExportApiArgs, Dict]] = None, propagate_encodings: bool = False,
export_to_torchscript: bool = False):
export_to_torchscript: bool = False, use_embedded_encodings: bool = False):
"""
This method exports out the quant-sim model so it is ready to be run on-target.

@@ -387,6 +387,7 @@ def export(self, path: str, filename_prefix: str, dummy_input: Union[torch.Tenso
multiple ONNX nodes) are filled with the same BW and data_type as the output tensor for that series of
ops. Defaults to False.
:param export_to_torchscript: If True, export to torchscript. Export to onnx otherwise. Defaults to False.
:param use_embedded_encodings: If True, an additional ONNX (or torchscript) model with embedded fake-quantization nodes will be exported
"""
# save the quantized model and encodings
model_filename = filename_prefix + '.pth'
@@ -399,23 +400,28 @@ def export(self, path: str, filename_prefix: str, dummy_input: Union[torch.Tenso

torch.save(model_to_export, model_path)

if export_to_torchscript:
self.export_torch_script_model_and_encodings(path, filename_prefix, model_to_export, self.model,
dummy_input, self._excluded_layer_names)
if onnx_export_args is None:
onnx_export_args = {'opset_version': None,
'input_names': None,
'output_names': None}
if version.parse(torch.__version__) < version.parse("1.10.0") and isinstance(onnx_export_args, dict):
onnx_export_args['enable_onnx_checker'] = False
log_with_error_and_assert_if_false(isinstance(onnx_export_args, (OnnxExportApiArgs, dict)),
logger,
f'unsupported opt_args type={type(onnx_export_args)}')

if use_embedded_encodings:
QuantizationSimModel.save_model_with_embedded_quantization_nodes(self.model, path, filename_prefix, dummy_input,
onnx_export_args, export_to_torchscript, self._is_conditional)
else:
if onnx_export_args is None:
onnx_export_args = {'opset_version': None,
'input_names': None,
'output_names': None}
if version.parse(torch.__version__) < version.parse("1.10.0") and isinstance(onnx_export_args, dict):
onnx_export_args['enable_onnx_checker'] = False
log_with_error_and_assert_if_false(isinstance(onnx_export_args, (OnnxExportApiArgs, dict)),
logger,
f'unsupported opt_args type={type(onnx_export_args)}')
self.export_onnx_model_and_encodings(path, filename_prefix, model_to_export, self.model,
dummy_input, onnx_export_args, propagate_encodings,
self._module_marker_map, self._is_conditional,
self._excluded_layer_names, quantizer_args=self.quant_args)
if export_to_torchscript:
self.export_torch_script_model_and_encodings(path, filename_prefix, model_to_export, self.model,
dummy_input, self._excluded_layer_names)
else:
self.export_onnx_model_and_encodings(path, filename_prefix, model_to_export, self.model,
dummy_input, onnx_export_args, propagate_encodings,
self._module_marker_map, self._is_conditional,
self._excluded_layer_names, quantizer_args=self.quant_args)
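A hedged usage sketch of the new `use_embedded_encodings` flag. The `sim` quantsim object, the output path, and `dummy_input` are assumed; the file names follow the `<prefix>_embedded` convention introduced further down in this diff:

```python
# Assumes `sim` is a QuantizationSimModel on which compute_encodings() has already run.
# ONNX export supports only 8-bit integer quantizers; other bitwidths need torchscript.
sim.export(path='./output', filename_prefix='model', dummy_input=dummy_input,
           use_embedded_encodings=True)                 # -> ./output/model_embedded.onnx

sim.export(path='./output', filename_prefix='model', dummy_input=dummy_input,
           export_to_torchscript=True,
           use_embedded_encodings=True)                 # -> ./output/model_embedded.torchscript.pth
```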

@staticmethod
def export_torch_script_model_and_encodings(path: str, filename_prefix: str,
@@ -1487,6 +1493,83 @@ def apply_act_rules(act: Tuple[int, QuantizationDataType], allowed_supported_ker
for candidate in set(act_candidates):
apply_act_rules(candidate, supported_kernels, name)

@staticmethod
def _replace_quantization_wrapper_with_native_torch_quantization_nodes(quant_sim_model, device: torch.device):
"""
Recursively replace quantization wrappers with modules that use native torch quantization nodes, starting with a given module
:param quant_sim_model: model in which each QcQuantizeWrapper is replaced with a wrapped module using
native torch quantization nodes
:param device: device on which model is present
:return:
"""
# Recursively replace quantization wrappers to native torch quantization nodes
for module_name, module_ref in quant_sim_model.named_children():
# Create a native torch quantization node
if isinstance(module_ref, QcQuantizeWrapper):
embedded_module = NativeTorchQuantWrapper(module_ref, '_module_to_wrap', device)
setattr(quant_sim_model, module_name, embedded_module)

elif isinstance(module_ref, QcQuantizeRecurrent):
logger.error('Saving a model with embedded native torch quantization nodes is not supported for QcQuantizeRecurrent modules.')
raise AssertionError

# Recursively call children modules if present
if not utils.is_leaf_module(module_ref):
QuantizationSimModel._replace_quantization_wrapper_with_native_torch_quantization_nodes(module_ref, device)
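A small sketch of what this replacement pass produces, operating on a deep copy as `save_model_with_embedded_quantization_nodes` does below. The `sim` object is assumed:

```python
# Hedged sketch: after the pass, no QcQuantizeWrapper leaves should remain;
# each has been replaced by a NativeTorchQuantWrapper.
import copy
from aimet_torch import utils
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.qc_quantize_op import QcQuantizeWrapper, NativeTorchQuantWrapper

model_copy = copy.deepcopy(sim.model)
device = utils.get_device(model_copy)
QuantizationSimModel._replace_quantization_wrapper_with_native_torch_quantization_nodes(model_copy, device)
assert not any(isinstance(m, QcQuantizeWrapper) for m in model_copy.modules())
assert any(isinstance(m, NativeTorchQuantWrapper) for m in model_copy.modules())
```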

@staticmethod
def save_model_with_embedded_quantization_nodes(sim_model, path: str, filename_prefix: str, dummy_input: Union[torch.Tensor, Tuple],
onnx_export_args: Optional[Union[OnnxExportApiArgs, Dict]] = None,
export_to_torchscript: bool = False, is_conditional: bool = False):
"""
Export a model embedded with native torch quantization nodes. These nodes will be exported
as standard onnx or torchscript quantized nodes.
:param sim_model: model with the quantsim wrappers
:param path: path where to store model pth and encodings
:param filename_prefix: Prefix to use for filenames of the model pth and encodings files
:param dummy_input: Dummy input to the model. Used to parse model graph
:param onnx_export_args: Optional export arguments with ONNX-specific overrides; if not provided, export falls
back to the torchscript graph. Int16 can only be exported via torchscript
:param export_to_torchscript: If True, export to torchscript. Export to onnx otherwise. Defaults to False.
:param is_conditional: True if model is conditional, False otherwise
:return:
"""
def _validate_torchquantizer(quant_sim_model):
# Ensure that no non-8-bit TorchQuantizer is exported to ONNX
for _, module in quant_sim_model.named_modules():
if isinstance(module, NativeTorchQuantWrapper):
quantizers = module.input_quantizers + module.output_quantizers
if 'weight' in module.param_quantizers:
quantizers += [module.param_quantizers['weight']]
if 'bias' in module.param_quantizers:
quantizers += [module.param_quantizers['bias']]

for quantizer in quantizers:
if quantizer.enabled and quantizer.data_type == QuantizationDataType.int and quantizer.bitwidth != 8:
raise ValueError('Only 8-bit quantizers are supported when exporting to an ONNX model. '
'Please enable export_to_torchscript if you want to export non-8-bit quantizers.')

model_filename = filename_prefix + '_embedded' + '.onnx'
model_path = os.path.join(path, model_filename)
quant_sim_model = copy.deepcopy(sim_model)

device = utils.get_device(quant_sim_model)
if isinstance(dummy_input, torch.Tensor):
dummy_input = dummy_input.to(device)
else:
dummy_input = tuple([input.to(device) for input in dummy_input])
QuantizationSimModel._replace_quantization_wrapper_with_native_torch_quantization_nodes(quant_sim_model, device)

if export_to_torchscript:
with utils.in_eval_mode(quant_sim_model), torch.no_grad():
trace = torch.jit.trace(quant_sim_model, dummy_input)
ts_path = os.path.join(path, filename_prefix + '_embedded' + '.torchscript.pth')
trace.save(ts_path)
else:
_validate_torchquantizer(quant_sim_model)
OnnxSaver._export_model_to_onnx(quant_sim_model, dummy_input, model_path, is_conditional, onnx_export_args) # pylint: disable=protected-access
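And a hedged sketch of calling this static helper directly, rather than through `export(..., use_embedded_encodings=True)`; `sim` and `dummy_input` are assumed:

```python
# The torchscript path works for any supported bitwidth; the default ONNX path
# raises ValueError for non-8-bit integer quantizers, as validated above.
QuantizationSimModel.save_model_with_embedded_quantization_nodes(
    sim.model, path='./output', filename_prefix='model',
    dummy_input=dummy_input, export_to_torchscript=True)
# -> ./output/model_embedded.torchscript.pth
```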



def save_checkpoint(quant_sim_model: QuantizationSimModel, file_path: str):
"""