Commit
[QuantizationModifier] initialize per channel scales and zps to correct shape (#1738)

* [QuantizationModifier] initialize per channel scales and zps to correct shape

* add adjustment for observer min/max vals

* typo bug fix

* Fixes to load pre-trained model w/ channel-wise quantization

* Quality fixes

* Switch fake initialization to just prior to loading model weights

* Style and quality fixes

---------

Co-authored-by: Alexandre Marques <alexandre@neuralmagic.com>
bfineran and anmarques committed Oct 13, 2023
1 parent 1ab17ad commit 1a2bddf
Showing 2 changed files with 63 additions and 0 deletions.
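
For context, here is a minimal sketch (not part of this commit; the observer and quantization parameters are illustrative assumptions) of the default buffer shapes that make this fix necessary:

import torch
from torch.ao.quantization import (
    FakeQuantize,
    MovingAveragePerChannelMinMaxObserver,
)

# A channel-wise weight fake-quantize module, before any calibration
weight_fake_quant = FakeQuantize(
    observer=MovingAveragePerChannelMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
)

print(weight_fake_quant.scale.shape)       # torch.Size([1])
print(weight_fake_quant.zero_point.shape)  # torch.Size([1])

# A calibrated checkpoint stores one scale and zero point per output channel
# (e.g. torch.Size([768]) for a Linear with 768 outputs), so load_state_dict
# raises a size mismatch unless these buffers are expanded first, which is
# what this commit adds.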
53 changes: 53 additions & 0 deletions src/sparseml/pytorch/sparsification/quantization/helpers.py
@@ -48,6 +48,7 @@
"freeze_bn_stats",
"fuse_module_conv_bn_relus",
"prepare_embeddings_qat",
"initialize_channel_wise_scale_zp",
"QConfigProperties",
"LINEAR_ACTIVATION_NAMES",
"CONV_ACTIVATION_NAMES",
@@ -710,6 +711,58 @@ def prepare_embeddings_qat(
        _prepare_qat_embedding(submodule, submodule_qconfig)


def initialize_channel_wise_scale_zp(module: Module):
    """
    On torch channel-wise quantization, zero points and scales are
    initialized to a default size of (1,) instead of their true size
    of (num_output_channels,). This can cause issues on reloading
    of saved checkpoints due to shape mismatch. This function expands
    these initial scales and zero points to match the true expected
    shape
    :param module: qat ready, uncalibrated model
    """
    for name, submodule in module.named_modules():
        weight_fake_quant = getattr(submodule, "weight_fake_quant", None)
        if not weight_fake_quant or (
            getattr(weight_fake_quant, "qscheme", None)
            not in [torch.per_channel_affine, torch.per_channel_symmetric]
        ):
            # only consider modules with channel-wise quantized weights
            continue
        num_channels = None
        if hasattr(submodule, "out_features"):
            # matmul layers
            num_channels = submodule.out_features
        elif hasattr(submodule, "out_channels"):
            num_channels = submodule.out_channels

        if not num_channels:
            # unable to infer num_channels or num_channels is 0
            continue

        # update scale and zero point if they are initialized to a size of 1
        scale = weight_fake_quant.scale
        if scale.numel() == 1:
            weight_fake_quant.scale = torch.ones(num_channels, dtype=scale.dtype)

        zero_point = weight_fake_quant.zero_point
        if zero_point.numel() == 1:
            weight_fake_quant.zero_point = torch.ones(
                num_channels, dtype=zero_point.dtype
            )

        # update the observer min and max vals
        if weight_fake_quant.activation_post_process.min_val.numel() == 0:
            weight_fake_quant.activation_post_process.min_val = torch.empty_like(
                weight_fake_quant.scale
            )
        if weight_fake_quant.activation_post_process.max_val.numel() == 0:
            weight_fake_quant.activation_post_process.max_val = torch.empty_like(
                weight_fake_quant.scale
            )


def _delete_get_block_hooks(
    module: Module,
    fuse_blocks: List[List[str]],
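
A minimal usage sketch of the new helper (an assumption, not code from this repository): prepare a toy module for channel-wise QAT, observe the (1,)-shaped buffers, then expand them with initialize_channel_wise_scale_zp before loading a checkpoint. The toy model and qconfig settings are illustrative.

import torch
from torch import nn
from torch.ao.quantization import (
    FakeQuantize,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    QConfig,
    prepare_qat,
)

from sparseml.pytorch.sparsification.quantization.helpers import (
    initialize_channel_wise_scale_zp,
)

# Small model prepared for QAT with channel-wise quantized weights
model = nn.Sequential(nn.Linear(16, 8))
model.train()
model.qconfig = QConfig(
    activation=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=255,
        dtype=torch.quint8,
    ),
    weight=FakeQuantize.with_args(
        observer=MovingAveragePerChannelMinMaxObserver,
        quant_min=-128,
        quant_max=127,
        dtype=torch.qint8,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
    ),
)
prepare_qat(model, inplace=True)

print(model[0].weight_fake_quant.scale.shape)  # torch.Size([1])

# Expand scale/zero_point (and empty observer min/max vals) so a checkpoint
# saved with per-channel shapes can be loaded without a size mismatch
initialize_channel_wise_scale_zp(model)
print(model[0].weight_fake_quant.scale.shape)  # torch.Size([8])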
10 changes: 10 additions & 0 deletions src/sparseml/transformers/sparsification/trainer.py
@@ -40,6 +40,9 @@
from transformers.trainer_utils import ShardedDDPOption, get_last_checkpoint

from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
from sparseml.pytorch.sparsification.quantization.helpers import (
    initialize_channel_wise_scale_zp,
)
from sparseml.pytorch.utils import (
    LoggerManager,
    ModuleSparsificationInfo,
@@ -671,6 +674,13 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
            )
            return False

        # PerChannel quantization observers initialize variables
        # to dummy shapes that do not match the ones saved in
        # state_dict.
        # Need to reshape these variables in order to load state_dict
        # properly.
        initialize_channel_wise_scale_zp(self.model)

        current_state_dict = self.model.state_dict()

        if set(orig_state_dict.keys()) == set(current_state_dict):
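
The following is a simplified, hypothetical sketch of the reload flow this trainer hook supports; it is not the actual _reload_model_state implementation, which also compares key sets and handles missing weights. The function name and loading details are assumptions.

import torch

from sparseml.pytorch.sparsification.quantization.helpers import (
    initialize_channel_wise_scale_zp,
)


def reload_quantized_checkpoint(model: torch.nn.Module, load_path: str) -> None:
    # Expand the (1,)-shaped per-channel scale/zero_point buffers first so the
    # freshly QAT-prepared model matches the shapes stored in the checkpoint
    initialize_channel_wise_scale_zp(model)

    # Simplified load; the real trainer diffs key sets and logs mismatches
    state_dict = torch.load(load_path, map_location="cpu")
    model.load_state_dict(state_dict)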
