From 9ec198941fcfd11d8c43c97c38860150edb2f305 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 16 Apr 2020 18:01:41 +0200
Subject: [PATCH] Call on_before_zero_grad model hook (#1493)

* call on_before_zero_grad

* update changelog

* add note about overriding both hooks

* added test

* move test_hooks.py to models folder
---
 CHANGELOG.md                        |  2 ++
 pytorch_lightning/core/lightning.py |  7 ++++++
 tests/models/test_hooks.py          | 39 +++++++++++++++++++++++++++++
 3 files changed, 48 insertions(+)
 create mode 100644 tests/models/test_hooks.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b101cd0932592f..671503a37ad408 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed loggers - flushing last logged metrics even before continue, e.g. `trainer.test()` results ([#1459](https://github.com/PyTorchLightning/pytorch-lightning/pull/1459))
 
+- Added a missing call to the `on_before_zero_grad` model hook ([#1493](https://github.com/PyTorchLightning/pytorch-lightning/pull/1493)).
+
 -
 
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 427febd2f81baf..71ea357ea3944a 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1158,6 +1158,10 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
                     optimizer.step()
                     optimizer.zero_grad()
 
+        Note:
+            If you also override the :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad`
+            model hook don't forget to add the call to it before ``optimizer.zero_grad()`` yourself.
+
         """
         if self.trainer.use_tpu and XLA_AVAILABLE:
             xm.optimizer_step(optimizer)
@@ -1166,6 +1170,9 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
         else:
             optimizer.step()
 
+        # model hook
+        self.on_before_zero_grad(optimizer)
+
         # clear gradients
         optimizer.zero_grad()
 
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
new file mode 100644
index 00000000000000..1d0e55df409e01
--- /dev/null
+++ b/tests/models/test_hooks.py
@@ -0,0 +1,39 @@
+import pytest
+
+import tests.base.utils as tutils
+from pytorch_lightning import Trainer
+from tests.base import (
+    LightTrainDataloader,
+    LightValidationMixin,
+    TestModelBase,
+    LightTestMixin)
+
+
+@pytest.mark.parametrize('max_steps', [1, 2, 3])
+def test_on_before_zero_grad_called(max_steps):
+
+    class CurrentTestModel(
+        LightTrainDataloader,
+        LightValidationMixin,
+        LightTestMixin,
+        TestModelBase,
+    ):
+        on_before_zero_grad_called = 0
+
+        def on_before_zero_grad(self, optimizer):
+            self.on_before_zero_grad_called += 1
+
+    hparams = tutils.get_default_hparams()
+    model = CurrentTestModel(hparams)
+
+    trainer = Trainer(
+        max_steps=max_steps,
+        num_sanity_val_steps=5,
+    )
+    assert 0 == model.on_before_zero_grad_called
+    trainer.fit(model)
+    assert max_steps == model.on_before_zero_grad_called
+
+    model.on_before_zero_grad_called = 0
+    trainer.test(model)
+    assert 0 == model.on_before_zero_grad_called
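
To illustrate the docstring note added in this patch: once the hook call lives inside the default `optimizer_step`, a `LightningModule` that overrides `optimizer_step` bypasses that call and must invoke `on_before_zero_grad` itself before clearing gradients. The sketch below shows that pattern; it is not part of the patch, the model class and its layers are made up for illustration, and the `optimizer_step` parameters beyond those visible in the hunk header (`optimizer_idx`, `second_order_closure`) are assumed from the API of this era.

    import torch
    from pytorch_lightning import LightningModule


    class ManualStepModel(LightningModule):
        """Hypothetical module overriding both hooks (illustrative only)."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return {'loss': torch.nn.functional.cross_entropy(self(x), y)}

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

        def on_before_zero_grad(self, optimizer):
            # gradients are still populated here; inspect or log them if needed
            pass

        # trailing parameter names are assumed, not taken from the patch
        def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx,
                           second_order_closure=None):
            optimizer.step()
            # overriding optimizer_step bypasses the default implementation,
            # so the hook must be called manually before zeroing gradients
            self.on_before_zero_grad(optimizer)
            optimizer.zero_grad()

The hook receives the optimizer whose gradients are about to be cleared, and the test added in the patch verifies the wiring by counting invocations: one call per optimizer step during `trainer.fit`, and none during `trainer.test`.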