diff --git a/adf.py b/adf.py new file mode 100644 index 00000000000000..41bb94a5685a07 --- /dev/null +++ b/adf.py @@ -0,0 +1,6 @@ +from unittest.mock import Mock + +from tests.helpers import BoringModel + + +print(model.train()) \ No newline at end of file diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 5cafa438cffcc2..a247fb92cd22f8 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -27,7 +27,10 @@ class TrainerOptimizersMixin(ABC): + _lightning_optimizers: Optional[List[LightningOptimizer]] + def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]: + self._lightning_optimizers = None optim_conf = model.configure_optimizers() if optim_conf is None: rank_zero_warn( diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1cd979c863d373..632428e6f1874f 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -19,7 +19,7 @@ from copy import deepcopy from distutils.version import LooseVersion from pathlib import Path -from unittest.mock import ANY, call, patch +from unittest.mock import ANY, call, patch, Mock import cloudpickle import pytest @@ -1785,3 +1785,33 @@ def backward(self, *args, **kwargs): "training_step", "backward", ] + + +def test_init_optimizers_resets_lightning_optimizers(tmpdir): + """ Test that the Trainer resets the `lightning_optimizers` list everytime new optimizers get initialized. """ + + def compare_optimizers(): + assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0] + + class OptimizerSpy(Callback): + def on_fit_start(self, *args, **kwargs): + compare_optimizers() + + model = BoringModel() + model.lr = 0.2 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + auto_lr_find=True, + callbacks=[OptimizerSpy()] + ) + + trainer.tune(model) + compare_optimizers() + + trainer.fit(model) + compare_optimizers() + + trainer.max_epochs = 2 # simulate multiple fit calls + trainer.fit(model) + compare_optimizers()