early stopping checks on_validation_end (#1458)
* Fixes #490

`EarlyStopping` should check the metric of interest in `on_validation_end` rather than in `on_epoch_end`.
In the default configuration this makes no difference, but in combination with `check_val_every_n_epoch>1` in the `Trainer`, the old behavior results in a warning or a `RuntimeError`, depending on `strict`.
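
A minimal sketch of the failing configuration (illustrative only; `model` stands in for a LightningModule defined elsewhere):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # Validation runs only every 2nd epoch, so 'val_loss' is missing from the
    # callback metrics on the epochs in between. The old on_epoch_end check
    # then warned (strict=False) or raised a RuntimeError (strict=True).
    early_stop = EarlyStopping(monitor='val_loss', strict=True)
    trainer = Trainer(
        check_val_every_n_epoch=2,
        early_stop_callback=early_stop,
    )
    trainer.fit(model)  # hypothetical LightningModule instance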

* Highlighted that ES callback runs on val epochs in docstring

* Updated EarlyStopping in rst doc

* Update early_stopping.py

* Update early_stopping.rst

* Update early_stopping.rst

* Update early_stopping.rst

* Update early_stopping.rst

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update docs/source/early_stopping.rst

* fix doctest indentation warning

* Train loop calls early_stop.on_validation_end

* chlog

Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
5 people committed May 25, 2020
1 parent 8ca8336 commit 65b4352
Showing 4 changed files with 54 additions and 22 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)
 
+- Early stopping checks `on_validation_end` ([#1458](https://github.com/PyTorchLightning/pytorch-lightning/pull/1458))
+
 ### Changed
 
 - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
61 changes: 44 additions & 17 deletions docs/source/early_stopping.rst
@@ -19,36 +19,63 @@ By default early stopping will be enabled if `'val_loss'`
 is found in :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`'s
 return dict. Otherwise training will proceed with early stopping disabled.
 
-Enable Early Stopping using Callbacks on epoch end
---------------------------------------------------
-There are two ways to enable early stopping using callbacks on epoch end.
+Enable Early Stopping using the EarlyStopping Callback
+------------------------------------------------------
+The
+:class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
+callback can be used to monitor a validation metric and stop the training when no improvement is observed.
+
+There are two ways to enable the EarlyStopping callback:
 
-- Set early_stop_callback to True. Will look for 'val_loss' in validation_epoch_end() return dict.
-  If it is not found an error is raised.
+- Set `early_stop_callback=True`.
+  The callback will look for 'val_loss' in the dict returned by
+  :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
+  and raise an error if `val_loss` is not present.
 
   .. testcode::
 
      trainer = Trainer(early_stop_callback=True)
 
-- Or configure your own callback
+- Create the callback object and pass it to the trainer.
+  This allows for further customization.
 
   .. testcode::
 
     early_stop_callback = EarlyStopping(
-       monitor='val_loss',
-       min_delta=0.00,
-       patience=3,
-       verbose=False,
-       mode='min'
+       monitor='val_accuracy',
+       min_delta=0.00,
+       patience=3,
+       verbose=False,
+       mode='max'
    )
    trainer = Trainer(early_stop_callback=early_stop_callback)
 
-In any case, the callback will fall back to the training metrics (returned in
-:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`,
-:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`)
-looking for a key to monitor if validation is disabled or
-:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
-is not defined.
+In case you need early stopping in a different part of training, subclass EarlyStopping
+and change where it is called:
+
+.. testcode::
+
+    class MyEarlyStopping(EarlyStopping):
+
+        def on_validation_end(self, trainer, pl_module):
+            # override this to disable early stopping at the end of val loop
+            pass
+
+        def on_train_end(self, trainer, pl_module):
+            # instead, do it at the end of training loop
+            self._run_early_stopping_check(trainer, pl_module)
+
+.. note::
+   The EarlyStopping callback runs at the end of every validation epoch,
+   which, under the default configuration, happens after every training epoch.
+   However, the frequency of validation can be modified by setting various parameters
+   on the :class:`~pytorch_lightning.trainer.trainer.Trainer`,
+   for example :paramref:`~pytorch_lightning.trainer.trainer.Trainer.check_val_every_n_epoch`
+   and :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval`.
+   Note that the `patience` parameter counts the number of
+   validation epochs with no improvement, and not the number of training epochs.
+   Therefore, with parameters `check_val_every_n_epoch=10` and `patience=3`, the trainer
+   will perform at least 40 training epochs before being stopped.
 
 .. seealso::
     - :class:`~pytorch_lightning.trainer.trainer.Trainer`
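
To make the arithmetic in the new note concrete, a small sketch (the parameter values are taken from the note; the monitored metric and everything else are illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # patience counts *validation* epochs. With validation every 10 training
    # epochs, stopping requires one check to set the baseline plus 3 checks
    # with no improvement: validations at epochs 10, 20, 30, 40 -> at least
    # 40 training epochs before the trainer can stop.
    early_stop = EarlyStopping(monitor='val_loss', patience=3, mode='min')
    trainer = Trainer(
        check_val_every_n_epoch=10,
        early_stop_callback=early_stop,
    )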
11 changes: 7 additions & 4 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -2,7 +2,7 @@
 Early Stopping
 ==============
 
-Stop training when a monitored quantity has stopped improving.
+Monitor a validation metric and stop training when it stops improving.
 
 """
 
@@ -25,7 +25,7 @@ class EarlyStopping(Callback):
             to qualify as an improvement, i.e. an absolute
             change of less than `min_delta`, will count as no
             improvement. Default: ``0``.
-        patience: number of epochs with no improvement
+        patience: number of validation epochs with no improvement
             after which training will be stopped. Default: ``0``.
         verbose: verbosity mode. Default: ``False``.
         mode: one of {auto, min, max}. In `min` mode,
@@ -36,7 +36,7 @@ class EarlyStopping(Callback):
             mode, the direction is automatically inferred
             from the name of the monitored quantity. Default: ``'auto'``.
         strict: whether to crash the training if `monitor` is
-            not found in the metrics. Default: ``True``.
+            not found in the validation metrics. Default: ``True``.
 
     Example::
 
@@ -109,7 +109,10 @@ def on_train_start(self, trainer, pl_module):
         self.stopped_epoch = 0
         self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf
 
-    def on_epoch_end(self, trainer, pl_module):
+    def on_validation_end(self, trainer, pl_module):
+        self._run_early_stopping_check(trainer, pl_module)
+
+    def _run_early_stopping_check(self, trainer, pl_module):
         logs = trainer.callback_metrics
         stop_training = False
         if not self._validate_condition_metric(logs):
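
A brief illustration of the `strict` flag described in the docstring above (a hedged sketch of the documented behavior, not the callback's full API):

    from pytorch_lightning.callbacks import EarlyStopping

    # strict=True (default): training crashes if 'val_loss' is not found
    # in the validation metrics; strict=False: it only warns and skips
    # the early-stopping check.
    strict_stopper = EarlyStopping(monitor='val_loss', strict=True)
    lenient_stopper = EarlyStopping(monitor='val_loss', strict=False)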
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -362,7 +362,7 @@ def train(self):
                 # TODO wrap this logic into the callback
                 if self.enable_early_stop:
                     if (met_min_epochs and met_min_steps) or self.fast_dev_run:
-                        should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
+                        should_stop = self.early_stop_callback.on_validation_end(self, self.get_model())
                         # stop training
                         stop = should_stop and met_min_epochs
                         if stop:
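
A simplified restatement of the gating above (a sketch, not the actual Trainer code; it assumes the callback hook returns the stop decision, as the assignment implies):

    def early_stop_decision(callback, trainer, model,
                            met_min_epochs, met_min_steps, fast_dev_run):
        # never stop before min_epochs / min_steps are satisfied,
        # unless this is a fast_dev_run
        if not ((met_min_epochs and met_min_steps) or fast_dev_run):
            return False
        # the hook runs the check against trainer.callback_metrics
        should_stop = callback.on_validation_end(trainer, model)
        return bool(should_stop) and met_min_epochs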
