Fixes to load pre-trained model w/ channel-wise quantization
anmarques committed Oct 2, 2023
1 parent 779af69 · commit 7cfad32
Showing 1 changed file with 3 additions and 3 deletions.
src/sparseml/pytorch/sparsification/quantization/helpers.py (6 changes: 3 additions & 3 deletions)
@@ -725,7 +725,7 @@ def initialize_channel_wise_scale_zp(module: Module):
     for name, submodule in module.named_modules():
         weight_fake_quant = getattr(submodule, "weight_fake_quant", None)
         if not weight_fake_quant or (
-            getattr(weight_fake_quant, "qscheme", None) is not torch.per_channel_affine
+            getattr(weight_fake_quant, "qscheme", None) not in [torch.per_channel_affine, torch.per_channel_symmetric]
         ):
             # only consider modules with channel-wise quantized weights
             continue
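
Note on the first change: PyTorch's stock per-channel weight fake-quantizers report torch.per_channel_symmetric as their qscheme, so the previous check (which only matched torch.per_channel_affine) skipped them entirely and their scale/zero_point buffers were never resized. A minimal sketch of that behavior, assuming a torch build that exports the default constructor from torch.ao.quantization:

    import torch
    from torch.ao.quantization import default_per_channel_weight_fake_quant

    # Default per-channel weight fake-quantizer used by common QAT qconfigs
    wfq = default_per_channel_weight_fake_quant()

    # Prints torch.per_channel_symmetric, which the old
    # `is not torch.per_channel_affine` check did not treat as channel-wise
    print(wfq.qscheme)
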
@@ -743,11 +743,11 @@ def initialize_channel_wise_scale_zp(module: Module):
         # update scale and zero point if they are initialized to a size of 1
         scale = weight_fake_quant.scale
         if scale.numel() == 1:
-            weight_fake_quant.scale = scale.reshape(-1).expand(num_channels)
+            weight_fake_quant.scale = torch.ones(num_channels, dtype=scale.dtype)
 
         zero_point = weight_fake_quant.zero_point
         if zero_point.numel() == 1:
-            weight_fake_quant.zero_point = zero_point.reshape(-1).expand(num_channels)
+            weight_fake_quant.zero_point = torch.ones(num_channels, dtype=zero_point.dtype)
 
         # update the observer min and max vals
         if weight_fake_quant.activation_post_process.min_val.numel() == 0:
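
These re-initializations matter when reloading a channel-wise quantized checkpoint: a freshly QAT-prepared module carries scale/zero_point buffers of shape (1,), while the saved state dict holds per-channel tensors of shape (num_channels,), so load_state_dict fails with a size mismatch unless the buffers are expanded first. A rough usage sketch under those assumptions; the "fbgemm" qconfig and the checkpoint path are illustrative placeholders, not part of this commit:

    import torch
    from torch import nn
    from torch.ao.quantization import get_default_qat_qconfig, prepare_qat

    from sparseml.pytorch.sparsification.quantization.helpers import (
        initialize_channel_wise_scale_zp,
    )

    # Tiny QAT-prepared model; "fbgemm" defaults to per-channel symmetric weight fake-quant
    model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3)).train()
    model.qconfig = get_default_qat_qconfig("fbgemm")
    prepare_qat(model, inplace=True)

    wfq = model[0].weight_fake_quant
    print(wfq.scale.shape)  # torch.Size([1]) before any calibration

    # Expand scale/zero_point to (out_channels,) so per-channel tensors from a
    # trained checkpoint can load without a shape-mismatch error
    initialize_channel_wise_scale_zp(model)
    print(wfq.scale.shape)  # torch.Size([8])

    # model.load_state_dict(torch.load("checkpoint.pth"))  # hypothetical checkpoint path
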
