From 6cdc5855473a14924d7570082028328564c21b03 Mon Sep 17 00:00:00 2001 From: Benjamin Fineran Date: Fri, 26 Aug 2022 17:12:31 -0400 Subject: [PATCH] Fix quant model re-load bug (#978) (#1027) --- src/sparseml/pytorch/utils/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/utils/model.py b/src/sparseml/pytorch/utils/model.py index 97dcd7adc35..9d676bcc8a4 100644 --- a/src/sparseml/pytorch/utils/model.py +++ b/src/sparseml/pytorch/utils/model.py @@ -80,7 +80,6 @@ def load_model( if path.startswith("zoo:"): path = download_framework_model_by_recipe_type(Model(path)) model_dict = torch.load(path, map_location="cpu") - current_dict = model.state_dict() recipe = model_dict.get("recipe") if recipe: @@ -90,6 +89,7 @@ def load_model( checkpoint_manager = ScheduledModifierManager.from_yaml(recipe) checkpoint_manager.apply_structure(module=model, epoch=epoch) + current_dict = model.state_dict() if "state_dict" in model_dict: model_dict = model_dict["state_dict"]