fix loading past checkpoints (#2405)
* fix #2334

* chlog
Borda committed Jun 28, 2020
1 parent 66ffbad commit 861a73b
Showing 3 changed files with 17 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed loading model with kwargs ([#2387](https://github.com/PyTorchLightning/pytorch-lightning/pull/2387))

- Fixed loading past checkpoints from v0.7.x ([#2405](https://github.com/PyTorchLightning/pytorch-lightning/pull/2405))

## [0.8.1] - 2020-06-19

### Fixed
15 changes: 13 additions & 2 deletions pytorch_lightning/core/saving.py
@@ -180,8 +180,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
                if hparam_key in checkpoint:
                    model_args.update(checkpoint[hparam_key])

-           if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
-               model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)
+           model_args = _convert_loaded_hparams(model_args, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))

            args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
            cls_spec = inspect.getfullargspec(cls.__init__)
@@ -248,6 +247,18 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
"""


def _convert_loaded_hparams(model_args: dict, hparams_type: Union[Callable, str] = None) -> object:
"""Convert hparams according given type in callable or string (past) format"""
# if not hparams type define
if not hparams_type:
return model_args
# if past checkpoint loaded, convert str to callable
if isinstance(hparams_type, str):
hparams_type = AttributeDict
# convert hparams
return hparams_type(model_args)


def update_hparams(hparams: dict, updates: dict) -> None:
"""
Overrides hparams with new values
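For context, a hand-written sketch (not part of the diff) of how the new `_convert_loaded_hparams` helper behaves for the three possible values stored under `CHECKPOINT_HYPER_PARAMS_TYPE`; it assumes pytorch-lightning at this revision is importable:

    from pytorch_lightning.core.saving import _convert_loaded_hparams

    plain = {'batch_size': 32, 'learning_rate': 1e-3}

    # 1) no type recorded in the checkpoint -> the plain dict passes through unchanged
    assert _convert_loaded_hparams(plain, None) == plain

    # 2) v0.7.x checkpoints recorded the type only as a string (e.g. 'Namespace');
    #    the helper now falls back to AttributeDict instead of trying to call the string
    restored = _convert_loaded_hparams(plain, 'Namespace')
    assert restored['batch_size'] == 32 and type(restored).__name__ == 'AttributeDict'

    # 3) current checkpoints record the container class itself, which is called directly
    class MyHparams(dict):
        pass

    assert isinstance(_convert_loaded_hparams(plain, MyHparams), MyHparams)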
2 changes: 2 additions & 0 deletions tests/models/test_hparams.py
@@ -268,6 +268,7 @@ def test_collect_init_arguments(tmpdir, cls):
    # verify that the checkpoint saved the correct values
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
    trainer.fit(model)

    raw_checkpoint_path = _raw_checkpoint_path(trainer)

    raw_checkpoint = torch.load(raw_checkpoint_path)
@@ -391,6 +392,7 @@ def test_load_past_checkpoint(tmpdir, past_key):
    raw_checkpoint_path = _raw_checkpoint_path(trainer)
    raw_checkpoint = torch.load(raw_checkpoint_path)
    raw_checkpoint[past_key] = raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
    raw_checkpoint['hparams_type'] = 'Namespace'
    raw_checkpoint[past_key]['batch_size'] = -17
    del raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
    # save back the checkpoint
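A rough sketch (not part of the diff) of what this test presumably exercises once the checkpoint has been rewritten into the old format; the `EvalModelTemplate` class and the final assertion are assumptions about the surrounding test code:

    # the rewritten checkpoint now emulates a v0.7.x file: hparams sit under an old
    # key (e.g. 'hparams') and 'hparams_type' holds a string instead of a class
    torch.save(raw_checkpoint, raw_checkpoint_path)

    # with this fix, loading the old-format checkpoint succeeds and the string type
    # is mapped to AttributeDict rather than being called as if it were a class
    model = EvalModelTemplate.load_from_checkpoint(raw_checkpoint_path)
    assert model.hparams.batch_size == -17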
