Check early stopping metric at the beginning of training #542

Merged
7 changes: 7 additions & 0 deletions pytorch_lightning/trainer/train_loop_mixin.py
@@ -113,6 +113,13 @@ def run_training_epoch(self):
             if self.fast_dev_run or should_check_val:
                 self.run_evaluation(test=self.testing)
 
+            if (self.enable_early_stop and
+                    self.callback_metrics.get(self.early_stop_callback.monitor) is None):
+                raise RuntimeError(f"Early stopping was configured to monitor "
f"{self.early_stop_callback.monitor} but it is not available"
f" after validation_end. Available metrics are: "
f"{','.join(list(self.callback_metrics.keys()))}")

# when logs should be saved
should_save_log = (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch
if should_save_log or self.fast_dev_run:
Expand Down
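On the user side, this check means the monitored key must actually be returned from validation_end, since those returned values are what populate callback_metrics. A minimal sketch of a compliant hook, assuming the 0.5-era validation_step/validation_end API; the key 'val_loss' is only an example monitor name, not something this diff prescribes:

import torch

# Sketch of a LightningModule method, shown standalone for brevity.
def validation_end(self, outputs):
    # 'outputs' collects the dicts returned by validation_step, one per batch
    val_loss_mean = torch.stack([out['val_loss'] for out in outputs]).mean()
    # every key returned here lands in callback_metrics, which is exactly
    # what the new check (and EarlyStopping itself) looks up
    return {'val_loss': val_loss_mean}

If the returned dict lacks the monitored key, the new RuntimeError fires immediately instead of early stopping silently never triggering.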
15 changes: 6 additions & 9 deletions pytorch_lightning/trainer/trainer.py
@@ -185,8 +185,6 @@ def __init__(self,
         # creates a default one if none passed in
         self.early_stop_callback = None
         self.configure_early_stopping(early_stop_callback, logger)
-        if self.enable_early_stop:
-            self.nb_sanity_val_steps = max(1, self.nb_sanity_val_steps)
 
         # configure checkpoint callback
         self.checkpoint_callback = checkpoint_callback
@@ -446,7 +444,6 @@ def run_pretrain_routine(self, model):
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
         ref_model.on_sanity_check_start()
-        callback_metrics = {}
         if self.get_val_dataloaders() is not None and self.nb_sanity_val_steps > 0:
             # init progress bars for validation sanity check
             pbar = tqdm.tqdm(desc='Validation sanity check', total=self.nb_sanity_val_steps,
@@ -464,12 +461,12 @@ def run_pretrain_routine(self, model):
             self.main_progress_bar.close()
             self.val_progress_bar.close()
 
-        if (self.enable_early_stop and
-                callback_metrics.get(self.early_stop_callback.monitor) is None):
-            raise RuntimeError(f"Early stopping was configured to monitor "
-                               f"{self.early_stop_callback.monitor} but it is not available "
-                               f"after validation_end. Available metrics are: "
-                               f"{','.join(list(callback_metrics.keys()))}")
+            if (self.enable_early_stop and
+                    callback_metrics.get(self.early_stop_callback.monitor) is None):
+                raise RuntimeError(f"Early stopping was configured to monitor "
+                                   f"{self.early_stop_callback.monitor} but it is not available "
+                                   f"after validation_end. Available metrics are: "
+                                   f"{','.join(list(callback_metrics.keys()))}")
 
         # init progress bar
        pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
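For completeness, a hedged usage sketch of the failure mode this PR makes explicit. The Trainer(early_stop_callback=...) argument matches the constructor shown in the diff; the EarlyStopping import path reflects the 0.5-era package layout and 'val_loss' is again an assumed monitor key:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Monitor a key that the model's validation_end may or may not return.
early_stop = EarlyStopping(monitor='val_loss', patience=3, mode='min')
trainer = Trainer(early_stop_callback=early_stop)

# With this PR, if validation_end never returns 'val_loss', fitting fails
# fast during the sanity check (or at the end of the first training epoch):
#   RuntimeError: Early stopping was configured to monitor val_loss but it
#   is not available after validation_end. Available metrics are: ...
# trainer.fit(model)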