Call on_before_zero_grad model hook (#1493)
* call on_before_zero_grad

* update changelog

* add note about overriding both hooks

* added test

* move test_hooks.py to models folder
Adrian Wälchli committed Apr 16, 2020
1 parent 06e6ead commit 3c549e8
Showing 3 changed files with 48 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -36,6 +36,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed loggers - flushing last logged metrics even before continue, e.g. `trainer.test()` results ([#1459](https://github.com/PyTorchLightning/pytorch-lightning/pull/1459))

- Added a missing call to the `on_before_zero_grad` model hook ([#1493](https://github.com/PyTorchLightning/pytorch-lightning/pull/1493)).

-


7 changes: 7 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -1158,6 +1158,10 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
optimizer.step()
optimizer.zero_grad()
Note:
If you also override the :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad`
model hook, don't forget to call it yourself before ``optimizer.zero_grad()``.
"""
if self.trainer.use_tpu and XLA_AVAILABLE:
xm.optimizer_step(optimizer)
@@ -1166,6 +1170,9 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
else:
optimizer.step()

# model hook
self.on_before_zero_grad(optimizer)

# clear gradients
optimizer.zero_grad()

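The Note added above matters for users who override ``optimizer_step``: once the default implementation is replaced, the framework no longer calls the hook for you. Below is a minimal sketch of that pattern, not part of this commit; the ``ManualStepModel`` name is invented, and because the diff elides the trailing parameters of ``optimizer_step``, the ``optimizer_idx`` and ``second_order_closure`` arguments shown here are assumptions based on the 0.7-era signature.

from pytorch_lightning import LightningModule


class ManualStepModel(LightningModule):
    # Hypothetical module that overrides both hooks (illustration only).

    def on_before_zero_grad(self, optimizer):
        # inspect or log gradients while they are still populated
        pass

    def optimizer_step(self, current_epoch, batch_idx, optimizer,
                       optimizer_idx, second_order_closure=None):
        optimizer.step()
        # mirror the behaviour this commit adds to the default implementation:
        # call the hook before the gradients are cleared
        self.on_before_zero_grad(optimizer)
        optimizer.zero_grad()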
39 changes: 39 additions & 0 deletions tests/models/test_hooks.py
@@ -0,0 +1,39 @@
import pytest

import tests.base.utils as tutils
from pytorch_lightning import Trainer
from tests.base import (
    LightTrainDataloader,
    LightValidationMixin,
    TestModelBase,
    LightTestMixin)


@pytest.mark.parametrize('max_steps', [1, 2, 3])
def test_on_before_zero_grad_called(max_steps):

    class CurrentTestModel(
        LightTrainDataloader,
        LightValidationMixin,
        LightTestMixin,
        TestModelBase,
    ):
        on_before_zero_grad_called = 0

        def on_before_zero_grad(self, optimizer):
            self.on_before_zero_grad_called += 1

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    trainer = Trainer(
        max_steps=max_steps,
        num_sanity_val_steps=5,
    )
    assert 0 == model.on_before_zero_grad_called
    trainer.fit(model)
    assert max_steps == model.on_before_zero_grad_called

    model.on_before_zero_grad_called = 0
    trainer.test(model)
    assert 0 == model.on_before_zero_grad_called
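The test above only counts hook invocations. For context on what the hook is typically used for, the sketch below (not part of this commit) records the total gradient norm inside ``on_before_zero_grad``; at that point ``optimizer.step()`` has run but the gradients have not yet been cleared. The ``GradNormLoggingModel`` class and its ``grad_norms`` attribute are invented for illustration.

from pytorch_lightning import LightningModule


class GradNormLoggingModel(LightningModule):
    # Hypothetical module: records gradient norms before they are cleared.

    def __init__(self):
        super().__init__()
        self.grad_norms = []

    def on_before_zero_grad(self, optimizer):
        # the hook fires after optimizer.step() and before optimizer.zero_grad(),
        # so gradients are still populated here
        total_sq = 0.0
        for group in optimizer.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    total_sq += p.grad.data.norm(2).item() ** 2
        self.grad_norms.append(total_sq ** 0.5)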
