Skip to content

Commit

Permalink
feat: save checkpoint before deleting old ones (#1453)
Browse files Browse the repository at this point in the history
* feat: save checkpoint before deleting old ones

* fix: make sure that the new model is not deleted

* changelog

Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
Co-authored-by: William Falcon <waf2107@columbia.edu>
  • Loading branch information
3 people committed Apr 16, 2020
1 parent 2ab2f7d commit 9b31272
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed


- Fixed saving checkpoint before deleting old ones ([#1453](https://github.com/PyTorchLightning/pytorch-lightning/pull/1453))

- Fixed loggers - flushing last logged metrics even before continue, e.g. `trainer.test()` results ([#1459](https://github.com/PyTorchLightning/pytorch-lightning/pull/1459))

- Added a missing call to the `on_before_zero_grad` model hook ([#1493](https://github.com/PyTorchLightning/pytorch-lightning/pull/1493)).
Expand Down
8 changes: 7 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,12 @@ def on_validation_end(self, trainer, pl_module):

def _do_check_save(self, filepath, current, epoch):
# remove kth

del_list = []
if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0:
delpath = self.kth_best_model
self.best_k_models.pop(self.kth_best_model)
self._del_model(delpath)
del_list.append(delpath)

self.best_k_models[filepath] = current
if len(self.best_k_models) == self.save_top_k:
Expand All @@ -238,3 +240,7 @@ def _do_check_save(self, filepath, current, epoch):
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}')
self._save_model(filepath)

for cur_path in del_list:
if cur_path != filepath:
self._del_model(cur_path)

0 comments on commit 9b31272

Please sign in to comment.