Fix ModelCheckpoint's name formatting (#3163)
* Fix ModelCheckpoint's name formatting

* Fix failing tests

* Add dot to CHECKPOINT_SUFFIX

* Set variables to their default values at the end of tests

* Fix logic for filepath='' and filename=None. Add test

* Fix Windows tests

* Fix typo. Remove leading line break and zeroes

* Remove CHECKPOINT_SUFFIX

* Fix typos. Use appropriate f-string format

* Apply suggestions from code review

* Fix broken tests after #3320

* Finish changes suggested by Borda

* Use explicit test var names

* Apply suggestions

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Apply suggestions

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update CHANGELOG

* Apply suggestions from code review

* for

* prepend whitespace in warn msg

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
4 people committed Sep 18, 2020
1 parent 197acd5 commit 580b04b
Showing 6 changed files with 113 additions and 66 deletions.
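
In short: the default checkpoint name is no longer the hard-coded `_ckpt_epoch_{n}` pattern; an empty template now falls back to `{epoch}`, and the prefix, formatted name, and version suffix are joined with the new `ModelCheckpoint.CHECKPOINT_JOIN_CHAR` (`-`). A minimal before/after sketch, with illustrative values inferred from the diff and tests below:

    # Default name when the template has no '{...}' groups:
    #   before: f'{prefix}_ckpt_epoch_{epoch}'  ->  '_ckpt_epoch_3.ckpt'
    #   after:  falls back to '{epoch}'         ->  'epoch=3.ckpt'
    # Prefix handling (test prefix changes from 'test_prefix_' to 'test_prefix'):
    #   before: prefix + filename               ->  'test_prefix_epoch=4.ckpt'
    #   after:  CHECKPOINT_JOIN_CHAR.join(...)  ->  'test_prefix-epoch=4.ckpt'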
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -64,6 +64,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed gradient norm tracking for `row_log_interval > 1` ([#3489](https://github.com/PyTorchLightning/pytorch-lightning/pull/3489))

- Fixed `ModelCheckpoint` name formatting ([#3163](https://github.com/PyTorchLightning/pytorch-lightning/pull/3163))

## [0.9.0] - YYYY-MM-DD

### Added
17 changes: 8 additions & 9 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -108,9 +108,9 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
def _validate_condition_metric(self, logs):
monitor_val = logs.get(self.monitor)
error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
f' which is not available. Either add `{self.monitor}` to the return of '
f' validation_epoch end or modify your EarlyStopping callback to use any of the '
f'following: `{"`, `".join(list(logs.keys()))}`')
f' which is not available. Either add `{self.monitor}` to the return of'
' `validation_epoch_end` or modify your `EarlyStopping` callback to use any of the'
f' following: `{"`, `".join(list(logs.keys()))}`')

if monitor_val is None:
if self.strict:
@@ -188,12 +188,11 @@ def __warn_deprecated_monitor_key(self):
invalid_key = self.monitor not in ['val_loss', 'early_stop_on', 'val_early_stop_on', 'loss']
if using_result_obj and not self.warned_result_obj and invalid_key:
self.warned_result_obj = True
m = f"""
When using EvalResult(early_stop_on=X) or TrainResult(early_stop_on=X) the
'monitor' key of EarlyStopping has no effect.
Remove EarlyStopping(monitor='{self.monitor}) to fix')
"""
rank_zero_warn(m)
rank_zero_warn(
f"When using `EvalResult(early_stop_on=X)` or `TrainResult(early_stop_on=X)`"
" the 'monitor' key of `EarlyStopping` has no effect. "
f" Remove `EarlyStopping(monitor='{self.monitor}')` to fix."
)
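# Note: adjacent string literals concatenate with no separator
# ('a' ' b' == 'a b'), so each continuation line above must start with an
# explicit leading space. The old triple-quoted message embedded its own
# newlines and indentation in the emitted warning.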

def _run_early_stopping_check(self, trainer, pl_module):
"""
93 changes: 50 additions & 43 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -47,11 +47,11 @@ class ModelCheckpoint(Callback):
Example::
# custom path
# saves a file like: my/path/epoch_0.ckpt
# saves a file like: my/path/epoch=0.ckpt
>>> checkpoint_callback = ModelCheckpoint('my/path/')
# save any arbitrary metrics like `val_loss`, etc. in name
# saves a file like: my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
# saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
>>> checkpoint_callback = ModelCheckpoint(
... filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
... )
@@ -97,9 +97,9 @@ class ModelCheckpoint(Callback):
>>> trainer = Trainer(checkpoint_callback=checkpoint_callback)
# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist_epoch=02_val_loss=0.32.ckpt
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
... filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
... filepath='my/path/sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )
# retrieve the best checkpoint after training
@@ -111,32 +111,30 @@ class ModelCheckpoint(Callback):
"""

CHECKPOINT_NAME_LAST = "last.ckpt"
CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_NAME_LAST = "last"
CHECKPOINT_STATE_BEST_SCORE = "checkpoint_callback_best_model_score"
CHECKPOINT_STATE_BEST_PATH = "checkpoint_callback_best_model_path"

def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
mode: str = 'auto', period: int = 1, prefix: str = ''):
super().__init__()
if filepath:
self._fs = get_filesystem(filepath)
else:
self._fs = get_filesystem("") # will give local fileystem
self._fs = get_filesystem(filepath if filepath is not None else "")
if save_top_k > 0 and filepath is not None and self._fs.isdir(filepath) and len(self._fs.ls(filepath)) > 0:
rank_zero_warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
" All files in this directory will be deleted when a checkpoint is saved!"
)
self._rank = 0

self.monitor = monitor
self.verbose = verbose
if filepath is None: # will be determined by trainer at runtime
if not filepath: # will be determined by trainer at runtime
self.dirpath, self.filename = None, None
else:
if self._fs.isdir(filepath):
self.dirpath, self.filename = filepath, "{epoch}"
self.dirpath, self.filename = filepath, None
else:
if self._fs.protocol == "file": # dont normalize remote paths
filepath = os.path.realpath(filepath)
@@ -153,6 +151,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
self.kth_best_model_path = ''
self.best_model_score = 0
self.best_model_path = ''
self.last_model_path = ''
self.save_function = None
self.warned_result_obj = False

@@ -220,6 +219,23 @@ def check_monitor_top_k(self, current):

return monitor_op(current, self.best_k_models[self.kth_best_model_path])

@classmethod
def _format_checkpoint_name(cls, filename, epoch, metrics, prefix=""):
if not filename:
# filename is not set, use default name
filename = '{epoch}'
# check and parse user passed keys in the string
groups = re.findall(r'(\{.*?)[:\}]', filename)
if groups:
metrics['epoch'] = epoch
for group in groups:
name = group[1:]
filename = filename.replace(group, name + '={' + name)
if name not in metrics:
metrics[name] = 0
filename = filename.format(**metrics)
return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt])
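# Worked example (illustrative values, mirroring the tests below):
#   _format_checkpoint_name('{epoch:03d}-{acc}', 3, {'acc': 0.03})
#   1. re.findall(r'(\{.*?)[:\}]', ...)  finds  ['{epoch', '{acc']
#   2. each '{name' -> 'name={name'      gives  'epoch={epoch:03d}-acc={acc}'
#   3. .format(epoch=3, acc=0.03)        yields 'epoch=003-acc=0.03'
# Keys absent from `metrics` default to 0, e.g. '{missing}' -> 'missing=0'.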

def format_checkpoint_name(self, epoch, metrics, ver=None):
"""Generate a filename according to the defined template.
@@ -239,24 +255,11 @@ def format_checkpoint_name(self, epoch, metrics, ver=None):
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'missing=0.ckpt'
"""
# check if user passed in keys to the string
groups = re.findall(r'(\{.*?)[:\}]', self.filename)

if len(groups) == 0:
# default name
filename = f'{self.prefix}_ckpt_epoch_{epoch}'
else:
metrics['epoch'] = epoch
filename = self.filename
for tmp in groups:
name = tmp[1:]
filename = filename.replace(tmp, name + '={' + name)
if name not in metrics:
metrics[name] = 0
filename = filename.format(**metrics)
str_ver = f'_v{ver}' if ver is not None else ''
filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt')
return filepath
filename = self._format_checkpoint_name(self.filename, epoch, metrics, prefix=self.prefix)
if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f'v{ver}'))
ckpt_name = f'{filename}.ckpt'
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name
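# Version suffix sketch (see test_model_checkpoint_format_checkpoint_name):
#   ModelCheckpoint(filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, {}, ver=3)
#   joins ('test-name', 'v3') with CHECKPOINT_JOIN_CHAR -> 'test-name-v3.ckpt',
#   prepended with `dirpath` only when one is set.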

@rank_zero_only
def on_pretrain_routine_start(self, trainer, pl_module):
@@ -276,7 +279,7 @@ def on_pretrain_routine_start(self, trainer, pl_module):
if self.dirpath is not None:
return # short circuit

self.filename = '{epoch}'
self.filename = None

if trainer.logger is not None:
if trainer.weights_save_path != trainer.default_root_dir:
@@ -306,12 +309,11 @@ def __warn_deprecated_monitor_key(self):
invalid_key = self.monitor not in ['val_loss', 'checkpoint_on', 'loss', 'val_checkpoint_on']
if using_result_obj and not self.warned_result_obj and invalid_key:
self.warned_result_obj = True
m = f"""
When using EvalResult(checkpoint_on=X) or TrainResult(checkpoint_on=X) the
'monitor' key of ModelCheckpoint has no effect.
Remove ModelCheckpoint(monitor='{self.monitor}) to fix')
"""
rank_zero_warn(m)
rank_zero_warn(
f"When using `EvalResult(checkpoint_on=X)` or `TrainResult(checkpoint_on=X)`"
" the 'monitor' key of `ModelCheckpoint` has no effect."
f" Remove `ModelCheckpoint(monitor='{self.monitor}')` to fix."
)

@rank_zero_only
def on_validation_end(self, trainer, pl_module):
@@ -371,18 +373,23 @@ def on_validation_end(self, trainer, pl_module):
elif self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch, trainer, pl_module)
elif self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')
log.info(f'Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}')

else:
if self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
log.info(f'Epoch {epoch:d}: saving model to {filepath}')

assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
self._save_model(filepath, trainer, pl_module)

if self.save_last:
filepath = os.path.join(self.dirpath, self.prefix + ModelCheckpoint.CHECKPOINT_NAME_LAST)
filename = self._format_checkpoint_name(
self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
)
filepath = os.path.join(self.dirpath, f'{filename}.ckpt')
self._save_model(filepath, trainer, pl_module)
if self.last_model_path and self.last_model_path != filepath:
self._del_model(self.last_model_path)
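# With the default CHECKPOINT_NAME_LAST = 'last' this still writes 'last.ckpt';
# since the name now goes through _format_checkpoint_name it may be templated,
# e.g. CHECKPOINT_NAME_LAST = 'last-{epoch}' -> 'last-epoch=3.ckpt' (see the
# save_last test below). The previous epoch's 'last' file is removed via the
# new `last_model_path` bookkeeping above.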

def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
# remove kth
@@ -407,9 +414,9 @@ def _do_check_save(self, filepath, current, epoch, trainer, pl_module):

if self.verbose > 0:
log.info(
f'\nEpoch {epoch:05d}: {self.monitor} reached'
f' {current:0.5f} (best {self.best_model_score:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}')
f'Epoch {epoch:d}: {self.monitor} reached'
f' {current:0.5f} (best {self.best_model_score:0.5f}),'
f' saving model to {filepath} as top {self.save_top_k}')
self._save_model(filepath, trainer, pl_module)

for cur_path in del_list:
56 changes: 47 additions & 9 deletions tests/callbacks/test_model_checkpoint.py
@@ -16,7 +16,7 @@

@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
""" Test that None in checkpoint callback is valid and that chkp_path is set correctly """
"""Test that None in checkpoint callback is valid and that ckpt_path is set correctly"""
tutils.reset_seed()
model = EvalModelTemplate()

@@ -111,25 +111,62 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
assert 1 == result


def test_model_checkpoint_format_checkpoint_name(tmpdir):
# empty filename:
ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, {})
assert ckpt_name == 'epoch=3'
ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, {}, prefix='test')
assert ckpt_name == 'test-epoch=3'
# no groups case:
ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, {}, prefix='test')
assert ckpt_name == 'test-ckpt'
# no prefix
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, {'acc': 0.03})
assert ckpt_name == 'epoch=003-acc=0.03'
# prefix
char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@'
ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, {'acc': 0.03}, prefix='test')
assert ckpt_name == 'test@epoch=3,acc=0.03000'
ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org
# no filepath set
ckpt_name = ModelCheckpoint(filepath=None).format_checkpoint_name(3, {})
assert ckpt_name == 'epoch=3.ckpt'
ckpt_name = ModelCheckpoint(filepath='').format_checkpoint_name(5, {})
assert ckpt_name == 'epoch=5.ckpt'
# CWD
ckpt_name = ModelCheckpoint(filepath='.').format_checkpoint_name(3, {})
assert Path(ckpt_name) == Path('.') / 'epoch=3.ckpt'
# dir does not exist so it is used as filename
filepath = tmpdir / 'dir'
ckpt_name = ModelCheckpoint(filepath=filepath, prefix='test').format_checkpoint_name(3, {})
assert ckpt_name == tmpdir / 'test-dir.ckpt'
# now, dir exists
os.mkdir(filepath)
ckpt_name = ModelCheckpoint(filepath=filepath, prefix='test').format_checkpoint_name(3, {})
assert ckpt_name == filepath / 'test-epoch=3.ckpt'
# with ver
ckpt_name = ModelCheckpoint(filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, {}, ver=3)
assert ckpt_name == tmpdir / 'test-name-v3.ckpt'


def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
""" Tests that the checkpoint saved as 'last.ckpt' contains the latest information. """
"""Tests that the save_last checkpoint contains the latest information."""
seed_everything(100)
model = EvalModelTemplate()
num_epochs = 3
model_checkpoint = ModelCheckpoint(
filepath=tmpdir, save_top_k=num_epochs, save_last=True
)
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
checkpoint_callback=model_checkpoint,
max_epochs=num_epochs,
)
trainer.fit(model)
path_last_epoch = model_checkpoint.format_checkpoint_name(
num_epochs - 1, {}
) # epoch=3.ckpt
path_last = str(tmpdir / ModelCheckpoint.CHECKPOINT_NAME_LAST) # last.ckpt
last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, num_epochs - 1, {})
path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {}) # epoch=3.ckpt
path_last = str(tmpdir / f'{last_filename}.ckpt') # last-epoch=3.ckpt
assert path_last_epoch != path_last
ckpt_last_epoch = torch.load(path_last_epoch)
ckpt_last = torch.load(path_last)
@@ -150,6 +187,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
model_last = EvalModelTemplate.load_from_checkpoint(path_last)
for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
assert w0.eq(w1).all()
ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'


def test_ckpt_metric_names(tmpdir):
7 changes: 4 additions & 3 deletions tests/trainer/test_trainer.py
@@ -380,7 +380,7 @@ def test_dp_output_reduce():
@pytest.mark.parametrize(["save_top_k", "save_last", "file_prefix", "expected_files"], [
pytest.param(-1, False, '', {'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt', 'epoch=1.ckpt', 'epoch=0.ckpt'},
id="CASE K=-1 (all)"),
pytest.param(1, False, 'test_prefix_', {'test_prefix_epoch=4.ckpt'},
pytest.param(1, False, 'test_prefix', {'test_prefix-epoch=4.ckpt'},
id="CASE K=1 (2.5, epoch 4)"),
pytest.param(2, False, '', {'epoch=4.ckpt', 'epoch=2.ckpt'},
id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"),
@@ -413,8 +413,9 @@ def mock_save_function(filepath, *args):

file_lists = set(os.listdir(tmpdir))

assert len(file_lists) == len(expected_files), \
"Should save %i models when save_top_k=%i" % (len(expected_files), save_top_k)
assert len(file_lists) == len(expected_files), (
f"Should save {len(expected_files)} models when save_top_k={save_top_k} but found={file_lists}"
)

# verify correct naming
for fname in expected_files:
4 changes: 2 additions & 2 deletions tests/trainer/test_trainer_steps_result_return.py
@@ -609,7 +609,7 @@ def test_result_monitor_warnings(tmpdir):
checkpoint_callback=ModelCheckpoint(monitor='not_val_loss')
)

with pytest.warns(UserWarning, match='key of ModelCheckpoint has no effect'):
with pytest.warns(UserWarning, match='key of `ModelCheckpoint` has no effect'):
trainer.fit(model)

trainer = Trainer(
@@ -621,5 +621,5 @@ def test_result_monitor_warnings(tmpdir):
early_stop_callback=EarlyStopping(monitor='not_val_loss')
)

with pytest.warns(UserWarning, match='key of EarlyStopping has no effect'):
with pytest.warns(UserWarning, match='key of `EarlyStopping` has no effect'):
trainer.fit(model)
