Commit: save last model after saving top_k when save_last=True (Lightning-AI#2881)

* save_last should be last

* changelog

* seed, docs

* retrigger ci

* compare filenames

* move constants

* fix test

* epoch, global step

* improve test
awaelchli authored Aug 8, 2020
1 parent d9d7e91 commit f798cff
Showing 4 changed files with 51 additions and 10 deletions.
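
In short, the fix moves the `save_last` write to the end of `ModelCheckpoint.on_validation_end`, so `last.ckpt` is produced only after the top-k bookkeeping has run. Previously it was written first, so the `best_model_score` and `best_model_path` it captured were one validation pass stale. A minimal sketch of the reordering (simplified from the diffs below; `_update_top_k_and_save` is a hypothetical stand-in for the top-k logic, not a real method):

import os

from pytorch_lightning.callbacks import ModelCheckpoint

# Before the fix: last.ckpt was written before the top-k bookkeeping,
# so it captured stale best_model_* state from the previous epoch.
def on_validation_end_before_fix(self, trainer, pl_module):
    if self.save_last:
        self._save_model(os.path.join(self.dirpath, self.prefix + 'last.ckpt'), trainer, pl_module)
    self._update_top_k_and_save(trainer, pl_module)  # updates best_model_score/path

# After the fix: top-k first, then last.ckpt, so last.ckpt reflects the
# freshest best_model_score and best_model_path.
def on_validation_end_after_fix(self, trainer, pl_module):
    self._update_top_k_and_save(trainer, pl_module)
    if self.save_last:
        filepath = os.path.join(self.dirpath, self.prefix + ModelCheckpoint.CHECKPOINT_NAME_LAST)
        self._save_model(filepath, trainer, pl_module)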
CHANGELOG.md (2 changes: 2 additions & 0 deletions)

@@ -106,6 +106,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed LR finder and `hparams` compatibility ([#2821](https://github.com/PyTorchLightning/pytorch-lightning/pull/2821))

+- Fixed `ModelCheckpoint` not saving the latest information when `save_last=True` ([#2881](https://github.com/PyTorchLightning/pytorch-lightning/pull/2881))
+
 ## [0.8.5] - 2020-07-09

 ### Added
pytorch_lightning/callbacks/model_checkpoint.py (12 changes: 8 additions & 4 deletions)

@@ -96,6 +96,10 @@ class ModelCheckpoint(Callback):
     """

+    CHECKPOINT_NAME_LAST = "last.ckpt"
+    CHECKPOINT_STATE_BEST_SCORE = "checkpoint_callback_best_model_score"
+    CHECKPOINT_STATE_BEST_PATH = "checkpoint_callback_best_model_path"
+
     def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
                  save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
@@ -302,10 +306,6 @@ def on_validation_end(self, trainer, pl_module):

         self.epoch_last_check = epoch

-        if self.save_last:
-            filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt')
-            self._save_model(filepath, trainer, pl_module)
-
         filepath = self.format_checkpoint_name(epoch, metrics)
         version_cnt = 0
         while os.path.isfile(filepath):
@@ -340,6 +340,10 @@ def on_validation_end(self, trainer, pl_module):
             assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
             self._save_model(filepath, trainer, pl_module)

+        if self.save_last:
+            filepath = os.path.join(self.dirpath, self.prefix + ModelCheckpoint.CHECKPOINT_NAME_LAST)
+            self._save_model(filepath, trainer, pl_module)
+
     def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
         # remove kth
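
For context, a user-facing configuration that exercises this path might look like the sketch below, using the 0.9-era constructor arguments that appear elsewhere in this commit (`filepath` was later split into `dirpath`/`filename`):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the three best checkpoints by the monitored metric and also maintain
# a rolling last.ckpt; with this fix, last.ckpt is written after the top-k
# files, so its best_model_* metadata matches the end of training.
checkpoint_callback = ModelCheckpoint(
    filepath='checkpoints/',
    monitor='val_loss',
    save_top_k=3,
    save_last=True,
)
trainer = Trainer(checkpoint_callback=checkpoint_callback, max_epochs=3)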
pytorch_lightning/trainer/training_io.py (10 changes: 5 additions & 5 deletions)

@@ -354,8 +354,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         if checkpoint_callbacks:
             # we add the official checkpoint callback to the end of the list
             # extra user provided callbacks will not be persisted yet
-            checkpoint['checkpoint_callback_best_model_score'] = self.checkpoint_callback.best_model_score
-            checkpoint['checkpoint_callback_best_model_path'] = self.checkpoint_callback.best_model_path
+            checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE] = self.checkpoint_callback.best_model_score
+            checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH] = self.checkpoint_callback.best_model_path

         if early_stopping_callbacks and checkpoint_callbacks:
             # we add the official early stopping callback to the end of the list
@@ -436,16 +436,16 @@ def restore_training_state(self, checkpoint):
         early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)]

         if checkpoint_callbacks:
-            if 'checkpoint_callback_best_model_score' in checkpoint:
-                checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best_model_score']
+            if ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE in checkpoint:
+                checkpoint_callbacks[-1].best_model_score = checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE]
             else:
                 # Old naming until version 0.7.6
                 rank_zero_warn(
                     'Loading a checkpoint created with an old version of Lightning; '
                     'this will not be supported in the future.'
                 )
                 checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best']
-            checkpoint_callbacks[-1].best_model_path = checkpoint['checkpoint_callback_best_model_path']
+            checkpoint_callbacks[-1].best_model_path = checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH]

         if early_stopping_callbacks:
             state = checkpoint['early_stop_callback_state_dict']
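
The new class-level constants remove the magic strings from both the dump and restore paths and give downstream code one stable spelling for these keys. A small sketch of reading the metadata back (a hypothetical inspection snippet, assuming a checkpoint written by this version):

import torch
from pytorch_lightning.callbacks import ModelCheckpoint

# Load the raw checkpoint dict and read the callback state through the
# class constants instead of hard-coded strings.
ckpt = torch.load('checkpoints/last.ckpt', map_location='cpu')
best_score = ckpt[ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE]
best_path = ckpt[ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH]
print(f'best model: {best_path} (score: {best_score})')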
tests/callbacks/test_model_checkpoint.py (37 changes: 36 additions & 1 deletion)

@@ -5,9 +5,10 @@

 import cloudpickle
 import pytest
+import torch

 import tests.base.develop_utils as tutils
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate
@@ -93,3 +94,37 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
     )
     result = trainer.fit(model)
     assert 1 == result
+
+
+def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
+    """ Tests that the checkpoint saved as 'last.ckpt' contains the latest information. """
+    seed_everything(100)
+    model = EvalModelTemplate()
+    num_epochs = 3
+    model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        early_stop_callback=False,
+        checkpoint_callback=model_checkpoint,
+        max_epochs=num_epochs,
+    )
+    trainer.fit(model)
+    path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})  # epoch=3.ckpt
+    path_last = str(tmpdir / ModelCheckpoint.CHECKPOINT_NAME_LAST)  # last.ckpt
+    assert path_last_epoch != path_last
+    ckpt_last_epoch = torch.load(path_last_epoch)
+    ckpt_last = torch.load(path_last)
+    matching_keys = (
+        "epoch",
+        "global_step",
+        ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE,
+        ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH,
+    )
+    for key in matching_keys:
+        assert ckpt_last_epoch[key] == ckpt_last[key]
+
+    # it is easier to load the model objects than to iterate over the raw dict of tensors
+    model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
+    model_last = EvalModelTemplate.load_from_checkpoint(path_last)
+    for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
+        assert w0.eq(w1).all()
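
Note the two-level verification in the new test: the `matching_keys` loop checks that `last.ckpt` carries the same `epoch`, `global_step`, and best-model bookkeeping as the final top-k checkpoint, while the `load_from_checkpoint` comparison confirms that the weights themselves are identical. `seed_everything(100)` keeps the run deterministic so these comparisons are stable across CI retriggers.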
