Learning rate log callback #1498

Merged
merged 33 commits on Apr 30, 2020
Changes from 10 commits
Commits
33 commits
d6d49c1
base implementation
Apr 13, 2020
5f5164a
docs + implementation
Apr 13, 2020
4b5921c
fix styling
Apr 13, 2020
b5e7122
add lr string
Apr 15, 2020
619a2df
renaming
Apr 15, 2020
ad1cd93
CHANGELOG.md
Apr 15, 2020
5a99e8b
add tests
Apr 15, 2020
589d95d
rebase
Apr 15, 2020
1a3fe9e
Apply suggestions from code review
Borda Apr 16, 2020
8d22184
Apply suggestions from code review
Borda Apr 16, 2020
36622e2
Update pytorch_lightning/callbacks/lr_logger.py
Borda Apr 16, 2020
f265ff7
Update pytorch_lightning/callbacks/lr_logger.py
Borda Apr 16, 2020
258ceaf
add test for naming
Apr 17, 2020
b73a1b9
fix merge conflict
Apr 21, 2020
4f02e06
base implementation
Apr 13, 2020
b1ad2a5
docs + implementation
Apr 13, 2020
935b239
fix styling
Apr 13, 2020
ab87b2d
add lr string
Apr 15, 2020
2a592e1
renaming
Apr 15, 2020
60c5361
CHANGELOG.md
Apr 15, 2020
db7cb77
add tests
Apr 15, 2020
65be2bb
Apply suggestions from code review
Borda Apr 16, 2020
a39a5bf
Apply suggestions from code review
Borda Apr 16, 2020
1047c15
Update pytorch_lightning/callbacks/lr_logger.py
Borda Apr 16, 2020
90eab28
Update pytorch_lightning/callbacks/lr_logger.py
Borda Apr 16, 2020
af3624b
add test for naming
Apr 17, 2020
9554095
Update pytorch_lightning/callbacks/lr_logger.py
Borda Apr 26, 2020
ab8a26c
rebase
Apr 27, 2020
8fe113c
suggestions from code review
Apr 27, 2020
8c7f945
fix styling
Apr 27, 2020
c18ce73
rebase
Apr 27, 2020
3c2b12e
Merge remote-tracking branch 'upstream/master' into feature/lr_log_ca…
Apr 27, 2020
d0a25ba
fix tests
Apr 27, 2020
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -9,9 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems.
- Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))

- Added learining rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))
- Added `terminate_on_nan` flag to trainer that performs a NaN check with each training iteration when set to `True`. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))
- Added callback for logging learning rates ([#1498](https://github.com/PyTorchLightning/pytorch-lightning/pull/1498))

### Changed

7 changes: 7 additions & 0 deletions docs/source/callbacks.rst
@@ -78,3 +78,10 @@ We successfully extended functionality without polluting our super clean
        _save_model,
        _abc_impl,
        check_monitor_top_k,

---------

.. automodule:: pytorch_lightning.callbacks.lr_logger
   :noindex:
   :exclude-members:
        _extract_lr
2 changes: 2 additions & 0 deletions pytorch_lightning/callbacks/__init__.py
@@ -2,10 +2,12 @@
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.lr_logger import LearningRateLogger

__all__ = [
    'Callback',
    'EarlyStopping',
    'ModelCheckpoint',
    'GradientAccumulationScheduler',
    'LearningRateLogger'
]
117 changes: 117 additions & 0 deletions pytorch_lightning/callbacks/lr_logger.py
@@ -0,0 +1,117 @@
r"""

Logging of learning rates
=========================

Log learning rate for lr schedulers during training

"""

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LearningRateLogger(Callback):
r"""
Automatically logs learning rate for learning rate schedulers during training.

Example::

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LearningRateLogger
>>> lr_logger = LearningRateLogger()
>>> trainer = Trainer(callbacks=[lr_logger])

Logging names are automatically determined based on optimizer class name.
In case of multiple optimizers of same type, they will be named `Adam`,
`Adam-1` etc. If a optimizer has multiple parameter groups they will
be named `Adam/pg1`, `Adam/pg2` etc. To control naming, pass in a
`name` keyword in the construction of the learning rate schdulers

Example::

def configure_optimizer(self):
optimizer = torch.optim.Adam(...)
lr_scheduler = {'scheduler': torch.optim.lr_schedulers.LambdaLR(optimizer, ...)
'name': 'my_logging_name'}
return [optimizer], [lr_scheduler]
"""
def __init__(self):
self.lrs = {}
self.names = []

def on_train_start(self, trainer, pl_module):
""" Called before training, determines unique names for all lr
schedulers in the case of multiple of the same type or in
the case of multiple parameter groups
"""
if trainer.lr_schedulers == []:
raise MisconfigurationException(
'Cannot use LearningRateLogger callback with models that have no'
' learning rate schedulers. Please see documentation for'
' `configure_optimizers` method.')

if not trainer.logger:
raise MisconfigurationException(
'Cannot use LearningRateLogger callback with Trainer that have no logger.')
Borda marked this conversation as resolved.
Show resolved Hide resolved

# Create uniqe names in the case we have multiple of the same learning
# rate schduler + multiple parameter groups
names = []
for scheduler in trainer.lr_schedulers:
sch = scheduler['scheduler']
if 'name' in scheduler:
name = scheduler['name']
else:
opt_name = 'lr-' + sch.optimizer.__class__.__name__
name = opt_name
counter = 0
# Multiple schduler of the same type
while True:
counter += 1
if name not in names:
break
name = opt_name + '-' + str(counter)
Borda marked this conversation as resolved.
Show resolved Hide resolved

# Multiple param groups for the same schduler
param_groups = sch.optimizer.param_groups
if len(param_groups) != 1:
for i, pg in enumerate(param_groups):
temp = name + '/pg' + str(i + 1)
names.append(temp)
else:
names.append(name)

self.names.append(name)

# Initialize for storing values
for name in names:
self.lrs[name] = []
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

def on_batch_start(self, trainer, pl_module):
latest_stat = self._extract_lr(trainer, 'step')
if trainer.logger and latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

def on_epoch_start(self, trainer, pl_module):
latest_stat = self._extract_lr(trainer, 'epoch')
if trainer.logger and latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

def _extract_lr(self, trainer, interval):
""" Extracts learning rates for lr schedulers and saves information
into dict structure. """
latest_stat = {}
for name, scheduler in zip(self.names, trainer.lr_schedulers):
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
if scheduler['interval'] == interval:
param_groups = scheduler['scheduler'].optimizer.param_groups
if len(param_groups) != 1:
for i, pg in enumerate(param_groups):
lr = pg['lr']
key = f'{name}/{i + 1}'
Borda marked this conversation as resolved.
Show resolved Hide resolved
self.lrs[key].append(lr)
latest_stat[key] = lr
else:
self.lrs[name].append(param_groups[0]['lr'])
latest_stat[name] = param_groups[0]['lr']
return latest_stat
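For context on how the new callback is meant to be wired up, here is a minimal usage sketch (not part of the diff). It assumes the 0.7-era `Trainer`/`LightningModule` API exercised by the tests below; the toy model, layer sizes, and the `'lr-head'` scheduler name are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateLogger


class TwoOptimizerModel(pl.LightningModule):
    """Toy model with two Adam optimizers, so the callback has to disambiguate keys."""

    def __init__(self):
        super().__init__()
        self.layer_a = torch.nn.Linear(4, 4)
        self.layer_b = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer_b(self.layer_a(x))

    def training_step(self, batch, batch_idx, optimizer_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        return {'loss': loss}

    def train_dataloader(self):
        data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
        return DataLoader(data, batch_size=8)

    def configure_optimizers(self):
        opt_a = torch.optim.Adam(self.layer_a.parameters(), lr=1e-3)
        opt_b = torch.optim.Adam(self.layer_b.parameters(), lr=1e-2)
        sched_a = torch.optim.lr_scheduler.StepLR(opt_a, step_size=1)
        # Giving the second scheduler an explicit `name` overrides the automatic key
        sched_b = {'scheduler': torch.optim.lr_scheduler.StepLR(opt_b, step_size=1),
                   'name': 'lr-head'}
        return [opt_a, opt_b], [sched_a, sched_b]


lr_logger = LearningRateLogger()
trainer = Trainer(max_epochs=2, callbacks=[lr_logger])
trainer.fit(TwoOptimizerModel())
# lr_logger.lrs now holds the logged values under 'lr-Adam' and 'lr-head'
```

With two optimizers of the same class, the unnamed scheduler keeps the automatic `lr-Adam` key while the explicitly named one is logged under `lr-head`, matching the naming rule described in the docstring.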
55 changes: 54 additions & 1 deletion tests/trainer/test_callbacks.py
@@ -1,11 +1,12 @@
import tests.base.utils as tutils
from pytorch_lightning import Callback
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger
from tests.base import (
    LightTrainDataloader,
    LightTestMixin,
    LightValidationMixin,
    LightTestOptimizersWithMixedSchedulingMixin,
    TestModelBase
)

@@ -181,3 +182,55 @@ def training_step(self, *args, **kwargs):

    assert result == 1, 'training failed to complete'
    assert trainer.current_epoch < trainer.max_epochs


def test_lr_logger_single_lr(tmpdir):
    """ Test that learning rates are extracted and logged for single lr scheduler """
    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    lr_logger = LearningRateLogger()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=5,
        val_percent_check=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger]
    )
    results = trainer.fit(model)

    assert lr_logger.lrs, 'No learning rates logged'
    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'


def test_lr_logger_multi_lrs(tmpdir):
    """ Test that learning rates are extracted and logged for multi lr schedulers """
    tutils.reset_seed()

    class CurrentTestModel(LightTestOptimizersWithMixedSchedulingMixin,
                           LightTrainDataloader,
                           TestModelBase):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    lr_logger = LearningRateLogger()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger]
    )
    results = trainer.fit(model)

    assert lr_logger.lrs, 'No learning rates logged'
    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
        'Number of learning rates logged does not match number of lr schedulers'
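The snapshot at 10 commits does not yet include the naming test added later in the PR ("add test for naming"). A hypothetical sketch of such a check, reusing the fixtures from the tests above and assuming `TestModelBase` tolerates an overridden `configure_optimizers` (the `'my_lr'` name and the in-function `torch` import are illustrative, not part of the diff):

```python
def test_lr_logger_custom_name(tmpdir):
    """ Hypothetical sketch: a scheduler dict carrying a `name` key should be
    logged under that name rather than the auto-generated `lr-<Optimizer>` key. """
    import torch  # assumed available; not shown in the imports of this diff

    tutils.reset_seed()

    class CurrentTestModel(LightTrainDataloader, TestModelBase):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
            scheduler = {'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
                         'name': 'my_lr'}
            return [optimizer], [scheduler]

    model = CurrentTestModel(tutils.get_default_hparams())

    lr_logger = LearningRateLogger()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.5,
        callbacks=[lr_logger]
    )
    trainer.fit(model)

    assert list(lr_logger.lrs.keys()) == ['my_lr'], \
        'Custom scheduler name was not used as the logging key'
```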