From 0e6822b82dda71036ddbcd43a273fbb297dbb4fd Mon Sep 17 00:00:00 2001
From: Kyunggeun Lee
Date: Thu, 12 Sep 2024 19:14:19 -0700
Subject: [PATCH] Exclude abstract nn.Module from quantization (#3339)

Signed-off-by: Kyunggeun Lee
---
 .../torch/src/python/aimet_torch/quantsim.py | 46 ++++++++-----------
 .../aimet_torch/v2/quantsim/quantsim.py      | 10 ++++
 .../test/python/v2/models_/test_models.py    | 12 +++--
 .../test/python/v2/quantsim/test_quantsim.py | 16 ++++++-
 4 files changed, 54 insertions(+), 30 deletions(-)

diff --git a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
index 199dbfa1e0..487d903398 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -202,12 +202,15 @@ def get_original_module(self) -> torch.nn.Module:
 
 # Types of modules which cannot be quantized
 unquantizable_modules = (
+    torch.nn.Identity,
+)
+
+quantized_modules = (
     QcQuantizeWrapper,
     QcQuantizeStandAloneBase,
     QcQuantizeRecurrent,
     ExportableQuantModule,
-    torch.nn.Identity,
-    LazyQuantizeWrapper
+    LazyQuantizeWrapper,
 )
 
 
@@ -1389,20 +1392,6 @@ def _get_qc_quantized_layers(model) -> List[Tuple[str, QcQuantizeWrapper]]:
                 quantized_layers.append((name, module))
         return quantized_layers
 
-    @staticmethod
-    def _is_quantizable_module(module_ref):
-        """ Function to check if a module is eligible for quantization.
-            If the module is NOT an PyTorch module type or if the module was already
-            Quantized or if the module is in the layers_to_ignore list, don't quantize.
-        """
-
-        if isinstance(module_ref, unquantizable_modules):
-            logger.debug("Module %s not quantizable", module_ref)
-            return False
-
-        logger.debug("Module %s is quantizable", module_ref)
-        return True
-
     def _create_quantizer_module(self, module_to_quantize: torch.nn.Module, num_inout_tensors: Dict,
                                  data_type: QuantizationDataType) -> torch.nn.Module:
         """Instantiates wrapper based on quant scheme
@@ -1432,25 +1421,30 @@ def _create_quantizer_module(self, module_to_quantize: torch.nn.Module, num_inou
 
         return quantized_module
 
+    @classmethod
+    def _is_quantizable_module(cls, module: torch.nn.Module):
+        # pylint: disable=unidiomatic-typecheck
+        return type(module) != torch.nn.Module and\
+               not isinstance(module, unquantizable_modules) and\
+               not cls._is_quantized_module(module)
+
+    @classmethod
+    def _is_quantized_module(cls, module: torch.nn.Module):
+        return isinstance(module, quantized_modules)
+
     def _add_quantization_wrappers(self, module, num_inout_tensors, default_data_type: QuantizationDataType):
         """Recursively add quantization wrappers to all appropriate modules starting with module
         """
+        if self._is_quantized_module(module):
+            return
+
         for module_name, module_ref in module.named_children():
             logger.debug("nn.Module found : %s", module_ref)
 
-            # check if the module already quantized then ignore
-            if not self._is_quantizable_module(module_ref):
-                continue
-
-            # check if the module is leaf or not
-            if utils.is_leaf_module(module_ref):
-
+            if self._is_quantizable_module(module_ref) and utils.is_leaf_module(module_ref):
                 # Create a new QcQuantize wrapper module
                 quantized_module = self._create_quantizer_module(module_ref, num_inout_tensors, default_data_type)
-
                 setattr(module, module_name, quantized_module)
-
-            # recursively call children modules
             else:
                 self._add_quantization_wrappers(module_ref, num_inout_tensors, default_data_type)
 
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py
index 2607a9f232..0518150a08 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py
@@ -263,3 +263,13 @@ def _apply_qdq_to_model_parameters(cls, model: torch.nn.Module):
     @deprecated(f'Use {V1QuantizationSimModel.named_qmodules.__qualname__} instead.')
     def quant_wrappers(self): # pylint: disable=missing-docstring
         return super().quant_wrappers()
+
+    @classmethod
+    def _is_quantizable_module(cls, module: torch.nn.Module):
+        return super()._is_quantizable_module(module) and\
+               not isinstance(module, QuantizerBase)
+
+    @classmethod
+    def _is_quantized_module(cls, module: torch.nn.Module):
+        return super()._is_quantized_module(module) or\
+               isinstance(module, BaseQuantizationMixin)
diff --git a/TrainingExtensions/torch/test/python/v2/models_/test_models.py b/TrainingExtensions/torch/test/python/v2/models_/test_models.py
index 2a54a4d014..308a6dce57 100644
--- a/TrainingExtensions/torch/test/python/v2/models_/test_models.py
+++ b/TrainingExtensions/torch/test/python/v2/models_/test_models.py
@@ -1302,7 +1302,6 @@ def forward(self, x):
         return self.log_softmax(x)
 
 class ModelWithUnusedAdd(torch.nn.Module):
-
     def __init__(self):
         super().__init__()
         self.identity = torch.nn.Identity()
@@ -1312,7 +1311,6 @@ def forward(self, x):
         return self.identity(x)
 
 class ModelWithUnusedRNN(torch.nn.Module):
-
     def __init__(self):
         super().__init__()
         self.identity = torch.nn.Identity()
@@ -1321,8 +1319,16 @@ def __init__(self):
     def forward(self, x):
         return self.identity(x)
 
-class ExpandModel(torch.nn.Module):
+class ModelWithAbstractModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.module = torch.nn.Module()
+        self.module.conv = torch.nn.Conv2d(3, 3, 3)
+
+    def forward(self, x):
+        return self.module.conv(x)
 
+class ExpandModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.expand = aimet_modules.Expand()
diff --git a/TrainingExtensions/torch/test/python/v2/quantsim/test_quantsim.py b/TrainingExtensions/torch/test/python/v2/quantsim/test_quantsim.py
index 380dce296e..f6a5d2e0e2 100644
--- a/TrainingExtensions/torch/test/python/v2/quantsim/test_quantsim.py
+++ b/TrainingExtensions/torch/test/python/v2/quantsim/test_quantsim.py
@@ -51,7 +51,7 @@
 from aimet_torch.v2.quantization.base import QuantizerBase
 from aimet_torch.v2.quantization.affine import AffineQuantizerBase, GroupedBlockQuantizeDequantize
 from aimet_torch.v2.experimental import propagate_output_encodings
-from aimet_torch.v2.nn import BaseQuantizationMixin
+from aimet_torch.v2.nn import BaseQuantizationMixin, QuantizedConv2d
 import aimet_torch.v2.nn.modules.custom as custom
 from ..models_ import test_models
 
@@ -822,6 +822,20 @@ def test_quantsim_with_unused_modules(self):
         assert len(sim.model.rnn.output_quantizers) == 2
         assert type(sim.model.rnn.output_quantizers[0]) is type(sim.model.rnn.output_quantizers[1])
 
+    def test_quantsim_with_abstract_modules(self):
+        """
+        Given: A model with an abstract nn.Module
+        When: Instantiate quantsim
+        Then: 1) No error is raised
+              2) Abstract modules stay unchanged
+              3) If the abstract module contains non-abstract child modules,
+                 the child modules should be converted to quantized modules.
+        """
+        model = test_models.ModelWithAbstractModule()
+        sim = QuantizationSimModel(model, dummy_input=torch.randn(1, 3, 10, 10))
+        assert type(sim.model.module) == torch.nn.Module
+        assert isinstance(sim.model.module.conv, QuantizedConv2d)
+
     def test_export_concat_encodings(self):
         num_inputs = 3
         model = ConcatModel()