Skip to content

Commit

Permalink
Fix quant model re-load bug (#978) (#1027)
Browse files Browse the repository at this point in the history
  • Loading branch information
bfineran authored Aug 26, 2022
1 parent 3c80cde commit 6cdc585
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/sparseml/pytorch/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def load_model(
if path.startswith("zoo:"):
path = download_framework_model_by_recipe_type(Model(path))
model_dict = torch.load(path, map_location="cpu")
current_dict = model.state_dict()
recipe = model_dict.get("recipe")

if recipe:
Expand All @@ -90,6 +89,7 @@ def load_model(
checkpoint_manager = ScheduledModifierManager.from_yaml(recipe)
checkpoint_manager.apply_structure(module=model, epoch=epoch)

current_dict = model.state_dict()
if "state_dict" in model_dict:
model_dict = model_dict["state_dict"]

Expand Down

0 comments on commit 6cdc585

Please sign in to comment.