From 7acbd65bcb76b39a0f8d7a6f6985168d767ab23a Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Thu, 4 Mar 2021 20:11:59 +0000
Subject: [PATCH] [bugfix] Check LightningOptimizer doesn't delete optimizer hooks (#6305)

* update

* resolve bug
---
 pytorch_lightning/core/optimizer.py    |  2 +-
 tests/core/test_lightning_optimizer.py | 77 ++++++++++++++++++++++++++
 2 files changed, 78 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py
index 8b6548f438756..162e17ca47bf5 100644
--- a/pytorch_lightning/core/optimizer.py
+++ b/pytorch_lightning/core/optimizer.py
@@ -38,7 +38,7 @@ class LightningOptimizer:
 
     def __init__(self, optimizer: Optimizer):
-        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k != 'step'}
+        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ('step', "__del__")}
 
         # For Horovod
         if hasattr(optimizer, "skip_synchronize"):
diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py
index 4fc6a06157ab0..3c6e34df8d5e3 100644
--- a/tests/core/test_lightning_optimizer.py
+++ b/tests/core/test_lightning_optimizer.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
+from typing import Any
 from unittest.mock import DEFAULT, patch
 
 import torch
@@ -303,3 +305,78 @@ def configure_optimizers(self):
     lbfgs = model.optimizers()
     max_iter = lbfgs.param_groups[0]["max_iter"]
     assert zero_grad.call_count == max_iter
+
+
+class OptimizerWithHooks(Optimizer):
+
+    def __init__(self, model):
+        self._fwd_handles = []
+        self._bwd_handles = []
+        self.params = []
+        for _, mod in model.named_modules():
+            mod_class = mod.__class__.__name__
+            if mod_class != 'Linear':
+                continue
+
+            handle = mod.register_forward_pre_hook(self._save_input)  # save the inputs
+            self._fwd_handles.append(handle)  # collect forward-save-input hooks in list
+            handle = mod.register_backward_hook(self._save_grad_output)  # save the gradients
+            self._bwd_handles.append(handle)  # collect backward-save-grad hook in list
+
+            # save the parameters
+            params = [mod.weight]
+            if mod.bias is not None:
+                params.append(mod.bias)
+
+            # save a param_group for each module
+            d = {'params': params, 'mod': mod, 'layer_type': mod_class}
+            self.params.append(d)
+
+        super(OptimizerWithHooks, self).__init__(self.params, {"lr": 0.01})
+
+    def _save_input(self, mod, i):
+        """Saves input of layer"""
+        if mod.training:
+            self.state[mod]['x'] = i[0]
+
+    def _save_grad_output(self, mod, _, grad_output):
+        """
+        Saves grad on output of layer.
+        The grad is scaled by batch_size since the gradient is spread over the samples in the mini-batch.
+        """
+        batch_size = grad_output[0].shape[0]
+        if mod.training:
+            self.state[mod]['grad'] = grad_output[0] * batch_size
+
+    def step(self, closure=None):
+        closure()
+        for group in self.param_groups:
+            _ = self.state[group['mod']]['x']
+            _ = self.state[group['mod']]['grad']
+        return True
+
+
+def test_lightning_optimizer_keeps_hooks(tmpdir):
+
+    class TestModel(BoringModel):
+        count_on_train_batch_start = 0
+        count_on_train_batch_end = 0
+
+        def configure_optimizers(self):
+            return OptimizerWithHooks(self)
+
+        def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_start += 1
+            optimizer = self.optimizers(use_pl_optimizer=False)
+            assert len(optimizer._fwd_handles) == 1
+
+        def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+            self.count_on_train_batch_end += 1
+            del self.trainer._lightning_optimizers
+            gc.collect()  # not necessary, just in case
+
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=4, limit_val_batches=1, max_epochs=1)
+    model = TestModel()
+    trainer.fit(model)
+    assert model.count_on_train_batch_start == 4
+    assert model.count_on_train_batch_end == 4
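
The one-line change above can also be exercised outside the Lightning test suite. Below is a minimal sketch of the same attribute-copy pattern in plain PyTorch; the Wrapper class and the my_handles attribute are hypothetical stand-ins added for illustration, not Lightning or PyTorch API:

    import gc

    from torch import nn
    from torch.optim import SGD


    class Wrapper:
        """Hypothetical stand-in for the LightningOptimizer attribute copy."""

        def __init__(self, optimizer):
            # Share the optimizer's instance attributes, but skip `step` and any
            # custom `__del__` that could tear hooks down when the wrapper dies.
            self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ('step', '__del__')}
            self._optimizer = optimizer


    model = nn.Linear(2, 2)
    handle = model.register_forward_pre_hook(lambda mod, inp: None)  # hook that must survive

    opt = SGD(model.parameters(), lr=0.1)
    opt.my_handles = [handle]  # instance attribute, so it lands in opt.__dict__

    wrapper = Wrapper(opt)
    del wrapper
    gc.collect()

    # the hook registered on the module is still in place
    assert len(model._forward_pre_hooks) == 1

Deleting the wrapper leaves the module's hooks untouched, which is the behaviour the new test_lightning_optimizer_keeps_hooks test asserts through its on_train_batch_start / on_train_batch_end hooks.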