diff --git a/CHANGELOG.md b/CHANGELOG.md index f78569c1b7a0b4..39b9a6a3a6a3f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915)) +- Added `teardown()` hook to LightningDataModule ([#4673](https://github.com/PyTorchLightning/pytorch-lightning/pull/4673)) + + - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274)) diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index ec257bf444f5c8..82b96c09acef82 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1246,7 +1246,7 @@ prepare_data setup ~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.setup +.. automethod:: pytorch_lightning.core.hooks.DataHooks.setup :noindex: tbptt_split_batch @@ -1258,7 +1258,7 @@ tbptt_split_batch teardown ~~~~~~~~ -.. automethod:: pytorch_lightning.core.hooks.ModelHooks.teardown +.. automethod:: pytorch_lightning.core.hooks.DataHooks.teardown :noindex: train_dataloader diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 85134fda06fa2f..881febe21316dc 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -94,6 +94,10 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=self.batch_size) + def teardown(self, stage: Optional[str] = None): + # Used to clean-up when the run is finished + ... 
+ But now, as the complexity of your processing grows (transforms, multiple-GPU training), you can let Lightning handle those details for you while making this dataset reusable so you can share with colleagues or use in different projects. @@ -243,7 +247,10 @@ There are also data operations you might want to perform on every GPU. Use setup self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape) -.. warning:: `setup` is called from every process. Setting state here is okay. +.. warning:: ``setup`` is called from every process. Setting state here is okay. + + +.. note:: ``teardown`` can be used to clean up the state. It is also called from every process train_dataloader @@ -411,10 +418,14 @@ You can of course use DataModules in plain PyTorch code as well. for batch in dm.val_dataloader(): ... + dm.teardown(stage='fit') + # lazy load test data dm.setup(stage='test') for batch in dm.test_dataloader(): ... + dm.teardown(stage='test') + But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified structure. 
diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 1b6852c071fe11..4514b140c46adf 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -15,7 +15,6 @@ import functools import inspect -from abc import abstractmethod from argparse import ArgumentParser, Namespace from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union @@ -45,6 +44,8 @@ def __call__(cls, *args, **kwargs): cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data)) # Track setup calls cls.setup = track_data_hook_calls(cls.setup) + # Track teardown calls + cls.teardown = track_data_hook_calls(cls.teardown) # Get instance of LightningDataModule by mocking its __init__ via __call__ obj = type.__call__(cls, *args, **kwargs) @@ -53,12 +54,13 @@ def __call__(cls, *args, **kwargs): def track_data_hook_calls(fn): - """A decorator that checks if prepare_data/setup have been called. + """A decorator that checks if prepare_data/setup/teardown has been called. - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. Its corresponding `dm_has_setup_{stage}` attribute gets set to True + - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup`` Args: fn (function): Function that will be tracked to see if it has been called. @@ -72,9 +74,10 @@ def wrapped_fn(*args, **kwargs): # The object instance from which setup or prepare_data was called obj = args[0] + name = fn.__name__ # If calling setup, we check the stage and assign stage-specific bool args - if fn.__name__ == "setup": + if name in ("setup", "teardown"): # Get stage either by grabbing from args or checking kwargs. # If not provided, set call status of 'fit', 'validate', and 'test' to True. 
@@ -83,11 +86,11 @@ def wrapped_fn(*args, **kwargs): if stage is None: for s in ("fit", "validate", "test"): - setattr(obj, f"_has_setup_{s}", True) + setattr(obj, f"_has_{name}_{s}", True) else: - setattr(obj, f"_has_setup_{stage}", True) + setattr(obj, f"_has_{name}_{stage}", True) - if fn.__name__ == "prepare_data": + elif name == "prepare_data": obj._has_prepared_data = True return fn(*args, **kwargs) @@ -120,14 +123,18 @@ def val_dataloader(self): def test_dataloader(self): test_split = Dataset(...) return DataLoader(test_split) + def teardown(self): + # clean up after fit or test + # called on every process in DDP - A DataModule implements 5 key methods: + A DataModule implements 6 key methods: * **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode). * **setup** (things to do on every accelerator in distributed mode). * **train_dataloader** the training dataloader. * **val_dataloader** the val dataloader(s). * **test_dataloader** the test dataloader(s). + * **teardown** (things to do on every accelerator in distributed mode when finished) This allows you to share a full dataset without explaining how to download, @@ -155,11 +162,17 @@ def __init__( # Private attrs to keep track of whether or not data hooks have been called yet self._has_prepared_data = False + self._has_setup_fit = False self._has_setup_validate = False self._has_setup_test = False self._has_setup_predict = False + self._has_teardown_fit = False + self._has_teardown_validate = False + self._has_teardown_test = False + self._has_teardown_predict = False + @property def train_transforms(self): """ @@ -260,13 +273,41 @@ def has_setup_predict(self) -> bool: """ return self._has_setup_predict - @abstractmethod - def prepare_data(self, *args, **kwargs): - pass + @property + def has_teardown_fit(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='fit')`` has been called or not. 
- @abstractmethod - def setup(self, stage: Optional[str] = None): - pass + Returns: + bool: True if ``datamodule.teardown(stage='fit')`` has been called. False by default. + """ + return self._has_teardown_fit + + @property + def has_teardown_validate(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='validate')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default. + """ + return self._has_teardown_validate + + @property + def has_teardown_test(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='test')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default. + """ + return self._has_teardown_test + + @property + def has_teardown_predict(self) -> bool: + """Return bool letting you know if ``datamodule.teardown(stage='predict')`` has been called or not. + + Returns: + bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default. + """ + return self._has_teardown_predict @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 9826f9d44ac2c0..40788324ade73a 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -25,42 +25,6 @@ class ModelHooks: """Hooks to be used in LightningModule.""" - def setup(self, stage: Optional[str] = None) -> None: - """ - Called at the beginning of fit (train + validate), validate, test, predict, or tune. - This is a good hook when you need to build models dynamically or adjust something about them. - This hook is called on every process when using DDP.
- - Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` - - Example:: - - class LitModel(...): - def __init__(self): - self.l1 = None - - def prepare_data(self): - download_data() - tokenize() - - # don't do this - self.something = else - - def setup(stage): - data = Load_data(...) - self.l1 = nn.Linear(28, data.num_classes) - - """ - - def teardown(self, stage: Optional[str] = None) -> None: - """ - Called at the end of fit (train + validate), validate, test, predict, or tune. - - Args: - stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` - """ - def on_fit_start(self) -> None: """ Called at the very beginning of fit. @@ -383,6 +347,42 @@ def prepare_data(self): model.test_dataloader() """ + def setup(self, stage: Optional[str] = None) -> None: + """ + Called at the beginning of fit (train + validate), validate, test, predict, or tune. + This is a good hook when you need to build models dynamically or adjust something about them. + This hook is called on every process when using DDP. + + Args: + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + + Example:: + + class LitModel(...): + def __init__(self): + self.l1 = None + + def prepare_data(self): + download_data() + tokenize() + + # don't do this + self.something = else + + def setup(stage): + data = Load_data(...) + self.l1 = nn.Linear(28, data.num_classes) + + """ + + def teardown(self, stage: Optional[str] = None) -> None: + """ + Called at the end of fit (train + validate), validate, test, predict, or tune. + + Args: + stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + """ + def train_dataloader(self) -> Any: """ Implement one or more PyTorch DataLoaders for training. 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 253d60c2858345..3970fc01e5a339 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1044,6 +1044,11 @@ def call_teardown_hook(self, model: LightningModule) -> None: else: state = None + if self.datamodule is not None: + # default=False guards `state is None` (attr `has_teardown_None` does not exist) + called = getattr(self.datamodule, f'has_teardown_{state}', False) + if not called: + self.datamodule.teardown(stage=state) + self.teardown(stage=state) model.teardown(stage=state) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index ab51a87329e2fe..9e79cc5097e43d 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -129,6 +129,10 @@ def test_data_hooks_called(tmpdir): assert not dm.has_setup_test assert not dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict dm.prepare_data() assert dm.has_prepared_data @@ -136,6 +140,10 @@ assert not dm.has_setup_test assert not dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict dm.setup() assert dm.has_prepared_data @@ -143,49 +151,84 @@ assert dm.has_setup_test assert dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict + + dm.teardown() + assert dm.has_prepared_data + assert dm.has_setup_fit + assert dm.has_setup_test + assert dm.has_setup_validate + assert not dm.has_setup_predict + assert dm.has_teardown_fit + assert dm.has_teardown_test + assert dm.has_teardown_validate + assert not dm.has_teardown_predict
@pytest.mark.parametrize("use_kwarg", (False, True)) def test_data_hooks_called_verbose(tmpdir, use_kwarg): dm = BoringDataModule() - assert not dm.has_prepared_data - assert not dm.has_setup_fit - assert not dm.has_setup_test - dm.prepare_data() - assert dm.has_prepared_data assert not dm.has_setup_fit assert not dm.has_setup_test + assert not dm.has_setup_validate assert not dm.has_setup_predict + assert not dm.has_teardown_fit + assert not dm.has_teardown_test + assert not dm.has_teardown_validate + assert not dm.has_teardown_predict dm.setup(stage='fit') if use_kwarg else dm.setup('fit') - assert dm.has_prepared_data assert dm.has_setup_fit assert not dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='validate') if use_kwarg else dm.setup('validate') - assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert not dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='test') if use_kwarg else dm.setup('test') - assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert dm.has_setup_test assert not dm.has_setup_predict dm.setup(stage='predict') if use_kwarg else dm.setup('predict') - assert dm.has_prepared_data assert dm.has_setup_fit assert dm.has_setup_validate assert dm.has_setup_test assert dm.has_setup_predict + dm.teardown(stage='fit') if use_kwarg else dm.teardown('fit') + assert dm.has_teardown_fit + assert not dm.has_teardown_validate + assert not dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='validate') if use_kwarg else dm.teardown('validate') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert not dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='test') if use_kwarg else dm.teardown('test') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert dm.has_teardown_test + assert not dm.has_teardown_predict + + dm.teardown(stage='predict') 
if use_kwarg else dm.teardown('predict') + assert dm.has_teardown_fit + assert dm.has_teardown_validate + assert dm.has_teardown_test + assert dm.has_teardown_predict + def test_dm_add_argparse_args(tmpdir): parser = ArgumentParser() diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 7c53925bd7cc4b..b6400ec5063f28 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect from unittest import mock from unittest.mock import PropertyMock @@ -20,7 +19,7 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.trainer.states import TrainerState -from tests.helpers import BoringModel, RandomDataset +from tests.helpers import BoringModel, RandomDataset, BoringDataModule from tests.helpers.runif import RunIf @@ -260,7 +259,7 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): def test_trainer_model_hook_system(tmpdir): - """Test the hooks system.""" + """Test the LightningModule hook system.""" class HookedModel(BoringModel): @@ -269,149 +268,151 @@ def __init__(self): self.called = [] def on_after_backward(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_after_backward") super().on_after_backward() - def on_before_zero_grad(self, optimizer): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_before_zero_grad(optimizer) + def on_before_zero_grad(self, *args, **kwargs): + self.called.append("on_before_zero_grad") + super().on_before_zero_grad(*args, **kwargs) def on_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_epoch_start") super().on_epoch_start() def on_epoch_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_epoch_end") 
super().on_epoch_end() def on_fit_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_fit_start") super().on_fit_start() def on_fit_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_fit_end") super().on_fit_end() - def on_hpc_load(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_hpc_load(checkpoint) + def on_hpc_load(self, *args, **kwargs): + self.called.append("on_hpc_load") + super().on_hpc_load(*args, **kwargs) - def on_hpc_save(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_hpc_save(checkpoint) + def on_hpc_save(self, *args, **kwargs): + self.called.append("on_hpc_save") + super().on_hpc_save(*args, **kwargs) - def on_load_checkpoint(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_load_checkpoint(checkpoint) + def on_load_checkpoint(self, *args, **kwargs): + self.called.append("on_load_checkpoint") + super().on_load_checkpoint(*args, **kwargs) - def on_save_checkpoint(self, checkpoint): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_save_checkpoint(checkpoint) + def on_save_checkpoint(self, *args, **kwargs): + self.called.append("on_save_checkpoint") + super().on_save_checkpoint(*args, **kwargs) def on_pretrain_routine_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_pretrain_routine_start") super().on_pretrain_routine_start() def on_pretrain_routine_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_pretrain_routine_end") super().on_pretrain_routine_end() def on_train_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_start") super().on_train_start() def on_train_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_end") 
super().on_train_end() - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_train_batch_start(batch, batch_idx, dataloader_idx) + def on_train_batch_start(self, *args, **kwargs): + self.called.append("on_train_batch_start") + super().on_train_batch_start(*args, **kwargs) - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) + def on_train_batch_end(self, *args, **kwargs): + self.called.append("on_train_batch_end") + super().on_train_batch_end(*args, **kwargs) def on_train_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_epoch_start") super().on_train_epoch_start() def on_train_epoch_end(self, outputs): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_train_epoch_end") super().on_train_epoch_end(outputs) def on_validation_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_start") super().on_validation_start() def on_validation_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_end") super().on_validation_end() - def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_validation_batch_start(batch, batch_idx, dataloader_idx) + def on_validation_batch_start(self, *args, **kwargs): + self.called.append("on_validation_batch_start") + super().on_validation_batch_start(*args, **kwargs) - def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx) + def on_validation_batch_end(self, *args, 
**kwargs): + self.called.append("on_validation_batch_end") + super().on_validation_batch_end(*args, **kwargs) def on_validation_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_epoch_start") super().on_validation_epoch_start() def on_validation_epoch_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_epoch_end") super().on_validation_epoch_end() def on_test_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_start") super().on_test_start() - def on_test_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_test_batch_start(batch, batch_idx, dataloader_idx) + def on_test_batch_start(self, *args, **kwargs): + self.called.append("on_test_batch_start") + super().on_test_batch_start(*args, **kwargs) - def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append(inspect.currentframe().f_code.co_name) - super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx) + def on_test_batch_end(self, *args, **kwargs): + self.called.append("on_test_batch_end") + super().on_test_batch_end(*args, **kwargs) def on_test_epoch_start(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_epoch_start") super().on_test_epoch_start() def on_test_epoch_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_epoch_end") super().on_test_epoch_end() def on_validation_model_eval(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_model_eval") super().on_validation_model_eval() def on_validation_model_train(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_validation_model_train") super().on_validation_model_train() def 
on_test_model_eval(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_model_eval") super().on_test_model_eval() def on_test_model_train(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_model_train") super().on_test_model_train() def on_test_end(self): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append("on_test_end") super().on_test_end() + def setup(self, stage=None): + self.called.append(f"setup_{stage}") + super().setup(stage=stage) + def teardown(self, stage=None): - self.called.append(inspect.currentframe().f_code.co_name) + self.called.append(f"teardown_{stage}") super().teardown(stage) model = HookedModel() - assert model.called == [] - # fit model trainer = Trainer( default_root_dir=tmpdir, @@ -427,6 +428,7 @@ def teardown(self, stage=None): trainer.fit(model) expected = [ + 'setup_fit', 'on_fit_start', 'on_pretrain_routine_start', 'on_pretrain_routine_end', @@ -464,15 +466,15 @@ def teardown(self, stage=None): 'on_validation_model_train', 'on_train_end', 'on_fit_end', - 'teardown', + 'teardown_fit', ] - assert model.called == expected model = HookedModel() - trainer.test(model, verbose=False) + expected = [ + 'setup_test', 'on_test_model_eval', 'on_test_start', 'on_test_epoch_start', @@ -482,6 +484,103 @@ def teardown(self, stage=None): 'on_epoch_end', 'on_test_end', 'on_test_model_train', - 'teardown', + 'teardown_test', ] assert model.called == expected + + +def test_trainer_datamodule_hook_system(tmpdir): + """Test the LightningDataModule hook system.""" + + class HookedDataModule(BoringDataModule): + def __init__(self): + super().__init__() + self.called = [] + + def prepare_data(self): + self.called.append("prepare_data") + super().prepare_data() + + def setup(self, stage=None): + self.called.append(f"setup_{stage}") + super().setup(stage=stage) + + def teardown(self, stage=None): + self.called.append(f"teardown_{stage}") + 
super().teardown(stage=stage) + + def train_dataloader(self): + self.called.append("train_dataloader") + return super().train_dataloader() + + def test_dataloader(self): + self.called.append("test_dataloader") + return super().test_dataloader() + + def val_dataloader(self): + self.called.append("val_dataloader") + return super().val_dataloader() + + def predict_dataloader(self): + self.called.append("predict_dataloader") + + def transfer_batch_to_device(self, *args, **kwargs): + self.called.append("transfer_batch_to_device") + return super().transfer_batch_to_device(*args, **kwargs) + + def on_before_batch_transfer(self, *args, **kwargs): + self.called.append("on_before_batch_transfer") + return super().on_before_batch_transfer(*args, **kwargs) + + def on_after_batch_transfer(self, *args, **kwargs): + self.called.append("on_after_batch_transfer") + return super().on_after_batch_transfer(*args, **kwargs) + + model = BoringModel() + dm = HookedDataModule() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=2, + limit_test_batches=1, + progress_bar_refresh_rate=0, + weights_summary=None, + ) + trainer.fit(model, datamodule=dm) + + expected = [ + 'prepare_data', + 'setup_fit', + 'val_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'train_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_fit' + ] + assert dm.called == expected + + dm = HookedDataModule() + trainer.test(model, datamodule=dm, verbose=False) + + expected = [ + 'prepare_data', + 'setup_test', + 'test_dataloader', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'teardown_test' + ] + assert dm.called == expected