Skip to content

Commit

Permalink
handle new HF interface
Browse files Browse the repository at this point in the history
  • Loading branch information
eldarkurtic committed Jul 6, 2023
1 parent 537afd2 commit e066230
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/sparseml/transformers/sparsification/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,11 +683,11 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
dd = torch.load(os.path.join(load_path, f), map_location="cpu")
loaded_state_dict.update(dd)

_, missing, unexpected, _, _ = self.model._load_pretrained_model(
_, missing, unexpected, mismatched, _, _ = self.model._load_pretrained_model(
model=self.model,
state_dict=loaded_state_dict,
loaded_keys=list(loaded_state_dict.keys()),
resolved_archive_file=[],
resolved_archive_file=None,
pretrained_model_name_or_path=load_path,
_fast_init=False,
)
Expand All @@ -704,6 +704,12 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
f"{unexpected}"
)

if mismatched:
_LOGGER.warning(
f"Mismatched keys found when reloading model state for SparseML recipe:"
f"{mismatched}"
)

total_loaded = len(current_state_dict) - (len(missing) if len(missing) else 0)
_LOGGER.info(
f"Reloaded {total_loaded} model params for SparseML Recipe from {load_path}"
Expand Down

0 comments on commit e066230

Please sign in to comment.