Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix trainer not resetting lightning_optimizers #6372

Merged
merged 5 commits into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))


- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))


## [1.2.2] - 2021-03-02
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@

class TrainerOptimizersMixin(ABC):

_lightning_optimizers: Optional[List[LightningOptimizer]]

def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
self._lightning_optimizers = None
optim_conf = model.configure_optimizers()
if optim_conf is None:
rank_zero_warn(
Expand Down
25 changes: 25 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1785,3 +1785,28 @@ def backward(self, *args, **kwargs):
"training_step",
"backward",
]


def test_init_optimizers_resets_lightning_optimizers(tmpdir):
""" Test that the Trainer resets the `lightning_optimizers` list everytime new optimizers get initialized. """

def compare_optimizers():
assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0]

model = BoringModel()
model.lr = 0.2
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
auto_lr_find=True,
)

trainer.tune(model)
compare_optimizers()

trainer.fit(model)
compare_optimizers()

trainer.max_epochs = 2 # simulate multiple fit calls
trainer.fit(model)
compare_optimizers()