Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

re-enabled naming metrics in ckpt name #3060

Merged
merged 6 commits into from
Aug 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,11 @@ def on_validation_end(self, trainer, pl_module):

self.epoch_last_check = epoch

filepath = self.format_checkpoint_name(epoch, metrics)
ckpt_name_metrics = trainer.logged_metrics
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
version_cnt = 0
while gfile.exists(filepath):
filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1

Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class TrainerLoggingMixin(ABC):
default_root_dir: str
slurm_job_id: int
num_gpus: int
logged_metrics: ...

def configure_logger(self, logger):
if logger is True:
Expand Down Expand Up @@ -75,6 +76,8 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
self.logger.agg_and_log_metrics(scalar_metrics, step=step)
self.logger.save()

# track the logged metrics
self.logged_metrics = scalar_metrics
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@williamFalcon does this mean that naming metrics in the checkpoint filename is not possible when no logger is used?

self.dev_debugger.track_logged_metrics_history(scalar_metrics)

def add_progress_bar_metrics(self, metrics):
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ def __init__(
self.batch_idx = 0
self.progress_bar_metrics = {}
self.callback_metrics = {}
self.logged_metrics = {}
self.num_training_batches = 0
self.num_val_batches = []
self.num_test_batches = []
Expand Down
56 changes: 56 additions & 0 deletions tests/callbacks/test_model_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
import pickle
import platform
from pathlib import Path
Expand Down Expand Up @@ -128,3 +129,58 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
model_last = EvalModelTemplate.load_from_checkpoint(path_last)
for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
assert w0.eq(w1).all()


def test_ckpt_metric_names(tmpdir):
    """Checkpoint filenames built from a metric template should embed the metric value."""
    model = EvalModelTemplate()

    # checkpoint filename template pulls `val_loss` from the logged metrics
    checkpoint = ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}')
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gradient_clip_val=1.0,
        overfit_batches=0.20,
        progress_bar_refresh_rate=0,
        limit_train_batches=0.01,
        limit_val_batches=0.01,
        checkpoint_callback=checkpoint,
    )

    trainer.fit(model)

    # exactly one checkpoint file should carry the metric name
    metric_ckpts = [name for name in os.listdir(tmpdir) if 'val_loss' in name]
    assert len(metric_ckpts) == 1

    # and the name should contain a formatted numeric value (e.g. "0.45")
    digits = re.sub('[^0-9.]', '', metric_ckpts[0])
    assert len(digits) > 3


def test_ckpt_metric_names_results(tmpdir):
    """Same as ``test_ckpt_metric_names`` but using Result-object step methods."""
    model = EvalModelTemplate()

    # switch the template model over to the Result-object API;
    # the *_end / *_epoch_end hooks must be disabled for that path
    model.training_step = model.training_step_result_obj
    model.training_step_end = None
    model.training_epoch_end = None
    model.validation_step = model.validation_step_result_obj
    model.validation_step_end = None
    model.validation_epoch_end = None

    checkpoint = ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}')
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gradient_clip_val=1.0,
        overfit_batches=0.20,
        progress_bar_refresh_rate=0,
        limit_train_batches=0.01,
        limit_val_batches=0.01,
        checkpoint_callback=checkpoint,
    )

    trainer.fit(model)

    # exactly one checkpoint file should carry the metric name
    metric_ckpts = [name for name in os.listdir(tmpdir) if 'val_loss' in name]
    assert len(metric_ckpts) == 1

    # and the name should contain a formatted numeric value (e.g. "0.45")
    digits = re.sub('[^0-9.]', '', metric_ckpts[0])
    assert len(digits) > 3
8 changes: 4 additions & 4 deletions tests/trainer/test_eval_loop_dict_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_validation_step_dict_return(tmpdir):
assert k in eval_results[1]

# ensure all the keys ended up as candidates for callbacks
assert len(trainer.callback_metrics) == 8
assert len(trainer.callback_metrics) == 7

# make sure correct steps were called
assert model.validation_step_called
Expand Down Expand Up @@ -211,7 +211,7 @@ def test_val_step_step_end(tmpdir):
assert k in eval_results[1]

# ensure all the keys ended up as candidates for callbacks
assert len(trainer.callback_metrics) == 9
assert len(trainer.callback_metrics) == 8

# make sure correct steps were called
assert model.validation_step_called
Expand Down Expand Up @@ -254,7 +254,7 @@ def test_no_val_step_end(tmpdir):
assert k in eval_results

# ensure all the keys ended up as candidates for callbacks
assert len(trainer.callback_metrics) == 9
assert len(trainer.callback_metrics) == 8

# make sure correct steps were called
assert model.validation_step_called
Expand Down Expand Up @@ -297,7 +297,7 @@ def test_full_val_loop(tmpdir):
assert k in eval_results

# ensure all the keys ended up as candidates for callbacks
assert len(trainer.callback_metrics) == 10
assert len(trainer.callback_metrics) == 9

# make sure correct steps were called
assert model.validation_step_called
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/test_trainer_steps_scalar_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_full_training_loop_scalar(tmpdir):
assert model.training_epoch_end_called

# assert epoch end metrics were added
assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
assert len(trainer.callback_metrics) == 0
assert len(trainer.progress_bar_metrics) == 0

# make sure training outputs what is expected
Expand Down Expand Up @@ -151,7 +151,7 @@ def test_train_step_epoch_end_scalar(tmpdir):
assert model.training_epoch_end_called

# assert epoch end metrics were added
assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
assert len(trainer.callback_metrics) == 0
assert len(trainer.progress_bar_metrics) == 0

# make sure training outputs what is expected
Expand Down