Skip to content

Commit

Permalink
runtime-warn if no schedulers are configured
Browse files Browse the repository at this point in the history
  • Loading branch information
ivannz committed May 25, 2020
1 parent 2c57d5d commit e0498bb
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 6 deletions.
15 changes: 9 additions & 6 deletions pytorch_lightning/callbacks/lr_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pytorch_lightning.utilities import rank_zero_warn


class LearningRateLogger(Callback):
r"""
Expand Down Expand Up @@ -45,16 +47,17 @@ def on_train_start(self, trainer, pl_module):
schedulers in the case of multiple of the same type or in
the case of multiple parameter groups
"""
if trainer.lr_schedulers == []:
raise MisconfigurationException(
'Cannot use LearningRateLogger callback with models that have no'
' learning rate schedulers. Please see documentation for'
' `configure_optimizers` method.')

if not trainer.logger:
raise MisconfigurationException(
'Cannot use LearningRateLogger callback with Trainer that has no logger.')

if not trainer.lr_schedulers:
rank_zero_warn(
'You are using LearningRateLogger callback with models that'
' have no learning rate schedulers. Please see documentation'
' for `configure_optimizers` method.', RuntimeWarning
)

# Find names for schedulers
names = self._find_names(trainer.lr_schedulers)

Expand Down
20 changes: 20 additions & 0 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import tests.base.utils as tutils
from pytorch_lightning import Callback
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate
Expand Down Expand Up @@ -308,6 +310,24 @@ def test_lr_logger_single_lr(tmpdir):
'Names of learning rates not set correctly'


def test_lr_logger_no_lr(tmpdir):
    """ Test that LearningRateLogger warns (instead of raising) when the
    model defines no learning rate schedulers """
    tutils.reset_seed()

    model = EvalModelTemplate()

    lr_logger = LearningRateLogger()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=5,
        val_percent_check=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger],
    )

    # the callback now emits a RuntimeWarning on_train_start rather than a
    # MisconfigurationException; fit should complete normally
    with pytest.warns(RuntimeWarning):
        trainer.fit(model)


def test_lr_logger_multi_lrs(tmpdir):
""" Test that learning rates are extracted and logged for multi lr schedulers """
tutils.reset_seed()
Expand Down

0 comments on commit e0498bb

Please sign in to comment.