From 440db6a42fa67060b138338bf83c4f75a1652013 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 6 Mar 2021 03:41:28 +0100 Subject: [PATCH] bugfix --- pytorch_lightning/trainer/optimizers.py | 3 +++ tests/trainer/test_trainer.py | 32 ++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 5cafa438cffcc..a247fb92cd22f 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -27,7 +27,10 @@ class TrainerOptimizersMixin(ABC): + _lightning_optimizers: Optional[List[LightningOptimizer]] + def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]: + self._lightning_optimizers = None optim_conf = model.configure_optimizers() if optim_conf is None: rank_zero_warn( diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1cd979c863d37..632428e6f1874 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -19,7 +19,7 @@ from copy import deepcopy from distutils.version import LooseVersion from pathlib import Path -from unittest.mock import ANY, call, patch +from unittest.mock import ANY, call, patch, Mock import cloudpickle import pytest @@ -1785,3 +1785,33 @@ def backward(self, *args, **kwargs): "training_step", "backward", ] + + +def test_init_optimizers_resets_lightning_optimizers(tmpdir): + """ Test that the Trainer resets the `lightning_optimizers` list every time new optimizers get initialized. 
""" + + def compare_optimizers(): + assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0] + + class OptimizerSpy(Callback): + def on_fit_start(self, *args, **kwargs): + compare_optimizers() + + model = BoringModel() + model.lr = 0.2 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + auto_lr_find=True, + callbacks=[OptimizerSpy()] + ) + + trainer.tune(model) + compare_optimizers() + + trainer.fit(model) + compare_optimizers() + + trainer.max_epochs = 2 # simulate multiple fit calls + trainer.fit(model) + compare_optimizers()