diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 7902d6fc769f65..1aa0ca573ec996 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -14,6 +14,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
+import os
 
 torch_inf = torch.tensor(np.Inf)
 
@@ -72,6 +73,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
         self.wait_count = 0
         self.stopped_epoch = 0
         self.mode = mode
+        self.warned_result_obj = False
 
         if mode not in self.mode_dict:
             if self.verbose > 0:
@@ -154,12 +156,26 @@ def on_train_epoch_end(self, trainer, pl_module):
         if should_check_early_stop:
             self._run_early_stopping_check(trainer, pl_module)
 
+    def __warn_deprecated_monitor_key(self):
+        using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
+        invalid_key = self.monitor not in ['val_loss', 'early_stop_on', 'val_early_step_on']
+        if using_result_obj and not self.warned_result_obj and invalid_key:
+            self.warned_result_obj = True
+            m = f"""
+            When using EvalResult(early_stop_on=X) or TrainResult(early_stop_on=X) the
+            'monitor' key of EarlyStopping has no effect.
+            Remove EarlyStopping(monitor='{self.monitor}') to fix.
+            """
+            rank_zero_warn(m)
+
     def _run_early_stopping_check(self, trainer, pl_module):
         logs = trainer.callback_metrics
 
         if not self._validate_condition_metric(logs):
             return  # short circuit if metric not present
 
+        self.__warn_deprecated_monitor_key()
+
         current = logs.get(self.monitor)
 
         # when in dev debugging
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 0346a1e8575bdf..32e643c11d51d7 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -138,6 +138,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         self.best_model_score = 0
         self.best_model_path = ''
         self.save_function = None
+        self.warned_result_obj = False
 
         torch_inf = torch.tensor(np.Inf)
         mode_dict = {
@@ -297,12 +298,27 @@ def on_train_start(self, trainer, pl_module):
             if not gfile.exists(self.dirpath):
                 makedirs(self.dirpath)
 
+    def __warn_deprecated_monitor_key(self):
+        using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
+        invalid_key = self.monitor not in ['val_loss', 'checkpoint_on']
+        if using_result_obj and not self.warned_result_obj and invalid_key:
+            self.warned_result_obj = True
+            m = f"""
+            When using EvalResult(checkpoint_on=X) or TrainResult(checkpoint_on=X) the
+            'monitor' key of ModelCheckpoint has no effect.
+            Remove ModelCheckpoint(monitor='{self.monitor}') to fix.
+            """
+            rank_zero_warn(m)
+
     @rank_zero_only
     def on_validation_end(self, trainer, pl_module):
         # only run on main process
         if trainer.global_rank != 0:
             return
 
+        # TODO: remove when dict results are deprecated
+        self.__warn_deprecated_monitor_key()
+
         metrics = trainer.callback_metrics
         epoch = trainer.current_epoch
 
diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index e47e186c817142..bbe81853d06e67 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -4,6 +4,7 @@
 
 import torch
 from torch import Tensor
+import os
 
 from pytorch_lightning.metrics.converters import _sync_ddp_if_available
 
@@ -20,6 +21,9 @@ def __init__(
 
         super().__init__()
 
+        # temporary until dict results are deprecated
+        os.environ['PL_USING_RESULT_OBJ'] = '1'
+
         if early_stop_on is not None:
             self.early_stop_on = early_stop_on
         if checkpoint_on is not None and checkpoint_on:
diff --git a/tests/trainer/test_trainer_steps_result_return.py b/tests/trainer/test_trainer_steps_result_return.py
index 7b9fc080b07656..ed3632b968c9fd 100644
--- a/tests/trainer/test_trainer_steps_result_return.py
+++ b/tests/trainer/test_trainer_steps_result_return.py
@@ -7,6 +7,7 @@
 import torch
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.step_result import TrainResult
 from tests.base import EvalModelTemplate
 from tests.base.deterministic_model import DeterministicModel
@@ -543,3 +544,43 @@ def test_result_map(tmpdir):
     assert 'x2' not in result
     assert 'y1' in result
     assert 'y2' in result
+
+
+def test_result_monitor_warnings(tmpdir):
+    """
+    Tests that we warn when the monitor key is changed and we use Result objects
+    """
+    model = EvalModelTemplate()
+    model.test_step = None
+    model.training_step = model.training_step_result_obj
+    model.training_step_end = None
+    model.training_epoch_end = None
+    model.validation_step = model.validation_step_result_obj
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+    model.test_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        early_stop_callback=True,
+        row_log_interval=2,
+        limit_train_batches=2,
+        weights_summary=None,
+        checkpoint_callback=ModelCheckpoint(monitor='not_val_loss')
+    )
+
+    with pytest.warns(UserWarning, match='key of ModelCheckpoint has no effect'):
+        trainer.fit(model)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        row_log_interval=2,
+        limit_train_batches=2,
+        weights_summary=None,
+        early_stop_callback=EarlyStopping(monitor='not_val_loss')
+    )
+
+    with pytest.warns(UserWarning, match='key of EarlyStopping has no effect'):
+        trainer.fit(model)
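For reference, the mechanism this patch introduces is a process-wide handshake: constructing a `TrainResult`/`EvalResult` sets the `PL_USING_RESULT_OBJ` environment variable, and each callback consults that flag (plus its `warned_result_obj` guard, so the warning fires at most once per callback instance) before telling the user that its `monitor` argument is ignored. Below is a minimal standalone sketch of that warn-once pattern; `DemoResult`, `DemoCallback`, `VALID_KEYS`, and `check_monitor` are illustrative names, not part of the patch or of the Lightning API.

```python
import os
import warnings


class DemoResult(dict):
    """Stands in for TrainResult/EvalResult: records that result objects are in use."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # the same handshake the patch adds to Result.__init__ in step_result.py
        os.environ['PL_USING_RESULT_OBJ'] = '1'


class DemoCallback:
    """Stands in for EarlyStopping/ModelCheckpoint."""

    VALID_KEYS = ('val_loss', 'early_stop_on')  # monitor keys that still have an effect

    def __init__(self, monitor='val_loss'):
        self.monitor = monitor
        self.warned_result_obj = False  # guard: warn at most once per instance

    def check_monitor(self):
        using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
        invalid_key = self.monitor not in self.VALID_KEYS
        if using_result_obj and not self.warned_result_obj and invalid_key:
            self.warned_result_obj = True
            warnings.warn(
                f"the 'monitor' key ({self.monitor!r}) has no effect when using Result objects",
                UserWarning,
            )


DemoResult(loss=0.1)                       # sets the env-var flag
cb = DemoCallback(monitor='not_val_loss')
cb.check_monitor()                         # warns once
cb.check_monitor()                         # silent on subsequent calls
```

One caveat of this design worth noting: an environment variable is global to the process (and inherited by subprocesses), so the flag stays set even if the user later switches back to dict results; the patch itself labels it temporary until dict results are deprecated.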