Skip to content

Commit

Permalink
🐞 Fix: Add map_location when loading the weights (#562)
Browse files Browse the repository at this point in the history
  • Loading branch information
samet-akcay committed Sep 16, 2022
1 parent e9809c4 commit 7305246
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions anomalib/utils/callbacks/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ def on_test_start(self, _trainer, pl_module: AnomalyModule) -> None: # pylint:
Loads the model weights from ``weights_path`` into the PyTorch module.
"""
logger.info("Loading the model from %s", self.weights_path)
pl_module.load_state_dict(torch.load(self.weights_path)["state_dict"])
pl_module.load_state_dict(torch.load(self.weights_path, map_location=pl_module.device)["state_dict"])

def on_predict_start(self, _trainer, pl_module: AnomalyModule) -> None:
"""Call when inference begins.
Loads the model weights from ``weights_path`` into the PyTorch module.
"""
logger.info("Loading the model from %s", self.weights_path)
pl_module.load_state_dict(torch.load(self.weights_path)["state_dict"])
pl_module.load_state_dict(torch.load(self.weights_path, map_location=pl_module.device)["state_dict"])

0 comments on commit 7305246

Please sign in to comment.