Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Mar 6, 2021
1 parent 4f391bc commit 3dfe832
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 1 deletion.
6 changes: 6 additions & 0 deletions adf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from unittest.mock import Mock

from tests.helpers import BoringModel


print(model.train())
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@

class TrainerOptimizersMixin(ABC):

_lightning_optimizers: Optional[List[LightningOptimizer]]

def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
self._lightning_optimizers = None
optim_conf = model.configure_optimizers()
if optim_conf is None:
rank_zero_warn(
Expand Down
32 changes: 31 additions & 1 deletion tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from copy import deepcopy
from distutils.version import LooseVersion
from pathlib import Path
from unittest.mock import ANY, call, patch
from unittest.mock import ANY, call, patch, Mock

import cloudpickle
import pytest
Expand Down Expand Up @@ -1785,3 +1785,33 @@ def backward(self, *args, **kwargs):
"training_step",
"backward",
]


def test_init_optimizers_resets_lightning_optimizers(tmpdir):
""" Test that the Trainer resets the `lightning_optimizers` list everytime new optimizers get initialized. """

def compare_optimizers():
assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0]

class OptimizerSpy(Callback):
def on_fit_start(self, *args, **kwargs):
compare_optimizers()

model = BoringModel()
model.lr = 0.2
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
auto_lr_find=True,
callbacks=[OptimizerSpy()]
)

trainer.tune(model)
compare_optimizers()

trainer.fit(model)
compare_optimizers()

trainer.max_epochs = 2 # simulate multiple fit calls
trainer.fit(model)
compare_optimizers()

0 comments on commit 3dfe832

Please sign in to comment.