diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 57b0b537b4dcaa..7cbd4177fac8c3 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -355,3 +355,29 @@ def test_lr_logger_param_groups(tmpdir):
         'Number of learning rates logged does not match number of param groups'
     assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
         'Names of learning rates not set correctly'
+
+
+def test_lr_logger_multiple_schedulers(tmpdir):
+    """ Test that learning rates are extracted and logged for multiple lr schedulers"""
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    results = trainer.fit(model)
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of lr schedulers'
+    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
+    assert all(len(lr) == 5 for k, lr in lr_logger.lrs.items()), \
+        'Inconsistent learning rate log'
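
For context, the test patches the model with configure_optimizers__multiple_schedulers, a helper defined on EvalModelTemplate elsewhere in the test suite. A minimal sketch of the shape that helper is assumed to have (not the actual template code): two Adam optimizers, each paired with its own scheduler, returned in the two-list form accepted by configure_optimizers. With two optimizers of the same class, LearningRateLogger disambiguates the logged keys by suffixing an index, which is why the test expects 'lr-Adam' and 'lr-Adam-1'.

    from torch import optim

    def configure_optimizers__multiple_schedulers(self):
        # Hypothetical sketch: exact lrs/scheduler types in the real
        # template may differ; only the return structure matters here.
        optimizer1 = optim.Adam(self.parameters(), lr=0.01)
        optimizer2 = optim.Adam(self.parameters(), lr=0.01)
        scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=1, gamma=0.1)
        scheduler2 = optim.lr_scheduler.StepLR(optimizer2, step_size=1, gamma=0.1)
        # Two optimizers and two schedulers -> trainer.lr_schedulers has
        # length 2, matching the assertion in the test above.
        return [optimizer1, optimizer2], [scheduler1, scheduler2]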