From 81f5f338086df6fe938cb007e6b4473cb8a755f1 Mon Sep 17 00:00:00 2001 From: Benjamin Fineran Date: Thu, 21 Apr 2022 15:39:16 -0400 Subject: [PATCH] Cherry pick 012 distilbert qat fixes (#726) * fix QATWrapper not properly overwritting qconfig properties for symmetric activations (#724) * re-add fix symmetric zero points for unit8 quantization (#604) (#725) --- .../sparsification/quantization/helpers.py | 34 ++++++++++--------- .../quantization/modifier_quantization.py | 3 +- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index b5c54ab309c..5621b786888 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -119,10 +119,12 @@ class QConfigProperties: Default is torch.qint8. :param activation_bits: number of bits for activations. Default is 8. :param weight_bits: number of bits for weights. Default is 8. + :param tensorrt: if True sets quantization configuration for compatibility with + explict quantization as supported by TensorRT 8.2. """ - _symmetric_activations: Optional[bool] = None - _symmetric_weights: Optional[bool] = None + _symmetric_activations: bool = False + _symmetric_weights: bool = True reduce_range: bool = False activation_dtype: torch.dtype = torch.quint8 weight_dtype: torch.dtype = torch.qint8 @@ -130,30 +132,24 @@ class QConfigProperties: weight_bits: int = 8 activation_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict) weight_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict) + tensorrt: bool = False @property def symmetric_activations(self) -> bool: - if self._symmetric_activations: - return self._symmetric_activations - else: - return False + # always use symmetric activations in tensorrt mode + return self.tensorrt or self._symmetric_activations @symmetric_activations.setter def symmetric_activations(self, value: bool): - if self._symmetric_activations is None: - self._symmetric_activations = value + self._symmetric_activations = value @property def symmetric_weights(self) -> bool: - if self._symmetric_weights: - return self._symmetric_weights - else: - return True + return self.tensorrt or self._symmetric_weights @symmetric_weights.setter def symmetric_weights(self, value: bool): - if self._symmetric_weights is None: - self._symmetric_weights = value + self._symmetric_weights = value class QATWrapper(Module): @@ -365,9 +361,10 @@ def _load_qconfigs( f"Found string with value {qconfig} in {name}" ) - qproperties.symmetric_activations = qconfig == "symmetric" + qproperties_idx = deepcopy(qproperties) + qproperties_idx.symmetric_activations = qconfig == "symmetric" - qconfigs[idx] = get_qat_qconfig(qproperties) + qconfigs[idx] = get_qat_qconfig(qproperties_idx) return qconfigs @@ -578,6 +575,11 @@ def fix_observer_quant_range(module: Module): fake_quantize.quant_min is None or fake_quantize.quant_max is None or (observer.quant_min is not None or observer.quant_max is not None) + or ( # do not propagate default uint8 symmetric range + observer.qscheme == torch.per_tensor_symmetric + and fake_quantize.quant_min == 0 + and fake_quantize.quant_max == 255 + ) ): continue observer.quant_min = fake_quantize.quant_min diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py index 6c7de08f2c7..2937b305259 100644 --- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py +++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py @@ -615,9 +615,8 @@ def _enable_module_qat(self, module: Module): "Overriding quantization scheme to symmetric int8 " "for both weights and activations because tensorrt flag is True." ) - qproperties.symmetric_activations = True + qproperties.tensorrt = True qproperties.activation_dtype = torch.qint8 - qproperties.symmetric_weights = True qproperties.weight_dtype = torch.qint8 qconfig = get_qat_qconfig(qproperties)