Fix early stopping with training step's return dict (#3347)
* Fixes the test for early stopping without a val step.

The expression that checked whether early stopping was triggered had an off-by-one error and was hence true even when early stopping was not triggered.

Furthermore, sets patience to 0 and max epochs to 10 to ensure the loss has enough time to flatten.
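
For context, a minimal sketch of the off-by-one, assuming Lightning's 0-indexed epoch counter (the numbers are illustrative, not from the commit):

```python
max_epochs = 10
# after a full, non-early-stopped run, the counter ends at max_epochs - 1
current_epoch_after_full_run = max_epochs - 1

assert current_epoch_after_full_run < max_epochs            # old check: always true
assert not (current_epoch_after_full_run < max_epochs - 1)  # fixed check: false unless training stopped early
```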

* Fixes early stopping without a val step.

The issue was that only the `early_stop_on` key was checked, not an arbitrary monitor key.
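
A minimal sketch of the setup this fixes, modeled on the updated test below; the subclass name is illustrative, the `EvalModelTemplate` import path is assumed from the test suite layout, and `early_stop_callback` is the Trainer argument used by this version of Lightning:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from tests.base import EvalModelTemplate


class TrainOnlyModel(EvalModelTemplate):
    def training_step(self, *args, **kwargs):
        output = super().training_step(*args, **kwargs)
        # early stopping should be able to monitor this arbitrary dict key
        output.update({'my_train_metric': output['loss']})
        return output


model = TrainOnlyModel()
model.validation_step = None  # no val loop at all
model.val_dataloader = None

trainer = Trainer(
    early_stop_callback=EarlyStopping(monitor='my_train_metric', patience=0),
    max_epochs=10,
)
trainer.fit(model)
```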

* Fixes the branch which checks whether early stopping is done during validation.

Before, only `val_early_stop_on` was checked. Since arbitrary keys can be used, the set of possible validation keys cannot be exhaustive. Hence, this disables early stopping in `on_train_epoch_end` via an instance attribute if early stopping was executed in `on_validation_epoch_end`.
Furthermore, adds a test which ensures that arbitrary keys work.

* Improves the check of whether eval results are used.

Only disables the train-results check if eval results are actually used; before, it was always disabled in ``on_validation_epoch_end``. A condensed sketch of the resulting control flow follows below.
Renames and documents the instance variable to make it clearer.
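
Condensed control flow after this change, paraphrasing the diff further down (method bodies trimmed; `_check_train_results` is a hypothetical stand-in for the train-loop branch shown in the diff):

```python
class EarlyStopping(Callback):
    def __init__(self, monitor='val_loss'):
        self.monitor = monitor
        # True once the monitored metric has been found in eval results
        self.based_on_eval_results = False

    def on_validation_epoch_end(self, trainer, pl_module):
        metrics = trainer.logger_connector.callback_metrics
        # claim the val loop only if the monitored metric is actually present
        if self._validate_condition_metric(metrics):
            self.based_on_eval_results = True

    def on_train_epoch_end(self, trainer, pl_module):
        if self.based_on_eval_results:
            return  # val-loop early stopping is active; skip the train-loop check
        self._check_train_results(trainer, pl_module)
```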

* Removes wrong documentation on the behaviour of early stopping with the training step's result dict.

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Lucas-Steinmann and Borda authored Sep 18, 2020
1 parent c46de8a commit 197acd5
Showing 3 changed files with 40 additions and 16 deletions.
docs/source/results.rst (10 changes: 0 additions & 10 deletions)
@@ -101,16 +101,6 @@ checkpointing or early stopping:

     return TrainResult(some_metric, checkpoint_on=metric_a, early_stop_on=metric_b)

-In the manual loop, checkpoint and early stop is based only on the loss returned. With the `TrainResult` you
-can change it every batch if you want, or even monitor different metrics for each purpose.
-
-.. code-block:: python
-
-    # early stop + checkpoint can only use the `loss` when done manually via dictionaries
-    def training_step(...):
-        return loss
-    def training_step(...):
-        return {'loss': loss}

 logging
 ^^^^^^^
pytorch_lightning/callbacks/early_stopping.py (19 changes: 16 additions & 3 deletions)
@@ -85,6 +85,9 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
         self.stopped_epoch = 0
         self.mode = mode
         self.warned_result_obj = False
+        # Indicates whether eval results are used as the basis for early stopping.
+        # Initially False; set to True once eval results have been validated.
+        self.based_on_eval_results = False

         if mode not in self.mode_dict:
             if self.verbose > 0:
@@ -157,19 +160,25 @@ def on_validation_epoch_end(self, trainer, pl_module):
         if val_es_key in trainer.logger_connector.callback_metrics:
             self.strict = False

-        self._validate_condition_metric(trainer.logger_connector.callback_metrics)
+        if self._validate_condition_metric(trainer.logger_connector.callback_metrics):
+            # turn off early stopping in on_train_epoch_end
+            self.based_on_eval_results = True

     def on_train_epoch_end(self, trainer, pl_module):
         # disable early stopping in train loop when there's a val loop
-        if self.monitor == 'val_early_stop_on':
+        if self.based_on_eval_results:
             return

-        # early stopping can also work in the train loop when there is no val loop and when using structured results
+        # early stopping can also work in the train loop when there is no val loop
         should_check_early_stop = False
+        # early_stop_on takes precedence over monitor key
         train_es_key = 'early_stop_on'
         if trainer.logger_connector.callback_metrics.get(train_es_key, None) is not None:
             self.monitor = train_es_key
             should_check_early_stop = True
+        # fallback to monitor key in result dict
+        if trainer.logger_connector.callback_metrics.get(self.monitor, None) is not None:
+            should_check_early_stop = True

         if should_check_early_stop:
             self._run_early_stopping_check(trainer, pl_module)
@@ -187,6 +196,10 @@ def __warn_deprecated_monitor_key(self):
         rank_zero_warn(m)

     def _run_early_stopping_check(self, trainer, pl_module):
+        """
+        Checks whether the early stopping condition is met
+        and if so tells the trainer to stop the training.
+        """
         logs = trainer.logger_connector.callback_metrics

         if not self._validate_condition_metric(logs):
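To summarize the lookup order that `on_train_epoch_end` now follows, a standalone sketch (the helper name is ours, not from the commit; `metrics` stands in for `trainer.logger_connector.callback_metrics`):

```python
def resolve_train_monitor(metrics: dict, monitor: str):
    """Return the key early stopping should check on train epoch end, or None."""
    if metrics.get('early_stop_on') is not None:  # structured-results key wins
        return 'early_stop_on'
    if metrics.get(monitor) is not None:          # fallback: arbitrary result-dict key
        return monitor
    return None                                   # nothing to check this epoch

assert resolve_train_monitor({'early_stop_on': 0.5, 'loss': 0.4}, 'loss') == 'early_stop_on'
assert resolve_train_monitor({'my_train_metric': 0.4}, 'my_train_metric') == 'my_train_metric'
assert resolve_train_monitor({}, 'loss') is None
```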
tests/callbacks/test_early_stopping.py (27 changes: 24 additions & 3 deletions)
@@ -137,17 +137,17 @@ def training_step(self, *args, **kwargs):
     model.validation_step = None
     model.val_dataloader = None

-    stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
+    stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1, patience=0)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=stopping,
         overfit_batches=0.20,
-        max_epochs=2,
+        max_epochs=10,
     )
     result = trainer.fit(model)

     assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch < trainer.max_epochs
+    assert trainer.current_epoch < trainer.max_epochs - 1


@@ -168,3 +168,24 @@ def validation_epoch_end(self, outputs):
     )
     trainer.fit(model)
     assert trainer.current_epoch == 5, 'early_stopping failed'
+
+
+def test_early_stopping_functionality_arbitrary_key(tmpdir):
+    """Tests whether early stopping works with a custom key and dictionary results on val step."""
+
+    class CurrentModel(EvalModelTemplate):
+        def validation_epoch_end(self, outputs):
+            losses = [8, 4, 2, 3, 4, 5, 8, 10]
+            val_loss = losses[self.current_epoch]
+            return {'jiraffe': torch.tensor(val_loss)}
+
+    model = CurrentModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        early_stop_callback=EarlyStopping(monitor='jiraffe'),
+        overfit_batches=0.20,
+        max_epochs=20,
+    )
+    trainer.fit(model)
+    assert trainer.current_epoch >= 5, 'early_stopping failed'
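
Why `current_epoch >= 5` is the expected outcome, as a sketch (assuming the callback's defaults of `patience=3` and `min_delta=0.0`):

```python
losses = [8, 4, 2, 3, 4, 5, 8, 10]
best, wait, patience = float('inf'), 0, 3
for epoch, loss in enumerate(losses):
    if loss < best:           # an improvement resets the patience counter
        best, wait = loss, 0
    else:
        wait += 1
        if wait >= patience:  # three epochs without improvement
            break
assert epoch == 5  # minimum (2) at epoch 2; patience exhausted at epoch 5
```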
