diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3e94d8ae61a20..d2a98780aa766 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -50,6 +50,9 @@ # pass in your own to override the default callback trainer = Trainer(early_stop_callback=early_stop_callback) + # pass in min_epochs to enable the callback after min_epochs have run + trainer = Trainer(early_stop_callback=early_stop_callback, min_epochs=5) + # pass in None to disable it trainer = Trainer(early_stop_callback=None) @@ -339,7 +342,7 @@ def train(self): self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch) # early stopping - met_min_epochs = epoch > self.min_epochs + met_min_epochs = epoch >= self.min_epochs - 1 if self.enable_early_stop and (met_min_epochs or self.fast_dev_run): should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch, logs=self.callback_metrics)