
early stopping checks on_validation_end #1458

Merged · 13 commits · May 25, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)

+- Early stopping checks `on_validation_end` ([#1458](https://github.com/PyTorchLightning/pytorch-lightning/pull/1458))

### Changed

- Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
61 changes: 44 additions & 17 deletions docs/source/early_stopping.rst
@@ -19,36 +19,63 @@ By default early stopping will be enabled if `'val_loss'`
is found in :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`'s
return dict. Otherwise training will proceed with early stopping disabled.
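
For instance, returning `'val_loss'` from
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
is enough to activate the default callback. A minimal sketch (the
`validation_step` output keys and the averaging are illustrative):

.. testcode::

    class LitModel(LightningModule):

        def validation_epoch_end(self, outputs):
            # average the per-batch 'val_loss' values collected from validation_step
            val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            return {'val_loss': val_loss}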

-Enable Early Stopping using Callbacks on epoch end
---------------------------------------------------
-There are two ways to enable early stopping using callbacks on epoch end.
+Enable Early Stopping using the EarlyStopping Callback
+------------------------------------------------------
The
:class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`
callback can be used to monitor a validation metric and stop the training when no improvement is observed.

There are two ways to enable the EarlyStopping callback:

-- Set early_stop_callback to True. Will look for 'val_loss' in validation_epoch_end() return dict.
-  If it is not found an error is raised.
+- Set `early_stop_callback=True`.
+  The callback will look for 'val_loss' in the dict returned by
+  :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
+  and raise an error if `val_loss` is not present.

.. testcode::

    trainer = Trainer(early_stop_callback=True)

-- Or configure your own callback
+- Create the callback object and pass it to the trainer.
+  This allows for further customization.

.. testcode::

    early_stop_callback = EarlyStopping(
-        monitor='val_loss',
-        min_delta=0.00,
-        patience=3,
-        verbose=False,
-        mode='min'
+        monitor='val_accuracy',
+        min_delta=0.00,
+        patience=3,
+        verbose=False,
+        mode='max'
    )
    trainer = Trainer(early_stop_callback=early_stop_callback)

In either case, if validation is disabled or
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
is not defined, the callback falls back to the training metrics (returned by
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` and
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`)
when looking for a key to monitor.
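
For example, to monitor a quantity produced during training, a minimal sketch
(this assumes the training loss is exposed to callbacks under the key ``loss``;
inspect ``trainer.callback_metrics`` to see which keys your setup provides):

.. testcode::

    # assumes the training loss reaches callbacks as 'loss'
    early_stop_callback = EarlyStopping(monitor='loss', mode='min', patience=3)
    trainer = Trainer(early_stop_callback=early_stop_callback)
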
If you need early stopping at a different point in training, subclass EarlyStopping
and change where the check is called:

.. testcode::

    class MyEarlyStopping(EarlyStopping):

        def on_validation_end(self, trainer, pl_module):
            # override this to disable early stopping at the end of val loop
            pass

        def on_train_end(self, trainer, pl_module):
            # instead, do it at the end of training loop
            self._run_early_stopping_check(trainer, pl_module)
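
An instance of the subclass is then passed to the trainer just like the built-in callback:

.. testcode::

    trainer = Trainer(early_stop_callback=MyEarlyStopping(monitor='val_loss', patience=3))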

.. note::
    The EarlyStopping callback runs at the end of every validation epoch,
    which, under the default configuration, happens after every training epoch.
    However, the frequency of validation can be modified by setting various parameters
    on the :class:`~pytorch_lightning.trainer.trainer.Trainer`,
    for example :paramref:`~pytorch_lightning.trainer.trainer.Trainer.check_val_every_n_epoch`
    and :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval`.
    Note that the `patience` parameter counts the number of
    validation epochs with no improvement, not the number of training epochs.
    Therefore, with `check_val_every_n_epoch=10` and `patience=3`, the trainer
    will perform at least 40 training epochs before stopping.
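
For example, a sketch of that combination (validation every 10 epochs, patience of 3):

.. testcode::

    early_stop_callback = EarlyStopping(monitor='val_loss', patience=3)
    trainer = Trainer(
        check_val_every_n_epoch=10,
        early_stop_callback=early_stop_callback,
    )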

.. seealso::
- :class:`~pytorch_lightning.trainer.trainer.Trainer`
11 changes: 7 additions & 4 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -2,7 +2,7 @@
Early Stopping
==============

-Stop training when a monitored quantity has stopped improving.
+Monitor a validation metric and stop training when it stops improving.

"""

@@ -25,7 +25,7 @@ class EarlyStopping(Callback):
to qualify as an improvement, i.e. an absolute
change of less than `min_delta`, will count as no
improvement. Default: ``0``.
-patience: number of epochs with no improvement
+patience: number of validation epochs with no improvement
after which training will be stopped. Default: ``0``.
verbose: verbosity mode. Default: ``False``.
mode: one of {auto, min, max}. In `min` mode,
@@ -36,7 +36,7 @@ class EarlyStopping(Callback):
mode, the direction is automatically inferred
from the name of the monitored quantity. Default: ``'auto'``.
strict: whether to crash the training if `monitor` is
-not found in the metrics. Default: ``True``.
+not found in the validation metrics. Default: ``True``.

Example::

@@ -109,7 +109,10 @@ def on_train_start(self, trainer, pl_module):
        self.stopped_epoch = 0
        self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf

-    def on_epoch_end(self, trainer, pl_module):
+    def on_validation_end(self, trainer, pl_module):
+        self._run_early_stopping_check(trainer, pl_module)
+
+    def _run_early_stopping_check(self, trainer, pl_module):
        logs = trainer.callback_metrics
        stop_training = False
        if not self._validate_condition_metric(logs):
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -362,7 +362,7 @@ def train(self):
            # TODO wrap this logic into the callback
            if self.enable_early_stop:
                if (met_min_epochs and met_min_steps) or self.fast_dev_run:
-                    should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
+                    should_stop = self.early_stop_callback.on_validation_end(self, self.get_model())
                    # stop training
                    stop = should_stop and met_min_epochs
                    if stop: