diff --git a/pytorch_lightning/callbacks/lr_logger.py b/pytorch_lightning/callbacks/lr_logger.py
index 15967bc91d8b42..43fee3477abf9b 100755
--- a/pytorch_lightning/callbacks/lr_logger.py
+++ b/pytorch_lightning/callbacks/lr_logger.py
@@ -10,6 +10,8 @@
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities import rank_zero_warn
+
 
 
 class LearningRateLogger(Callback):
     r"""
@@ -45,16 +47,17 @@ def on_train_start(self, trainer, pl_module):
             schedulers in the case of multiple of the same type or in
             the case of multiple parameter groups
         """
-        if trainer.lr_schedulers == []:
-            raise MisconfigurationException(
-                'Cannot use LearningRateLogger callback with models that have no'
-                ' learning rate schedulers. Please see documentation for'
-                ' `configure_optimizers` method.')
-
         if not trainer.logger:
             raise MisconfigurationException(
                 'Cannot use LearningRateLogger callback with Trainer that has no logger.')
 
+        if not trainer.lr_schedulers:
+            rank_zero_warn(
+                'You are using LearningRateLogger callback with models that'
+                ' have no learning rate schedulers. Please see documentation'
+                ' for `configure_optimizers` method.', RuntimeWarning
+            )
+
         # Find names for schedulers
         names = self._find_names(trainer.lr_schedulers)
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index b5bc800a1e6909..5bd4f24659bf6f 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -3,6 +3,8 @@
 import tests.base.utils as tutils
 from pytorch_lightning import Callback
 from pytorch_lightning import Trainer, LightningModule
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
 from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate
@@ -308,6 +310,24 @@ def test_lr_logger_single_lr(tmpdir):
         'Names of learning rates not set correctly'
 
 
+def test_lr_logger_no_lr(tmpdir):
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+
+    with pytest.warns(RuntimeWarning):
+        results = trainer.fit(model)
+
+
 def test_lr_logger_multi_lrs(tmpdir):
     """ Test that learning rates are extracted and logged for multi lr schedulers """
     tutils.reset_seed()
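
Note (illustration only, not part of the diff): after this change, attaching `LearningRateLogger` to a run with no LR schedulers emits a `RuntimeWarning` instead of raising `MisconfigurationException`. A minimal sketch of that contract, assuming `on_train_start` only reads the `trainer.logger` and `trainer.lr_schedulers` attributes shown in the hunk above; the stub trainer object is hypothetical:

    # Hypothetical stub, not a real Trainer: it carries only the two attributes
    # that on_train_start() inspects in the hunk above.
    import warnings
    from types import SimpleNamespace

    from pytorch_lightning.callbacks import LearningRateLogger

    fake_trainer = SimpleNamespace(logger=object(), lr_schedulers=[])  # no schedulers configured

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        LearningRateLogger().on_train_start(fake_trainer, pl_module=None)

    # Before this PR the call above raised MisconfigurationException; now it only warns.
    assert any(issubclass(w.category, RuntimeWarning) for w in caught)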