From 4c34d16a349bc96a717be5674606c2577fab8946 Mon Sep 17 00:00:00 2001
From: Alexey Karnachev
Date: Fri, 10 Apr 2020 18:43:06 +0300
Subject: [PATCH] Fixed configure optimizer from dict without "scheduler" key
 (#1443)

* `configure_optimizer` from dict with only "optimizer" key. bug fixed

* autopep8

* pep8speaks suggested fixes

* CHANGELOG.md upd
---
 CHANGELOG.md                            |  1 +
 pytorch_lightning/trainer/optimizers.py |  2 ++
 pytorch_lightning/trainer/supporters.py |  1 +
 pytorch_lightning/trainer/trainer.py    |  3 ++-
 tests/trainer/test_optimizers.py        | 23 +++++++++++++++++++++++
 5 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a1833c38ed993..fd29d715a8686 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed optimizer configuration when `configure_optimizers` returns dict without `lr_scheduler` ([#1443](https://github.com/PyTorchLightning/pytorch-lightning/pull/1443))
 - Fixed default `DistributedSampler` for DDP training ([#1425](https://github.com/PyTorchLightning/pytorch-lightning/pull/1425))
 - Fixed workers warning not on windows ([#1430](https://github.com/PyTorchLightning/pytorch-lightning/pull/1430))
 - Fixed returning tuple from `run_training_batch` ([#1431](https://github.com/PyTorchLightning/pytorch-lightning/pull/1431))
diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index 2c4f0ed57e04f..8dd77f7971a48 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -39,6 +39,8 @@ def init_optimizers(
             lr_scheduler = optim_conf.get("lr_scheduler", [])
             if lr_scheduler:
                 lr_schedulers = self.configure_schedulers([lr_scheduler])
+            else:
+                lr_schedulers = []
             return [optimizer], lr_schedulers, []
 
         # multiple dictionaries
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index 925d96ed8baa3..dc29e8d36a08d 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -20,6 +20,7 @@ class TensorRunningAccum(object):
         >>> accum.last(), accum.mean(), accum.min(), accum.max()
         (tensor(12.), tensor(10.), tensor(8.), tensor(12.))
     """
+
     def __init__(self, window_length: int):
         self.window_length = window_length
         self.memory = torch.Tensor(self.window_length)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 9cf6678c5837f..3e3b7c9679df9 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -554,7 +554,8 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
                                             if at[0] not in depr_arg_names):
             for allowed_type in (at for at in allowed_types if at in arg_types):
                 if isinstance(allowed_type, bool):
-                    allowed_type = lambda x: bool(distutils.util.strtobool(x))
+                    def allowed_type(x):
+                        return bool(distutils.util.strtobool(x))
                 parser.add_argument(
                     f'--{arg}',
                     default=arg_default,
diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py
index b07cbd4fd7318..f65878cf3262b 100644
--- a/tests/trainer/test_optimizers.py
+++ b/tests/trainer/test_optimizers.py
@@ -275,3 +275,26 @@ class CurrentTestModel(
 
     # verify training completed
     assert result == 1
+
+
+def test_configure_optimizer_from_dict(tmpdir):
+    """Tests if `configure_optimizer` method could return a dictionary with
+    `optimizer` field only.
+    """
+
+    class CurrentTestModel(LightTrainDataloader, TestModelBase):
+        def configure_optimizers(self):
+            config = {
+                'optimizer': torch.optim.SGD(params=self.parameters(), lr=1e-03)
+            }
+            return config
+
+    hparams = tutils.get_default_hparams()
+    model = CurrentTestModel(hparams)
+
+    trainer_options = dict(default_save_path=tmpdir, max_epochs=1)
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    assert result == 1