diff --git a/CHANGELOG.md b/CHANGELOG.md index 51ad97decd867..5a5d824bde66c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) +- Added `teardown()` hook to LightningDataModule ([#4673](https://github.com/PyTorchLightning/pytorch-lightning/pull/4673)) + + - Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277)) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index d826e062ffdbf..7b2c2bb9519d1 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1256,7 +1256,7 @@ prepare_data setup ~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.setup +.. automethod:: pytorch_lightning.core.hooks.DataHooks.setup :noindex: tbptt_split_batch @@ -1268,7 +1268,7 @@ tbptt_split_batch teardown ~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.teardown +.. automethod:: pytorch_lightning.core.hooks.DataHooks.teardown :noindex: train_dataloader diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 85134fda06fa2..881febe21316d 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -94,6 +94,10 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=self.batch_size) + def teardown(self, stage: Optional[str] = None): + # Used to clean-up when the run is finished + ... + But now, as the complexity of your processing grows (transforms, multiple-GPU training), you can let Lightning handle those details for you while making this dataset reusable so you can share with colleagues or use in different projects. @@ -243,7 +247,10 @@ There are also data operations you might want to perform on every GPU. Use setup self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape) -.. warning:: `setup` is called from every process. Setting state here is okay. +.. warning:: ``setup`` is called from every process. Setting state here is okay. + + +.. note:: ``teardown`` can be used to clean up the state. It is also called from every process train_dataloader @@ -411,10 +418,14 @@ You can of course use DataModules in plain PyTorch code as well. for batch in dm.val_dataloader(): ... + dm.teardown(stage='fit') + # lazy load test data dm.setup(stage='test') for batch in dm.test_dataloader(): ... + dm.teardown(stage='test') + But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified structure. 
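For readers following the datamodules.rst changes above, here is a minimal, self-contained sketch (not part of the patch; ``ToyDataModule`` and its random datasets are hypothetical stand-ins) of how the new ``teardown`` hook pairs with ``setup`` when a DataModule is driven from plain PyTorch code::

    from typing import Optional

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning import LightningDataModule


    class ToyDataModule(LightningDataModule):

        def setup(self, stage: Optional[str] = None) -> None:
            # called on every process; assigning state here is fine
            self.train_set = TensorDataset(torch.randn(64, 32))
            self.test_set = TensorDataset(torch.randn(16, 32))

        def train_dataloader(self) -> DataLoader:
            return DataLoader(self.train_set, batch_size=8)

        def test_dataloader(self) -> DataLoader:
            return DataLoader(self.test_set, batch_size=8)

        def teardown(self, stage: Optional[str] = None) -> None:
            # also called on every process; release per-stage state here
            if stage in (None, "fit"):
                self.train_set = None
            if stage in (None, "test"):
                self.test_set = None


    dm = ToyDataModule()
    dm.setup(stage="fit")
    for batch in dm.train_dataloader():
        ...
    dm.teardown(stage="fit")    # dm.has_teardown_fit is now True

    dm.setup(stage="test")
    for batch in dm.test_dataloader():
        ...
    dm.teardown(stage="test")   # dm.has_teardown_test is now True
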
diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 994c259f48964..4178c9eeacd50 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -14,7 +14,6 @@ """LightningDataModule for loading DataLoaders with ease.""" import functools -from abc import abstractmethod from argparse import ArgumentParser, Namespace from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union @@ -44,6 +43,8 @@ def __call__(cls, *args, **kwargs): cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data)) # Track setup calls cls.setup = track_data_hook_calls(cls.setup) + # Track teardown calls + cls.teardown = track_data_hook_calls(cls.teardown) # Get instance of LightningDataModule by mocking its __init__ via __call__ obj = type.__call__(cls, *args, **kwargs) @@ -52,12 +53,13 @@ def __call__(cls, *args, **kwargs): def track_data_hook_calls(fn): - """A decorator that checks if prepare_data/setup have been called. + """A decorator that checks if prepare_data/setup/teardown has been called. - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. Its corresponding `dm_has_setup_{stage}` attribute gets set to True + - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup`` Args: fn (function): Function that will be tracked to see if it has been called. @@ -71,9 +73,10 @@ def wrapped_fn(*args, **kwargs): # The object instance from which setup or prepare_data was called obj = args[0] + name = fn.__name__ # If calling setup, we check the stage and assign stage-specific bool args - if fn.__name__ == "setup": + if name in ("setup", "teardown"): # Get stage either by grabbing from args or checking kwargs. # If not provided, set call status of 'fit', 'validate', and 'test' to True. @@ -82,11 +85,11 @@ def wrapped_fn(*args, **kwargs): if stage is None: for s in ("fit", "validate", "test"): - setattr(obj, f"_has_setup_{s}", True) + setattr(obj, f"_has_{name}_{s}", True) else: - setattr(obj, f"_has_setup_{stage}", True) + setattr(obj, f"_has_{name}_{stage}", True) - if fn.__name__ == "prepare_data": + elif name == "prepare_data": obj._has_prepared_data = True return fn(*args, **kwargs) @@ -119,14 +122,18 @@ def val_dataloader(self): def test_dataloader(self): test_split = Dataset(...) return DataLoader(test_split) + def teardown(self): + # clean up after fit or test + # called on every process in DDP - A DataModule implements 5 key methods: + A DataModule implements 6 key methods: * **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode). * **setup** (things to do on every accelerator in distributed mode). * **train_dataloader** the training dataloader. * **val_dataloader** the val dataloader(s). * **test_dataloader** the test dataloader(s). 
+ * **teardown** (things to do on every accelerator in distributed mode when finished) This allows you to share a full dataset without explaining how to download, @@ -154,11 +161,17 @@ def __init__( # Private attrs to keep track of whether or not data hooks have been called yet self._has_prepared_data = False + self._has_setup_fit = False self._has_setup_validate = False self._has_setup_test = False self._has_setup_predict = False + self._has_teardown_fit = False + self._has_teardown_validate = False + self._has_teardown_test = False + self._has_teardown_predict = False + @property def train_transforms(self): """ @@ -259,13 +272,41 @@ def has_setup_predict(self) -> bool: """ return self._has_setup_predict - @abstractmethod - def prepare_data(self, *args, **kwargs): - pass + @property + def has_teardown_fit(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='fit')`` has been called or not. - @abstractmethod - def setup(self, stage: Optional[str] = None): - pass + Returns: + bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default. + """ + return self._has_teardown_fit + + @property + def has_teardown_validate(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='validate')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default. + """ + return self._has_teardown_validate + + @property + def has_teardown_test(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='test')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default. + """ + return self._has_teardown_test + + @property + def has_teardown_predict(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='predict')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default. + """ + return self._has_teardown_predict @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 9624f94652713..c74de98a3bd6c 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -25,42 +25,6 @@ class ModelHooks: """Hooks to be used in LightningModule.""" - def setup(self, stage: Optional[str] = None) -> None: - """ - Called at the beginning of fit (train + validate), validate, test, predict, or tune. - This is a good hook when you need to build models dynamically or adjust something about them. - This hook is called on every process when using DDP. - - Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` - - Example:: - - class LitModel(...): - def __init__(self): - self.l1 = None - - def prepare_data(self): - download_data() - tokenize() - - # don't do this - self.something = else - - def setup(stage): - data = Load_data(...) - self.l1 = nn.Linear(28, data.num_classes) - - """ - - def teardown(self, stage: Optional[str] = None) -> None: - """ - Called at the end of fit (train + validate), validate, test, predict, or tune. - - Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` - """ - def on_fit_start(self) -> None: """ Called at the very beginning of fit. 
@@ -383,6 +347,42 @@ def prepare_data(self): model.test_dataloader() """ + def setup(self, stage: Optional[str] = None) -> None: + """ + Called at the beginning of fit (train + validate), validate, test, predict, or tune. + This is a good hook when you need to build models dynamically or adjust something about them. + This hook is called on every process when using DDP. + + Args: + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + + Example:: + + class LitModel(...): + def __init__(self): + self.l1 = None + + def prepare_data(self): + download_data() + tokenize() + + # don't do this + self.something = else + + def setup(stage): + data = Load_data(...) + self.l1 = nn.Linear(28, data.num_classes) + + """ + + def teardown(self, stage: Optional[str] = None) -> None: + """ + Called at the end of fit (train + validate), validate, test, predict, or tune. + + Args: + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + """ + def train_dataloader(self) -> Any: """ Implement one or more PyTorch DataLoaders for training. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f7bd1757b9bc2..7e68ff4e8ab10 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1077,6 +1077,12 @@ def call_setup_hook(self, model: LightningModule) -> None: def call_teardown_hook(self, model: LightningModule) -> None: state = self._teardown_state + + if self.datamodule is not None: + called = getattr(self.datamodule, f'has_teardown_{state}') + if not called: + self.datamodule.teardown(stage=state) + self.profiler.teardown(stage=state) self.teardown(stage=state) model.teardown(stage=state) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 2118fec6c207b..c8808ec37326c 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -128,6 +128,10 @@ def test_data_hooks_called(tmpdir): assert not dm.has_setup_test assert not dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict dm.prepare_data() assert dm.has_prepared_data @@ -135,6 +139,10 @@ def test_data_hooks_called(tmpdir): assert not dm.has_setup_test assert not dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict dm.setup() assert dm.has_prepared_data @@ -142,49 +150,84 @@ def test_data_hooks_called(tmpdir): assert dm.has_setup_test assert dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict + + dm.teardown() + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_test + assert dm.has_setup_validate + assert not dm.has_setup_predict + assert dm.has_teardown_fit + assert dm.has_teardown_test + assert dm.has_teardown_validate + assert not dm.has_teardown_predict @pytest.mark.parametrize("use_kwarg", (False, True)) def test_data_hooks_called_verbose(tmpdir, use_kwarg): dm = BoringDataModule() - assert not dm.has_prepared_data - assert not dm.has_setup_fit - assert not dm.has_setup_test - dm.prepare_data() - assert dm.has_prepared_data assert not dm.has_setup_fit assert not dm.has_setup_test + assert not dm.has_setup_validate assert not dm.has_setup_predict + 
assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict dm.setup(stage='fit') if use_kwarg else dm.setup('fit') - assert dm.has_prepared_data assert dm.has_setup_fit assert not dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='validate') if use_kwarg else dm.setup('validate') - assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='test') if use_kwarg else dm.setup('test') - assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='predict') if use_kwarg else dm.setup('predict') - assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert dm.has_setup_test assert dm.has_setup_predict + dm.teardown(stage='fit') if use_kwarg else dm.teardown('fit') + assert dm.has_teardown_fit + assert not dm.has_teardown_validate + assert not dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='validate') if use_kwarg else dm.teardown('validate') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert not dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='test') if use_kwarg else dm.teardown('test') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='predict') if use_kwarg else dm.teardown('predict') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert dm.has_teardown_test + assert dm.has_teardown_predict + def test_dm_add_argparse_args(tmpdir): parser = ArgumentParser() diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 69859547f4a1f..4ead0d1e14e78 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import inspect from unittest import mock from unittest.mock import PropertyMock @@ -20,7 +19,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.trainer.states import TrainerState -from tests.helpers import BoringModel, RandomDataset +from tests.helpers import BoringModel, RandomDataset, BoringDataModule from tests.helpers.runif import RunIf @@ -260,7 +259,7 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): def test_trainer_model_hook_system(tmpdir): - """Test the hooks system.""" + """Test the LightningModule hook system.""" class HookedModel(BoringModel): @@ -269,149 +268,151 @@ def __init__(self): self.called = [] def on_after_backward(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_after_backward") super().on_after_backward() - def on_before_zero_grad(self, optimizer): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_before_zero_grad(optimizer) + def on_before_zero_grad(self, *args, **kwargs): + self.called.append("on_before_zero_grad") + super().on_before_zero_grad(*args, **kwargs) def on_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_epoch_start") super().on_epoch_start() def on_epoch_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_epoch_end") super().on_epoch_end() def on_fit_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_fit_start") super().on_fit_start() def on_fit_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_fit_end") super().on_fit_end() - def on_hpc_load(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_hpc_load(checkpoint) + def on_hpc_load(self, *args, **kwargs): + self.called.append("on_hpc_load") + super().on_hpc_load(*args, **kwargs) - def on_hpc_save(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_hpc_save(checkpoint) + def on_hpc_save(self, *args, **kwargs): + self.called.append("on_hpc_save") + super().on_hpc_save(*args, **kwargs) - def on_load_checkpoint(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_load_checkpoint(checkpoint) + def on_load_checkpoint(self, *args, **kwargs): + self.called.append("on_load_checkpoint") + super().on_load_checkpoint(*args, **kwargs) - def on_save_checkpoint(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_save_checkpoint(checkpoint) + def on_save_checkpoint(self, *args, **kwargs): + self.called.append("on_save_checkpoint") + super().on_save_checkpoint(*args, **kwargs) def on_pretrain_routine_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_pretrain_routine_start") super().on_pretrain_routine_start() def on_pretrain_routine_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_pretrain_routine_end") super().on_pretrain_routine_end() def on_train_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_start") super().on_train_start() def on_train_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_end") super().on_train_end() - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - 
super().on_train_batch_start(batch, batch_idx, dataloader_idx) + def on_train_batch_start(self, *args, **kwargs): + self.called.append("on_train_batch_start") + super().on_train_batch_start(*args, **kwargs) - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) + def on_train_batch_end(self, *args, **kwargs): + self.called.append("on_train_batch_end") + super().on_train_batch_end(*args, **kwargs) def on_train_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_epoch_start") super().on_train_epoch_start() def on_train_epoch_end(self, outputs): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_epoch_end") super().on_train_epoch_end(outputs) def on_validation_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_start") super().on_validation_start() def on_validation_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_end") super().on_validation_end() - def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_validation_batch_start(batch, batch_idx, dataloader_idx) + def on_validation_batch_start(self, *args, **kwargs): + self.called.append("on_validation_batch_start") + super().on_validation_batch_start(*args, **kwargs) - def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx) + def on_validation_batch_end(self, *args, **kwargs): + self.called.append("on_validation_batch_end") + super().on_validation_batch_end(*args, **kwargs) def on_validation_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_epoch_start") super().on_validation_epoch_start() - def on_validation_epoch_end(self, outputs): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_validation_epoch_end(outputs) + def on_validation_epoch_end(self, *args, **kwargs): + self.called.append("on_validation_epoch_end") + super().on_validation_epoch_end(*args, **kwargs) def on_test_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_start") super().on_test_start() - def on_test_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_test_batch_start(batch, batch_idx, dataloader_idx) + def on_test_batch_start(self, *args, **kwargs): + self.called.append("on_test_batch_start") + super().on_test_batch_start(*args, **kwargs) - def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx) + def on_test_batch_end(self, *args, **kwargs): + self.called.append("on_test_batch_end") + super().on_test_batch_end(*args, **kwargs) def on_test_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_epoch_start") super().on_test_epoch_start() - def on_test_epoch_end(self, outputs): - self.called.append(inspect.currentframe().f_code.co_name) - 
super().on_test_epoch_end(outputs) + def on_test_epoch_end(self, *args, **kwargs): + self.called.append("on_test_epoch_end") + super().on_test_epoch_end(*args, **kwargs) def on_validation_model_eval(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_model_eval") super().on_validation_model_eval() def on_validation_model_train(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_model_train") super().on_validation_model_train() def on_test_model_eval(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_model_eval") super().on_test_model_eval() def on_test_model_train(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_model_train") super().on_test_model_train() def on_test_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_end") super().on_test_end() + def setup(self, stage=None): + self.called.append(f"setup_{stage}") + super().setup(stage=stage) + def teardown(self, stage=None): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append(f"teardown_{stage}") super().teardown(stage) model = HookedModel() - assert model.called == [] - # fit model trainer = Trainer( default_root_dir=tmpdir, @@ -427,6 +428,7 @@ def teardown(self, stage=None): trainer.fit(model) expected = [ + 'setup_fit', 'on_fit_start', 'on_pretrain_routine_start', 'on_pretrain_routine_end', @@ -464,7 +466,7 @@ def teardown(self, stage=None): 'on_validation_model_train', 'on_train_end', 'on_fit_end', - 'teardown', + 'teardown_fit', ] assert model.called == expected @@ -472,6 +474,7 @@ def teardown(self, stage=None): trainer.validate(model, verbose=False) expected = [ + 'setup_validate', 'on_validation_model_eval', 'on_validation_start', 'on_validation_epoch_start', @@ -481,14 +484,15 @@ def teardown(self, stage=None): 'on_epoch_end', 'on_validation_end', 'on_validation_model_train', - 'teardown', + 'teardown_validate', ] assert model.called == expected model = HookedModel() - trainer.test(model, verbose=False) + expected = [ + 'setup_test', 'on_test_model_eval', 'on_test_start', 'on_test_epoch_start', @@ -498,6 +502,119 @@ def teardown(self, stage=None): 'on_epoch_end', 'on_test_end', 'on_test_model_train', - 'teardown', + 'teardown_test', ] assert model.called == expected + + +def test_trainer_datamodule_hook_system(tmpdir): + """Test the LightningDataModule hook system.""" + + class HookedDataModule(BoringDataModule): + def __init__(self): + super().__init__() + self.called = [] + + def prepare_data(self): + self.called.append("prepare_data") + super().prepare_data() + + def setup(self, stage=None): + self.called.append(f"setup_{stage}") + super().setup(stage=stage) + + def teardown(self, stage=None): + self.called.append(f"teardown_{stage}") + super().teardown(stage=stage) + + def train_dataloader(self): + self.called.append("train_dataloader") + return super().train_dataloader() + + def test_dataloader(self): + self.called.append("test_dataloader") + return super().test_dataloader() + + def val_dataloader(self): + self.called.append("val_dataloader") + return super().val_dataloader() + + def predict_dataloader(self): + self.called.append("predict_dataloader") + + def transfer_batch_to_device(self, *args, **kwargs): + self.called.append("transfer_batch_to_device") + return super().transfer_batch_to_device(*args, **kwargs) + + def on_before_batch_transfer(self, 
*args, **kwargs): + self.called.append("on_before_batch_transfer") + return super().on_before_batch_transfer(*args, **kwargs) + + def on_after_batch_transfer(self, *args, **kwargs): + self.called.append("on_after_batch_transfer") + return super().on_after_batch_transfer(*args, **kwargs) + + model = BoringModel() + dm = HookedDataModule() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=2, + limit_test_batches=1, + progress_bar_refresh_rate=0, + weights_summary=None, + reload_dataloaders_every_epoch=True, + ) + trainer.fit(model, datamodule=dm) + + expected = [ + 'prepare_data', + 'setup_fit', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'train_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_fit' + ] + assert dm.called == expected + + dm = HookedDataModule() + trainer.validate(model, datamodule=dm, verbose=False) + + expected = [ + 'prepare_data', + 'setup_validate', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_validate' + ] + assert dm.called == expected + + dm = HookedDataModule() + trainer.test(model, datamodule=dm, verbose=False) + + expected = [ + 'prepare_data', + 'setup_test', + 'test_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_test' + ] + assert dm.called == expected
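
The new tests above exercise the per-stage call tracking that ``DataModuleWrapper``/``track_data_hook_calls`` now applies to ``teardown``. As a framework-free illustration of that pattern (simplified names; not the actual Lightning implementation), the bookkeeping the tests assert on can be sketched as::

    import functools
    from typing import Optional


    def track_stage_calls(fn):
        """Mark ``_has_{fn.__name__}_{stage}`` as True whenever ``fn`` is called."""
        name = fn.__name__  # e.g. "teardown"

        @functools.wraps(fn)
        def wrapped(self, stage: Optional[str] = None):
            if stage is None:
                # no stage given: fit, validate and test are all considered called
                for s in ("fit", "validate", "test"):
                    setattr(self, f"_has_{name}_{s}", True)
            else:
                setattr(self, f"_has_{name}_{stage}", True)
            return fn(self, stage)

        return wrapped


    class ToyDataModule:
        _has_teardown_fit = False
        _has_teardown_validate = False
        _has_teardown_test = False
        _has_teardown_predict = False

        @track_stage_calls
        def teardown(self, stage: Optional[str] = None):
            ...  # clean-up would go here


    dm = ToyDataModule()
    dm.teardown(stage="fit")
    assert dm._has_teardown_fit and not dm._has_teardown_test
    dm.teardown()  # no stage: fit/validate/test are all marked as called
    assert dm._has_teardown_validate and dm._has_teardown_test
    assert not dm._has_teardown_predict  # predict is only marked when passed explicitly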