Skip to content

Commit

Permalink
added warning when changing monitor and using results obj (Lightning-…
Browse files Browse the repository at this point in the history
…AI#3014)

* added warning when changing monitor and using results obj

* added warning when changing monitor and using results obj

* added warning when changing monitor and using results obj
  • Loading branch information
williamFalcon authored and atee committed Aug 17, 2020
1 parent d18da8e commit a03a0bf
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 0 deletions.
16 changes: 16 additions & 0 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
import os

torch_inf = torch.tensor(np.Inf)

Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
self.wait_count = 0
self.stopped_epoch = 0
self.mode = mode
self.warned_result_obj = False

if mode not in self.mode_dict:
if self.verbose > 0:
Expand Down Expand Up @@ -154,12 +156,26 @@ def on_train_epoch_end(self, trainer, pl_module):
if should_check_early_stop:
self._run_early_stopping_check(trainer, pl_module)

def __warn_deprecated_monitor_key(self):
    """Warn once if a custom ``monitor`` key is set while Result objects are in use.

    The warning fires only when all of the following hold:

    * the ``PL_USING_RESULT_OBJ`` environment variable is set to a non-empty value,
    * this callback instance has not already emitted the warning, and
    * ``self.monitor`` is not one of the keys accepted below.

    Side effects: sets ``self.warned_result_obj`` to ``True`` so the warning is
    emitted at most once per callback instance.
    """
    using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)

    # NOTE(review): 'val_early_step_on' looks like a misspelling of 'val_early_stop_on'.
    # The corrected spelling is accepted as well, and the old one is kept so that no
    # previously-silent configuration starts warning. Confirm which key the evaluation
    # loop actually emits.
    accepted_keys = ['val_loss', 'early_stop_on', 'val_early_stop_on', 'val_early_step_on']
    invalid_key = self.monitor not in accepted_keys

    if using_result_obj and not self.warned_result_obj and invalid_key:
        self.warned_result_obj = True
        # The closing quote now follows the monitor value, so the suggested call
        # reads as valid Python: EarlyStopping(monitor='<key>').
        m = f"""
            When using EvalResult(early_stop_on=X) or TrainResult(early_stop_on=X) the
            'monitor' key of EarlyStopping has no effect.
            Remove EarlyStopping(monitor='{self.monitor}') to fix
        """
        rank_zero_warn(m)

def _run_early_stopping_check(self, trainer, pl_module):
logs = trainer.callback_metrics

if not self._validate_condition_metric(logs):
return # short circuit if metric not present

self.__warn_deprecated_monitor_key()

current = logs.get(self.monitor)

# when in dev debugging
Expand Down
16 changes: 16 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
self.best_model_score = 0
self.best_model_path = ''
self.save_function = None
self.warned_result_obj = False

torch_inf = torch.tensor(np.Inf)
mode_dict = {
Expand Down Expand Up @@ -297,12 +298,27 @@ def on_train_start(self, trainer, pl_module):
if not gfile.exists(self.dirpath):
makedirs(self.dirpath)

def __warn_deprecated_monitor_key(self):
    """Warn once if a custom ``monitor`` key is set while Result objects are in use.

    The warning fires only when all of the following hold:

    * the ``PL_USING_RESULT_OBJ`` environment variable is set to a non-empty value,
    * this callback instance has not already emitted the warning, and
    * ``self.monitor`` is neither ``'val_loss'`` nor ``'checkpoint_on'``.

    Side effects: sets ``self.warned_result_obj`` to ``True`` so the warning is
    emitted at most once per callback instance.
    """
    using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
    invalid_key = self.monitor not in ['val_loss', 'checkpoint_on']

    if using_result_obj and not self.warned_result_obj and invalid_key:
        self.warned_result_obj = True
        # The message names `checkpoint_on` (the key this callback accepts) rather than
        # `early_stop_on`, and the closing quote now follows the monitor value.
        m = f"""
            When using EvalResult(checkpoint_on=X) or TrainResult(checkpoint_on=X) the
            'monitor' key of ModelCheckpoint has no effect.
            Remove ModelCheckpoint(monitor='{self.monitor}') to fix
        """
        rank_zero_warn(m)

@rank_zero_only
def on_validation_end(self, trainer, pl_module):
# only run on main process
if trainer.global_rank != 0:
return

# TODO: remove when dict results are deprecated
self.__warn_deprecated_monitor_key()

metrics = trainer.callback_metrics
epoch = trainer.current_epoch

Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
from torch import Tensor
import os

from pytorch_lightning.metrics.converters import _sync_ddp_if_available

Expand All @@ -20,6 +21,9 @@ def __init__(

super().__init__()

# temporary until dict results are deprecated
os.environ['PL_USING_RESULT_OBJ'] = '1'

if early_stop_on is not None:
self.early_stop_on = early_stop_on
if checkpoint_on is not None and checkpoint_on:
Expand Down
41 changes: 41 additions & 0 deletions tests/trainer/test_trainer_steps_result_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.step_result import TrainResult
from tests.base import EvalModelTemplate
from tests.base.deterministic_model import DeterministicModel
Expand Down Expand Up @@ -543,3 +544,43 @@ def test_result_map(tmpdir):
assert 'x2' not in result
assert 'y1' in result
assert 'y2' in result


def test_result_monitor_warnings(tmpdir):
    """
    Tests that we warn when the monitor key is changed and we use Results obj
    """
    # Build a model whose train/val steps return Result objects and whose
    # other hooks are disabled.
    model = EvalModelTemplate()
    model.test_step = None
    model.training_step = model.training_step_result_obj
    model.training_step_end = None
    model.training_epoch_end = None
    model.validation_step = model.validation_step_result_obj
    model.validation_step_end = None
    model.validation_epoch_end = None
    model.test_dataloader = None

    # A ModelCheckpoint with a non-default monitor key must warn.
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        early_stop_callback=True,
        row_log_interval=2,
        limit_train_batches=2,
        weights_summary=None,
        checkpoint_callback=ModelCheckpoint(monitor='not_val_loss'),
    )

    with pytest.warns(UserWarning, match='key of ModelCheckpoint has no effect'):
        trainer.fit(model)

    # An EarlyStopping with a non-default monitor key must warn as well.
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        row_log_interval=2,
        limit_train_batches=2,
        weights_summary=None,
        early_stop_callback=EarlyStopping(monitor='not_val_loss'),
    )

    # The pattern is spelled out in full ('effect', not 'effec') so it is as
    # strict as the ModelCheckpoint assertion above.
    with pytest.warns(UserWarning, match='key of EarlyStopping has no effect'):
        trainer.fit(model)

0 comments on commit a03a0bf

Please sign in to comment.