Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change Checkpoint callback's save_best_only to save_top_k #128

Merged
merged 50 commits into from
Nov 19, 2019
Merged
Changes from 2 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
6b54c1e
docs: enable syntax highlight
Ir1d Aug 13, 2019
49e35b1
feat: change Checkpoint callback's `save_best_only` to `save_top_k`
Ir1d Aug 17, 2019
08f57b6
docs: update docs for save_top_k
Ir1d Aug 17, 2019
4855bf8
revert other files
Ir1d Aug 17, 2019
2f6f784
style: lint for travis-ci
Ir1d Aug 17, 2019
daae566
fix typo
Ir1d Aug 17, 2019
a7da269
make flake8 happy
Ir1d Aug 17, 2019
373cd1d
update according to review
Ir1d Aug 19, 2019
49f78d6
add tests
Ir1d Aug 19, 2019
a34811a
Merge remote-tracking branch 'wf/master' into save_top_k
Ir1d Aug 19, 2019
fbc8a4e
rename func to private
Ir1d Aug 19, 2019
52bfcb7
add doc on `save_top_k == 0`
Ir1d Aug 19, 2019
38afd66
make flake8 happy
Ir1d Aug 20, 2019
ab92212
update according to PR comments
Ir1d Aug 22, 2019
7e8767f
change some f-strings
Ir1d Aug 22, 2019
3079b71
Update pt_callbacks.py
williamFalcon Aug 23, 2019
3bb6e3e
Update test_models.py
williamFalcon Aug 23, 2019
4626e1a
update options
Ir1d Aug 23, 2019
6a47d55
create folders
Ir1d Aug 23, 2019
4b3d5bf
Update test_models.py
williamFalcon Aug 23, 2019
d29eff7
change epoch num
Ir1d Aug 23, 2019
22cfeb5
Merge remote-tracking branch 'origin/save_top_k' into save_top_k
Ir1d Aug 23, 2019
d15b03b
support calling multiple times, add docs and tests
Ir1d Aug 23, 2019
58a8410
update docs
Ir1d Aug 23, 2019
b9a855d
Merge remote-tracking branch 'wf/master' into save_top_k
Ir1d Aug 23, 2019
e1f7a48
Merge remote-tracking branch 'wf/master' into save_top_k
Ir1d Aug 24, 2019
2dd65e9
roll back changes in earlystopping
Ir1d Aug 24, 2019
634ad80
clean test files
Ir1d Aug 24, 2019
1861f7d
rebase upstream
Ir1d Oct 22, 2019
b66b33c
make flake8 happy
Ir1d Oct 22, 2019
aed9ad9
Merge remote-tracking branch 'will/master' into save_top_k
Ir1d Nov 4, 2019
abb9629
fix epoch number
Ir1d Nov 4, 2019
a0c2269
update tests about epoch numbers
Ir1d Nov 4, 2019
4490466
clean debugging code
Ir1d Nov 4, 2019
0e98423
fix testing utils codes
Ir1d Nov 4, 2019
2eab575
fix testing utils codes
Ir1d Nov 4, 2019
27ebd1c
fix testing utils codes
Ir1d Nov 4, 2019
37cfedf
fix testing utils codes
Ir1d Nov 4, 2019
3f49122
change save_dir to tests/tests according to previous lines
Ir1d Nov 4, 2019
585fff0
remove unused overwrite option
Ir1d Nov 4, 2019
94a6c49
make flake8 happy
Ir1d Nov 4, 2019
b983860
change var name as per review
Ir1d Nov 4, 2019
c3f56dd
Merge remote-tracking branch 'will/master' into save_top_k
Ir1d Nov 5, 2019
13803ff
make flake8 happy
Ir1d Nov 5, 2019
7c58669
Merge remote-tracking branch 'will/master' into save_top_k
Ir1d Nov 5, 2019
235e5f1
update property name to work on master
Ir1d Nov 5, 2019
8635612
elaborate in the docs
Ir1d Nov 5, 2019
e59d5bd
update docs as per review
Ir1d Nov 6, 2019
3f05419
Merge branch 'master' into save_top_k
Ir1d Nov 16, 2019
b000db9
revert previous commit
Ir1d Nov 16, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 12 additions & 66 deletions pytorch_lightning/callbacks/pt_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,10 @@ def _save_model(self, filepath, overwrite):
self.save_function(filepath)

def check_monitor_top_k(self, current):
return ((len(self.best_k_models.keys()) < self.save_top_k) or
(self.monitor_op(current, self.best_k_models[self.kth_value])))
less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k
if less_than_k_models:
williamFalcon marked this conversation as resolved.
Show resolved Hide resolved
return True
return self.monitor_op(current, self.best_k_models[self.kth_value])

def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
Expand All @@ -249,7 +251,7 @@ def on_epoch_end(self, epoch, logs=None):
print('Can save best model only with %s available,'
' skipping.' % (self.monitor), RuntimeWarning)
else:
if (self.check_monitor_top_k(current)):
if self.check_monitor_top_k(current):
if len(self.best_k_models.keys()) == self.save_top_k:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you have a method _del_model which is quite empty, so move this logic about removing k-th model there

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think so :)
_save_model handles a filename to save, and _del_model should do the same, and handles a filename to delete.

# need to pop the kth
delpath = '{}/{}_ckpt_epoch_{}.ckpt'.format(
Expand All @@ -268,19 +270,19 @@ def on_epoch_end(self, epoch, logs=None):
else:
self.best = max(self.best_k_models.values())
if self.verbose > 0:
print('\nEpoch %05d: %s reached %s (best %s),'
' saving model to %s as top %d'
% (epoch + 1, self.monitor, current, self.best,
filepath, self.save_top_k))
print(f"\nEpoch {epoch + 1:05d}: {self.monitor} reached",
f" {current} (best {self.best}), saving model to",
f" {filepath} as top {self.save_top_k}")
self._save_model(filepath, overwrite=False)

else:
if self.verbose > 0:
print('\nEpoch %05d: %s was not in top %d' %
(epoch + 1, self.monitor, self.save_top_k))
print(f"\nEpoch {epoch + 1:05d}: {self.monitor}",
f" was not in top {self.save_top_k}")

else:
if self.verbose > 0:
print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
print(f"\nEpoch {epoch + 1:05d}: saving model to {filepath}")
self._save_model(filepath, overwrite=False)


Expand All @@ -292,59 +294,3 @@ def on_epoch_end(self, epoch, logs=None):
print(loss)
if should_stop:
break

def my_own_save_function(filepath):
open(filepath, 'a').close()

def init_save_dir():
root_dir = os.path.dirname(os.path.realpath(__file__))
save_dir = os.path.join(root_dir, 'save_dir')

if os.path.exists(save_dir):
shutil.rmtree(save_dir)

os.makedirs(save_dir, exist_ok=True)

return save_dir

def clear_save_dir():
root_dir = os.path.dirname(os.path.realpath(__file__))
save_dir = os.path.join(root_dir, 'save_dir')
if os.path.exists(save_dir):
shutil.rmtree(save_dir)

save_dir = init_save_dir()
print(save_dir)

w = ModelCheckpoint(save_dir, save_top_k=0, verbose=1)
w.save_function = my_own_save_function
for i, loss in enumerate(losses):
w.on_epoch_end(i, logs={'val_loss': loss})

file_lists = os.listdir(save_dir)

assert len(file_lists) == 10, "Should save 10 models when save_top_k=0"

clear_save_dir()

w = ModelCheckpoint(save_dir, save_top_k=1, verbose=1)
w.save_function = my_own_save_function
for i, loss in enumerate(losses):
w.on_epoch_end(i, logs={'val_loss': loss})

file_lists = os.listdir(save_dir)

assert len(file_lists) == 1, "Should save 1 model when save_top_k=1"

clear_save_dir()

w = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
w.save_function = my_own_save_function
for i, loss in enumerate(losses):
w.on_epoch_end(i, logs={'val_loss': loss})

file_lists = os.listdir(save_dir)

assert len(file_lists) == 2, "Should save 2 model when save_top_k=2"

clear_save_dir()