From 5f02a1b1c61fee5435f86c12b51006162481b5e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 14 Apr 2020 22:54:13 +0200
Subject: [PATCH 1/5] call on_before_zero_grad

---
 pytorch_lightning/core/lightning.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 04a386d77b25d..2930b20441fe9 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1165,6 +1165,9 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
         else:
             optimizer.step()
 
+        # model hook
+        self.on_before_zero_grad(optimizer)
+
         # clear gradients
         optimizer.zero_grad()
 

From b0659e8256371304c8ff2588d9a883ebe1f2d9ff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 14 Apr 2020 23:02:54 +0200
Subject: [PATCH 2/5] update changelog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9ce083b11c250..124164fd338bc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed loggers - flushing last logged metrics even before continue, e.g. `trainer.test()` results ([#1459](https://github.com/PyTorchLightning/pytorch-lightning/pull/1459))
 
+- Added a missing call to the `on_before_zero_grad` model hook ([#1493](https://github.com/PyTorchLightning/pytorch-lightning/pull/1493)).
+
 -

From 9d8d303b235b577b74b1e3cbaa9582fe6df2b334 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 14 Apr 2020 23:10:22 +0200
Subject: [PATCH 3/5] add note about overriding both hooks

---
 pytorch_lightning/core/lightning.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 2930b20441fe9..3bf33578f0f64 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1157,6 +1157,10 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
                     optimizer.step()
                     optimizer.zero_grad()
 
+        Note:
+            If you also override the :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad`
+            model hook, don't forget to call it yourself before ``optimizer.zero_grad()``.
+
         """
         if self.trainer.use_tpu and XLA_AVAILABLE:
             xm.optimizer_step(optimizer)
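The note above matters because overriding `optimizer_step` replaces the default implementation entirely, including the `self.on_before_zero_grad(optimizer)` call that patch 1 adds. A minimal sketch of a compliant override follows; the class name and the trailing signature arguments (`optimizer_idx`, `second_order_closure`) are assumptions based on the 0.7-era API, not part of this patch series:

```python
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    # Sketch of a custom optimizer_step that keeps the hook call, as the
    # new docstring note advises. The signature tail is an assumption.
    def optimizer_step(self, current_epoch, batch_idx, optimizer,
                       optimizer_idx, second_order_closure=None):
        optimizer.step()
        # Re-add the hook manually: the default implementation is bypassed.
        self.on_before_zero_grad(optimizer)
        optimizer.zero_grad()
```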
+ """ if self.trainer.use_tpu and XLA_AVAILABLE: xm.optimizer_step(optimizer) From bd4004bced2fa3176a8832b96eb8e7b22463a92b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 14 Apr 2020 23:32:45 +0200 Subject: [PATCH 4/5] added test --- tests/trainer/test_hooks.py | 39 +++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/trainer/test_hooks.py diff --git a/tests/trainer/test_hooks.py b/tests/trainer/test_hooks.py new file mode 100644 index 0000000000000..1d0e55df409e0 --- /dev/null +++ b/tests/trainer/test_hooks.py @@ -0,0 +1,39 @@ +import pytest + +import tests.base.utils as tutils +from pytorch_lightning import Trainer +from tests.base import ( + LightTrainDataloader, + LightValidationMixin, + TestModelBase, + LightTestMixin) + + +@pytest.mark.parametrize('max_steps', [1, 2, 3]) +def test_on_before_zero_grad_called(max_steps): + + class CurrentTestModel( + LightTrainDataloader, + LightValidationMixin, + LightTestMixin, + TestModelBase, + ): + on_before_zero_grad_called = 0 + + def on_before_zero_grad(self, optimizer): + self.on_before_zero_grad_called += 1 + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + trainer = Trainer( + max_steps=max_steps, + num_sanity_val_steps=5, + ) + assert 0 == model.on_before_zero_grad_called + trainer.fit(model) + assert max_steps == model.on_before_zero_grad_called + + model.on_before_zero_grad_called = 0 + trainer.test(model) + assert 0 == model.on_before_zero_grad_called From b35bb855222ae5c4e6b6eb8228800b3de36e123f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 15 Apr 2020 00:30:38 +0200 Subject: [PATCH 5/5] move test_hooks.py to models folder --- tests/{trainer => models}/test_hooks.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{trainer => models}/test_hooks.py (100%) diff --git a/tests/trainer/test_hooks.py b/tests/models/test_hooks.py similarity index 100% rename from tests/trainer/test_hooks.py rename to tests/models/test_hooks.py