diff --git a/CHANGELOG.md b/CHANGELOG.md
index 003aec2ea75f9..cc9af571c04da 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed loading model with kwargs ([#2387](https://github.com/PyTorchLightning/pytorch-lightning/pull/2387))
 
+- Fixed loading past checkpoints from v0.7.x ([#2405](https://github.com/PyTorchLightning/pytorch-lightning/pull/2405))
+
 ## [0.8.1] - 2020-06-19
 
 ### Fixed
diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 4ad11e6ec1023..31a8ca9481165 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -180,8 +180,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
                 if hparam_key in checkpoint:
                     model_args.update(checkpoint[hparam_key])
 
-            if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
-                model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)
+            model_args = _convert_loaded_hparams(model_args, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
 
             args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
             cls_spec = inspect.getfullargspec(cls.__init__)
@@ -248,6 +247,18 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
         """
 
 
+def _convert_loaded_hparams(model_args: dict, hparams_type: Union[Callable, str] = None) -> object:
+    """Convert hparams according to the given type, passed either as a callable or as a string (past format)."""
+    # if no hparams type is defined, keep the plain dict
+    if not hparams_type:
+        return model_args
+    # if a past checkpoint was loaded, the type is a string; convert it to a callable
+    if isinstance(hparams_type, str):
+        hparams_type = AttributeDict
+    # convert hparams
+    return hparams_type(model_args)
+
+
 def update_hparams(hparams: dict, updates: dict) -> None:
     """
     Overrides hparams with new values
diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py
index e2b200adc563b..98df8d205e0df 100644
--- a/tests/models/test_hparams.py
+++ b/tests/models/test_hparams.py
@@ -268,6 +268,7 @@ def test_collect_init_arguments(tmpdir, cls):
     # verify that the checkpoint saved the correct values
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
     trainer.fit(model)
+
     raw_checkpoint_path = _raw_checkpoint_path(trainer)
 
     raw_checkpoint = torch.load(raw_checkpoint_path)
@@ -391,6 +392,7 @@ def test_load_past_checkpoint(tmpdir, past_key):
     raw_checkpoint_path = _raw_checkpoint_path(trainer)
     raw_checkpoint = torch.load(raw_checkpoint_path)
     raw_checkpoint[past_key] = raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
+    raw_checkpoint['hparams_type'] = 'Namespace'
     raw_checkpoint[past_key]['batch_size'] = -17
     del raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
     # save back the checkpoint
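
For reference, the behaviour added by the new `_convert_loaded_hparams` helper can be exercised in isolation: checkpoints written by v0.7.x store `hparams_type` as a string such as 'Namespace' rather than a callable, and the helper now falls back to `AttributeDict` instead of attempting to call that string. The snippet below is a minimal, self-contained sketch of that logic; the `AttributeDict` class here is a simplified stand-in for the library's own class, kept only so the example runs without importing pytorch_lightning.

from typing import Callable, Union


class AttributeDict(dict):
    # simplified stand-in for pytorch_lightning's AttributeDict: a dict with attribute access
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc


def _convert_loaded_hparams(model_args: dict, hparams_type: Union[Callable, str] = None) -> object:
    # same logic as the helper added in the diff above
    if not hparams_type:
        # no hparams type recorded in the checkpoint -> keep the plain dict
        return model_args
    if isinstance(hparams_type, str):
        # past (v0.7.x) checkpoints store the type as a string, e.g. 'Namespace';
        # fall back to a dict-like container instead of calling the string
        hparams_type = AttributeDict
    return hparams_type(model_args)


# current checkpoint: the stored type is a callable and is applied directly
hparams = _convert_loaded_hparams({'batch_size': 32}, AttributeDict)
assert hparams.batch_size == 32

# past (v0.7.x) checkpoint: the stored type is the string 'Namespace'
legacy = _convert_loaded_hparams({'batch_size': 32}, 'Namespace')
assert isinstance(legacy, AttributeDict) and legacy['batch_size'] == 32

The test change above mirrors the old-checkpoint case by writing the string 'Namespace' into the raw checkpoint under 'hparams_type' before reloading it.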