
[QuantizationModifier] initialize per-channel scales and zps to correct shape #1738

Merged: 12 commits, Oct 13, 2023
src/sparseml/pytorch/sparsification/quantization/helpers.py (53 additions, 0 deletions)
@@ -48,6 +48,7 @@
"freeze_bn_stats",
"fuse_module_conv_bn_relus",
"prepare_embeddings_qat",
"initialize_channel_wise_scale_zp",
"QConfigProperties",
"LINEAR_ACTIVATION_NAMES",
"CONV_ACTIVATION_NAMES",
@@ -710,6 +711,58 @@ def prepare_embeddings_qat(
        _prepare_qat_embedding(submodule, submodule_qconfig)


def initialize_channel_wise_scale_zp(module: Module):
    """
    With torch channel-wise quantization, zero points and scales are
    initialized to a default size of (1,) instead of their true size
    of (num_output_channels,). This can cause shape-mismatch issues when
    reloading saved checkpoints. This function expands these initial
    scales and zero points to match the true expected shape.

    :param module: QAT-ready, uncalibrated model
    """
    for name, submodule in module.named_modules():
        weight_fake_quant = getattr(submodule, "weight_fake_quant", None)
        if not weight_fake_quant or (
            getattr(weight_fake_quant, "qscheme", None)
            not in [torch.per_channel_affine, torch.per_channel_symmetric]
        ):
            # only consider modules with channel-wise quantized weights
            continue
        num_channels = None
        if hasattr(submodule, "out_features"):
            # matmul layers
            num_channels = submodule.out_features
        elif hasattr(submodule, "out_channels"):
            # conv layers
            num_channels = submodule.out_channels

        if not num_channels:
            # unable to infer num_channels or num_channels is 0
            continue

        # update scale and zero point if they are initialized to a size of 1
        scale = weight_fake_quant.scale
        if scale.numel() == 1:
            weight_fake_quant.scale = torch.ones(num_channels, dtype=scale.dtype)

        zero_point = weight_fake_quant.zero_point
        if zero_point.numel() == 1:
            weight_fake_quant.zero_point = torch.ones(
                num_channels, dtype=zero_point.dtype
            )

        # update the observer min and max vals
        if weight_fake_quant.activation_post_process.min_val.numel() == 0:
            weight_fake_quant.activation_post_process.min_val = torch.empty_like(
                weight_fake_quant.scale
            )
        if weight_fake_quant.activation_post_process.max_val.numel() == 0:
            weight_fake_quant.activation_post_process.max_val = torch.empty_like(
                weight_fake_quant.scale
            )


def _delete_get_block_hooks(
    module: Module,
    fuse_blocks: List[List[str]],
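For context (not part of the diff): a minimal sketch of the new helper's effect on a QAT-prepared Linear layer. The layer size and the qconfig below are illustrative assumptions using torch's stock per-channel weight fake-quant; SparseML's modifier builds its own qconfig via `QConfigProperties`, so the exact setup may differ.

```python
import torch
from torch import nn
from torch.ao.quantization import (
    QConfig,
    default_fake_quant,
    default_per_channel_weight_fake_quant,
    prepare_qat,
)

from sparseml.pytorch.sparsification.quantization.helpers import (
    initialize_channel_wise_scale_zp,
)

# QAT-prepare a single Linear layer with per-channel weight quantization
model = nn.Sequential(nn.Linear(16, 8)).train()
model.qconfig = QConfig(
    activation=default_fake_quant,
    weight=default_per_channel_weight_fake_quant,
)
prepare_qat(model, inplace=True)

fake_quant = model[0].weight_fake_quant
print(fake_quant.scale.shape)       # torch.Size([1]) -- default init size
print(fake_quant.zero_point.shape)  # torch.Size([1])

initialize_channel_wise_scale_zp(model)
print(fake_quant.scale.shape)       # torch.Size([8]) -- matches out_features
print(fake_quant.zero_point.shape)  # torch.Size([8])
```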
@@ -38,6 +38,7 @@
    configure_module_bn_wrappers,
    freeze_bn_stats,
    fuse_module_conv_bn_relus,
    initialize_channel_wise_scale_zp,
)
from sparseml.pytorch.sparsification.quantization.legacy_modifier_quantization import (
    QuantizationModifier as LegacyQuantizationModifier,
@@ -516,6 +517,10 @@ def _enable_module_qat(self, module: Module):

        self._calibrate_if_possible(module)

        # if channel-wise quantization is targeted, properly initialize
        # the scale and zp shapes
        initialize_channel_wise_scale_zp(module)

    def _fuse(self, module: Module):
        if self.model_fuse_fn_name in [None, "conv_bn_relus"]:
            self._model_fuse_fn_kwargs["inplace"] = True
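A hedged sketch of the checkpoint-reload scenario motivating the call above: a calibrated checkpoint stores per-channel scales and zero points of shape (out_channels,), while a freshly prepared model still holds (1,)-shaped buffers, so expanding them before loading keeps the shapes consistent with the checkpoint. The layer size, qconfig, and the `qat_linear` helper below are illustrative assumptions, and the exact failure mode without the fix depends on the loading path used.

```python
import torch
from torch import nn
from torch.ao.quantization import (
    QConfig,
    default_fake_quant,
    default_per_channel_weight_fake_quant,
    prepare_qat,
)

from sparseml.pytorch.sparsification.quantization.helpers import (
    initialize_channel_wise_scale_zp,
)


def qat_linear() -> nn.Sequential:
    # hypothetical stand-in for a QAT-prepared model with per-channel weights
    model = nn.Sequential(nn.Linear(16, 8)).train()
    model.qconfig = QConfig(
        activation=default_fake_quant,
        weight=default_per_channel_weight_fake_quant,
    )
    return prepare_qat(model, inplace=True)


calibrated = qat_linear()
calibrated(torch.randn(4, 16))  # a calibration pass expands scale/zp to shape (8,)
checkpoint = calibrated.state_dict()

fresh = qat_linear()                     # scale/zp buffers are still shape (1,)
initialize_channel_wise_scale_zp(fresh)  # expand them to (8,) before loading
fresh.load_state_dict(checkpoint)        # buffer shapes now match the checkpoint
```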