Check early stopping metric in the beginning of the training #542

Merged
pytorch_lightning/callbacks/pt_callbacks.py (5 changes: 3 additions & 2 deletions)

@@ -124,9 +124,10 @@ def on_epoch_end(self, epoch, logs=None):
         if current is None:
             warnings.warn(
                 f'Early stopping conditioned on metric `{self.monitor}`'
-                f' which is not available. Available metrics are: {",".join(list(logs.keys()))}',
+                f' which is not available, so early stopping will not work.'
+                f' Available metrics are: {",".join(list(logs.keys()))}',
                 RuntimeWarning)
-            stop_training = True
Member:
then you should return True

Contributor (author):
Not exactly. Returning True was the previous behavior, and it interrupted training whenever the required metric was not found. Now the callback only emits a warning, and training proceeds as though early stopping were disabled. The point is that the callback should not stop training just because it cannot find the metric.

Actually, in the current implementation this branch is unreachable, because the availability of the metric is already checked during trainer initialization. But the idea is that if we ever decide to enable early stopping by default, this branch can emit a warning without stopping the training.

See also #524 for more context.

+
             return stop_training

         if self.monitor_op(current - self.min_delta, self.best):
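To make the thread concrete, here is a minimal, self-contained sketch of the post-change behavior. The `logs` contents and the `val_loss` monitor name are made-up examples, and `stop_training = False` is assumed to be initialized earlier in `on_epoch_end`:

```python
import warnings

# Hypothetical epoch metrics: the monitored key is absent.
logs = {'train_loss': 0.42}
monitor = 'val_loss'   # what early stopping was configured to watch

stop_training = False  # assumption: set earlier in on_epoch_end
current = logs.get(monitor)
if current is None:
    warnings.warn(
        f'Early stopping conditioned on metric `{monitor}`'
        f' which is not available, so early stopping will not work.'
        f' Available metrics are: {",".join(list(logs.keys()))}',
        RuntimeWarning)

# Before this PR, stop_training was set to True here and training was
# interrupted; now it stays False and training continues as though
# early stopping were disabled.
print(stop_training)  # False
```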
pytorch_lightning/trainer/train_loop_mixin.py (7 changes: 7 additions & 0 deletions)

@@ -113,6 +113,13 @@ def run_training_epoch(self):
             if self.fast_dev_run or should_check_val:
                 self.run_evaluation(test=self.testing)
 
+                if (self.enable_early_stop and
+                        self.callback_metrics.get(self.early_stop_callback.monitor) is None):
+                    raise RuntimeError(f"Early stopping was configured to monitor "
+                                       f"{self.early_stop_callback.monitor} but it is not available"
+                                       f" after validation_end. Available metrics are: "
+                                       f"{','.join(list(self.callback_metrics.keys()))}")
+
             # when logs should be saved
             should_save_log = (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch
             if should_save_log or self.fast_dev_run:

Member (on the new RuntimeError check):

👍
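A hedged sketch of how this in-training guard surfaces to a user; the test scaffolding and the `NoValLossModel` stub are assumptions for illustration, not code from this PR:

```python
import pytest
from pytorch_lightning import Trainer

def test_early_stopping_metric_is_checked():
    # Assumed stub: a LightningModule whose validation loop never
    # produces the monitored 'val_loss' metric.
    model = NoValLossModel()

    trainer = Trainer(early_stop_callback=True)  # monitors 'val_loss' by default

    # With this PR the run fails fast with a RuntimeError that names the
    # missing metric and lists the available ones, instead of silently
    # training on with early stopping inert.
    with pytest.raises(RuntimeError, match='Early stopping was configured to monitor'):
        trainer.fit(model)
```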
pytorch_lightning/trainer/trainer.py (11 changes: 10 additions & 1 deletion)

@@ -453,12 +453,21 @@ def run_pretrain_routine(self, model):
             # dummy validation progress bar
             self.val_progress_bar = tqdm.tqdm(disable=True)
 
-            self.evaluate(model, self.get_val_dataloaders(), self.nb_sanity_val_steps, self.testing)
+            eval_results = self.evaluate(model, self.get_val_dataloaders(),
+                                         self.nb_sanity_val_steps, False)
+            _, _, _, callback_metrics, _ = self.process_output(eval_results)
 
             # close progress bars
             self.main_progress_bar.close()
             self.val_progress_bar.close()
 
+            if (self.enable_early_stop and
+                    callback_metrics.get(self.early_stop_callback.monitor) is None):
+                raise RuntimeError(f"Early stopping was configured to monitor "
+                                   f"{self.early_stop_callback.monitor} but it is not available "
+                                   f"after validation_end. Available metrics are: "
+                                   f"{','.join(list(callback_metrics.keys()))}")
+
         # init progress bar
         pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
                          disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
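For completeness, a sketch of a model that satisfies the new check, following the 0.5.x-era `validation_end` convention; the loss computation and the class body are illustrative assumptions:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    # ... __init__, forward, training_step, configure_optimizers,
    # and dataloaders elided ...

    def validation_step(self, batch, batch_nb):
        x, y = batch
        return {'val_loss': F.cross_entropy(self(x), y)}

    def validation_end(self, outputs):
        # Returning 'val_loss' here is what makes the metric show up in
        # callback_metrics, so both the pretrain sanity check above and
        # the per-epoch check in the train loop pass.
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}
```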