Cherry pick 012 distilbert qat fixes (#726)
* fix QATWrapper not properly overwriting qconfig properties for symmetric activations (#724)

* re-add fix symmetric zero points for uint8 quantization (#604) (#725)
bfineran committed Apr 21, 2022
1 parent fe598cb commit 81f5f33
Showing 2 changed files with 19 additions and 18 deletions.
34 changes: 18 additions & 16 deletions src/sparseml/pytorch/sparsification/quantization/helpers.py
@@ -119,41 +119,37 @@ class QConfigProperties:
         Default is torch.qint8.
     :param activation_bits: number of bits for activations. Default is 8.
     :param weight_bits: number of bits for weights. Default is 8.
+    :param tensorrt: if True sets quantization configuration for compatibility with
+        explicit quantization as supported by TensorRT 8.2.
     """
 
-    _symmetric_activations: Optional[bool] = None
-    _symmetric_weights: Optional[bool] = None
+    _symmetric_activations: bool = False
+    _symmetric_weights: bool = True
     reduce_range: bool = False
     activation_dtype: torch.dtype = torch.quint8
     weight_dtype: torch.dtype = torch.qint8
     activation_bits: int = 8
     weight_bits: int = 8
     activation_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict)
     weight_qconfig_kwargs: Dict[str, Any] = field(default_factory=dict)
+    tensorrt: bool = False
 
     @property
     def symmetric_activations(self) -> bool:
-        if self._symmetric_activations:
-            return self._symmetric_activations
-        else:
-            return False
+        # always use symmetric activations in tensorrt mode
+        return self.tensorrt or self._symmetric_activations
 
     @symmetric_activations.setter
     def symmetric_activations(self, value: bool):
-        if self._symmetric_activations is None:
-            self._symmetric_activations = value
+        self._symmetric_activations = value
 
     @property
     def symmetric_weights(self) -> bool:
-        if self._symmetric_weights:
-            return self._symmetric_weights
-        else:
-            return True
+        return self.tensorrt or self._symmetric_weights
 
     @symmetric_weights.setter
     def symmetric_weights(self, value: bool):
-        if self._symmetric_weights is None:
-            self._symmetric_weights = value
+        self._symmetric_weights = value
 
 
 class QATWrapper(Module):
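
The heart of #724 is the setter change above: the old guard (if self._symmetric_activations is None) honored only the first assignment, so any later per-layer override was silently dropped. A minimal, self-contained sketch of that failure mode (illustrative class name, not the SparseML source):

from dataclasses import dataclass
from typing import Optional


@dataclass
class OldStyleProps:
    # mirrors the pre-fix QConfigProperties backing field
    _symmetric_activations: Optional[bool] = None

    @property
    def symmetric_activations(self) -> bool:
        return bool(self._symmetric_activations)

    @symmetric_activations.setter
    def symmetric_activations(self, value: bool):
        # pre-fix guard: only the first write is ever honored
        if self._symmetric_activations is None:
            self._symmetric_activations = value


props = OldStyleProps()
props.symmetric_activations = False  # first write: backing field becomes False
props.symmetric_activations = True   # second write: ignored, field is not None
assert props.symmetric_activations is False  # the True override was lost

The patched setter assigns unconditionally, and both getters fold in the tensorrt flag so TensorRT mode always reports symmetric quantization.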
@@ -365,9 +361,10 @@ def _load_qconfigs(
                     f"Found string with value {qconfig} in {name}"
                 )
 
-            qproperties.symmetric_activations = qconfig == "symmetric"
+            qproperties_idx = deepcopy(qproperties)
+            qproperties_idx.symmetric_activations = qconfig == "symmetric"
 
-            qconfigs[idx] = get_qat_qconfig(qproperties)
+            qconfigs[idx] = get_qat_qconfig(qproperties_idx)
 
     return qconfigs
 
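Even with the setter fixed, every entry previously mutated the one shared QConfigProperties instance, so the last "symmetric"/"asymmetric" string processed would win for all layers. The hunk above isolates each entry by copying first. A condensed sketch of the pattern (illustrative function name, not the full _load_qconfigs body):

from copy import deepcopy


def build_qconfigs(qconfig_strings, qproperties, get_qat_qconfig):
    # one qconfig per entry, each derived from a private copy of the
    # shared properties so settings cannot leak between entries
    qconfigs = {}
    for idx, qconfig in enumerate(qconfig_strings):
        qproperties_idx = deepcopy(qproperties)
        qproperties_idx.symmetric_activations = qconfig == "symmetric"
        qconfigs[idx] = get_qat_qconfig(qproperties_idx)
    return qconfigs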

@@ -578,6 +575,11 @@ def fix_observer_quant_range(module: Module):
             fake_quantize.quant_min is None
             or fake_quantize.quant_max is None
             or (observer.quant_min is not None or observer.quant_max is not None)
+            or (  # do not propagate default uint8 symmetric range
+                observer.qscheme == torch.per_tensor_symmetric
+                and fake_quantize.quant_min == 0
+                and fake_quantize.quant_max == 255
+            )
         ):
             continue
         observer.quant_min = fake_quantize.quant_min
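
The extra guard ties into #604/#725: PyTorch's default quint8 observer range is [0, 255], but under per_tensor_symmetric PyTorch pins the quint8 zero point to 128 and derives the scale from the larger of |min| and |max|. Copying the default 0/255 range from the fake-quantize module onto a symmetric observer would therefore undo the centered zero point that the original fix established. A small demonstration of the symmetric uint8 behavior (torch observer API of the 1.9-1.11 era):

import torch
from torch.quantization import MovingAverageMinMaxObserver

# a symmetric uint8 observer, as used for activations in TensorRT-style QAT
obs = MovingAverageMinMaxObserver(
    dtype=torch.quint8, qscheme=torch.per_tensor_symmetric
)
obs(torch.tensor([-1.0, 0.5]))  # record some activation statistics
scale, zero_point = obs.calculate_qparams()
print(scale.item(), zero_point.item())  # zero_point comes out as 128, not 0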
@@ -615,9 +615,8 @@ def _enable_module_qat(self, module: Module):
"Overriding quantization scheme to symmetric int8 "
"for both weights and activations because tensorrt flag is True."
)
qproperties.symmetric_activations = True
qproperties.tensorrt = True
qproperties.activation_dtype = torch.qint8
qproperties.symmetric_weights = True
qproperties.weight_dtype = torch.qint8

qconfig = get_qat_qconfig(qproperties)
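
Downstream, the modifier now only has to raise the flag; both symmetry getters then report True regardless of what was assigned before, which is exactly what the two dropped lines used to do by hand. A quick illustration (assuming the patched QConfigProperties above is importable from sparseml):

from sparseml.pytorch.sparsification.quantization.helpers import QConfigProperties

props = QConfigProperties()
assert not props.symmetric_activations  # default is asymmetric activations
props.tensorrt = True
assert props.symmetric_activations and props.symmetric_weights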
