Skip to content

Commit

Permalink
adds setup+teardown hook (#2229)
Browse files Browse the repository at this point in the history
* allow regression metrics to import

* allow regression metrics to import

* allow regression metrics to import

* allow regression metrics to import

* allow regression metrics to import

* allow regression metrics to import

* allow regression metrics to import

* allow regression metrics to import

* allow regression metrics to import
  • Loading branch information
williamFalcon committed Jun 17, 2020
1 parent 1635ba1 commit 34816e9
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 1 deletion.
4 changes: 4 additions & 0 deletions docs/source/hooks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ To enable a hook, simply override the method in your LightningModule and the tra

3. Add it in the correct place in :mod:`pytorch_lightning.trainer` where it should be called.

---

Hooks lifecycle
---------------
Expand Down Expand Up @@ -71,7 +72,10 @@ Test loop
- ``torch.set_grad_enabled(True)``
- :meth:`~pytorch_lightning.core.hooks.ModelHooks.on_post_performance_check`

---

General hooks
-------------

.. automodule:: pytorch_lightning.core.hooks
:noindex:
8 changes: 8 additions & 0 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ class Callback(abc.ABC):
Abstract base class used to build new callbacks.
"""

def setup(self, trainer, step: str):
"""Called when fit or test begins"""
pass

def teardown(self, trainer, step: str):
"""Called when fit or test ends"""
pass

def on_init_start(self, trainer):
"""Called when the trainer initialization begins, model has not yet been set."""
pass
Expand Down
16 changes: 16 additions & 0 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,22 @@

class ModelHooks(Module):

def setup(self, step: str):
"""
Called at the beginning of fit and test.
Args:
step: either 'fit' or 'test'
"""

def teardown(self, step: str):
"""
Called at the end of fit and test.
Args:
step: either 'fit' or 'test'
"""

def on_fit_start(self):
"""
Called at the very beginning of fit.
Expand Down
10 changes: 10 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ class TrainerCallbackHookMixin(ABC):
callbacks: List[Callback] = []
get_model: Callable = ...

def setup(self, step: str):
"""Called in the beginning of fit and test"""
for callback in self.callbacks:
callback.setup(self, step)

def teardown(self, step: str):
"""Called at the end of fit and test"""
for callback in self.callbacks:
callback.teardown(self, step)

def on_init_start(self):
"""Called when the trainer initialization begins, model has not yet been set."""
for callback in self.callbacks:
Expand Down
17 changes: 17 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,10 @@ def fit(
if self.is_function_implemented('on_fit_start'):
model.on_fit_start()

self.setup('fit')
if self.is_function_implemented('setup'):
model.setup('fit')

# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
if self.can_prepare_data():
Expand Down Expand Up @@ -945,6 +949,10 @@ def fit(
if self.is_function_implemented('on_fit_end'):
model.on_fit_end()

self.teardown('fit')
if self.is_function_implemented('teardown'):
model.teardown('fit')

# return 1 when finished
# used for testing or when we need to know that training succeeded
return 1
Expand Down Expand Up @@ -1128,6 +1136,11 @@ def test(
trainer = Trainer()
trainer.test(model, test_dataloaders=test)
"""
self.setup('test')
if self.is_function_implemented('setup'):
model_ref = self.model if model is None else model
model_ref.setup('test')

if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
raise MisconfigurationException(
'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.')
Expand Down Expand Up @@ -1167,6 +1180,10 @@ def test(

self.testing = False

self.teardown('test')
if self.is_function_implemented('teardown'):
self.model.teardown('test')

def check_model_configuration(self, model: LightningModule):
r"""
Checks that the model is configured correctly before training or testing is started.
Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,6 @@ def run_training_teardown(self):

# summarize profile results
self.profiler.describe()

self._teardown_already_run = True

def training_forward(self, batch, batch_idx, opt_idx, hiddens):
Expand Down
22 changes: 22 additions & 0 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def _check_args(trainer, pl_module):
class TestCallback(Callback):
def __init__(self):
super().__init__()
self.setup_called = False
self.teardown_called = False
self.on_init_start_called = False
self.on_init_end_called = False
self.on_fit_start_called = False
Expand All @@ -67,6 +69,14 @@ def __init__(self):
self.on_test_start_called = False
self.on_test_end_called = False

def setup(self, trainer, step: str):
assert isinstance(trainer, Trainer)
self.setup_called = True

def teardown(self, trainer, step: str):
assert isinstance(trainer, Trainer)
self.teardown_called = True

def on_init_start(self, trainer):
assert isinstance(trainer, Trainer)
self.on_init_start_called = True
Expand Down Expand Up @@ -157,6 +167,8 @@ def on_test_end(self, trainer, pl_module):
progress_bar_refresh_rate=0,
)

assert not test_callback.setup_called
assert not test_callback.teardown_called
assert not test_callback.on_init_start_called
assert not test_callback.on_init_end_called
assert not test_callback.on_fit_start_called
Expand Down Expand Up @@ -184,6 +196,8 @@ def on_test_end(self, trainer, pl_module):
assert trainer.callbacks[0] == test_callback
assert test_callback.on_init_start_called
assert test_callback.on_init_end_called
assert not test_callback.setup_called
assert not test_callback.teardown_called
assert not test_callback.on_fit_start_called
assert not test_callback.on_fit_end_called
assert not test_callback.on_sanity_check_start_called
Expand All @@ -205,6 +219,8 @@ def on_test_end(self, trainer, pl_module):

trainer.fit(model)

assert test_callback.setup_called
assert test_callback.teardown_called
assert test_callback.on_init_start_called
assert test_callback.on_init_end_called
assert test_callback.on_fit_start_called
Expand All @@ -226,11 +242,17 @@ def on_test_end(self, trainer, pl_module):
assert not test_callback.on_test_start_called
assert not test_callback.on_test_end_called

# reset setup teardown callback
test_callback.teardown_called = False
test_callback.setup_called = False

test_callback = TestCallback()
trainer_options.update(callbacks=[test_callback])
trainer = Trainer(**trainer_options)
trainer.test(model)

assert test_callback.setup_called
assert test_callback.teardown_called
assert test_callback.on_test_batch_start_called
assert test_callback.on_test_batch_end_called
assert test_callback.on_test_start_called
Expand Down

0 comments on commit 34816e9

Please sign in to comment.