diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index d6b30541906616..9e9f6414098434 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -16,7 +16,7 @@
 import inspect
 from abc import abstractmethod
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union

 from torch.utils.data import DataLoader

@@ -32,13 +32,10 @@ def __call__(cls, *args, **kwargs):
         3. Lets you check prepare_data and setup to see if they've been called
         """

-        # Wrap cls's prepare_data function with rank_zero_only
-        cls.prepare_data = rank_zero_only(cls.prepare_data)
-
-        # prepare_data and setup wrapped w/ function to track if they've been called.
-        # Usage: your_dm.setup.has_been_called & your_dm.prepare_data.has_been_called
-        cls.prepare_data = track_func_calls(cls.prepare_data)
-        cls.setup = track_func_calls(cls.setup)
+        # Track prepare_data calls and make sure it runs on rank zero
+        cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data))
+        # Track setup calls
+        cls.setup = track_data_hook_calls(cls.setup)

         # Get instance of LightningDataModule by mocking its __init__ via __call__
         obj = type.__call__(cls, *args, **kwargs)
@@ -46,21 +43,45 @@ def __call__(cls, *args, **kwargs):
         return obj


-def track_func_calls(fn):
-    """A decorator that checks if a function has been called.
+def track_data_hook_calls(fn):
+    """A decorator that checks if prepare_data/setup have been called.
+
+    - When dm.prepare_data() is called, dm.has_prepared_data gets set to True
+    - When dm.setup('fit') is called, dm.has_setup_fit gets set to True
+    - When dm.setup('test') is called, dm.has_setup_test gets set to True
+    - When dm.setup() is called without a stage arg, both dm.has_setup_fit and dm.has_setup_test get set to True

     Args:
         fn (function): Function that will be tracked to see if it has been called.

     Returns:
-        callable: Your function with an added bool attr fn.has_been_called.
+        function: Decorated function that tracks its call status and saves it to private attrs on its object instance.
     """

+    @functools.wraps(fn)
     def wrapped_fn(*args, **kwargs):
-        wrapped_fn.has_been_called = True
-        return fn(*args, **kwargs)
-
-    wrapped_fn.has_been_called = False
+
+        # The object instance from which setup or prepare_data was called
+        obj = args[0]
+
+        # If calling setup, we check the stage and assign stage-specific bool attrs
+        if fn.__name__ == 'setup':
+
+            # Get stage either by grabbing from args or checking kwargs.
+            # If not provided, set call status of 'fit' and 'test' to True.
+            # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
+            stage = args[1] if len(args) > 1 else kwargs.get('stage', None)
+
+            if stage == 'fit' or stage is None:
+                obj._has_setup_fit = True
+
+            if stage == 'test' or stage is None:
+                obj._has_setup_test = True
+
+        if fn.__name__ == 'prepare_data':
+            obj._has_prepared_data = True
+
+        return fn(*args, **kwargs)

     return wrapped_fn

@@ -116,6 +137,11 @@ def __init__(
         self._test_transforms = test_transforms
         self.dims = ()

+        # Private attrs to keep track of whether or not data hooks have been called yet
+        self._has_prepared_data = False
+        self._has_setup_fit = False
+        self._has_setup_test = False
+
     @property
     def train_transforms(self):
         """
@@ -159,6 +185,33 @@ def size(self, dim=None) -> Union[Tuple, int]:

         return self.dims

+    @property
+    def has_prepared_data(self):
+        """Return bool letting you know if datamodule.prepare_data() has been called or not.
+
+        Returns:
+            bool: True if datamodule.prepare_data() has been called. False by default.
+        """
+        return self._has_prepared_data
+
+    @property
+    def has_setup_fit(self):
+        """Return bool letting you know if datamodule.setup('fit') has been called or not.
+
+        Returns:
+            bool: True if datamodule.setup('fit') has been called. False by default.
+        """
+        return self._has_setup_fit
+
+    @property
+    def has_setup_test(self):
+        """Return bool letting you know if datamodule.setup('test') has been called or not.
+
+        Returns:
+            bool: True if datamodule.setup('test') has been called. False by default.
+        """
+        return self._has_setup_test
+
     @abstractmethod
     def prepare_data(self, *args, **kwargs):
         """
@@ -181,14 +234,14 @@ def prepare_data(self):
         """

     @abstractmethod
-    def setup(self, *args, **kwargs):
+    def setup(self, stage: Optional[str] = None):
         """
         Use this to load your data from file, split it, etc. You are safe to make state assignments here.

         This hook is called on every process when using DDP.

         Example::

-            def setup(self):
+            def setup(self, stage):
                 data = load_data(...)
                 self.train_ds, self.val_ds, self.test_ds = split_data(data)
         """
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 3f9a99cd7b9b58..916a47b3eebdaa 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1024,7 +1024,7 @@ def fit(

         # set up the passed in dataloaders (if needed)
         self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
-        self.__attach_datamodule(model, datamodule)
+        self.__attach_datamodule(model, datamodule, 'fit')

         # check that model is configured correctly
         self.config_validator.verify_loop_configurations(model)
@@ -1152,17 +1152,29 @@ def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
         if test_dataloaders is not None:
             model.test_dataloader = _PatchDataLoader(test_dataloaders)

-    def __attach_datamodule(self, model, datamodule=None):
+    def __attach_datamodule(self, model, datamodule, stage):

         # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
         datamodule = datamodule or getattr(model, 'datamodule', None)

         # If we have a datamodule, attach necessary hooks + dataloaders
         if datamodule:
-            if self.is_overridden('prepare_data', datamodule) and not datamodule.prepare_data.has_been_called:
+
+            # If datamodule.prepare_data() has not been called yet, call it
+            if self.is_overridden('prepare_data', datamodule) and not datamodule.has_prepared_data:
                 datamodule.prepare_data()
-            if self.is_overridden('setup', datamodule) and not datamodule.setup.has_been_called:
-                datamodule.setup()
+
+            # If datamodule.setup('fit') has not been called yet, call it
+            if stage == 'fit':
+                if self.is_overridden('setup', datamodule) and not datamodule.has_setup_fit:
+                    datamodule.setup('fit')
+
+            # If datamodule.setup('test') has not been called yet, call it
+            if stage == 'test':
+                if self.is_overridden('setup', datamodule) and not datamodule.has_setup_test:
+                    datamodule.setup('test')
+
+            # Override loader hooks
             if self.is_overridden('train_dataloader', datamodule):
                 model.train_dataloader = datamodule.train_dataloader
             if self.is_overridden('val_dataloader', datamodule):
                 model.val_dataloader = datamodule.val_dataloader
@@ -1367,7 +1379,7 @@ def test(
         )

         # Attach datamodule to get setup/prepare_data added to model before the call to it below
-        self.__attach_datamodule(model or self.get_model(), datamodule)
+        self.__attach_datamodule(model or self.get_model(), datamodule, 'test')

         self.setup('test')
diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py
index 23c07f93d46976..27705ad424676d 100644
--- a/tests/base/datamodules.py
+++ b/tests/base/datamodules.py
@@ -13,12 +13,17 @@ def __init__(self, data_dir: str = './'):
     def prepare_data(self):
         TrialMNIST(self.data_dir, train=True, download=True)
         TrialMNIST(self.data_dir, train=False, download=True)
-
-    def setup(self):
-        mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
-        self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
-        self.dims = tuple(self.mnist_train[0][0].shape)
-        self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=32, download=True)
+
+    def setup(self, stage: str = None):
+
+        if stage == 'fit' or stage is None:
+            mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
+            self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
+            self.dims = self.mnist_train[0][0].shape
+
+        if stage == 'test' or stage is None:
+            self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=32, download=True)
+            self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)

     def train_dataloader(self):
         return DataLoader(self.mnist_train, batch_size=32)
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index c3bcc67f27b20b..929c5c3d2c805f 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -16,18 +16,64 @@ def test_base_datamodule(tmpdir):
     dm.setup()


-def test_dm_has_been_called(tmpdir):
+def test_base_datamodule_with_verbose_setup(tmpdir):
     dm = TrialMNISTDataModule()
-    assert dm.prepare_data.has_been_called is False
-    assert dm.setup.has_been_called is False
+    dm.prepare_data()
+    dm.setup('fit')
+    dm.setup('test')
+
+
+def test_data_hooks_called(tmpdir):
+    dm = TrialMNISTDataModule()
+    assert dm.has_prepared_data is False
+    assert dm.has_setup_fit is False
+    assert dm.has_setup_test is False

     dm.prepare_data()
-    assert dm.prepare_data.has_been_called is True
-    assert dm.setup.has_been_called is False
+    assert dm.has_prepared_data is True
+    assert dm.has_setup_fit is False
+    assert dm.has_setup_test is False

     dm.setup()
-    assert dm.prepare_data.has_been_called is True
-    assert dm.setup.has_been_called is True
+    assert dm.has_prepared_data is True
+    assert dm.has_setup_fit is True
+    assert dm.has_setup_test is True
+
+
+def test_data_hooks_called_verbose(tmpdir):
+    dm = TrialMNISTDataModule()
+    assert dm.has_prepared_data is False
+    assert dm.has_setup_fit is False
+    assert dm.has_setup_test is False
+
+    dm.prepare_data()
+    assert dm.has_prepared_data is True
+    assert dm.has_setup_fit is False
+    assert dm.has_setup_test is False
+
+    dm.setup('fit')
+    assert dm.has_prepared_data is True
+    assert dm.has_setup_fit is True
+    assert dm.has_setup_test is False
+
+    dm.setup('test')
+    assert dm.has_prepared_data is True
+    assert dm.has_setup_fit is True
+    assert dm.has_setup_test is True
+
+
+def test_data_hooks_called_with_stage_kwarg(tmpdir):
+    dm = TrialMNISTDataModule()
+    dm.prepare_data()
+    assert dm.has_prepared_data is True
+
+    dm.setup(stage='fit')
+    assert dm.has_setup_fit is True
+    assert dm.has_setup_test is False
+
+    dm.setup(stage='test')
+    assert dm.has_setup_fit is True
+    assert dm.has_setup_test is True


 def test_dm_add_argparse_args(tmpdir):
@@ -58,6 +104,14 @@ def test_dm_pickle_after_setup(tmpdir):
     pickle.dumps(dm)


+def test_dm_pickle_after_setup_verbose(tmpdir):
+    dm = TrialMNISTDataModule()
+    dm.prepare_data()
+    dm.setup('fit')
+    dm.setup('test')
+    pickle.dumps(dm)
+
+
 def test_train_loop_only(tmpdir):
     dm = TrialMNISTDataModule(tmpdir)
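
For readers skimming the patch, here is the call-tracking idea from `track_data_hook_calls` in isolation. This is a framework-free sketch: `DummyModule` is a hypothetical stand-in for a `LightningDataModule`, and the manual wrapping at the bottom is done automatically in the patch by the `_DataModuleWrapper` metaclass at instantiation time.

```python
import functools


def track_data_hook_calls(fn):
    """Set stage-specific flags on the instance when setup/prepare_data run."""
    @functools.wraps(fn)  # preserve fn.__name__ so the checks below survive decorator stacking
    def wrapped_fn(*args, **kwargs):
        obj = args[0]  # the instance the hook was called on
        if fn.__name__ == 'setup':
            # stage may arrive positionally or as a keyword
            stage = args[1] if len(args) > 1 else kwargs.get('stage', None)
            if stage == 'fit' or stage is None:
                obj._has_setup_fit = True
            if stage == 'test' or stage is None:
                obj._has_setup_test = True
        if fn.__name__ == 'prepare_data':
            obj._has_prepared_data = True
        return fn(*args, **kwargs)
    return wrapped_fn


class DummyModule:
    """Hypothetical stand-in for a LightningDataModule."""
    _has_prepared_data = False
    _has_setup_fit = False
    _has_setup_test = False

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        pass


# Wrap by hand here; the PR's metaclass performs this step for every LightningDataModule.
DummyModule.prepare_data = track_data_hook_calls(DummyModule.prepare_data)
DummyModule.setup = track_data_hook_calls(DummyModule.setup)

dm = DummyModule()
dm.setup('fit')
assert dm._has_setup_fit and not dm._has_setup_test
dm.setup(stage='test')   # the keyword form is tracked too
assert dm._has_setup_test

dm2 = DummyModule()
dm2.setup()              # no stage: both flags flip
assert dm2._has_setup_fit and dm2._has_setup_test
```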
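And a usage sketch of the new public properties end to end, assuming a pytorch_lightning build that includes this patch; `RandomDataModule` and its random-tensor data are made up for the example:

```python
import torch
from torch.utils.data import TensorDataset, random_split
from pytorch_lightning import LightningDataModule


class RandomDataModule(LightningDataModule):
    """Hypothetical datamodule over random tensors (illustration only)."""

    def prepare_data(self):
        pass  # download/write-to-disk work would go here; runs on rank zero only

    def setup(self, stage=None):
        data = TensorDataset(torch.randn(96, 4), torch.randint(0, 2, (96,)))
        if stage == 'fit' or stage is None:
            self.train_ds, self.val_ds = random_split(data, [64, 32])
        if stage == 'test' or stage is None:
            self.test_ds = data


dm = RandomDataModule()
assert not (dm.has_prepared_data or dm.has_setup_fit or dm.has_setup_test)

dm.prepare_data()
dm.setup('fit')    # flips has_setup_fit only
assert dm.has_prepared_data and dm.has_setup_fit and not dm.has_setup_test

dm.setup('test')   # now the trainer would skip its own setup('test') call
assert dm.has_setup_test
```

Because the flags live on the instance rather than on the wrapped function, `Trainer.__attach_datamodule` can call `setup('fit')` from `fit()` and `setup('test')` from `test()` without re-running a stage the user already ran by hand.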