Fix lr key name in case of param groups (#1719)
* Fix lr key name in case of param groups

* Add tests

* Update test and added configure_optimizers__param_groups

* Update CHANGELOG
rohitgr7 committed May 10, 2020
1 parent 7f64ad7 commit d962ab5
Showing 5 changed files with 38 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug in Trainer that prepended the checkpoint path with `version_` when it shouldn't ([#1748](https://github.com/PyTorchLightning/pytorch-lightning/pull/1748))

+- Fixed lr key name in case of param groups in LearningRateLogger ([#1719](https://github.com/PyTorchLightning/pytorch-lightning/pull/1719))

## [0.7.5] - 2020-04-27

### Changed
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/lr_logger.py
@@ -80,7 +80,7 @@ def _extract_lr(self, trainer, interval):
                param_groups = scheduler['scheduler'].optimizer.param_groups
                if len(param_groups) != 1:
                    for i, pg in enumerate(param_groups):
-                       lr, key = pg['lr'], f'{name}/{i + 1}'
+                       lr, key = pg['lr'], f'{name}/pg{i + 1}'
                        self.lrs[key].append(lr)
                        latest_stat[key] = lr
                else:
@@ -109,7 +109,7 @@ def _find_names(self, lr_schedulers):
            param_groups = sch.optimizer.param_groups
            if len(param_groups) != 1:
                for i, pg in enumerate(param_groups):
-                   temp = name + '/pg' + str(i + 1)
+                   temp = f'{name}/pg{i + 1}'
                    names.append(temp)
            else:
                names.append(name)
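For context, a minimal sketch (plain Python, outside the library) of what the one-character difference above means: before this commit, `_extract_lr` built per-group keys without the `pg` prefix, while `_find_names` registered names with it, so the logged values never landed under the registered names. The values of `name` and `i` below are illustrative.

```python
# Illustrative values: 'lr-Adam' mimics the name the callback derives for an
# Adam-based scheduler, and i is the param-group index used in the f-strings above.
name, i = 'lr-Adam', 0

old_extract_key = f'{name}/{i + 1}'     # 'lr-Adam/1'   -- old _extract_lr
registered_name = f'{name}/pg{i + 1}'   # 'lr-Adam/pg1' -- _find_names (unchanged)
new_extract_key = f'{name}/pg{i + 1}'   # 'lr-Adam/pg1' -- fixed _extract_lr

assert old_extract_key != registered_name   # the mismatch this commit fixes
assert new_extract_key == registered_name   # keys now line up with the registered names
```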
Empty file added tests/base/mixins.py
Empty file.
10 changes: 10 additions & 0 deletions tests/base/model_optimizers.py
@@ -59,3 +59,13 @@ def configure_optimizers__reduce_lr_on_plateau(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return [optimizer], [lr_scheduler]
+
+   def configure_optimizers__param_groups(self):
+       param_groups = [
+           {'params': list(self.parameters())[:2], 'lr': self.hparams.learning_rate * 0.1},
+           {'params': list(self.parameters())[2:], 'lr': self.hparams.learning_rate}
+       ]
+
+       optimizer = optim.Adam(param_groups)
+       lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
+       return [optimizer], [lr_scheduler]
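As a standalone illustration (plain PyTorch, no Lightning; the module shape and the base learning rate below are made up, not the template's `hparams.learning_rate`), this is the kind of optimizer/scheduler pair the new test helper returns: two parameter groups with different learning rates, and a `StepLR` that decays each group's lr by `gamma` every step.

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in for the LightningModule's parameters; only the split
# into two param groups with different learning rates matters here.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
params = list(model.parameters())

base_lr = 0.001  # illustrative value
optimizer = optim.Adam([
    {'params': params[:2], 'lr': base_lr * 0.1},  # first group: reduced lr
    {'params': params[2:], 'lr': base_lr},        # second group: base lr
])
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

print([pg['lr'] for pg in optimizer.param_groups])  # roughly [0.0001, 0.001]
optimizer.step()     # normally preceded by a backward pass
scheduler.step()     # StepLR multiplies every group's lr by gamma
print([pg['lr'] for pg in optimizer.param_groups])  # roughly [0.00001, 0.0001]
```

Each group keeps its own `'lr'` entry, which is what LearningRateLogger now records under the `.../pg1`, `.../pg2` keys.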
24 changes: 24 additions & 0 deletions tests/callbacks/test_callbacks.py
@@ -331,3 +331,27 @@ def test_lr_logger_multi_lrs(tmpdir):
        'Number of learning rates logged does not match number of lr schedulers'
    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
        'Names of learning rates not set correctly'
+
+
+def test_lr_logger_param_groups(tmpdir):
+   """ Test that learning rates are extracted and logged for parameter groups in a single lr scheduler """
+   tutils.reset_seed()
+
+   model = EvalModelTemplate()
+   model.configure_optimizers = model.configure_optimizers__param_groups
+
+   lr_logger = LearningRateLogger()
+   trainer = Trainer(
+       default_root_dir=tmpdir,
+       max_epochs=5,
+       val_percent_check=0.1,
+       train_percent_check=0.5,
+       callbacks=[lr_logger]
+   )
+   results = trainer.fit(model)
+
+   assert lr_logger.lrs, 'No learning rates logged'
+   assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
+       'Number of learning rates logged does not match number of param groups'
+   assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
+       'Names of learning rates not set correctly'
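To make the asserted key scheme concrete, here is a hypothetical helper (not part of pytorch-lightning; the function name and `base_name` argument are made up for illustration) that mirrors the naming logic from `_find_names` shown earlier.

```python
import torch.nn as nn
import torch.optim as optim

def expected_lr_keys(optimizer, base_name):
    """Mirror the '<name>/pg<i>' naming used when an optimizer has several param groups."""
    groups = optimizer.param_groups
    if len(groups) != 1:
        return [f'{base_name}/pg{i + 1}' for i in range(len(groups))]
    return [base_name]

net = nn.Linear(4, 2)
single = optim.Adam(net.parameters(), lr=0.001)
multi = optim.Adam([
    {'params': [net.weight], 'lr': 0.0001},
    {'params': [net.bias], 'lr': 0.001},
])

print(expected_lr_keys(single, 'lr-Adam'))  # ['lr-Adam']
print(expected_lr_keys(multi, 'lr-Adam'))   # ['lr-Adam/pg1', 'lr-Adam/pg2']
```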
