Skip to content

Commit

Permalink
re-enabled naming metrics in ckpt name (#3060)
Browse files Browse the repository at this point in the history
* re-enabled naming metrics in ckpt name

* re-enabled naming metrics in ckpt name

* re-enabled naming metrics in ckpt name

* re-enabled naming metrics in ckpt name

* re-enabled naming metrics in ckpt name

* re-enabled naming metrics in ckpt name
  • Loading branch information
williamFalcon committed Aug 20, 2020
1 parent cefc7f7 commit 3453bba
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 8 deletions.
5 changes: 3 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,11 @@ def on_validation_end(self, trainer, pl_module):

self.epoch_last_check = epoch

filepath = self.format_checkpoint_name(epoch, metrics)
ckpt_name_metrics = trainer.logged_metrics
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
version_cnt = 0
while gfile.exists(filepath):
filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1

Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class TrainerLoggingMixin(ABC):
default_root_dir: str
slurm_job_id: int
num_gpus: int
logged_metrics: ...

def configure_logger(self, logger):
if logger is True:
Expand Down Expand Up @@ -75,6 +76,8 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
self.logger.agg_and_log_metrics(scalar_metrics, step=step)
self.logger.save()

# track the logged metrics
self.logged_metrics = scalar_metrics
self.dev_debugger.track_logged_metrics_history(scalar_metrics)

def add_progress_bar_metrics(self, metrics):
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ def __init__(
self.batch_idx = 0
self.progress_bar_metrics = {}
self.callback_metrics = {}
self.logged_metrics = {}
self.num_training_batches = 0
self.num_val_batches = []
self.num_test_batches = []
Expand Down
56 changes: 56 additions & 0 deletions tests/callbacks/test_model_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
import pickle
import platform
from pathlib import Path
Expand Down Expand Up @@ -128,3 +129,58 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
model_last = EvalModelTemplate.load_from_checkpoint(path_last)
for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
assert w0.eq(w1).all()


def test_ckpt_metric_names(tmpdir):
    """Checkpoint filenames should embed the formatted ``val_loss`` metric value."""
    model = EvalModelTemplate()

    # filepath template asks for the metric to be interpolated into the name
    checkpoint_cb = ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}')
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gradient_clip_val=1.0,
        overfit_batches=0.20,
        progress_bar_refresh_rate=0,
        limit_train_batches=0.01,
        limit_val_batches=0.01,
        checkpoint_callback=checkpoint_cb,
    )

    trainer.fit(model)

    # exactly one saved checkpoint should carry the metric name in its filename
    matching = [name for name in os.listdir(tmpdir) if 'val_loss' in name]
    assert len(matching) == 1

    # strip everything except digits and dots; a formatted float value must remain
    digits = re.sub('[^0-9.]', '', matching[0])
    assert len(digits) > 3


def test_ckpt_metric_names_results(tmpdir):
    """Same as ``test_ckpt_metric_names`` but with Result-object step methods:
    the metric must still end up in the checkpoint filename."""
    model = EvalModelTemplate()

    # swap in the Result-object variants of the step hooks and drop the
    # *_end / *_epoch_end hooks so only the Result path is exercised
    model.training_step = model.training_step_result_obj
    model.training_step_end = None
    model.training_epoch_end = None
    model.validation_step = model.validation_step_result_obj
    model.validation_step_end = None
    model.validation_epoch_end = None

    checkpoint_cb = ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}')
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        gradient_clip_val=1.0,
        overfit_batches=0.20,
        progress_bar_refresh_rate=0,
        limit_train_batches=0.01,
        limit_val_batches=0.01,
        checkpoint_callback=checkpoint_cb,
    )

    trainer.fit(model)

    # exactly one saved checkpoint should carry the metric name in its filename
    matching = [name for name in os.listdir(tmpdir) if 'val_loss' in name]
    assert len(matching) == 1

    # strip everything except digits and dots; a formatted float value must remain
    digits = re.sub('[^0-9.]', '', matching[0])
    assert len(digits) > 3
8 changes: 4 additions & 4 deletions tests/trainer/test_eval_loop_dict_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_validation_step_dict_return(tmpdir):
assert k in eval_results[1]

# ensure all the keys ended up as candidates for callbacks
assert len(trainer.callback_metrics) == 8
assert len(trainer.callback_metrics) == 7

# make sure correct steps were called
assert model.validation_step_called
Expand Down Expand Up @@ -211,7 +211,7 @@ def test_val_step_step_end(tmpdir):
assert k in eval_results[1]

# ensure all the keys ended up as candidates for callbacks
assert len(trainer.callback_metrics) == 9
assert len(trainer.callback_metrics) == 8

# make sure correct steps were called
assert model.validation_step_called
Expand Down Expand Up @@ -254,7 +254,7 @@ def test_no_val_step_end(tmpdir):
assert k in eval_results

# ensure all the keys ended up as candidates for callbacks
assert len(trainer.callback_metrics) == 9
assert len(trainer.callback_metrics) == 8

# make sure correct steps were called
assert model.validation_step_called
Expand Down Expand Up @@ -297,7 +297,7 @@ def test_full_val_loop(tmpdir):
assert k in eval_results

# ensure all the keys ended up as candidates for callbacks
assert len(trainer.callback_metrics) == 10
assert len(trainer.callback_metrics) == 9

# make sure correct steps were called
assert model.validation_step_called
Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/test_trainer_steps_scalar_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_full_training_loop_scalar(tmpdir):
assert model.training_epoch_end_called

# assert epoch end metrics were added
assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
assert len(trainer.callback_metrics) == 0
assert len(trainer.progress_bar_metrics) == 0

# make sure training outputs what is expected
Expand Down Expand Up @@ -151,7 +151,7 @@ def test_train_step_epoch_end_scalar(tmpdir):
assert model.training_epoch_end_called

# assert epoch end metrics were added
assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
assert len(trainer.callback_metrics) == 0
assert len(trainer.progress_bar_metrics) == 0

# make sure training outputs what is expected
Expand Down

0 comments on commit 3453bba

Please sign in to comment.