From 5651a25ebd72b54ed8df259c6308e91087c88473 Mon Sep 17 00:00:00 2001 From: Abhinav Moudgil Date: Tue, 17 Mar 2020 18:45:36 -0400 Subject: [PATCH] Backward compatibility for checkpoint loading (#1132) * check if hparams_type exists in checkpoint dictionary for backward compatibility * concisely maintain backward compatibility for hparams type * Bug fix in checkpoint loading (#1132) --- CHANGELOG.md | 1 + pytorch_lightning/core/lightning.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e12875f459898..176b1826b37199 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fixed bug related to type cheking of `ReduceLROnPlateau` lr schedulers([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114)) +- Fixed a bug to ensure lightning checkpoints to be backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132)) ## [0.7.1] - 2020-03-07 diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9542783d386796..d5b408423dce12 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1396,7 +1396,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any]) -> 'LightningModule': if cls_takes_hparams: if ckpt_hparams is not None: - is_namespace = checkpoint.get('hparams_type') == 'namespace' + is_namespace = checkpoint.get('hparams_type', 'namespace') == 'namespace' hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams else: warnings.warn(