Support teardown hook on DataModule (#4673)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>
3 people authored Mar 25, 2021
1 parent 92a1671 commit 40976e4
Showing 8 changed files with 352 additions and 131 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))


- Added `teardown()` hook to LightningDataModule ([#4673](https://github.com/PyTorchLightning/pytorch-lightning/pull/4673))


- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277))


4 changes: 2 additions & 2 deletions docs/source/common/lightning_module.rst
@@ -1256,7 +1256,7 @@ prepare_data
setup
~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.setup
.. automethod:: pytorch_lightning.core.hooks.DataHooks.setup
:noindex:

tbptt_split_batch
@@ -1268,7 +1268,7 @@ tbptt_split_batch
teardown
~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.teardown
.. automethod:: pytorch_lightning.core.hooks.DataHooks.teardown
:noindex:

train_dataloader
13 changes: 12 additions & 1 deletion docs/source/extensions/datamodules.rst
@@ -94,6 +94,10 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)
def teardown(self, stage: Optional[str] = None):
# Used to clean-up when the run is finished
...
But now, as the complexity of your processing grows (transforms, multiple-GPU training), you can
let Lightning handle those details for you while making this dataset reusable so you can share with
colleagues or use in different projects.
@@ -243,7 +247,10 @@ There are also data operations you might want to perform on every GPU. Use setup
self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
.. warning:: `setup` is called from every process. Setting state here is okay.
.. warning:: ``setup`` is called from every process. Setting state here is okay.


.. note:: ``teardown`` can be used to clean up the state. It is also called from every process


train_dataloader
@@ -411,10 +418,14 @@ You can of course use DataModules in plain PyTorch code as well.
for batch in dm.val_dataloader():
...
dm.teardown(stage='fit')
# lazy load test data
dm.setup(stage='test')
for batch in dm.test_dataloader():
...
dm.teardown(stage='test')
But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified
structure.
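
To make the ``teardown`` usage in the docs hunk above concrete, here is a minimal, hypothetical DataModule whose ``teardown`` releases the per-stage datasets that ``setup`` created; the random tensors stand in for a real dataset and are not part of the commit:

    import torch
    import pytorch_lightning as pl
    from typing import Optional
    from torch.utils.data import DataLoader, TensorDataset

    class RandomDataModule(pl.LightningDataModule):
        def setup(self, stage: Optional[str] = None):
            # called on every process; build only what the stage needs
            if stage in (None, "fit"):
                self.train_set = TensorDataset(torch.randn(64, 32))
            if stage in (None, "test"):
                self.test_set = TensorDataset(torch.randn(64, 32))

        def teardown(self, stage: Optional[str] = None):
            # also called on every process; drop references so memory can be reclaimed
            if stage in (None, "fit"):
                self.train_set = None
            if stage in (None, "test"):
                self.test_set = None

        def train_dataloader(self):
            return DataLoader(self.train_set, batch_size=16)

        def test_dataloader(self):
            return DataLoader(self.test_set, batch_size=16)

Used manually, ``dm.setup(stage='fit')`` followed by iteration and ``dm.teardown(stage='fit')`` mirrors the plain-PyTorch loop shown in the hunk above.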
67 changes: 54 additions & 13 deletions pytorch_lightning/core/datamodule.py
@@ -14,7 +14,6 @@
"""LightningDataModule for loading DataLoaders with ease."""

import functools
from abc import abstractmethod
from argparse import ArgumentParser, Namespace
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

@@ -44,6 +43,8 @@ def __call__(cls, *args, **kwargs):
cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data))
# Track setup calls
cls.setup = track_data_hook_calls(cls.setup)
# Track teardown calls
cls.teardown = track_data_hook_calls(cls.teardown)

# Get instance of LightningDataModule by mocking its __init__ via __call__
obj = type.__call__(cls, *args, **kwargs)
@@ -52,12 +53,13 @@ def __call__(cls, *args, **kwargs):


def track_data_hook_calls(fn):
"""A decorator that checks if prepare_data/setup have been called.
"""A decorator that checks if prepare_data/setup/teardown has been called.
- When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
- When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True
- When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``.
Its corresponding `dm_has_setup_{stage}` attribute gets set to True
- ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``
Args:
fn (function): Function that will be tracked to see if it has been called.
@@ -71,9 +73,10 @@ def wrapped_fn(*args, **kwargs):

# The object instance from which setup or prepare_data was called
obj = args[0]
name = fn.__name__

# If calling setup, we check the stage and assign stage-specific bool args
if fn.__name__ == "setup":
if name in ("setup", "teardown"):

# Get stage either by grabbing from args or checking kwargs.
# If not provided, set call status of 'fit', 'validate', and 'test' to True.
Expand All @@ -82,11 +85,11 @@ def wrapped_fn(*args, **kwargs):

if stage is None:
for s in ("fit", "validate", "test"):
setattr(obj, f"_has_setup_{s}", True)
setattr(obj, f"_has_{name}_{s}", True)
else:
setattr(obj, f"_has_setup_{stage}", True)
setattr(obj, f"_has_{name}_{stage}", True)

if fn.__name__ == "prepare_data":
elif name == "prepare_data":
obj._has_prepared_data = True

return fn(*args, **kwargs)
@@ -119,14 +122,18 @@ def val_dataloader(self):
def test_dataloader(self):
test_split = Dataset(...)
return DataLoader(test_split)
def teardown(self):
# clean up after fit or test
# called on every process in DDP
A DataModule implements 5 key methods:
A DataModule implements 6 key methods:
* **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode).
* **setup** (things to do on every accelerator in distributed mode).
* **train_dataloader** the training dataloader.
* **val_dataloader** the val dataloader(s).
* **test_dataloader** the test dataloader(s).
* **teardown** (things to do on every accelerator in distributed mode when finished)
This allows you to share a full dataset without explaining how to download,
@@ -154,11 +161,17 @@ def __init__(

# Private attrs to keep track of whether or not data hooks have been called yet
self._has_prepared_data = False

self._has_setup_fit = False
self._has_setup_validate = False
self._has_setup_test = False
self._has_setup_predict = False

self._has_teardown_fit = False
self._has_teardown_validate = False
self._has_teardown_test = False
self._has_teardown_predict = False

@property
def train_transforms(self):
"""
@@ -259,13 +272,41 @@ def has_setup_predict(self) -> bool:
"""
return self._has_setup_predict

@abstractmethod
def prepare_data(self, *args, **kwargs):
pass
@property
def has_teardown_fit(self) -> bool:
"""Return bool letting you know if ``datamodule.teardown(stage='fit')`` has been called or not.
@abstractmethod
def setup(self, stage: Optional[str] = None):
pass
Returns:
bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default.
"""
return self._has_teardown_fit

@property
def has_teardown_validate(self) -> bool:
"""Return bool letting you know if ``datamodule.teardown(stage='validate')`` has been called or not.
Returns:
bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default.
"""
return self._has_teardown_validate

@property
def has_teardown_test(self) -> bool:
"""Return bool letting you know if ``datamodule.teardown(stage='test')`` has been called or not.
Returns:
bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default.
"""
return self._has_teardown_test

@property
def has_teardown_predict(self) -> bool:
"""Return bool letting you know if ``datamodule.teardown(stage='predict')`` has been called or not.
Returns:
bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default.
"""
return self._has_teardown_predict

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
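
A short, hypothetical sketch of what the tracking added in this file does in practice — each ``teardown`` call flips the matching ``has_teardown_*`` flag, and a call without a stage flips ``fit``, ``validate`` and ``test`` at once (the DataModule below is made up for illustration):

    import pytorch_lightning as pl

    class TrackedDataModule(pl.LightningDataModule):
        def teardown(self, stage=None):
            # no real cleanup needed for this sketch
            pass

    dm = TrackedDataModule()
    assert not dm.has_teardown_fit and not dm.has_teardown_test

    dm.teardown(stage="fit")
    assert dm.has_teardown_fit            # only the 'fit' flag was set
    assert not dm.has_teardown_test

    dm.teardown()                         # no stage: 'fit', 'validate' and 'test' are all marked
    assert dm.has_teardown_validate and dm.has_teardown_test
    assert not dm.has_teardown_predict    # 'predict' is only set when requested explicitly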
72 changes: 36 additions & 36 deletions pytorch_lightning/core/hooks.py
@@ -25,42 +25,6 @@
class ModelHooks:
"""Hooks to be used in LightningModule."""

def setup(self, stage: Optional[str] = None) -> None:
"""
Called at the beginning of fit (train + validate), validate, test, predict, or tune.
This is a good hook when you need to build models dynamically or adjust something about them.
This hook is called on every process when using DDP.
Args:
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
Example::
class LitModel(...):
def __init__(self):
self.l1 = None
def prepare_data(self):
download_data()
tokenize()
# don't do this
self.something = else
def setup(stage):
data = Load_data(...)
self.l1 = nn.Linear(28, data.num_classes)
"""

def teardown(self, stage: Optional[str] = None) -> None:
"""
Called at the end of fit (train + validate), validate, test, predict, or tune.
Args:
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
"""

def on_fit_start(self) -> None:
"""
Called at the very beginning of fit.
@@ -395,6 +359,42 @@ def prepare_data(self):
model.test_dataloader()
"""

def setup(self, stage: Optional[str] = None) -> None:
"""
Called at the beginning of fit (train + validate), validate, test, predict, or tune.
This is a good hook when you need to build models dynamically or adjust something about them.
This hook is called on every process when using DDP.
Args:
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
Example::
class LitModel(...):
def __init__(self):
self.l1 = None
def prepare_data(self):
download_data()
tokenize()
# don't do this
self.something = else
def setup(stage):
data = Load_data(...)
self.l1 = nn.Linear(28, data.num_classes)
"""

def teardown(self, stage: Optional[str] = None) -> None:
"""
Called at the end of fit (train + validate), validate, test, predict, or tune.
Args:
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
"""

def train_dataloader(self) -> Any:
"""
Implement one or more PyTorch DataLoaders for training.
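
For completeness, a hypothetical module using the two hooks whose docstrings are moved to ``DataHooks`` above — the layer is built in ``setup`` (on every process) and released in ``teardown``; the sizes are made up and this is only a sketch, not code from the commit:

    from typing import Optional
    import torch.nn as nn
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = None

        def setup(self, stage: Optional[str] = None) -> None:
            # runs on every process at the start of fit/validate/test/predict
            if self.l1 is None:
                self.l1 = nn.Linear(28 * 28, 10)

        def teardown(self, stage: Optional[str] = None) -> None:
            # runs on every process at the end of the same stage
            self.l1 = None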
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -1086,6 +1086,12 @@ def call_setup_hook(self, model: LightningModule) -> None:

def call_teardown_hook(self, model: LightningModule) -> None:
state = self._teardown_state

if self.datamodule is not None:
called = getattr(self.datamodule, f'has_teardown_{state}')
if not called:
self.datamodule.teardown(stage=state)

self.profiler.teardown(stage=state)
self.teardown(stage=state)
model.teardown(stage=state)
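
A runnable sketch of the guard added to ``call_teardown_hook`` above — the helper function below is hypothetical and only mirrors the check; it is not the Trainer code itself:

    import pytorch_lightning as pl

    class BoringDataModule(pl.LightningDataModule):
        def teardown(self, stage=None):
            print(f"datamodule teardown, stage={stage}")

    def call_teardown_like_the_trainer(datamodule, stage):
        # only tear down what the user has not already torn down manually
        if datamodule is not None and not getattr(datamodule, f"has_teardown_{stage}"):
            datamodule.teardown(stage=stage)

    dm = BoringDataModule()
    call_teardown_like_the_trainer(dm, "fit")   # prints once and sets has_teardown_fit
    call_teardown_like_the_trainer(dm, "fit")   # no-op: the flag is already True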