Also update progress_bar in training_epoch_end (#1724)
* update prog. bar metrics on train epoch end

* changelog

* wip test

* more thorough testing

* comments

* update docs

* move test

Co-authored-by: Jirka <jirka.borovec@seznam.cz>
awaelchli and Borda authored May 9, 2020
1 parent 3a64260 commit 25bbd05
Showing 4 changed files with 53 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)).

+- The progress bar metrics now also get updated in `training_epoch_end` ([#1724](https://github.com/PyTorchLightning/pytorch-lightning/pull/1724)).

### Changed

- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
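The changelog entry for #1724 above means that, in user code, a `progress_bar` dict returned from `training_epoch_end` now reaches the progress bar the same way one returned from `training_step` does. A minimal sketch against the 0.7-era API this commit targets (the model, data, and metric names are illustrative, not taken from the repository):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import LightningModule

    class LitClassifier(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(16, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = torch.nn.functional.cross_entropy(logits, y)
            acc = (logits.argmax(dim=1) == y).float().mean()
            # the per-step accuracy is carried to training_epoch_end via the output dict
            return {'loss': loss, 'train_acc': acc}

        def training_epoch_end(self, outputs):
            # average the per-step metric over the whole epoch
            train_acc_mean = torch.stack([out['train_acc'] for out in outputs]).mean()
            return {
                'log': {'train_acc': train_acc_mean.item()},    # sent to the logger
                'progress_bar': {'train_acc': train_acc_mean},  # now also shown in the progress bar
            }

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            x, y = torch.randn(64, 16), torch.randint(0, 2, (64,))
            return DataLoader(TensorDataset(x, y), batch_size=8)

Fitting such a model with `Trainer(max_epochs=3).fit(LitClassifier())` should then show `train_acc` in the progress bar once the first epoch ends.
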
5 changes: 4 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -257,6 +257,7 @@ def training_epoch_end(
            May contain the following optional keys:
            - log (metrics to be added to the logger; only tensors)
            - progress_bar (dict for progress bar display)
+           - any metric used in a callback (e.g. early stopping).

        Note:
@@ -280,7 +281,8 @@ def training_epoch_end(self, outputs):
            # log training accuracy at the end of an epoch
            results = {
-               'log': {'train_acc': train_acc_mean.item()}
+               'log': {'train_acc': train_acc_mean.item()},
+               'progress_bar': {'train_acc': train_acc_mean},
            }
            return results
@@ -303,6 +305,7 @@ def training_epoch_end(self, outputs):
            # log training accuracy at the end of an epoch
            results = {
                'log': {'train_acc': train_acc_mean.item(), 'step': self.current_epoch},
+               'progress_bar': {'train_acc': train_acc_mean},
            }
            return results
        """
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -491,6 +491,7 @@ def run_training_epoch(self):
            callback_epoch_metrics = _processed_outputs[3]
            self.log_metrics(log_epoch_metrics, {})
            self.callback_metrics.update(callback_epoch_metrics)
+           self.add_progress_bar_metrics(_processed_outputs[1])

        # when no val loop is present or fast-dev-run still need to call checkpoints
        if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val):
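For context on the trainer-side line added above: `_processed_outputs` is the tuple the trainer builds a few lines earlier from the dict returned by `training_epoch_end`; per the surrounding lines, index 1 holds the progress-bar metrics and index 3 the callback metrics, and before this commit only the log and callback parts were consumed at epoch end. A rough sketch of what `add_progress_bar_metrics` does (a simplified approximation for illustration, not the actual library source):

    import torch

    def add_progress_bar_metrics(self, metrics):
        # merge new metrics into the running progress-bar dict,
        # unwrapping 0-d tensors into Python scalars so they can be displayed
        for name, value in metrics.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            self.progress_bar_metrics[name] = value

This reduction to plain Python numbers is consistent with the test below asserting plain ints against metrics that the model returned as tensors.
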
46 changes: 46 additions & 0 deletions tests/models/test_module_hooks.py
@@ -0,0 +1,46 @@
import torch

from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate

import tests.base.utils as tutils


def test_training_epoch_end_metrics_collection(tmpdir):
    """ Test that progress bar metrics also get collected at the end of an epoch. """
    num_epochs = 3

    class CurrentModel(EvalModelTemplate):

        def training_step(self, *args, **kwargs):
            output = super().training_step(*args, **kwargs)
            output['progress_bar'].update({'step_metric': torch.tensor(-1)})
            output['progress_bar'].update({'shared_metric': 100})
            return output

        def training_epoch_end(self, outputs):
            epoch = self.current_epoch
            # both scalar tensors and Python numbers are accepted
            return {
                'progress_bar': {
                    f'epoch_metric_{epoch}': torch.tensor(epoch),  # add a new metric key every epoch
                    'shared_metric': 111,
                }
            }

    model = CurrentModel(tutils.get_default_hparams())
    trainer = Trainer(
        max_epochs=num_epochs,
        default_root_dir=tmpdir,
        overfit_pct=0.1,
    )
    result = trainer.fit(model)
    assert result == 1
    metrics = trainer.progress_bar_dict

    # metrics added in training step should be unchanged by epoch end method
    assert metrics['step_metric'] == -1
    # a metric shared in both methods gets overwritten by epoch_end
    assert metrics['shared_metric'] == 111
    # metrics are kept after each epoch
    for i in range(num_epochs):
        assert metrics[f'epoch_metric_{i}'] == i
