Skip to content

Commit

Permalink
temporarily fixes early stopping bug (#2119)
Browse files Browse the repository at this point in the history
* fixes early stopping bug

* fixes early stopping bug

* fixes early stopping bug

* fixes early stopping bug

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* fixe docs

* added test
  • Loading branch information
williamFalcon authored Jun 8, 2020
1 parent 73a6a95 commit 479ab49
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def on_train_start(self, trainer, pl_module):
self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf

def on_validation_end(self, trainer, pl_module):
self._run_early_stopping_check(trainer, pl_module)
return self._run_early_stopping_check(trainer, pl_module)

def _run_early_stopping_check(self, trainer, pl_module):
logs = trainer.callback_metrics
Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@ def train(self):
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

# TODO wrap this logic into the callback
# DO NOT DELETE
# early stopping as a (new Callback) class doesn't yet work because we have to know these
# trainer flags including the current epoch stuff
# all of this needs to go into the early stopping to clean up better
if self.enable_early_stop:
if (met_min_epochs and met_min_steps) or self.fast_dev_run:
should_stop = self.early_stop_callback.on_validation_end(self, self.get_model())
Expand Down
24 changes: 24 additions & 0 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,30 @@
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate
import torch


def test_early_stopping_functionality(tmpdir):

class CurrentModel(EvalModelTemplate):
def validation_epoch_end(self, outputs):
losses = [8, 4, 2, 3, 4, 5, 8, 10]
val_loss = losses[self.current_epoch]
val_loss = torch.tensor(val_loss)
return {'val_loss': val_loss}

model = CurrentModel()

trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=True,
overfit_pct=0.20,
max_epochs=20,
)
result = trainer.fit(model)
print(trainer.current_epoch)

assert trainer.current_epoch == 5, 'early_stopping failed'


def test_trainer_callback_system(tmpdir):
Expand Down

0 comments on commit 479ab49

Please sign in to comment.