Add teardown hook to LightningDataModule #4673

Merged: 25 commits, Mar 25, 2021
Changes from 2 commits
49 changes: 48 additions & 1 deletion pytorch_lightning/core/datamodule.py
@@ -58,6 +58,9 @@ def track_data_hook_calls(fn):
- When dm.setup('fit') is called, dm.has_setup_fit gets set to True
- When dm.setup('test') is called, dm.has_setup_test gets set to True
- When dm.setup() is called without stage arg, both dm.has_setup_fit and dm.has_setup_test get set to True
- When dm.teardown('fit') is called, dm.has_teardown_fit gets set to True
- When dm.teardown('test') is called, dm.has_teardown_test gets set to True
- When dm.teardown() is called without stage arg, both dm.has_teardown_fit and dm.has_teardown_test get set to True

Args:
fn (function): Function that will be tracked to see if it has been called.
@@ -86,6 +89,21 @@ def wrapped_fn(*args, **kwargs):
if stage == "test" or stage is None:
    obj._has_setup_test = True

# If calling teardown, we check the stage and assign stage-specific bool args
if fn.__name__ == "teardown":

    # Get stage either by grabbing from args or checking kwargs.
    # If not provided, set call status of 'fit' and 'test' to True.
    # We do this so __attach_datamodule in trainer.py doesn't mistakenly call teardown('test') on trainer.test()
    stage = args[1] if len(args) > 1 else kwargs.get("stage", None)

    if stage == "fit" or stage is None:
        obj._has_teardown_fit = True

    if stage == "test" or stage is None:
        obj._has_teardown_test = True


if fn.__name__ == "prepare_data":
    obj._has_prepared_data = True

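For reference, the flag tracking added above means a datamodule's teardown state can be queried as in this minimal sketch (ExampleDM is a hypothetical subclass invented for illustration, not part of this PR):

    from pytorch_lightning import LightningDataModule

    class ExampleDM(LightningDataModule):
        """Hypothetical datamodule used only to illustrate the tracking flags."""

        def teardown(self, stage=None):
            # clean up after fit or test, e.g. delete temporary files or caches
            pass

    dm = ExampleDM()
    assert not dm.has_teardown_fit and not dm.has_teardown_test

    dm.teardown('fit')   # only the 'fit' flag is set
    assert dm.has_teardown_fit and not dm.has_teardown_test

    dm.teardown()        # no stage argument: both flags are set
    assert dm.has_teardown_test
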
@@ -119,14 +137,18 @@ def val_dataloader(self):
def test_dataloader(self):
test_split = Dataset(...)
return DataLoader(test_split)
def teardown(self):
# clean up after fit or test
# called on every process in DDP

A DataModule implements 5 key methods:
A DataModule implements 6 key methods:

* **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode).
* **setup** (things to do on every accelerator in distributed mode).
* **train_dataloader** the training dataloader.
* **val_dataloader** the val dataloader(s).
* **test_dataloader** the test dataloader(s).
* **teardown** (things to do on every accelerator in distributed mode after fit/test).


This allows you to share a full dataset without explaining how to download,
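
A fuller version of the docstring example above, covering all six hooks, might look like the following (the class name and random dataset are made up for illustration):

    import torch
    from torch.utils.data import DataLoader, TensorDataset, random_split
    from pytorch_lightning import LightningDataModule

    class RandomDataModule(LightningDataModule):
        """Hypothetical datamodule implementing the six key methods."""

        def prepare_data(self):
            # download or write data to disk; runs on a single process
            pass

        def setup(self, stage=None):
            # split and assign datasets; runs on every process
            full = TensorDataset(torch.randn(100, 32))
            self.train_split, self.val_split = random_split(full, [80, 20])
            self.test_split = TensorDataset(torch.randn(20, 32))

        def train_dataloader(self):
            return DataLoader(self.train_split, batch_size=8)

        def val_dataloader(self):
            return DataLoader(self.val_split, batch_size=8)

        def test_dataloader(self):
            return DataLoader(self.test_split, batch_size=8)

        def teardown(self, stage=None):
            # clean up after fit or test; called on every process in DDP
            self.train_split = self.val_split = self.test_split = None
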
@@ -156,6 +178,8 @@ def __init__(
self._has_prepared_data = False
self._has_setup_fit = False
self._has_setup_test = False
self._has_teardown_fit = False
self._has_teardown_test = False

@property
def train_transforms(self):
@@ -239,6 +263,25 @@ def has_setup_test(self):
"""
return self._has_setup_test


@property
def has_teardown_fit(self):
"""Return bool letting you know if datamodule.teardown('fit') has been called or not.

Returns:
bool: True if datamodule.teardown('fit') has been called. False by default.
"""
return self._has_teardown_fit

@property
def has_teardown_test(self):
"""Return bool letting you know if datamodule.teardown('test') has been called or not.

Returns:
bool: True if datamodule.teardown('test') has been called. False by default.
"""
return self._has_teardown_test

@abstractmethod
def prepare_data(self, *args, **kwargs):
pass
@@ -259,6 +302,10 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]
def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
pass

@abstractmethod
def teardown(self, stage: Optional[str] = None):
pass

@abstractmethod
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
pass
18 changes: 10 additions & 8 deletions pytorch_lightning/trainer/trainer.py
@@ -477,6 +477,9 @@ def fit(

# hook
self.teardown('fit')
if self.datamodule is not None:
    if not self.datamodule.has_teardown_fit:
        self.datamodule.teardown('fit')
if self.is_function_implemented('teardown'):
    model.teardown('fit')

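The same guard is applied for the 'test' stage in test() below; boiled down, the control flow added to both entry points looks roughly like this (a simplified sketch with a hypothetical helper, not the literal trainer code):

    def run_teardown(trainer, model, stage):
        # hypothetical helper mirroring the logic added to fit()/test()
        trainer.teardown(stage)

        dm = trainer.datamodule
        already_done = dm is not None and getattr(dm, 'has_teardown_' + stage)
        if dm is not None and not already_done:
            # skip if teardown for this stage already ran, e.g. because the
            # user called dm.teardown() or dm.teardown(stage) themselves
            dm.teardown(stage)

        if trainer.is_function_implemented('teardown'):
            model.teardown(stage)
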
@@ -759,7 +762,14 @@ def test(
else:
results = self.__test_using_best_weights(ckpt_path, test_dataloaders)

# teardown
self.teardown('test')
if self.datamodule is not None:
    if not self.datamodule.has_teardown_test:
        self.datamodule.teardown('test')
if self.is_function_implemented('teardown'):
    model_ref = self.get_model()
    model_ref.teardown('test')

return results

@@ -803,11 +813,6 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
self.testing = False
del os.environ['PL_TESTING_MODE']

# teardown
if self.is_function_implemented('teardown'):
model_ref = self.get_model()
model_ref.teardown('test')

return results

def __test_given_model(self, model, test_dataloaders):
@@ -823,9 +828,6 @@ def __test_given_model(self, model, test_dataloaders):
results = self.fit(model)
self.testing = False

# teardown
if self.is_function_implemented('teardown'):
model.teardown('test')

return results

33 changes: 32 additions & 1 deletion tests/base/boring_model.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning import LightningDataModule, LightningModule
from torch.utils.data import Dataset


@@ -129,3 +129,34 @@ def val_dataloader(self):

def test_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64))


class BoringDataModule(LightningDataModule):

def __init__(self):
"""
Testing PL DataModule

Use as follows:
- subclass
- modify the behavior for what you want

class TestDM(BoringDataModule):
def train_dataloader(...):
# do your own thing

or:

model = TestDM()
model.setup = None
"""
super().__init__()

def train_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64))

def val_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64))

def test_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64))
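
As the docstring suggests, a test could subclass BoringDataModule and override just the hook under scrutiny; for example (an illustrative sketch, not one of the tests added in this PR):

    from pytorch_lightning import Trainer
    from tests.base.boring_model import BoringDataModule, BoringModel

    class TeardownDM(BoringDataModule):
        # hypothetical helper that records which stages were torn down
        def __init__(self):
            super().__init__()
            self.teardown_stages = []

        def teardown(self, stage=None):
            self.teardown_stages.append(stage)

    def test_dm_teardown_called_after_fit(tmpdir):
        dm = TeardownDM()
        trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
        trainer.fit(BoringModel(), datamodule=dm)
        assert dm.has_teardown_fit
        assert 'fit' in dm.teardown_stages
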
1 change: 1 addition & 0 deletions tests/core/test_datamodules.py
@@ -22,6 +22,7 @@

from pytorch_lightning import LightningDataModule, Trainer, seed_everything
from tests.base import EvalModelTemplate
from tests.base.boring_model import BoringDataModule
from tests.base.datasets import TrialMNIST
from tests.base.datamodules import TrialMNISTDataModule
from tests.base.develop_utils import reset_seed