Check early stopping metric in the beginning of the training #542

Merged
45 changes: 30 additions & 15 deletions pytorch_lightning/callbacks/pt_callbacks.py
@@ -71,21 +71,23 @@ class EarlyStopping(Callback):
     Stop training when a monitored quantity has stopped improving.

     Args:
-        monitor (str): quantity to be monitored.
+        monitor (str): quantity to be monitored. Default: ``'val_loss'``.
         min_delta (float): minimum change in the monitored quantity
             to qualify as an improvement, i.e. an absolute
-            change of less than min_delta, will count as no
-            improvement.
+            change of less than `min_delta`, will count as no
+            improvement. Default: ``0``.
         patience (int): number of epochs with no improvement
-            after which training will be stopped.
-        verbose (bool): verbosity mode.
+            after which training will be stopped. Default: ``0``.
+        verbose (bool): verbosity mode. Default: ``0``.
         mode (str): one of {auto, min, max}. In `min` mode,
             training will stop when the quantity
             monitored has stopped decreasing; in `max`
             mode it will stop when the quantity
             monitored has stopped increasing; in `auto`
             mode, the direction is automatically inferred
-            from the name of the monitored quantity.
+            from the name of the monitored quantity. Default: ``'auto'``.
+        strict (bool): whether to crash the training if `monitor` is
+            not found in the metrics. Default: ``True``.

     Example::

@@ -97,18 +99,20 @@ class EarlyStopping(Callback):
     """

     def __init__(self, monitor='val_loss',
-                 min_delta=0.0, patience=0, verbose=0, mode='auto'):
+                 min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
         super(EarlyStopping, self).__init__()

         self.monitor = monitor
         self.patience = patience
         self.verbose = verbose
+        self.strict = strict
         self.min_delta = min_delta
         self.wait = 0
         self.stopped_epoch = 0

         if mode not in ['auto', 'min', 'max']:
-            logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
+            if self.verbose > 0:
+                logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
             mode = 'auto'

         if mode == 'min':
@@ -128,23 +132,34 @@ def __init__(self, monitor='val_loss',

         self.on_train_begin()

+    def check_metrics(self, logs):
+        monitor_val = logs.get(self.monitor)
+        error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
+                     f' which is not available. Available metrics are:'
+                     f' `{"`, `".join(list(logs.keys()))}`')
+
+        if monitor_val is None:
+            if self.strict:
+                raise RuntimeError(error_msg)
+            elif self.verbose > 0:
+                warnings.warn(error_msg, RuntimeWarning)
+
+            return False
+
+        return True
+
     def on_train_begin(self, logs=None):
         # Allow instances to be re-used
         self.wait = 0
         self.stopped_epoch = 0
         self.best = np.Inf if self.monitor_op == np.less else -np.Inf

     def on_epoch_end(self, epoch, logs=None):
-        current = logs.get(self.monitor)
         stop_training = False
-        if current is None:
-            warnings.warn(
-                f'Early stopping conditioned on metric `{self.monitor}`'
-                f' which is not available. Available metrics are: {",".join(list(logs.keys()))}',
-                RuntimeWarning)
-            stop_training = True
+        if not self.check_metrics(logs):
+            return stop_training

+        current = logs.get(self.monitor)
         if self.monitor_op(current - self.min_delta, self.best):
             self.best = current
             self.wait = 0
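A minimal sketch of how the new `strict` flag behaves (the metric names below are illustrative, not from this PR):

from pytorch_lightning.callbacks import EarlyStopping

# strict=True (the new default): a missing monitor key raises immediately
strict_stopper = EarlyStopping(monitor='val_loss', strict=True)
# strict_stopper.check_metrics({'train_loss': 0.5})  # -> RuntimeError

# strict=False: the same lookup only warns (when verbose > 0), and
# on_epoch_end returns early instead of stopping training
lenient_stopper = EarlyStopping(monitor='val_loss', strict=False, verbose=1)
assert lenient_stopper.check_metrics({'train_loss': 0.5}) is False
assert lenient_stopper.check_metrics({'val_loss': 0.42}) is True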
10 changes: 10 additions & 0 deletions pytorch_lightning/trainer/callback_config.py
@@ -55,10 +55,20 @@ def configure_early_stopping(self, early_stop_callback, logger):
             self.early_stop_callback = EarlyStopping(
                 monitor='val_loss',
                 patience=3,
+                strict=True,
                 verbose=True,
                 mode='min'
             )
             self.enable_early_stop = True
+        elif early_stop_callback is None:
+            self.early_stop_callback = EarlyStopping(
+                monitor='val_loss',
+                patience=3,
+                strict=False,
+                verbose=False,
+                mode='min'
+            )
+            self.enable_early_stop = True
         elif not early_stop_callback:
             self.early_stop_callback = None
             self.enable_early_stop = False
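In effect, the early_stop_callback flag now has three distinct settings. A hedged sketch of the resulting Trainer behavior, assuming the usual imports:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# True: default callback with strict=True -- raises if 'val_loss'
# is never logged
trainer = Trainer(early_stop_callback=True)

# None (the new default): same callback, but strict=False and quiet --
# behaves as if early stopping were disabled when 'val_loss' is missing
trainer = Trainer(early_stop_callback=None)

# False: early stopping disabled outright
trainer = Trainer(early_stop_callback=False)

# passing a configured instance (handled outside this hunk) presumably
# uses it as-is
trainer = Trainer(early_stop_callback=EarlyStopping(
    monitor='val_acc', mode='max', patience=5))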
20 changes: 16 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -52,7 +52,7 @@ def __init__(
         self,
         logger=True,
         checkpoint_callback=True,
-        early_stop_callback=True,
+        early_stop_callback=None,
         default_save_path=None,
         gradient_clip_val=0,
         gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
@@ -121,15 +121,22 @@ def __init__(
         )

         trainer = Trainer(checkpoint_callback=checkpoint_callback)
-        early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping
+        early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping. If
+            set to ``True``, then the default callback monitoring ``'val_loss'`` is created.
+            Will raise an error if ``'val_loss'`` is not found.
+            If set to ``False``, then early stopping will be disabled.
+            If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
+            If ``'val_loss'`` is not found, it will work as if early stopping is disabled.
+            Default: ``None``.
         Example::
             from pytorch_lightning.callbacks import EarlyStopping

             # default used by the Trainer
             early_stop_callback = EarlyStopping(
                 monitor='val_loss',
                 patience=3,
-                verbose=True,
+                strict=False,
+                verbose=False,
                 mode='min'
             )

@@ -809,12 +816,17 @@ def run_pretrain_routine(self, model):
         # dummy validation progress bar
         self.val_progress_bar = tqdm.tqdm(disable=True)

-        self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
+        eval_results = self.evaluate(model, self.get_val_dataloaders(),
+                                     self.num_sanity_val_steps, False)
+        _, _, _, callback_metrics, _ = self.process_output(eval_results)

         # close progress bars
         self.main_progress_bar.close()
         self.val_progress_bar.close()

+        if self.enable_early_stop:
+            self.early_stop_callback.check_metrics(callback_metrics)
+
         # init progress bar
         pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
                          disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
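With this change the sanity-check pass already feeds its metrics into check_metrics, so a strict callback fails before any real training begins. A hedged sketch of a LightningModule that satisfies the check, assuming the validation API of this era (names illustrative):

import torch
import pytorch_lightning as pl

class CoolModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {'val_loss': torch.nn.functional.cross_entropy(self(x), y)}

    def validation_end(self, outputs):
        # the 'val_loss' key returned here is what lands in callback_metrics
        # and what EarlyStopping.check_metrics looks up during the sanity check
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}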
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -346,7 +346,8 @@ def train(self):

            # early stopping
            met_min_epochs = epoch >= self.min_epochs - 1
-            if self.enable_early_stop and (met_min_epochs or self.fast_dev_run):
+            if (self.enable_early_stop and not self.disable_validation and
+                    (met_min_epochs or self.fast_dev_run)):
                should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch,
                                                                    logs=self.callback_metrics)
                # stop training
@@ -401,6 +402,9 @@ def run_training_epoch(self):
             if self.fast_dev_run or should_check_val:
                 self.run_evaluation(test=self.testing)

+                if self.enable_early_stop:
+                    self.early_stop_callback.check_metrics(self.callback_metrics)
+
             # when logs should be saved
             should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
             if should_save_log or self.fast_dev_run:
4 changes: 3 additions & 1 deletion tests/test_cpu_models.py
@@ -140,7 +140,8 @@ class CurrentTestModel(LightningTestMixin, LightningTestModelBase):
         val_percent_check=0.2,
         test_percent_check=0.2,
         checkpoint_callback=checkpoint,
-        logger=logger
+        logger=logger,
+        early_stop_callback=False
     )

     # fit model
@@ -318,6 +319,7 @@ def train_dataloader(self):
         truncated_bptt_steps=truncated_bptt_steps,
         val_percent_check=0,
         weights_summary=None,
+        early_stop_callback=False
     )

     hparams = tutils.get_hparams()
2 changes: 1 addition & 1 deletion tests/test_trainer.py
@@ -392,7 +392,7 @@ class CurrentTestModel(
         default_save_path=tmpdir,
         max_epochs=1,
         val_percent_check=0.1,
-        train_percent_check=0.2,
+        train_percent_check=0.2
     )

     # fit model