From 4c635ad250fd197d76dc0bead0beb36949cc14f9 Mon Sep 17 00:00:00 2001 From: Vadim Bereznyuk Date: Mon, 25 Nov 2019 11:38:32 +0300 Subject: [PATCH 1/9] Early stopping fix --- pytorch_lightning/callbacks/pt_callbacks.py | 5 +++-- pytorch_lightning/trainer/trainer.py | 16 ++++++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py index 81b5be70f682b..555e5c67b944e 100644 --- a/pytorch_lightning/callbacks/pt_callbacks.py +++ b/pytorch_lightning/callbacks/pt_callbacks.py @@ -124,9 +124,10 @@ def on_epoch_end(self, epoch, logs=None): if current is None: warnings.warn( f'Early stopping conditioned on metric `{self.monitor}`' - f' which is not available. Available metrics are: {",".join(list(logs.keys()))}', + f' which is not available, so early stopping will not work.' + f' Available metrics are: {",".join(list(logs.keys()))}', RuntimeWarning) - stop_training = True + return stop_training if self.monitor_op(current - self.min_delta, self.best): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 155a6a7fdb609..087cb627fcfff 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -53,7 +53,7 @@ class Trainer(TrainerIOMixin, def __init__(self, logger=True, checkpoint_callback=True, - early_stop_callback=True, + early_stop_callback=False, default_save_path=None, gradient_clip_val=0, gradient_clip=None, # backward compatible @@ -185,6 +185,8 @@ def __init__(self, # creates a default one if none passed in self.early_stop_callback = None self.configure_early_stopping(early_stop_callback, logger) + if self.enable_early_stop: + self.nb_sanity_val_steps = max(1, self.nb_sanity_val_steps) # configure checkpoint callback self.checkpoint_callback = checkpoint_callback @@ -444,6 +446,7 @@ def run_pretrain_routine(self, model): # run tiny validation (if validation defined) # to make sure program won't crash during val ref_model.on_sanity_check_start() + callback_metrics = {} if self.get_val_dataloaders() is not None and self.nb_sanity_val_steps > 0: # init progress bars for validation sanity check pbar = tqdm.tqdm(desc='Validation sanity check', total=self.nb_sanity_val_steps, @@ -453,12 +456,21 @@ def run_pretrain_routine(self, model): # dummy validation progress bar self.val_progress_bar = tqdm.tqdm(disable=True) - self.evaluate(model, self.get_val_dataloaders(), self.nb_sanity_val_steps, self.testing) + eval_results = self.evaluate(model, self.get_val_dataloaders(), + self.nb_sanity_val_steps, False) + _, _, _, callback_metrics, _ = self.process_output(eval_results) # close progress bars self.main_progress_bar.close() self.val_progress_bar.close() + if (self.enable_early_stop and + callback_metrics.get(self.early_stop_callback.monitor) is None): + raise RuntimeError(f"Early stopping was configured to monitor " + f"{self.early_stop_callback.monitor} but it is not available " + f"after validation_end. 
Available metrics are: " + f"{','.join(list(callback_metrics.keys()))}") + # init progress bar pbar = tqdm.tqdm(leave=True, position=2 * self.process_position, disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch', From c06bad6a9a5e0dd9bcfe2e4f512e5a80019b562b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 30 Nov 2019 14:54:21 -0500 Subject: [PATCH 2/9] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 087cb627fcfff..1ab4ef3d12beb 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -53,7 +53,7 @@ class Trainer(TrainerIOMixin, def __init__(self, logger=True, checkpoint_callback=True, - early_stop_callback=False, + early_stop_callback=True, default_save_path=None, gradient_clip_val=0, gradient_clip=None, # backward compatible From acda2b70ecc43a64a2c38b939ef95340b2757758 Mon Sep 17 00:00:00 2001 From: Vadim Bereznyuk Date: Sun, 1 Dec 2019 14:57:55 +0300 Subject: [PATCH 3/9] Don't force validation sanity check --- pytorch_lightning/trainer/train_loop_mixin.py | 7 +++++++ pytorch_lightning/trainer/trainer.py | 15 ++++++--------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/train_loop_mixin.py b/pytorch_lightning/trainer/train_loop_mixin.py index 6130fd4931342..800497f7c3460 100644 --- a/pytorch_lightning/trainer/train_loop_mixin.py +++ b/pytorch_lightning/trainer/train_loop_mixin.py @@ -113,6 +113,13 @@ def run_training_epoch(self): if self.fast_dev_run or should_check_val: self.run_evaluation(test=self.testing) + if (self.enable_early_stop and + self.callback_metrics.get(self.early_stop_callback.monitor) is None): + raise RuntimeError(f"Early stopping was configured to monitor " + f"{self.early_stop_callback.monitor} but it is not available" + f" after validation_end. Available metrics are: " + f"{','.join(list(self.callback_metrics.keys()))}") + # when logs should be saved should_save_log = (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch if should_save_log or self.fast_dev_run: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1ab4ef3d12beb..041dea5586d4f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -185,8 +185,6 @@ def __init__(self, # creates a default one if none passed in self.early_stop_callback = None self.configure_early_stopping(early_stop_callback, logger) - if self.enable_early_stop: - self.nb_sanity_val_steps = max(1, self.nb_sanity_val_steps) # configure checkpoint callback self.checkpoint_callback = checkpoint_callback @@ -446,7 +444,6 @@ def run_pretrain_routine(self, model): # run tiny validation (if validation defined) # to make sure program won't crash during val ref_model.on_sanity_check_start() - callback_metrics = {} if self.get_val_dataloaders() is not None and self.nb_sanity_val_steps > 0: # init progress bars for validation sanity check pbar = tqdm.tqdm(desc='Validation sanity check', total=self.nb_sanity_val_steps, @@ -464,12 +461,12 @@ def run_pretrain_routine(self, model): self.main_progress_bar.close() self.val_progress_bar.close() - if (self.enable_early_stop and - callback_metrics.get(self.early_stop_callback.monitor) is None): - raise RuntimeError(f"Early stopping was configured to monitor " - f"{self.early_stop_callback.monitor} but it is not available " - f"after validation_end. 
Available metrics are: " - f"{','.join(list(callback_metrics.keys()))}") + if (self.enable_early_stop and + callback_metrics.get(self.early_stop_callback.monitor) is None): + raise RuntimeError(f"Early stopping was configured to monitor " + f"{self.early_stop_callback.monitor} but it is not available " + f"after validation_end. Available metrics are: " + f"{','.join(list(callback_metrics.keys()))}") # init progress bar pbar = tqdm.tqdm(leave=True, position=2 * self.process_position, From 10e76f655838cece0591c3d5476c7e00f36e5ec3 Mon Sep 17 00:00:00 2001 From: Vadim Bereznyuk Date: Mon, 2 Dec 2019 13:41:04 +0300 Subject: [PATCH 4/9] fix tests --- tests/test_cpu_models.py | 4 +++- tests/test_trainer.py | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_cpu_models.py b/tests/test_cpu_models.py index 1b7a2573d70f1..0fc4b71d2b5ff 100644 --- a/tests/test_cpu_models.py +++ b/tests/test_cpu_models.py @@ -158,7 +158,8 @@ class CurrentTestModel(LightningTestMixin, LightningTestModelBase): val_percent_check=0.2, test_percent_check=0.2, checkpoint_callback=checkpoint, - logger=logger + logger=logger, + early_stop_callback=False ) # fit model @@ -352,6 +353,7 @@ def train_dataloader(self): truncated_bptt_steps=truncated_bptt_steps, val_percent_check=0, weights_summary=None, + early_stop_callback=False ) hparams = testing_utils.get_hparams() diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 8f2d830c02ee0..c6926633ac3de 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -41,7 +41,8 @@ class CurrentTestModel(LightningTestModelBase): trainer_options = dict( max_nb_epochs=1, logger=logger, - checkpoint_callback=ModelCheckpoint(save_dir) + checkpoint_callback=ModelCheckpoint(save_dir), + early_stop_callback=False ) # fit model @@ -87,7 +88,8 @@ class CurrentTestModel(LightningValidationStepMixin, LightningTestModelBase): trainer_options = dict( max_nb_epochs=1, logger=logger, - checkpoint_callback=ModelCheckpoint(save_dir) + checkpoint_callback=ModelCheckpoint(save_dir), + early_stop_callback=False ) # fit model @@ -414,6 +416,7 @@ class CurrentTestModel( max_nb_epochs=1, val_percent_check=0.1, train_percent_check=0.1, + early_stop_callback=False ) # fit model From 25e1bdaff98cb583fc494af7d59920fbd391ed48 Mon Sep 17 00:00:00 2001 From: Vadim Bereznyuk Date: Tue, 14 Jan 2020 11:27:18 +0300 Subject: [PATCH 5/9] update --- pytorch_lightning/callbacks/pt_callbacks.py | 2 +- pytorch_lightning/trainer/trainer.py | 8 ++++---- pytorch_lightning/trainer/training_loop.py | 7 +++++++ 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py index b0ba603688019..b41ea32d5eb58 100644 --- a/pytorch_lightning/callbacks/pt_callbacks.py +++ b/pytorch_lightning/callbacks/pt_callbacks.py @@ -128,7 +128,7 @@ def on_epoch_end(self, epoch, logs=None): warnings.warn( f'Early stopping conditioned on metric `{self.monitor}`' f' which is not available, so early stopping will not work.' 
- f' Available metrics are: {",".join(list(logs.keys()))}', + f' Available metrics are: {", ".join(list(logs.keys()))}', RuntimeWarning) return stop_training diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9b7cf003789a5..ffdd76eb1c730 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -516,10 +516,10 @@ def run_pretrain_routine(self, model): if (self.enable_early_stop and callback_metrics.get(self.early_stop_callback.monitor) is None): - raise RuntimeError(f"Early stopping was configured to monitor " - f"{self.early_stop_callback.monitor} but it is not available " - f"after validation_end. Available metrics are: " - f"{','.join(list(callback_metrics.keys()))}") + raise RuntimeError(f"Early stopping was configured to monitor" + f" {self.early_stop_callback.monitor} but it is not available" + f" after validation_end. Available metrics are:" + f" {', '.join(list(callback_metrics.keys()))}") # init progress bar pbar = tqdm.tqdm(leave=True, position=2 * self.process_position, diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 657fe49829fbc..cd6c13949c0ee 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -400,6 +400,13 @@ def run_training_epoch(self): if self.fast_dev_run or should_check_val: self.run_evaluation(test=self.testing) + if (self.enable_early_stop and + self.callback_metrics.get(self.early_stop_callback.monitor) is None): + raise RuntimeError(f"Early stopping was configured to monitor" + f" {self.early_stop_callback.monitor} but it is not" + f" available after validation_end. Available metrics are:" + f" {', '.join(list(self.callback_metrics.keys()))}") + # when logs should be saved should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch if should_save_log or self.fast_dev_run: From 748640bc93433508284db483e6453eaa36c5800a Mon Sep 17 00:00:00 2001 From: Vadim Bereznyuk Date: Tue, 14 Jan 2020 12:40:35 +0300 Subject: [PATCH 6/9] Added early_stopping check_metrics --- pytorch_lightning/callbacks/pt_callbacks.py | 32 ++++++++++++++------ pytorch_lightning/trainer/callback_config.py | 10 ++++++ pytorch_lightning/trainer/trainer.py | 10 ++---- pytorch_lightning/trainer/training_loop.py | 8 ++--- 4 files changed, 37 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/callbacks/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py index b41ea32d5eb58..6842e72204bcc 100644 --- a/pytorch_lightning/callbacks/pt_callbacks.py +++ b/pytorch_lightning/callbacks/pt_callbacks.py @@ -84,18 +84,20 @@ class EarlyStopping(Callback): """ def __init__(self, monitor='val_loss', - min_delta=0.0, patience=0, verbose=0, mode='auto'): + min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True): super(EarlyStopping, self).__init__() self.monitor = monitor self.patience = patience self.verbose = verbose + self.strict = strict self.min_delta = min_delta self.wait = 0 self.stopped_epoch = 0 if mode not in ['auto', 'min', 'max']: - logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.') + if self.verbose > 0: + logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.') mode = 'auto' if mode == 'min': @@ -115,6 +117,22 @@ def __init__(self, monitor='val_loss', self.on_train_begin() + def check_metrics(self, logs): + monitor_val = logs.get(self.monitor) + error_msg = (f'Early stopping conditioned on metric `{self.monitor}`' 
+ f' which is not available. Available metrics are:' + f' `{"`, `".join(list(logs.keys()))}`') + + if monitor_val is None: + if self.strict: + raise RuntimeError(error_msg) + elif self.verbose > 0: + warnings.warn(error_msg, RuntimeWarning) + + return False + + return True + def on_train_begin(self, logs=None): # Allow instances to be re-used self.wait = 0 @@ -122,17 +140,11 @@ def on_train_begin(self, logs=None): self.best = np.Inf if self.monitor_op == np.less else -np.Inf def on_epoch_end(self, epoch, logs=None): - current = logs.get(self.monitor) stop_training = False - if current is None: - warnings.warn( - f'Early stopping conditioned on metric `{self.monitor}`' - f' which is not available, so early stopping will not work.' - f' Available metrics are: {", ".join(list(logs.keys()))}', - RuntimeWarning) - + if not self.check_metrics(logs): return stop_training + current = logs.get(self.monitor) if self.monitor_op(current - self.min_delta, self.best): self.best = current self.wait = 0 diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index 02db09c31cab1..05852f53a9c0c 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -55,10 +55,20 @@ def configure_early_stopping(self, early_stop_callback, logger): self.early_stop_callback = EarlyStopping( monitor='val_loss', patience=3, + strict=True, verbose=True, mode='min' ) self.enable_early_stop = True + elif early_stop_callback is None: + self.early_stop_callback = EarlyStopping( + monitor='val_loss', + patience=3, + strict=False, + verbose=False, + mode='min' + ) + self.enable_early_stop = True elif not early_stop_callback: self.early_stop_callback = None self.enable_early_stop = False diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ffdd76eb1c730..3ec15bb57a2db 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -55,7 +55,7 @@ def __init__( self, logger=True, checkpoint_callback=True, - early_stop_callback=True, + early_stop_callback=None, default_save_path=None, gradient_clip_val=0, gradient_clip=None, # backward compatible, todo: remove in v0.8.0 @@ -514,12 +514,8 @@ def run_pretrain_routine(self, model): self.main_progress_bar.close() self.val_progress_bar.close() - if (self.enable_early_stop and - callback_metrics.get(self.early_stop_callback.monitor) is None): - raise RuntimeError(f"Early stopping was configured to monitor" - f" {self.early_stop_callback.monitor} but it is not available" - f" after validation_end. Available metrics are:" - f" {', '.join(list(callback_metrics.keys()))}") + if self.enable_early_stop: + self.early_stop_callback.check_metrics(callback_metrics) # init progress bar pbar = tqdm.tqdm(leave=True, position=2 * self.process_position, diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index cd6c13949c0ee..8c862488fcec9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -400,12 +400,8 @@ def run_training_epoch(self): if self.fast_dev_run or should_check_val: self.run_evaluation(test=self.testing) - if (self.enable_early_stop and - self.callback_metrics.get(self.early_stop_callback.monitor) is None): - raise RuntimeError(f"Early stopping was configured to monitor" - f" {self.early_stop_callback.monitor} but it is not" - f" available after validation_end. 
Available metrics are:" - f" {', '.join(list(self.callback_metrics.keys()))}") + if self.enable_early_stop: + self.early_stop_callback.check_metrics(self.callback_metrics) # when logs should be saved should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch From 0f83b1aff3b5c5dae302d73ded85fb73851e36d7 Mon Sep 17 00:00:00 2001 From: Vadim Bereznyuk Date: Wed, 22 Jan 2020 11:06:50 +0300 Subject: [PATCH 7/9] Updated docs --- pytorch_lightning/callbacks/pt_callbacks.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py index 1fd05c779c63a..20d035679fa7e 100644 --- a/pytorch_lightning/callbacks/pt_callbacks.py +++ b/pytorch_lightning/callbacks/pt_callbacks.py @@ -71,21 +71,23 @@ class EarlyStopping(Callback): Stop training when a monitored quantity has stopped improving. Args: - monitor (str): quantity to be monitored. + monitor (str): quantity to be monitored. Default: ``'val_loss'``. min_delta (float): minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute - change of less than min_delta, will count as no - improvement. + change of less than `min_delta`, will count as no + improvement. Default: ``0``. patience (int): number of epochs with no improvement - after which training will be stopped. - verbose (bool): verbosity mode. + after which training will be stopped. Default: ``0``. + verbose (bool): verbosity mode. Default: ``0``. mode (str): one of {auto, min, max}. In `min` mode, training will stop when the quantity monitored has stopped decreasing; in `max` mode it will stop when the quantity monitored has stopped increasing; in `auto` mode, the direction is automatically inferred - from the name of the monitored quantity. + from the name of the monitored quantity. Default: ``'auto'``. + strict (bool): whether to crash the training if `monitor` is + not found in the metrics. Default: ``True``. Example:: From 4e35bee58bdd60653ca10035502f5b15a05b0b83 Mon Sep 17 00:00:00 2001 From: Vadim Bereznyuk Date: Wed, 22 Jan 2020 11:19:55 +0300 Subject: [PATCH 8/9] Update docs --- pytorch_lightning/trainer/trainer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6889aecda526e..c2add79416021 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -121,7 +121,13 @@ def __init__( ) trainer = Trainer(checkpoint_callback=checkpoint_callback) - early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping + early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping. If + set to ``True``, then the default callback monitoring ``'val_loss'`` is created. + Will raise an error if ``'val_loss'`` is not found. + If set to ``False``, then early stopping will be disabled. + If set to ``None``, then the default callback monitoring ``'val_loss'`` is created. + If ``'val_loss'`` is not found will work as if early stopping is disabled. + Default: ``None``. 
Example:: from pytorch_lightning.callbacks import EarlyStopping @@ -129,7 +135,8 @@ def __init__( early_stop_callback = EarlyStopping( monitor='val_loss', patience=3, - verbose=True, + strict=False, + verbose=False, mode='min' ) From 320e4c180e252ba77cf0ffe77e364cc4addf7e9b Mon Sep 17 00:00:00 2001 From: Vadim Bereznyuk Date: Wed, 22 Jan 2020 12:10:01 +0300 Subject: [PATCH 9/9] Do not call early stopping when validation is disabled --- pytorch_lightning/trainer/training_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index dd685e444fcb3..484983d7b62be 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -346,7 +346,8 @@ def train(self): # early stopping met_min_epochs = epoch >= self.min_epochs - 1 - if self.enable_early_stop and (met_min_epochs or self.fast_dev_run): + if (self.enable_early_stop and not self.disable_validation and + (met_min_epochs or self.fast_dev_run)): should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch, logs=self.callback_metrics) # stop training
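
Usage sketch (appended note, not part of the patches above): the end state of this series gives `early_stop_callback` three scalar settings plus an explicit-callback form, and adds a `strict` flag to `EarlyStopping`. A minimal example of the resulting API, assuming the standard top-level `Trainer` import; the `EarlyStopping` import path and constructor arguments are the ones shown in the PATCH 8/9 docstring::

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # Explicit callback: with strict=False, a missing 'val_loss' in the metrics
    # returned by validation_end only produces a warning (when verbose) via
    # check_metrics, instead of raising a RuntimeError.
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=3,
        strict=False,
        verbose=False,
        mode='min'
    )
    trainer = Trainer(early_stop_callback=early_stop_callback)

    # Scalar settings, handled by configure_early_stopping (PATCH 6/9):
    Trainer(early_stop_callback=True)   # default callback with strict=True: raises if 'val_loss' is absent
    Trainer(early_stop_callback=False)  # early stopping disabled
    Trainer(early_stop_callback=None)   # new default: non-strict callback, behaves as disabled when 'val_loss' is missing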