Check early stopping metric in the beginning of the training #542

Merged
5 changes: 3 additions & 2 deletions pytorch_lightning/callbacks/pt_callbacks.py
@@ -124,9 +124,10 @@ def on_epoch_end(self, epoch, logs=None):
         if current is None:
             warnings.warn(
                 f'Early stopping conditioned on metric `{self.monitor}`'
-                f' which is not available. Available metrics are: {",".join(list(logs.keys()))}',
+                f' which is not available, so early stopping will not work.'
+                f' Available metrics are: {",".join(list(logs.keys()))}',
                 RuntimeWarning)
+            stop_training = True

Member:

then you should return True

Contributor Author:

Not exactly. Returning True was the previous behavior, and it interrupted the training if the required metric was not found. Now it just gives a warning and training proceeds as though early stopping were not enabled. The point is that the callback should not stop the training if it can't find the metric.

Actually, in the current implementation this branch is not reachable, because we check for the availability of the metric during trainer initialization. But my idea was that if we decide to set early_stopping to True by default, this branch can be used to give a warning without stopping the training.

You can also look at #524 for more context.
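For illustration, here is a minimal standalone sketch of the warn-but-continue behavior described above (a hypothetical helper, not the actual EarlyStopping callback; only the warning message mirrors the diff):

import warnings

# Hypothetical standalone version of the missing-metric branch (illustration only).
def check_monitor(logs, monitor='val_loss'):
    current = logs.get(monitor)
    if current is None:
        warnings.warn(
            f'Early stopping conditioned on metric `{monitor}`'
            f' which is not available, so early stopping will not work.'
            f' Available metrics are: {",".join(list(logs.keys()))}',
            RuntimeWarning)
        return True  # flag the missing metric; the trainer decides how to react
    return False

check_monitor({'val_acc': 0.9})  # emits a RuntimeWarning and returns True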

             return stop_training
 
         if self.monitor_op(current - self.min_delta, self.best):
14 changes: 13 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -185,6 +185,8 @@ def __init__(self,
         # creates a default one if none passed in
         self.early_stop_callback = None
         self.configure_early_stopping(early_stop_callback, logger)
+        if self.enable_early_stop:
+            self.nb_sanity_val_steps = max(1, self.nb_sanity_val_steps)
Member:

maybe max(1, nb_sanity_val_steps) since earlier you have

if self.fast_dev_run:
    self.nb_sanity_val_steps = 1

Contributor Author:

But for exactly that reason it should be max(1, self.nb_sanity_val_steps) :)

We just take the previously determined final self.nb_sanity_val_steps and bump it to 1 if it is less than 1.

If we did it as you suggested, self.nb_sanity_val_steps would end up equal to the user-defined value in fast dev run mode, but it should be 1.
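As a rough standalone sketch of the ordering being argued here (the helper and variable names are made up for illustration; this is not the actual Trainer code):

# Illustrative sketch of the ordering discussed above (made-up helper, not Trainer code).
def resolve_sanity_steps(nb_sanity_val_steps, fast_dev_run, enable_early_stop):
    resolved = nb_sanity_val_steps
    if fast_dev_run:
        resolved = 1  # fast dev run forces a single sanity step

    if enable_early_stop:
        # max(1, resolved) only bumps values below 1 and keeps the fast_dev_run
        # override; max(1, nb_sanity_val_steps) on the raw argument would restore
        # the user-defined value instead.
        resolved = max(1, resolved)
    return resolved

print(resolve_sanity_steps(5, fast_dev_run=True, enable_early_stop=True))   # -> 1, not 5
print(resolve_sanity_steps(0, fast_dev_run=False, enable_early_stop=True))  # -> 1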

Contributor:

Let's not do this. People need to have the option of turning sanity_val_check off.

Contributor Author:

But how then will we check that early stopping will work correctly? (Note that we force this check only if early stopping is turned on.)

Contributor (@williamFalcon, Nov 30, 2019):

I understand what you're saying, but forcing the sanity check on EVERYONE will certainly block some esoteric research or production cases, so we can't do this.

But I think this is on the user at that point. If they turn the sanity check off, it's on them and they are willingly exposing themselves to these kinds of issues... but for people who keep it on, we use what you suggest.

Contributor Author:

Done


         # configure checkpoint callback
         self.checkpoint_callback = checkpoint_callback
@@ -444,6 +446,7 @@ def run_pretrain_routine(self, model):
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
         ref_model.on_sanity_check_start()
+        callback_metrics = {}
         if self.get_val_dataloaders() is not None and self.nb_sanity_val_steps > 0:
             # init progress bars for validation sanity check
             pbar = tqdm.tqdm(desc='Validation sanity check', total=self.nb_sanity_val_steps,
@@ -453,12 +456,21 @@
             # dummy validation progress bar
             self.val_progress_bar = tqdm.tqdm(disable=True)
 
-        self.evaluate(model, self.get_val_dataloaders(), self.nb_sanity_val_steps, self.testing)
+        eval_results = self.evaluate(model, self.get_val_dataloaders(),
+                                     self.nb_sanity_val_steps, False)
+        _, _, _, callback_metrics, _ = self.process_output(eval_results)
 
         # close progress bars
         self.main_progress_bar.close()
         self.val_progress_bar.close()
 
+        if (self.enable_early_stop and
+                callback_metrics.get(self.early_stop_callback.monitor) is None):
+            raise RuntimeError(f"Early stopping was configured to monitor "
+                               f"{self.early_stop_callback.monitor} but it is not available "
+                               f"after validation_end. Available metrics are: "
+                               f"{','.join(list(callback_metrics.keys()))}")
 
         # init progress bar
         pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
                          disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
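For context, a minimal user-side sketch of the situation this new check guards against (assumes the circa-0.5.x EarlyStopping callback and Trainer arguments referenced in this PR; `model` stands in for a user-defined LightningModule and is not shown):

# Illustrative user-side sketch (circa-0.5.x API; not part of this PR's diff).
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Early stopping is told to watch 'val_loss' ...
early_stop = EarlyStopping(monitor='val_loss', patience=3)
trainer = Trainer(early_stop_callback=early_stop, nb_sanity_val_steps=2)

# ... but if the LightningModule's validation_end only returns, say, {'val_acc': ...}
# and never 'val_loss', the sanity validation run now raises a RuntimeError before
# training starts, instead of silently training without early stopping.
# trainer.fit(model)  # `model`: a user-defined LightningModule (not shown here)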