Skip to content

Commit

Permalink
a test for log length consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
ivannz committed May 25, 2020
1 parent 21b14a7 commit 2ec5688
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,29 @@ def test_lr_logger_param_groups(tmpdir):
'Number of learning rates logged does not match number of param groups'
assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
'Names of learning rates not set correctly'


def test_lr_logger_multiple_schedulers(tmpdir):
""" Test that learning rates are extracted and logged for single lr scheduler"""
tutils.reset_seed()

model = EvalModelTemplate()
model.configure_optimizers = model.configure_optimizers__multiple_schedulers

lr_logger = LearningRateLogger()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=5,
val_percent_check=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)
results = trainer.fit(model)

assert lr_logger.lrs, 'No learning rates logged'
assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
'Number of learning rates logged does not match number of param groups'
assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
'Names of learning rates not set correctly'
assert all(len(lr) == 5 for k, lr in lr_logger.lrs.items()), \
'Inconsistent learning rate log'

0 comments on commit 2ec5688

Please sign in to comment.