diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 57b0b537b4dcaa..b5bc800a1e6909 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -318,7 +318,7 @@ def test_lr_logger_multi_lrs(tmpdir): lr_logger = LearningRateLogger() trainer = Trainer( default_root_dir=tmpdir, - max_epochs=1, + max_epochs=10, val_percent_check=0.1, train_percent_check=0.5, callbacks=[lr_logger] @@ -331,6 +331,8 @@ def test_lr_logger_multi_lrs(tmpdir): 'Number of learning rates logged does not match number of lr schedulers' assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \ 'Names of learning rates not set correctly' + assert all(len(lr) == trainer.max_epochs for k, lr in lr_logger.lrs.items()), \ + 'Length of logged learning rates exceeds the number of epochs' def test_lr_logger_param_groups(tmpdir):