Learning rate log callback #1498

Merged
merged 33 commits into from
Apr 30, 2020

33 commits
d6d49c1
base implementation
Apr 13, 2020
5f5164a
docs + implementation
Apr 13, 2020
4b5921c
fix styling
Apr 13, 2020
b5e7122
add lr string
Apr 15, 2020
619a2df
renaming
Apr 15, 2020
ad1cd93
CHANGELOG.md
Apr 15, 2020
5a99e8b
add tests
Apr 15, 2020
589d95d
rebase
Apr 15, 2020
1a3fe9e
Apply suggestions from code review
Borda Apr 16, 2020
8d22184
Apply suggestions from code review
Borda Apr 16, 2020
36622e2
Update pytorch_lightning/callbacks/lr_logger.py
Borda Apr 16, 2020
f265ff7
Update pytorch_lightning/callbacks/lr_logger.py
Borda Apr 16, 2020
258ceaf
add test for naming
Apr 17, 2020
b73a1b9
fix merge conflict
Apr 21, 2020
4f02e06
base implementation
Apr 13, 2020
b1ad2a5
docs + implementation
Apr 13, 2020
935b239
fix styling
Apr 13, 2020
ab87b2d
add lr string
Apr 15, 2020
2a592e1
renaming
Apr 15, 2020
60c5361
CHANGELOG.md
Apr 15, 2020
db7cb77
add tests
Apr 15, 2020
65be2bb
Apply suggestions from code review
Borda Apr 16, 2020
a39a5bf
Apply suggestions from code review
Borda Apr 16, 2020
1047c15
Update pytorch_lightning/callbacks/lr_logger.py
Borda Apr 16, 2020
90eab28
Update pytorch_lightning/callbacks/lr_logger.py
Borda Apr 16, 2020
af3624b
add test for naming
Apr 17, 2020
9554095
Update pytorch_lightning/callbacks/lr_logger.py
Borda Apr 26, 2020
ab8a26c
rebase
Apr 27, 2020
8fe113c
suggestions from code review
Apr 27, 2020
8c7f945
fix styling
Apr 27, 2020
c18ce73
rebase
Apr 27, 2020
3c2b12e
Merge remote-tracking branch 'upstream/master' into feature/lr_log_ca…
Apr 27, 2020
d0a25ba
fix tests
Apr 27, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added callback for logging learning rates ([#1498](https://github.com/PyTorchLightning/pytorch-lightning/pull/1498))

### Changed

- Allow logging of metrics together with hparams ([#1630](https://github.com/PyTorchLightning/pytorch-lightning/pull/1630))
8 changes: 8 additions & 0 deletions docs/source/callbacks.rst
@@ -84,3 +84,11 @@ We successfully extended functionality without polluting our super clean
.. automodule:: pytorch_lightning.callbacks.progress
    :noindex:
    :exclude-members:

---------

.. automodule:: pytorch_lightning.callbacks.lr_logger
    :noindex:
    :exclude-members:
        _extract_lr,
        _find_names
2 changes: 2 additions & 0 deletions pytorch_lightning/callbacks/__init__.py
@@ -2,13 +2,15 @@
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.lr_logger import LearningRateLogger
from pytorch_lightning.callbacks.progress import ProgressBarBase, ProgressBar

__all__ = [
    'Callback',
    'EarlyStopping',
    'ModelCheckpoint',
    'GradientAccumulationScheduler',
    'LearningRateLogger',
    'ProgressBarBase',
    'ProgressBar',
]
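With the new export in place, the callback can be imported straight from `pytorch_lightning.callbacks`. Below is a minimal usage sketch, assuming a hypothetical `LitModel` LightningModule whose `configure_optimizers` returns at least one scheduler; the callback raises a `MisconfigurationException` at train start otherwise, and it also requires the Trainer to have a logger.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateLogger

# LitModel is a placeholder for any LightningModule whose
# configure_optimizers returns both an optimizer and an lr scheduler;
# without a scheduler the callback raises MisconfigurationException.
model = LitModel()

lr_logger = LearningRateLogger()
trainer = Trainer(max_epochs=5, callbacks=[lr_logger])
trainer.fit(model)

# Besides sending values to the logger, the callback keeps a history
# on the instance, keyed by the generated names (e.g. 'lr-Adam').
print(lr_logger.lrs)
```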
118 changes: 118 additions & 0 deletions pytorch_lightning/callbacks/lr_logger.py
@@ -0,0 +1,118 @@
r"""

Logging of learning rates
=========================

Log learning rate for lr schedulers during training

"""

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LearningRateLogger(Callback):
r"""
Automatically logs learning rate for learning rate schedulers during training.

Example::

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LearningRateLogger
>>> lr_logger = LearningRateLogger()
>>> trainer = Trainer(callbacks=[lr_logger])

Logging names are automatically determined based on optimizer class name.
In case of multiple optimizers of same type, they will be named `Adam`,
`Adam-1` etc. If a optimizer has multiple parameter groups they will
be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a
`name` keyword in the construction of the learning rate schdulers

Example::

def configure_optimizer(self):
optimizer = torch.optim.Adam(...)
lr_scheduler = {'scheduler': torch.optim.lr_schedulers.LambdaLR(optimizer, ...)
'name': 'my_logging_name'}
return [optimizer], [lr_scheduler]
"""
def __init__(self):
self.lrs = None
self.lr_sch_names = []

def on_train_start(self, trainer, pl_module):
""" Called before training, determines unique names for all lr
schedulers in the case of multiple of the same type or in
the case of multiple parameter groups
"""
if trainer.lr_schedulers == []:
raise MisconfigurationException(
'Cannot use LearningRateLogger callback with models that have no'
' learning rate schedulers. Please see documentation for'
' `configure_optimizers` method.')

if not trainer.logger:
raise MisconfigurationException(
'Cannot use LearningRateLogger callback with Trainer that has no logger.')

# Find names for schedulers
names = self._find_names(trainer.lr_schedulers)

# Initialize for storing values
self.lrs = dict.fromkeys(names, [])

def on_batch_start(self, trainer, pl_module):
latest_stat = self._extract_lr(trainer, 'step')
if trainer.logger and latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

def on_epoch_start(self, trainer, pl_module):
latest_stat = self._extract_lr(trainer, 'epoch')
if trainer.logger and latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

def _extract_lr(self, trainer, interval):
""" Extracts learning rates for lr schedulers and saves information
into dict structure. """
latest_stat = {}
for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
if scheduler['interval'] == interval:
param_groups = scheduler['scheduler'].optimizer.param_groups
if len(param_groups) != 1:
for i, pg in enumerate(param_groups):
lr, key = pg['lr'], f'{name}/{i + 1}'
self.lrs[key].append(lr)
latest_stat[key] = lr
else:
self.lrs[name].append(param_groups[0]['lr'])
latest_stat[name] = param_groups[0]['lr']
return latest_stat

def _find_names(self, lr_schedulers):
# Create uniqe names in the case we have multiple of the same learning
# rate schduler + multiple parameter groups
names = []
for scheduler in lr_schedulers:
sch = scheduler['scheduler']
if 'name' in scheduler:
name = scheduler['name']
else:
opt_name = 'lr-' + sch.optimizer.__class__.__name__
i, name = 1, opt_name
# Multiple schduler of the same type
while True:
if name not in names:
break
i, name = i + 1, f'{opt_name}-{i}'

# Multiple param groups for the same schduler
param_groups = sch.optimizer.param_groups
if len(param_groups) != 1:
for i, pg in enumerate(param_groups):
temp = name + '/pg' + str(i + 1)
names.append(temp)
else:
names.append(name)

self.lr_sch_names.append(name)
return names
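To make the naming rules in `_find_names` concrete, here is an illustrative module (not part of this PR; the layer attributes and class name are placeholders) together with the keys the callback would log for it.

```python
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule


class NamingDemo(LightningModule):
    """ Hypothetical module: only the parts relevant to lr logging are shown. """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(32, 16)
        self.decoder = nn.Linear(16, 32)
        self.head = nn.Linear(16, 2)

    def configure_optimizers(self):
        # first Adam, single param group  -> logged as 'lr-Adam'
        opt1 = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
        # second Adam, two param groups   -> 'lr-Adam-1/pg1' and 'lr-Adam-1/pg2'
        opt2 = torch.optim.Adam([
            {'params': self.decoder.parameters(), 'lr': 1e-3},
            {'params': self.head.parameters(), 'lr': 1e-4},
        ])
        sched1 = torch.optim.lr_scheduler.StepLR(opt1, step_size=10)
        sched2 = torch.optim.lr_scheduler.StepLR(opt2, step_size=10)
        # returning a dict with a 'name' key instead of a bare scheduler
        # would override the generated name for that scheduler
        return [opt1, opt2], [sched1, sched2]
```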
59 changes: 58 additions & 1 deletion tests/callbacks/test_callbacks.py
@@ -1,11 +1,12 @@
import tests.base.utils as tutils
from pytorch_lightning import Callback
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
from tests.base import (
    LightTrainDataloader,
    LightTestMixin,
    LightValidationMixin,
    LightTestOptimizersWithMixedSchedulingMixin,
    TestModelBase
)

@@ -271,3 +272,59 @@ class CurrentTestModel(LightTrainDataloader, TestModelBase):

    # These should be different if the dirpath has been overridden
assert trainer.ckpt_path != trainer.default_root_dir


def test_lr_logger_single_lr(tmpdir):
    """ Test that learning rates are extracted and logged for a single lr scheduler. """
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    lr_logger = LearningRateLogger()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=5,
        val_percent_check=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger]
    )
    results = trainer.fit(model)

    assert lr_logger.lrs, 'No learning rates logged'
    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'
    assert all(k in ['lr-Adam'] for k in lr_logger.lrs.keys()), \
        'Names of learning rates not set correctly'


def test_lr_logger_multi_lrs(tmpdir):
    """ Test that learning rates are extracted and logged for multiple lr schedulers. """
    tutils.reset_seed()

    class CurrentTestModel(LightTestOptimizersWithMixedSchedulingMixin,
                           LightTrainDataloader,
                           TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    lr_logger = LearningRateLogger()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger]
    )
    results = trainer.fit(model)

    assert lr_logger.lrs, 'No learning rates logged'
    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'
    assert all(k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()), \
        'Names of learning rates not set correctly'
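The commit history also mentions a test for naming. A hedged sketch of what such a test could look like follows; the `torch` import and the `configure_optimizers` override on the test model are assumptions, not part of this diff.

```python
import torch


def test_lr_logger_custom_name(tmpdir):
    """ Sketch: a user-supplied scheduler 'name' should become the logging key. """
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
            lr_scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
                            'name': 'my_logging_name'}
            return [optimizer], [lr_scheduler]

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    lr_logger = LearningRateLogger()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger]
    )
    trainer.fit(model)

    assert list(lr_logger.lrs.keys()) == ['my_logging_name'], \
        'Scheduler name was not used as the logging key'
```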