Commit: move

Borda authored and ivannz committed May 25, 2020
1 parent f6bebc1 commit b4f7855
Showing 2 changed files with 105 additions and 98 deletions.
tests/callbacks/test_callbacks.py: 101 changes (3 additions, 98 deletions)
@@ -1,14 +1,13 @@
+from pathlib import Path
+
 import pytest
 
 import tests.base.utils as tutils
 from pytorch_lightning import Callback
 from pytorch_lightning import Trainer, LightningModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-
-from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate
-from pathlib import Path
 
 
 def test_trainer_callback_system(tmpdir):
@@ -283,97 +282,3 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):

     ckpt_version = Path(trainer.ckpt_path).parent.name
     assert ckpt_version == expected
-
-
-def test_lr_logger_single_lr(tmpdir):
-    """ Test that learning rates are extracted and logged for single lr scheduler"""
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__single_scheduler
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=5,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-
-    assert results == 1
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'
-
-
-def test_lr_logger_no_lr(tmpdir):
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=5,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-
-    with pytest.warns(RuntimeWarning):
-        results = trainer.fit(model)
-
-
-def test_lr_logger_multi_lrs(tmpdir):
-    """ Test that learning rates are extracted and logged for multi lr schedulers """
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=10,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-
-    assert results == 1
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'
-    assert all(len(lr) == trainer.max_epochs for k, lr in lr_logger.lrs.items()), \
-        'Length of logged learning rates exceeds the number of epochs'
-
-
-def test_lr_logger_param_groups(tmpdir):
-    """ Test that learning rates are extracted and logged for single lr scheduler"""
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__param_groups
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=5,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of param groups'
-    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'
tests/callbacks/test_lr.py: 102 changes (102 additions, 0 deletions)
@@ -0,0 +1,102 @@
+import pytest
+
+import tests.base.utils as tutils
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import LearningRateLogger
+from tests.base import EvalModelTemplate
+
+
+def test_lr_logger_single_lr(tmpdir):
+    """ Test that learning rates are extracted and logged for single lr scheduler. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__single_scheduler
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of lr schedulers'
+    assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
+
+
+def test_lr_logger_no_lr(tmpdir):
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+
+    with pytest.warns(RuntimeWarning):
+        result = trainer.fit(model)
+        assert result
+
+
+def test_lr_logger_multi_lrs(tmpdir):
+    """ Test that learning rates are extracted and logged for multi lr schedulers. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=10,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of lr schedulers'
+    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
+    assert all(len(lr) == trainer.max_epochs for k, lr in lr_logger.lrs.items()), \
+        'Length of logged learning rates exceeds the number of epochs'
+
+
+def test_lr_logger_param_groups(tmpdir):
+    """ Test that learning rates are extracted and logged for single lr scheduler. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__param_groups
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of param groups'
+    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'

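For reference, the relocated tests can be run in isolation. The snippet below is a minimal sketch and not part of this commit: it assumes pytest is installed and that it is executed from the repository root; pytest.main is pytest's standard programmatic entry point, and -v is its ordinary verbosity flag.

# Minimal sketch: run only the relocated learning-rate logger tests.
# Assumes pytest is installed and the working directory is the repository root.
import pytest

# pytest.main accepts the same arguments as the pytest CLI
# and returns an exit code (0 when all tests pass).
exit_code = pytest.main(["tests/callbacks/test_lr.py", "-v"])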