Exclude abstract nn.Module from quantization (#3339)
Signed-off-by: Kyunggeun Lee <quic_kyunggeu@quicinc.com>
quic-kyunggeu authored Sep 13, 2024
1 parent 83d00df commit 0e6822b
Showing 4 changed files with 54 additions and 30 deletions.
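Here, an "abstract" module means a bare torch.nn.Module instance used only as a container for child layers: it has no forward of its own, so there is nothing to quantize, but its children may still be ordinary quantizable layers. A minimal sketch of the situation this commit handles (illustrative only, not part of the diff):

import torch

container = torch.nn.Module()                    # bare container: its type is exactly torch.nn.Module
container.conv = torch.nn.Conv2d(3, 3, 3)        # ordinary child layer attached to the container

print(type(container) is torch.nn.Module)        # True  -> left unwrapped after this change
print(type(container.conv) is torch.nn.Module)   # False -> still eligible for quantization wrapping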
46 changes: 20 additions & 26 deletions TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -202,12 +202,15 @@ def get_original_module(self) -> torch.nn.Module:

 # Types of modules which cannot be quantized
 unquantizable_modules = (
+    torch.nn.Identity,
+)
+
+quantized_modules = (
     QcQuantizeWrapper,
     QcQuantizeStandAloneBase,
     QcQuantizeRecurrent,
     ExportableQuantModule,
-    torch.nn.Identity,
-    LazyQuantizeWrapper
+    LazyQuantizeWrapper,
 )

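The hunk above splits the old catch-all tuple in two: unquantizable_modules lists types that must never be wrapped, while quantized_modules lists types that already are quantized wrappers. A rough sketch of how the two tuples are consumed by the classmethods added further down (illustrative only, assumes the names defined in quantsim.py are in scope):

import torch

m = torch.nn.Identity()
assert isinstance(m, unquantizable_modules)      # Identity is never wrapped
assert not isinstance(m, quantized_modules)      # and it is not already a quantized wrapper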

@@ -1389,20 +1392,6 @@ def _get_qc_quantized_layers(model) -> List[Tuple[str, QcQuantizeWrapper]]:
             quantized_layers.append((name, module))
         return quantized_layers
 
-    @staticmethod
-    def _is_quantizable_module(module_ref):
-        """ Function to check if a module is eligible for quantization.
-            If the module is NOT an PyTorch module type or if the module was already
-            Quantized or if the module is in the layers_to_ignore list, don't quantize.
-        """
-
-        if isinstance(module_ref, unquantizable_modules):
-            logger.debug("Module %s not quantizable", module_ref)
-            return False
-
-        logger.debug("Module %s is quantizable", module_ref)
-        return True
-
     def _create_quantizer_module(self, module_to_quantize: torch.nn.Module, num_inout_tensors: Dict,
                                  data_type: QuantizationDataType) -> torch.nn.Module:
         """Instantiates wrapper based on quant scheme
@@ -1432,25 +1421,30 @@ def _create_quantizer_module(self, module_to_quantize: torch.nn.Module, num_inou

         return quantized_module
 
+    @classmethod
+    def _is_quantizable_module(cls, module: torch.nn.Module):
+        # pylint: disable=unidiomatic-typecheck
+        return type(module) != torch.nn.Module and\
+               not isinstance(module, unquantizable_modules) and\
+               not cls._is_quantized_module(module)
+
+    @classmethod
+    def _is_quantized_module(cls, module: torch.nn.Module):
+        return isinstance(module, quantized_modules)
+
     def _add_quantization_wrappers(self, module, num_inout_tensors, default_data_type: QuantizationDataType):
         """Recursively add quantization wrappers to all appropriate modules starting with module
         """
+        if self._is_quantized_module(module):
+            return
+
         for module_name, module_ref in module.named_children():
             logger.debug("nn.Module found : %s", module_ref)
 
-            # check if the module already quantized then ignore
-            if not self._is_quantizable_module(module_ref):
-                continue
-
-            # check if the module is leaf or not
-            if utils.is_leaf_module(module_ref):
-
+            if self._is_quantizable_module(module_ref) and utils.is_leaf_module(module_ref):
                 # Create a new QcQuantize wrapper module
                 quantized_module = self._create_quantizer_module(module_ref, num_inout_tensors, default_data_type)
-
                 setattr(module, module_name, quantized_module)
 
-            # recursively call children modules
             else:
                 self._add_quantization_wrappers(module_ref, num_inout_tensors, default_data_type)

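The key change above is the exact-type check in _is_quantizable_module: type(module) != torch.nn.Module excludes only bare nn.Module containers, whereas an isinstance check would be True for every layer (all layers subclass torch.nn.Module) and would exclude everything. A bare container fails the combined check, so _add_quantization_wrappers falls through to the recursive branch and still visits its non-abstract children. A minimal standalone sketch of that selection logic (illustrative only, plain PyTorch):

import torch

container = torch.nn.Module()                        # abstract container
container.conv = torch.nn.Conv2d(3, 3, 3)            # quantizable child reached by the recursion

# Exact-type check, as in the new _is_quantizable_module:
assert type(container) is torch.nn.Module            # excluded from wrapping
assert type(container.conv) is not torch.nn.Module   # Conv2d is a subclass -> still wrappable

# An isinstance() check would not work here, since every layer is an nn.Module:
assert isinstance(container.conv, torch.nn.Module)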
@@ -263,3 +263,13 @@ def _apply_qdq_to_model_parameters(cls, model: torch.nn.Module):
     @deprecated(f'Use {V1QuantizationSimModel.named_qmodules.__qualname__} instead.')
     def quant_wrappers(self): # pylint: disable=missing-docstring
         return super().quant_wrappers()
+
+    @classmethod
+    def _is_quantizable_module(cls, module: torch.nn.Module):
+        return super()._is_quantizable_module(module) and\
+               not isinstance(module, QuantizerBase)
+
+    @classmethod
+    def _is_quantized_module(cls, module: torch.nn.Module):
+        return super()._is_quantized_module(module) or\
+               isinstance(module, BaseQuantizationMixin)
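In the v2 overrides above, two extra cases matter: v2 quantizers (QuantizerBase subclasses) are themselves nn.Modules attached as submodules of quantized layers, so they must never be picked up as quantizable leaves, and BaseQuantizationMixin is how already-quantized v2 layers are recognized. A hedged sketch of that intent (illustrative only; assumes aimet_torch.v2 is available and that QuantizedConv2d can be constructed directly with Conv2d arguments):

from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.nn import BaseQuantizationMixin, QuantizedConv2d

qconv = QuantizedConv2d(3, 3, 3)                 # an already-quantized conv layer
assert isinstance(qconv, BaseQuantizationMixin)  # recognized by _is_quantized_module -> not re-wrapped
assert not isinstance(qconv, QuantizerBase)      # the layer is not a quantizer; quantizer submodules
                                                 # are what _is_quantizable_module rejects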
12 changes: 9 additions & 3 deletions TrainingExtensions/torch/test/python/v2/models_/test_models.py
@@ -1302,7 +1302,6 @@ def forward(self, x):
         return self.log_softmax(x)
 
 class ModelWithUnusedAdd(torch.nn.Module):
-
     def __init__(self):
         super().__init__()
         self.identity = torch.nn.Identity()
@@ -1312,7 +1311,6 @@ def forward(self, x):
         return self.identity(x)
 
 class ModelWithUnusedRNN(torch.nn.Module):
-
     def __init__(self):
         super().__init__()
         self.identity = torch.nn.Identity()
@@ -1321,8 +1319,16 @@ def __init__(self):
     def forward(self, x):
         return self.identity(x)
 
-class ExpandModel(torch.nn.Module):
+class ModelWithAbstractModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.module = torch.nn.Module()
+        self.module.conv = torch.nn.Conv2d(3, 3, 3)
+
+    def forward(self, x):
+        return self.module.conv(x)
+
+class ExpandModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.expand = aimet_modules.Expand()
@@ -51,7 +51,7 @@
 from aimet_torch.v2.quantization.base import QuantizerBase
 from aimet_torch.v2.quantization.affine import AffineQuantizerBase, GroupedBlockQuantizeDequantize
 from aimet_torch.v2.experimental import propagate_output_encodings
-from aimet_torch.v2.nn import BaseQuantizationMixin
+from aimet_torch.v2.nn import BaseQuantizationMixin, QuantizedConv2d
 import aimet_torch.v2.nn.modules.custom as custom
 from ..models_ import test_models

@@ -822,6 +822,20 @@ def test_quantsim_with_unused_modules(self):
         assert len(sim.model.rnn.output_quantizers) == 2
         assert type(sim.model.rnn.output_quantizers[0]) is type(sim.model.rnn.output_quantizers[1])
 
+    def test_quantsim_with_abstract_modules(self):
+        """
+        Given: A model with an abstract nn.Module
+        When: Instantiate quantsim
+        Then: 1) No error is raised
+              2) Abstract modules stay unchanged
+              3) If the abstract module contains non-abstract child modules,
+                 the child modules should be converted to quantized modules.
+        """
+        model = test_models.ModelWithAbstractModule()
+        sim = QuantizationSimModel(model, dummy_input=torch.randn(1, 3, 10, 10))
+        assert type(sim.model.module) == torch.nn.Module
+        assert isinstance(sim.model.module.conv, QuantizedConv2d)
+
     def test_export_concat_encodings(self):
         num_inputs = 3
         model = ConcatModel()
