Skip to content

Commit

Permalink
Switch fake initialization to just prior to loading model weights
Browse files Browse the repository at this point in the history
  • Loading branch information
anmarques committed Oct 12, 2023
1 parent 1129c7d commit 6b41c1f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
configure_module_bn_wrappers,
freeze_bn_stats,
fuse_module_conv_bn_relus,
initialize_channel_wise_scale_zp,
)
from sparseml.pytorch.sparsification.quantization.legacy_modifier_quantization import (
QuantizationModifier as LegacyQuantizationModifier,
Expand Down Expand Up @@ -517,10 +516,6 @@ def _enable_module_qat(self, module: Module):

self._calibrate_if_possible(module)

# if channel-wise quantization is targeted, properly initialize
# the scale and zp shapes
initialize_channel_wise_scale_zp(module)

def _fuse(self, module: Module):
if self.model_fuse_fn_name in [None, "conv_bn_relus"]:
self._model_fuse_fn_kwargs["inplace"] = True
Expand Down
9 changes: 8 additions & 1 deletion src/sparseml/transformers/sparsification/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
)
from sparseml.transformers.utils import SparseAutoModel
from sparseml.transformers.utils.helpers import RECIPE_NAME

from sparseml.pytorch.sparsification.quantization.helpers import initialize_channel_wise_scale_zp

__all__ = [
"RecipeManagerTrainerInterface",
Expand Down Expand Up @@ -671,6 +671,13 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
)
return False

# PerChannel quantization observers initialize variables
# to dummy shapes that do not match the ones saved in
# state_dict.
# Need to reshape these variables in order to load state_dict
# properly.
initialize_channel_wise_scale_zp(self.model)

current_state_dict = self.model.state_dict()

if set(orig_state_dict.keys()) == set(current_state_dict):
Expand Down

0 comments on commit 6b41c1f

Please sign in to comment.