diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 5bd4f24659bf6f..c962e6cdb0a559 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -1,14 +1,13 @@
+from pathlib import Path
+
 import pytest
 
 import tests.base.utils as tutils
 from pytorch_lightning import Callback
 from pytorch_lightning import Trainer, LightningModule
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-
-from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate
-from pathlib import Path
 
 
 def test_trainer_callback_system(tmpdir):
@@ -283,97 +282,3 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
 
     ckpt_version = Path(trainer.ckpt_path).parent.name
     assert ckpt_version == expected
-
-
-def test_lr_logger_single_lr(tmpdir):
-    """ Test that learning rates are extracted and logged for single lr scheduler"""
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__single_scheduler
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=5,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-
-    assert results == 1
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'
-
-
-def test_lr_logger_no_lr(tmpdir):
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=5,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-
-    with pytest.warns(RuntimeWarning):
-        results = trainer.fit(model)
-
-
-def test_lr_logger_multi_lrs(tmpdir):
-    """ Test that learning rates are extracted and logged for multi lr schedulers """
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=10,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-
-    assert results == 1
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'
-    assert all(len(lr) == trainer.max_epochs for k, lr in lr_logger.lrs.items()), \
-        'Length of logged learning rates exceeds the number of epochs'
-
-
-def test_lr_logger_param_groups(tmpdir):
-    """ Test that learning rates are extracted and logged for single lr scheduler"""
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__param_groups
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=5,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of param groups'
-    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'
diff --git a/tests/callbacks/test_lr.py b/tests/callbacks/test_lr.py
new file mode 100644
index 00000000000000..e8c914ef2a0848
--- /dev/null
+++ b/tests/callbacks/test_lr.py
@@ -0,0 +1,102 @@
+import pytest
+
+import tests.base.utils as tutils
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import LearningRateLogger
+from tests.base import EvalModelTemplate
+
+
+def test_lr_logger_single_lr(tmpdir):
+    """ Test that learning rates are extracted and logged for a single lr scheduler. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__single_scheduler
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of lr schedulers'
+    assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
+
+
+def test_lr_logger_no_lr(tmpdir):
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+
+    with pytest.warns(RuntimeWarning):
+        result = trainer.fit(model)
+        assert result
+
+
+def test_lr_logger_multi_lrs(tmpdir):
+    """ Test that learning rates are extracted and logged for multiple lr schedulers. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=10,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of lr schedulers'
+    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
+    assert all(len(lr) == trainer.max_epochs for k, lr in lr_logger.lrs.items()), \
+        'Number of logged learning rates does not match the number of epochs'
+
+
+def test_lr_logger_param_groups(tmpdir):
""" + tutils.reset_seed() + + model = EvalModelTemplate() + model.configure_optimizers = model.configure_optimizers__param_groups + + lr_logger = LearningRateLogger() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=5, + val_percent_check=0.1, + train_percent_check=0.5, + callbacks=[lr_logger] + ) + result = trainer.fit(model) + assert result + + assert lr_logger.lrs, 'No learning rates logged' + assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \ + 'Number of learning rates logged does not match number of param groups' + assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \ + 'Names of learning rates not set correctly'