ReduceLROnPlateau bug fix (#1126)
* bug fix and test

* update CHANGELOG.md

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
SkafteNicki and Nicki Skafte authored Mar 16, 2020
1 parent 774d9be commit 384e124
Showing 5 changed files with 51 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -28,7 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

--
+- Fixed bug related to type cheking of `ReduceLROnPlateau` lr schedulers([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))

## [0.7.1] - 2020-03-07

4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -707,8 +707,8 @@ def configure_schedulers(self, schedulers: list):
if 'scheduler' not in scheduler:
    raise ValueError(f'Lr scheduler should have key `scheduler`',
                     ' with item being a lr scheduler')
-scheduler['reduce_on_plateau'] = \
-    isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau)
+scheduler['reduce_on_plateau'] = isinstance(
+    scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)

lr_schedulers.append({**default_config, **scheduler})
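
In the removed lines, `isinstance` was applied to `scheduler`, which at this point is the user's config dict rather than the scheduler object, so `reduce_on_plateau` could never be set to `True`; the fix checks the object stored under the `'scheduler'` key instead. A minimal standalone sketch of the difference, assuming a throwaway model and optimizer (the names below are illustrative, not from the repository):

import torch
from torch import optim

model = torch.nn.Linear(2, 2)  # dummy model, for illustration only
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# A scheduler config dict like the ones configure_schedulers() receives.
scheduler = {'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer)}

# Old check: tests the dict itself, so the flag is always False.
old_flag = isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau)

# Fixed check: tests the scheduler instance under the 'scheduler' key.
new_flag = isinstance(scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)

print(old_flag, new_flag)  # False True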

3 changes: 2 additions & 1 deletion tests/models/__init__.py
@@ -24,7 +24,8 @@
    LightInfTestDataloader,
    LightTestOptimizerWithSchedulingMixin,
    LightTestMultipleOptimizersWithSchedulingMixin,
-    LightTestOptimizersWithMixedSchedulingMixin
+    LightTestOptimizersWithMixedSchedulingMixin,
+    LightTestReduceLROnPlateauMixin
)


10 changes: 10 additions & 0 deletions tests/models/mixins.py
@@ -678,6 +678,16 @@ def configure_optimizers(self):
[{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2]


+class LightTestReduceLROnPlateauMixin:
+    def configure_optimizers(self):
+        if self.hparams.optimizer_name == 'lbfgs':
+            optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
+        else:
+            optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+        return [optimizer], [lr_scheduler]
+
+
def _get_output_metric(output, name):
    if isinstance(output, dict):
        val = output[name]
37 changes: 36 additions & 1 deletion tests/trainer/test_optimizers.py
@@ -10,9 +10,12 @@
from tests.models import (
    TestModelBase,
    LightTrainDataloader,
+    LightValidationStepMixin,
+    LightValidationMixin,
    LightTestOptimizerWithSchedulingMixin,
    LightTestMultipleOptimizersWithSchedulingMixin,
-    LightTestOptimizersWithMixedSchedulingMixin
+    LightTestOptimizersWithMixedSchedulingMixin,
+    LightTestReduceLROnPlateauMixin
)


@@ -144,3 +147,35 @@ class CurrentTestModel(
    # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times
    assert init_lr * 0.1 == adjusted_lr2, \
        'lr for optimizer 2 not adjusted correctly'
+
+
+def test_reduce_lr_on_plateau_scheduling(tmpdir):
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+        LightTestReduceLROnPlateauMixin,
+        LightTrainDataloader,
+        LightValidationMixin,
+        LightValidationStepMixin,
+        TestModelBase):
+        pass
+
+    hparams = tutils.get_hparams()
+    model = CurrentTestModel(hparams)
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    results = trainer.fit(model)
+
+    assert trainer.lr_schedulers[0] == \
+        dict(scheduler=trainer.lr_schedulers[0]['scheduler'], monitor='val_loss',
+             interval='epoch', frequency=1, reduce_on_plateau=True), \
+        'lr schduler was not correctly converted to dict'
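
The test checks that a bare `ReduceLROnPlateau` returned from `configure_optimizers` is expanded by the Trainer into the full config dict, with `monitor='val_loss'`, `interval='epoch'`, `frequency=1` and `reduce_on_plateau=True` filled in from the defaults. A sketch of the user-facing side (not part of this commit), assuming the 0.7-era `configure_optimizers` API and a hypothetical model class:

import pytorch_lightning as pl
from torch import optim

class PlateauModel(pl.LightningModule):  # hypothetical example model
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.02)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)
        # The explicit dict form: 'interval', 'frequency' and the
        # reduce_on_plateau flag are merged in by configure_schedulers();
        # 'monitor' overrides the 'val_loss' default if needed.
        return [optimizer], [{'scheduler': scheduler, 'monitor': 'val_loss'}]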
