def test_early_stopping_mode_options():
    """An unsupported ``mode`` string must raise ``MisconfigurationException``."""
    bad_mode = "unknown_option"
    with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"):
        EarlyStopping(mode=bad_mode)
class EarlyStoppingModel(BoringModel):
    """``BoringModel`` variant that logs two scripted validation metrics.

    ``abc`` dips and then rises epoch over epoch (so its early-stopping
    callback eventually runs out of patience), while ``cba`` is constant at 0
    (so its callback never sees an improvement after the first epoch).
    ``on_train_end`` asserts training actually halted at the expected epoch.
    """

    def __init__(self, expected_end_epoch):
        super().__init__()
        # Epoch index at which early stopping is expected to have fired.
        self.expected_end_epoch = expected_end_epoch

    def validation_epoch_end(self, outputs):
        # Scripted per-epoch values; indexing by current_epoch hard-fails
        # (IndexError) if training runs past epoch 7 without stopping.
        scripted = (8, 4, 2, 3, 4, 5, 8, 10)
        self.log('abc', torch.tensor(scripted[self.current_epoch]))
        self.log('cba', torch.tensor(0))

    def on_train_end(self) -> None:
        assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed'


# Skip the distributed variants on Windows, where ddp_cpu is unavailable.
_skip_on_windows = pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")

_fast_stopper = EarlyStopping(monitor='abc')          # default patience, stops at epoch 3
_slow_stopper = EarlyStopping(monitor='cba', patience=3)  # flat metric, higher patience


@pytest.mark.parametrize(
    "callbacks, expected_stop_epoch, accelerator, num_processes",
    [
        # Single-process: either callback ordering must stop at epoch 3.
        ([_fast_stopper, _slow_stopper], 3, None, 1),
        ([_slow_stopper, _fast_stopper], 3, None, 1),
        # ddp_cpu with 2 processes: the stop decision must be reduced across ranks.
        pytest.param([_fast_stopper, _slow_stopper], 3, 'ddp_cpu', 2, marks=_skip_on_windows),
        pytest.param([_slow_stopper, _fast_stopper], 3, 'ddp_cpu', 2, marks=_skip_on_windows),
    ],
)
def test_multiple_early_stopping_callbacks(callbacks, expected_stop_epoch, accelerator, num_processes, tmpdir):
    """
    Ensure when using multiple early stopping callbacks we stop if any signals we should stop.
    """

    model = EarlyStoppingModel(expected_stop_epoch)

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=callbacks,
        overfit_batches=0.20,
        max_epochs=20,
        accelerator=accelerator,
        num_processes=num_processes
    )
    trainer.fit(model)