change Checkpoint callback's save_best_only to save_top_k #128

Merged (50 commits, Nov 19, 2019)
Changes from 7 commits
Commits (50)
6b54c1e
docs: enable syntax highlight
Ir1d Aug 13, 2019
49e35b1
feat: change Checkpoint callback's `save_best_only` to `save_top_k`
Ir1d Aug 17, 2019
08f57b6
docs: update docs for save_top_k
Ir1d Aug 17, 2019
4855bf8
revert other files
Ir1d Aug 17, 2019
2f6f784
style: lint for travis-ci
Ir1d Aug 17, 2019
daae566
fix typo
Ir1d Aug 17, 2019
a7da269
make flake8 happy
Ir1d Aug 17, 2019
373cd1d
update according to review
Ir1d Aug 19, 2019
49f78d6
add tests
Ir1d Aug 19, 2019
a34811a
Merge remote-tracking branch 'wf/master' into save_top_k
Ir1d Aug 19, 2019
fbc8a4e
rename func to private
Ir1d Aug 19, 2019
52bfcb7
add doc on `save_top_k == 0`
Ir1d Aug 19, 2019
38afd66
make flake8 happy
Ir1d Aug 20, 2019
ab92212
update according to PR comments
Ir1d Aug 22, 2019
7e8767f
change some f-strings
Ir1d Aug 22, 2019
3079b71
Update pt_callbacks.py
williamFalcon Aug 23, 2019
3bb6e3e
Update test_models.py
williamFalcon Aug 23, 2019
4626e1a
update options
Ir1d Aug 23, 2019
6a47d55
create folders
Ir1d Aug 23, 2019
4b3d5bf
Update test_models.py
williamFalcon Aug 23, 2019
d29eff7
change epoch num
Ir1d Aug 23, 2019
22cfeb5
Merge remote-tracking branch 'origin/save_top_k' into save_top_k
Ir1d Aug 23, 2019
d15b03b
support calling multiple times, add docs and tests
Ir1d Aug 23, 2019
58a8410
update docs
Ir1d Aug 23, 2019
b9a855d
Merge remote-tracking branch 'wf/master' into save_top_k
Ir1d Aug 23, 2019
e1f7a48
Merge remote-tracking branch 'wf/master' into save_top_k
Ir1d Aug 24, 2019
2dd65e9
roll back changes in earlystopping
Ir1d Aug 24, 2019
634ad80
clean test files
Ir1d Aug 24, 2019
1861f7d
rebase upstream
Ir1d Oct 22, 2019
b66b33c
make flake8 happy
Ir1d Oct 22, 2019
aed9ad9
Merge remote-tracking branch 'will/master' into save_top_k
Ir1d Nov 4, 2019
abb9629
fix epoch number
Ir1d Nov 4, 2019
a0c2269
update tests about epoch numbers
Ir1d Nov 4, 2019
4490466
clean debugging code
Ir1d Nov 4, 2019
0e98423
fix testing utils codes
Ir1d Nov 4, 2019
2eab575
fix testing utils codes
Ir1d Nov 4, 2019
27ebd1c
fix testing utils codes
Ir1d Nov 4, 2019
37cfedf
fix testing utils codes
Ir1d Nov 4, 2019
3f49122
change save_dir to tests/tests according to previous lines
Ir1d Nov 4, 2019
585fff0
remove unused overwrite option
Ir1d Nov 4, 2019
94a6c49
make flake8 happy
Ir1d Nov 4, 2019
b983860
change var name as per review
Ir1d Nov 4, 2019
c3f56dd
Merge remote-tracking branch 'will/master' into save_top_k
Ir1d Nov 5, 2019
13803ff
make flake8 happy
Ir1d Nov 5, 2019
7c58669
Merge remote-tracking branch 'will/master' into save_top_k
Ir1d Nov 5, 2019
235e5f1
update property name to work on master
Ir1d Nov 5, 2019
8635612
elaborate in the docs
Ir1d Nov 5, 2019
e59d5bd
update docs as per review
Ir1d Nov 6, 2019
3f05419
Merge branch 'master' into save_top_k
Ir1d Nov 16, 2019
b000db9
revert previous commit
Ir1d Nov 16, 2019
2 changes: 1 addition & 1 deletion docs/Trainer/Checkpointing.md
@@ -9,7 +9,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
filepath='/path/to/store/weights.ckpt',
save_best_only=True,
save_top_k=1,
verbose=True,
monitor='val_loss',
mode='min'
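The docs hunk above swaps `save_best_only=True` for `save_top_k=1`. The mapping between the old flag and the new integer option can be sketched as a hypothetical migration helper (not part of this PR):

```python
def migrate_save_best_only(save_best_only: bool) -> int:
    """Map the legacy `save_best_only` flag to an equivalent `save_top_k`.

    True  -> 1: keep only the single best checkpoint.
    False -> 0: top-k tracking disabled; a checkpoint is written
                every `period` epochs unconditionally.
    """
    return 1 if save_best_only else 0
```

Callers that previously passed `save_best_only=True` therefore become `save_top_k=1`, which is exactly the substitution the updated docs make.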
2 changes: 1 addition & 1 deletion docs/examples/Examples.md
@@ -77,7 +77,7 @@ def main(hparams, cluster, results_dict):
checkpoint = ModelCheckpoint(
filepath=model_save_path,
save_function=None,
save_best_only=True,
save_top_k=1,
verbose=True,
monitor=hparams.model_save_monitor_value,
mode=hparams.model_save_monitor_mode
@@ -66,7 +66,7 @@ def main(hparams, cluster, results_dict):

checkpoint = ModelCheckpoint(
filepath=model_save_path,
save_best_only=True,
save_top_k=1,
verbose=True,
monitor='val_loss',
mode='min'
2 changes: 1 addition & 1 deletion examples/new_project_templates/single_cpu_template.py
@@ -57,7 +57,7 @@ def main(hparams):

checkpoint = ModelCheckpoint(
filepath=model_save_path,
save_best_only=True,
save_top_k=1,
verbose=True,
monitor='val_loss',
mode='min'
@@ -57,7 +57,7 @@ def main(hparams):

checkpoint = ModelCheckpoint(
filepath=model_save_path,
save_best_only=True,
save_top_k=1,
verbose=True,
monitor='val_loss',
mode='min'
@@ -57,7 +57,7 @@ def main(hparams):

checkpoint = ModelCheckpoint(
filepath=model_save_path,
save_best_only=True,
save_top_k=1,
verbose=True,
monitor='val_loss',
mode='min'
@@ -57,7 +57,7 @@ def main(hparams):

checkpoint = ModelCheckpoint(
filepath=model_save_path,
save_best_only=True,
save_top_k=1,
verbose=True,
monitor='val_loss',
mode='min'
2 changes: 1 addition & 1 deletion examples/new_project_templates/trainer_cpu_template.py
@@ -42,7 +42,7 @@ def main(hparams):
model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
checkpoint = ModelCheckpoint(
filepath=model_save_path,
save_best_only=True,
save_top_k=1,
verbose=True,
monitor='val_acc',
mode='min'
87 changes: 64 additions & 23 deletions pytorch_lightning/callbacks/pt_callbacks.py
@@ -156,11 +156,11 @@ class ModelCheckpoint(Callback):
filepath: string, path to save the model file.
monitor: quantity to monitor.
verbose: verbosity mode, 0 or 1.
save_best_only: if `save_best_only=True`,
the latest best model according to
the quantity monitored will not be overwritten.
save_top_k: if `save_top_k == k`,
the best k models according to
the quantity monitored will be saved.
mode: one of {auto, min, max}.
If `save_best_only=True`, the decision
If `save_top_k > 0`, the decision
to overwrite the current save file is made
based on either the maximization or the
minimization of the monitored quantity. For `val_acc`,
@@ -174,17 +174,20 @@ class ModelCheckpoint(Callback):
"""

def __init__(self, filepath, monitor='val_loss', verbose=0,
save_best_only=False, save_weights_only=False,
save_top_k=0, save_weights_only=False,
mode='auto', period=1, prefix=''):
super(ModelCheckpoint, self).__init__()
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
self.save_best_only = save_best_only
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.epochs_since_last_save = 0
self.epochs_since_last_check = 0
self.prefix = prefix
self.bestk = {}
# {epoch: monitor}
self.best = 0

if mode not in ['auto', 'min', 'max']:
print('ModelCheckpoint mode %s is unknown, '
@@ -193,17 +196,35 @@ def __init__(self, filepath, monitor='val_loss', verbose=0,

if mode == 'min':
self.monitor_op = np.less
self.best = np.Inf
self.kth = np.Inf
self.mode = 'min'
elif mode == 'max':
self.monitor_op = np.greater
self.best = -np.Inf
self.kth = -np.Inf
self.mode = 'max'
else:
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
self.monitor_op = np.greater
self.best = -np.Inf
self.kth = -np.Inf
self.mode = 'max'
else:
self.monitor_op = np.less
self.best = np.Inf
self.kth = np.Inf
self.mode = 'min'

def del_model(self, filepath):
dirpath = '/'.join(filepath.split('/')[:-1])

# make paths
os.makedirs(os.path.dirname(filepath), exist_ok=True)

for filename in os.listdir(dirpath):
if self.prefix in filename:
path_to_delete = os.path.join(dirpath, filename)
try:
shutil.rmtree(path_to_delete)
except OSError:
os.remove(path_to_delete)

def save_model(self, filepath, overwrite):
dirpath = '/'.join(filepath.split('/')[:-1])
Expand All @@ -225,29 +246,46 @@ def save_model(self, filepath, overwrite):

def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epochs_since_last_save += 1
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save = 0
self.epochs_since_last_check += 1
if self.epochs_since_last_check >= self.period:
self.epochs_since_last_check = 0
filepath = '{}/{}_ckpt_epoch_{}.ckpt'.format(self.filepath, self.prefix, epoch + 1)
if self.save_best_only:
if self.save_top_k:
current = logs.get(self.monitor)
if current is None:
print('Can save best model only with %s available,'
' skipping.' % (self.monitor), RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if ((len(self.bestk.keys()) < self.save_top_k) or
(self.monitor_op(current, self.bestk[self.kth]))):
if len(self.bestk.keys()) == self.save_top_k:
# need to pop the kth
delpath = '{}/{}_ckpt_epoch_{}.ckpt'.format(
self.filepath, self.prefix, self.kth + 1)
self.bestk.pop(self.kth)
self.del_model(delpath)
self.bestk[epoch] = current
if len(self.bestk.keys()) == self.save_top_k:
# monitor dict has reached k elements
if self.mode == 'min':
self.kth = max(self.bestk, key=self.bestk.get)
else:
self.kth = min(self.bestk, key=self.bestk.get)
if self.mode == 'min':
self.best = min(self.bestk.values())
else:
self.best = max(self.bestk.values())
if self.verbose > 0:
print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s'
% (epoch + 1, self.monitor, self.best,
current, filepath))
self.best = current
print('\nEpoch %05d: %s reached %s (best %s),'
' saving model to %s as top %d'
% (epoch + 1, self.monitor, current, self.best,
filepath, self.save_top_k))
self.save_model(filepath, overwrite=True)

else:
if self.verbose > 0:
print('\nEpoch %05d: %s did not improve' %
(epoch + 1, self.monitor))
print('\nEpoch %05d: %s was not in top %d' %
(epoch + 1, self.monitor, self.save_top_k))
else:
if self.verbose > 0:
print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
@@ -262,3 +300,6 @@ def on_epoch_end(self, epoch, logs=None):
print(loss)
if should_stop:
break
w = ModelCheckpoint('res', save_top_k=2, verbose=1)
for i, loss in enumerate(losses):
w.on_epoch_end(i, logs={'val_loss': loss})
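The `bestk`/`kth` bookkeeping in the hunk above can be condensed into a pure-Python sketch. This is a simplification, not the PR's actual implementation: it covers `mode='min'` only and replaces the checkpoint file I/O with a boolean return.

```python
class TopKTracker:
    """Simplified sketch of the `save_top_k` logic: keep the k epochs
    with the lowest monitored value (mode='min')."""

    def __init__(self, k):
        self.k = k
        self.bestk = {}  # epoch -> monitored value, at most k entries

    def on_epoch_end(self, epoch, current):
        """Return True if this epoch's checkpoint should be saved."""
        if len(self.bestk) < self.k:
            # Fewer than k checkpoints kept so far: always save.
            self.bestk[epoch] = current
            return True
        # `kth` is the epoch holding the worst (largest) kept value,
        # mirroring `self.kth` in the PR's implementation.
        kth = max(self.bestk, key=self.bestk.get)
        if current < self.bestk[kth]:
            del self.bestk[kth]  # the PR calls del_model(...) here
            self.bestk[epoch] = current
            return True
        return False


tracker = TopKTracker(k=2)
losses = [0.9, 0.8, 0.85, 0.7, 0.95]
saved = [tracker.on_epoch_end(i, v) for i, v in enumerate(losses)]
```

With these losses the tracker ends up keeping epochs 1 and 3, the two lowest values; epoch 4's 0.95 is rejected because it is worse than the current kth value.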