From 50881c0b31dadc8cda93701cccbd2fe4cc77aa31 Mon Sep 17 00:00:00 2001
From: Vadim Bereznyuk
Date: Thu, 23 Jan 2020 19:12:51 +0300
Subject: [PATCH] Check early stopping metric in the beginning of the training
 (#542)

* Early stopping fix

* Update trainer.py

* Don't force validation sanity check

* fix tests

* update

* Added early_stopping check_metrics

* Updated docs

* Update docs

* Do not call early stopping when validation is disabled

Co-authored-by: William Falcon
---
 pytorch_lightning/callbacks/pt_callbacks.py  | 45 +++++++++++++-------
 pytorch_lightning/trainer/callback_config.py | 10 +++++
 pytorch_lightning/trainer/trainer.py         | 20 +++++++--
 pytorch_lightning/trainer/training_loop.py   |  6 ++-
 tests/test_cpu_models.py                     |  4 +-
 tests/test_trainer.py                        |  2 +-
 6 files changed, 65 insertions(+), 22 deletions(-)

diff --git a/pytorch_lightning/callbacks/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py
index 4c7d877a85bd6..20d035679fa7e 100644
--- a/pytorch_lightning/callbacks/pt_callbacks.py
+++ b/pytorch_lightning/callbacks/pt_callbacks.py
@@ -71,21 +71,23 @@ class EarlyStopping(Callback):
     Stop training when a monitored quantity has stopped improving.
 
     Args:
-        monitor (str): quantity to be monitored.
+        monitor (str): quantity to be monitored. Default: ``'val_loss'``.
         min_delta (float): minimum change in the monitored quantity
             to qualify as an improvement, i.e. an absolute
-            change of less than min_delta, will count as no
-            improvement.
+            change of less than `min_delta`, will count as no
+            improvement. Default: ``0``.
         patience (int): number of epochs with no improvement
-            after which training will be stopped.
-        verbose (bool): verbosity mode.
+            after which training will be stopped. Default: ``0``.
+        verbose (bool): verbosity mode. Default: ``0``.
         mode (str): one of {auto, min, max}. In `min` mode,
             training will stop when the quantity
             monitored has stopped decreasing; in `max`
             mode it will stop when the quantity
             monitored has stopped increasing; in `auto`
             mode, the direction is automatically inferred
-            from the name of the monitored quantity.
+            from the name of the monitored quantity. Default: ``'auto'``.
+        strict (bool): whether to crash the training if `monitor` is
+            not found in the metrics. Default: ``True``.
 
     Example::
 
@@ -97,18 +99,20 @@ class EarlyStopping(Callback):
     """
 
     def __init__(self, monitor='val_loss',
-                 min_delta=0.0, patience=0, verbose=0, mode='auto'):
+                 min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
         super(EarlyStopping, self).__init__()
 
         self.monitor = monitor
         self.patience = patience
         self.verbose = verbose
+        self.strict = strict
         self.min_delta = min_delta
         self.wait = 0
         self.stopped_epoch = 0
 
         if mode not in ['auto', 'min', 'max']:
-            logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
+            if self.verbose > 0:
+                logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
             mode = 'auto'
 
         if mode == 'min':
@@ -128,6 +132,22 @@ def __init__(self, monitor='val_loss',
 
         self.on_train_begin()
 
+    def check_metrics(self, logs):
+        monitor_val = logs.get(self.monitor)
+        error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
+                     f' which is not available. Available metrics are:'
+                     f' `{"`, `".join(list(logs.keys()))}`')
+
+        if monitor_val is None:
+            if self.strict:
+                raise RuntimeError(error_msg)
+            elif self.verbose > 0:
+                warnings.warn(error_msg, RuntimeWarning)
+
+            return False
+
+        return True
+
     def on_train_begin(self, logs=None):
         # Allow instances to be re-used
         self.wait = 0
@@ -135,16 +155,11 @@ def on_train_begin(self, logs=None):
         self.best = np.Inf if self.monitor_op == np.less else -np.Inf
 
     def on_epoch_end(self, epoch, logs=None):
-        current = logs.get(self.monitor)
         stop_training = False
-        if current is None:
-            warnings.warn(
-                f'Early stopping conditioned on metric `{self.monitor}`'
-                f' which is not available. Available metrics are: {",".join(list(logs.keys()))}',
-                RuntimeWarning)
-            stop_training = True
+        if not self.check_metrics(logs):
             return stop_training
 
+        current = logs.get(self.monitor)
         if self.monitor_op(current - self.min_delta, self.best):
             self.best = current
             self.wait = 0
diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py
index d2d042e51f6d0..c3334ab542c58 100644
--- a/pytorch_lightning/trainer/callback_config.py
+++ b/pytorch_lightning/trainer/callback_config.py
@@ -55,10 +55,20 @@ def configure_early_stopping(self, early_stop_callback, logger):
             self.early_stop_callback = EarlyStopping(
                 monitor='val_loss',
                 patience=3,
+                strict=True,
                 verbose=True,
                 mode='min'
             )
             self.enable_early_stop = True
+        elif early_stop_callback is None:
+            self.early_stop_callback = EarlyStopping(
+                monitor='val_loss',
+                patience=3,
+                strict=False,
+                verbose=False,
+                mode='min'
+            )
+            self.enable_early_stop = True
         elif not early_stop_callback:
             self.early_stop_callback = None
             self.enable_early_stop = False
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 820b2d8384858..c2add79416021 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -52,7 +52,7 @@ def __init__(
             self,
             logger=True,
             checkpoint_callback=True,
-            early_stop_callback=True,
+            early_stop_callback=None,
             default_save_path=None,
             gradient_clip_val=0,
             gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
@@ -121,7 +121,13 @@ def __init__(
                 )
 
                 trainer = Trainer(checkpoint_callback=checkpoint_callback)
-        early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping
+        early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping. If
+            set to ``True``, then the default callback monitoring ``'val_loss'`` is created.
+            Will raise an error if ``'val_loss'`` is not found.
+            If set to ``False``, then early stopping will be disabled.
+            If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
+            If ``'val_loss'`` is not found will work as if early stopping is disabled.
+            Default: ``None``.
             Example::
 
                 from pytorch_lightning.callbacks import EarlyStopping
@@ -129,7 +135,8 @@ def __init__(
                 early_stop_callback = EarlyStopping(
                     monitor='val_loss',
                     patience=3,
-                    verbose=True,
+                    strict=False,
+                    verbose=False,
                     mode='min'
                 )
 
@@ -809,12 +816,17 @@ def run_pretrain_routine(self, model):
             # dummy validation progress bar
             self.val_progress_bar = tqdm.tqdm(disable=True)
 
-            self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
+            eval_results = self.evaluate(model, self.get_val_dataloaders(),
+                                         self.num_sanity_val_steps, False)
+            _, _, _, callback_metrics, _ = self.process_output(eval_results)
 
             # close progress bars
             self.main_progress_bar.close()
             self.val_progress_bar.close()
 
+            if self.enable_early_stop:
+                self.early_stop_callback.check_metrics(callback_metrics)
+
         # init progress bar
         pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
                          disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index d5f16c9462bf1..484983d7b62be 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -346,7 +346,8 @@ def train(self):
 
                 # early stopping
                 met_min_epochs = epoch >= self.min_epochs - 1
-                if self.enable_early_stop and (met_min_epochs or self.fast_dev_run):
+                if (self.enable_early_stop and not self.disable_validation and
+                        (met_min_epochs or self.fast_dev_run)):
                     should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch,
                                                                         logs=self.callback_metrics)
                     # stop training
@@ -401,6 +402,9 @@ def run_training_epoch(self):
             if self.fast_dev_run or should_check_val:
                 self.run_evaluation(test=self.testing)
 
+                if self.enable_early_stop:
+                    self.early_stop_callback.check_metrics(self.callback_metrics)
+
             # when logs should be saved
             should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
             if should_save_log or self.fast_dev_run:
diff --git a/tests/test_cpu_models.py b/tests/test_cpu_models.py
index 475e6ec5f2be5..fa643c64ac5fc 100644
--- a/tests/test_cpu_models.py
+++ b/tests/test_cpu_models.py
@@ -140,7 +140,8 @@ class CurrentTestModel(LightningTestMixin, LightningTestModelBase):
         val_percent_check=0.2,
         test_percent_check=0.2,
         checkpoint_callback=checkpoint,
-        logger=logger
+        logger=logger,
+        early_stop_callback=False
     )
 
     # fit model
@@ -318,6 +319,7 @@ def train_dataloader(self):
         truncated_bptt_steps=truncated_bptt_steps,
         val_percent_check=0,
         weights_summary=None,
+        early_stop_callback=False
     )
 
     hparams = tutils.get_hparams()
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 2301104531cec..ddc74b40f8ca4 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -392,7 +392,7 @@ class CurrentTestModel(
         default_save_path=tmpdir,
         max_epochs=1,
         val_percent_check=0.1,
-        train_percent_check=0.2,
+        train_percent_check=0.2
     )
 
     # fit model
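Usage note (not part of the patch): the sketch below illustrates the behavior this change introduces, assuming a hypothetical `LitModel` LightningModule that logs `val_loss` from its validation loop. With the new default `early_stop_callback=None`, the Trainer builds a non-strict `EarlyStopping` on `val_loss` and simply skips early stopping if the metric is missing; passing an `EarlyStopping(strict=True)` explicitly makes the metric check fail fast, since `check_metrics` runs right after the validation sanity check and after each validation run. Example::

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    model = LitModel()  # hypothetical LightningModule that returns a 'val_loss' metric

    # Default after this patch: early_stop_callback=None creates a non-strict
    # EarlyStopping on 'val_loss'; if the metric is never logged, training
    # proceeds as if early stopping were disabled.
    trainer = Trainer(max_epochs=10)
    trainer.fit(model)

    # Opt-in strict check: raise a RuntimeError as soon as the monitored
    # metric turns out to be missing from the callback metrics.
    early_stopping = EarlyStopping(monitor='val_loss', patience=3,
                                   strict=True, verbose=True, mode='min')
    trainer = Trainer(max_epochs=10, early_stop_callback=early_stopping)
    trainer.fit(model)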