diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst
index 9a78158d947e8..8bc7cfc56447b 100644
--- a/docs/source/datamodules.rst
+++ b/docs/source/datamodules.rst
@@ -11,33 +11,63 @@ Data preparation in PyTorch follows 5 steps:
 A DataModule is simply a collection of a train_dataloader, val_dataloader(s), test_dataloader(s) along with the
 matching transforms and data processing/downloads steps required.
 
+.. code-block:: python
+
+    import pytorch_lightning as pl
+    from torch.utils.data import random_split, DataLoader
+
+    # Note - you must have torchvision installed for this example
+    from torchvision.datasets import MNIST
+    from torchvision import transforms
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+
+        def __init__(self, data_dir: str = './'):
+            super().__init__()
+            self.data_dir = data_dir
+            self.transform = transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize((0.1307,), (0.3081,))
+            ])
+
+            # self.dims is returned when you call dm.size()
+            # Setting default dims here because we know them.
+            # Could optionally be assigned dynamically in dm.setup()
+            self.dims = (1, 28, 28)
+
+        def prepare_data(self):
+            # download
+            MNIST(self.data_dir, train=True, download=True)
+            MNIST(self.data_dir, train=False, download=True)
+
+        def setup(self, stage=None):
+
+            # Assign train/val datasets for use in dataloaders
+            if stage == 'fit' or stage is None:
+                mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
+                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
+
+                # Optionally...
+                # self.dims = tuple(self.mnist_train[0][0].shape)
-    >>> import pytorch_lightning as pl
-    >>> class MNISTDataModule(pl.LightningDataModule):
-    ...     def prepare_data(self):
-    ...         # download
-    ...         MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
-    ...         MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
-    ...
-    ...     def setup(self, stage):
-    ...         mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
-    ...         mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
-    ...         # train/val split
-    ...         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
-    ...
-    ...         # assign to use in dataloaders
-    ...         self.train_dataset = mnist_train
-    ...         self.val_dataset = mnist_val
-    ...         self.test_dataset = mnist_test
-    ...
-    ...     def train_dataloader(self):
-    ...         return DataLoader(self.train_dataset, batch_size=64)
-    ...
-    ...     def val_dataloader(self):
-    ...         return DataLoader(self.val_dataset, batch_size=64)
-    ...
-    ...     def test_dataloader(self):
-    ...         return DataLoader(self.test_dataset, batch_size=64)
+
+            # Assign test dataset for use in dataloader(s)
+            if stage == 'test' or stage is None:
+                self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
+
+                # Optionally...
+                # self.dims = tuple(self.mnist_test[0][0].shape)
+
+        def train_dataloader(self):
+            return DataLoader(self.mnist_train, batch_size=32)
+
+        def val_dataloader(self):
+            return DataLoader(self.mnist_val, batch_size=32)
+
+        def test_dataloader(self):
+            return DataLoader(self.mnist_test, batch_size=32)
+
+.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
 
 ---------------
 
@@ -60,11 +90,13 @@ settings.
 - tokenize
 - etc...
 
-    >>> class MNISTDataModule(pl.LightningDataModule):
-    ...     def prepare_data(self):
-    ...         # download
-    ...         MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
-    ...         MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
+.. code-block:: python
+
+    class MNISTDataModule(pl.LightningDataModule):
+        def prepare_data(self):
+            # download
+            MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
+            MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
 
 .. warning:: `prepare_data` is called from a single GPU. Do not use it to assign state (`self.x = y`).
 
@@ -77,33 +109,46 @@ There are also data operations you might want to perform on every GPU. Use setup
 - perform train/val/test splits
 - etc...
 
-    >>> import pytorch_lightning as pl
-    >>> class MNISTDataModule(pl.LightningDataModule):
-    ...     def setup(self, stage):
-    ...         mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
-    ...         mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
-    ...         # train/val split
-    ...         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
-    ...
-    ...         # assign to use in dataloaders
-    ...         self.train_dataset = mnist_train
-    ...         self.val_dataset = mnist_val
-    ...         self.test_dataset = mnist_test
+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+
+        def setup(self, stage: Optional[str] = None):
+
+            # Assign Train/val split(s) for use in Dataloaders
+            if stage == 'fit' or stage is None:
+                mnist_full = MNIST(self.data_dir, train=True, download=True)
+                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
+                self.dims = self.mnist_train[0][0].shape
+
+            # Assign Test split(s) for use in Dataloaders
+            if stage == 'test' or stage is None:
+                self.mnist_test = MNIST(self.data_dir, train=False, download=True)
+                self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
+
 
 .. warning:: `setup` is called from every GPU. Setting state here is okay.
 
+
 train_dataloader
 ^^^^^^^^^^^^^^^^
 
 Use this method to generate the train dataloader. This is also a good place to place default transformations.
 
-    >>> import pytorch_lightning as pl
-    >>> class MNISTDataModule(pl.LightningDataModule):
-    ...     def train_dataloader(self):
-    ...         transforms = transform_lib.Compose([
-    ...             transform_lib.ToTensor(),
-    ...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-    ...         ])
-    ...         return DataLoader(self.train_dataset, transform=transforms, batch_size=64)
+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+        def train_dataloader(self):
+            transforms = transform_lib.Compose([
+                transform_lib.ToTensor(),
+                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
+            ])
+            return DataLoader(self.train_dataset, transform=transforms, batch_size=64)
 
 However, to decouple your data from transforms you can parametrize them via `__init__`.
 
@@ -119,32 +164,41 @@ val_dataloader
 ^^^^^^^^^^^^^^
 
 Use this method to generate the val dataloader. This is also a good place to place default transformations.
 
-    >>> import pytorch_lightning as pl
-    >>> class MNISTDataModule(pl.LightningDataModule):
-    ...     def val_dataloader(self):
-    ...         transforms = transform_lib.Compose([
-    ...             transform_lib.ToTensor(),
-    ...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-    ...         ])
-    ...         return DataLoader(self.val_dataset, transform=transforms, batch_size=64)
+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+        def val_dataloader(self):
+            transforms = transform_lib.Compose([
+                transform_lib.ToTensor(),
+                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
+            ])
+            return DataLoader(self.val_dataset, transform=transforms, batch_size=64)
 
 test_dataloader
 ^^^^^^^^^^^^^^^
 
 Use this method to generate the test dataloader. This is also a good place to place default transformations.
 
-    >>> import pytorch_lightning as pl
-    >>> class MNISTDataModule(pl.LightningDataModule):
-    ...     def test_dataloader(self):
-    ...         transforms = transform_lib.Compose([
-    ...             transform_lib.ToTensor(),
-    ...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-    ...         ])
-    ...         return DataLoader(self.test_dataset, transform=transforms, batch_size=64)
+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+        def test_dataloader(self):
+            transforms = transform_lib.Compose([
+                transform_lib.ToTensor(),
+                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
+            ])
+            return DataLoader(self.test_dataset, transform=transforms, batch_size=64)
 
 ------------------
 Using a DataModule
 ------------------
+
 The recommended way to use a DataModule is simply:
 
 .. code-block:: python
@@ -162,12 +216,13 @@ still ensures the method runs on the correct devices)
 
     dm = MNISTDataModule()
     dm.prepare_data()
-    dm.setup()
+    dm.setup('fit')
     model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
     trainer.fit(model, dm)
 
-    trainer.test(model, datamodule=dm)
+    dm.setup('test')
+    trainer.test(datamodule=dm)
 
 ----------------
 
@@ -184,12 +239,14 @@ DataModules have a few key advantages:
 
     dm = MNISTDataModule()
     dm.prepare_data()
-    dm.setup()
+    dm.setup('fit')
 
     for batch in dm.train_dataloader():
         ...
     for batch in dm.val_dataloader():
         ...
+
+    dm.setup('test')
     for batch in dm.test_dataloader():
         ...
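+
+Since ``prepare_data`` and ``setup`` calls are tracked, you can also check whether they have
+already run. A minimal sketch using the ``has_prepared_data``, ``has_setup_fit`` and
+``has_setup_test`` properties of ``LightningDataModule``:
+
+.. code-block:: python
+
+    dm = MNISTDataModule()
+    assert not dm.has_prepared_data
+
+    dm.prepare_data()
+    assert dm.has_prepared_data
+
+    dm.setup('fit')
+    assert dm.has_setup_fit and not dm.has_setup_test
+
+    dm.setup('test')
+    assert dm.has_setup_test
+
+Calling ``dm.setup()`` with no ``stage`` argument sets both ``has_setup_fit`` and ``has_setup_test``.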
+ """ + return self._has_setup_test + @abstractmethod def prepare_data(self, *args, **kwargs): """ @@ -155,14 +234,14 @@ def prepare_data(self): """ @abstractmethod - def setup(self, *args, **kwargs): + def setup(self, stage: Optional[str] = None): """ Use this to load your data from file, split it, etc. You are safe to make state assignments here. This hook is called on every process when using DDP. Example:: - def setup(self): + def setup(self, stage): data = load_data(...) self.train_ds, self.val_ds, self.test_ds = split_data(data) """ diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 3ca5f6ffa68f3..c8efc55b49e37 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -145,6 +145,7 @@ def train_fx(trial_hparams, cluster_manager, _): from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info +from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule @@ -204,12 +205,17 @@ class TrainerDDPMixin(ABC): node_rank: int tpu_cores: int testing: bool + datamodule: Optional[LightningDataModule] @property @abstractmethod def is_global_zero(self) -> bool: """Warning: this is just empty shell for code implemented in other class.""" + @abstractmethod + def call_setup_hook(self, *args): + """Warning: this is just empty shell for code implemented in other class.""" + @property @abstractmethod def num_gpus(self) -> int: @@ -530,9 +536,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks) # call setup after the ddp process has connected - if not self.testing: - self.setup('fit') - model.setup('fit') + self.call_setup_hook(model) # on world_size=0 let everyone know training is starting if self.is_global_zero: diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 63db623d91f0b..7d5a00523ef9e 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -88,6 +88,10 @@ class TrainerDPMixin(ABC): def use_amp(self) -> bool: """Warning: this is just empty shell for code implemented in other class.""" + @abstractmethod + def call_setup_hook(self, *args): + """Warning: this is just empty shell for code implemented in other class.""" + @abstractmethod def run_pretrain_routine(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @@ -180,9 +184,7 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device): def horovod_train(self, model): # call setup after the ddp process has connected - if not self.testing: - self.setup('fit') - model.setup('fit') + self.call_setup_hook(model) if torch.cuda.is_available() and self.on_gpu: # Horovod: pin GPU to local rank diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0a4c9c349fba4..59a33dad7e5dd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -378,6 +378,7 @@ def __init__( # training state self.model = None + self.datamodule = None self.testing = False self.prepare_data_per_node = prepare_data_per_node self.lr_schedulers = [] @@ -941,7 
+942,7 @@ def fit( # set up the passed in dataloaders (if needed) self.__attach_dataloaders(model, train_dataloader, val_dataloaders) - self.__attach_datamodule(model, datamodule) + self.__attach_datamodule(model, datamodule, 'fit') # check that model is configured correctly self.config_validator.verify_loop_configurations(model) @@ -954,6 +955,8 @@ def fit( # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 # or in the case where each node needs to do its own manipulation in which case just local_rank=0 if self.can_prepare_data(): + if datamodule is not None: + datamodule.prepare_data() model.prepare_data() self._is_data_prepared = True @@ -1052,10 +1055,14 @@ def fit( return results or 1 def can_prepare_data(self): + should_call_dm_prepare_data = True + if self.datamodule is not None and self.is_overridden('prepare_data', self.datamodule): + should_call_dm_prepare_data = not self.datamodule.has_prepared_data + if self.prepare_data_per_node: - return self.local_rank == 0 + return self.local_rank == 0 and should_call_dm_prepare_data else: - return self.node_rank == 0 and self.local_rank == 0 + return self.node_rank == 0 and self.local_rank == 0 and should_call_dm_prepare_data def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None): # when dataloader is passed via fit, patch the train_dataloader @@ -1069,13 +1076,20 @@ def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=Non if test_dataloaders is not None: model.test_dataloader = _PatchDataLoader(test_dataloaders) - def __attach_datamodule(self, model, datamodule=None): + def __attach_datamodule(self, model, datamodule, stage): # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it datamodule = datamodule or getattr(model, 'datamodule', None) # If we have a datamodule, attach necessary hooks + dataloaders if datamodule: + + # If datamodule.setup('test') has not been called yet, call it + # if stage == 'test': + # if self.is_overridden('setup', datamodule) and not datamodule.has_setup_test: + # datamodule.setup('test') + + # Override loader hooks if self.is_overridden('train_dataloader', datamodule): model.train_dataloader = datamodule.train_dataloader if self.is_overridden('val_dataloader', datamodule): @@ -1083,6 +1097,8 @@ def __attach_datamodule(self, model, datamodule=None): if self.is_overridden('test_dataloader', datamodule): model.test_dataloader = datamodule.test_dataloader + self.datamodule = datamodule + def run_pretrain_routine(self, model: LightningModule): """Sanity check a few things before starting actual training. 
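+
+        A stage-aware sketch of the same hook (``stage`` is ``'fit'``, ``'test'`` or ``None``),
+        using the same placeholder helpers as above::
+
+            def setup(self, stage=None):
+                if stage == 'fit' or stage is None:
+                    data = load_data(...)
+                    self.train_ds, self.val_ds = split_data(data)
+                if stage == 'test' or stage is None:
+                    self.test_ds = load_data(...)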
         """
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index 3ca5f6ffa68f3..c8efc55b49e37 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -145,6 +145,7 @@ def train_fx(trial_hparams, cluster_manager, _):
 from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
+from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
 
 
@@ -204,12 +205,17 @@ class TrainerDDPMixin(ABC):
     node_rank: int
     tpu_cores: int
     testing: bool
+    datamodule: Optional[LightningDataModule]
 
     @property
     @abstractmethod
     def is_global_zero(self) -> bool:
         """Warning: this is just empty shell for code implemented in other class."""
 
+    @abstractmethod
+    def call_setup_hook(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
+
     @property
     @abstractmethod
     def num_gpus(self) -> int:
@@ -530,9 +536,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks)
 
         # call setup after the ddp process has connected
-        if not self.testing:
-            self.setup('fit')
-            model.setup('fit')
+        self.call_setup_hook(model)
 
         # on world_size=0 let everyone know training is starting
         if self.is_global_zero:
diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py
index 63db623d91f0b..7d5a00523ef9e 100644
--- a/pytorch_lightning/trainer/distrib_parts.py
+++ b/pytorch_lightning/trainer/distrib_parts.py
@@ -88,6 +88,10 @@ class TrainerDPMixin(ABC):
     def use_amp(self) -> bool:
         """Warning: this is just empty shell for code implemented in other class."""
 
+    @abstractmethod
+    def call_setup_hook(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
+
     @abstractmethod
     def run_pretrain_routine(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
@@ -180,9 +184,7 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
 
     def horovod_train(self, model):
         # call setup after the ddp process has connected
-        if not self.testing:
-            self.setup('fit')
-            model.setup('fit')
+        self.call_setup_hook(model)
 
         if torch.cuda.is_available() and self.on_gpu:
             # Horovod: pin GPU to local rank
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 0a4c9c349fba4..59a33dad7e5dd 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -378,6 +378,7 @@ def __init__(
 
         # training state
         self.model = None
+        self.datamodule = None
         self.testing = False
         self.prepare_data_per_node = prepare_data_per_node
         self.lr_schedulers = []
@@ -941,7 +942,7 @@ def fit(
 
         # set up the passed in dataloaders (if needed)
         self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
-        self.__attach_datamodule(model, datamodule)
+        self.__attach_datamodule(model, datamodule, 'fit')
 
         # check that model is configured correctly
         self.config_validator.verify_loop_configurations(model)
@@ -954,6 +955,8 @@ def fit(
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
         # or in the case where each node needs to do its own manipulation in which case just local_rank=0
         if self.can_prepare_data():
+            if datamodule is not None:
+                datamodule.prepare_data()
             model.prepare_data()
             self._is_data_prepared = True
 
@@ -1052,10 +1055,14 @@ def fit(
         return results or 1
 
     def can_prepare_data(self):
+        should_call_dm_prepare_data = True
+        if self.datamodule is not None and self.is_overridden('prepare_data', self.datamodule):
+            should_call_dm_prepare_data = not self.datamodule.has_prepared_data
+
         if self.prepare_data_per_node:
-            return self.local_rank == 0
+            return self.local_rank == 0 and should_call_dm_prepare_data
         else:
-            return self.node_rank == 0 and self.local_rank == 0
+            return self.node_rank == 0 and self.local_rank == 0 and should_call_dm_prepare_data
 
     def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
         # when dataloader is passed via fit, patch the train_dataloader
@@ -1069,13 +1076,20 @@ def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=Non
         if test_dataloaders is not None:
             model.test_dataloader = _PatchDataLoader(test_dataloaders)
 
-    def __attach_datamodule(self, model, datamodule=None):
+    def __attach_datamodule(self, model, datamodule, stage):
 
         # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
         datamodule = datamodule or getattr(model, 'datamodule', None)
 
         # If we have a datamodule, attach necessary hooks + dataloaders
         if datamodule:
+
+            # If datamodule.setup('test') has not been called yet, call it
+            # if stage == 'test':
+            #     if self.is_overridden('setup', datamodule) and not datamodule.has_setup_test:
+            #         datamodule.setup('test')
+
+            # Override loader hooks
             if self.is_overridden('train_dataloader', datamodule):
                 model.train_dataloader = datamodule.train_dataloader
             if self.is_overridden('val_dataloader', datamodule):
@@ -1083,6 +1097,8 @@ def __attach_datamodule(self, model, datamodule=None):
             if self.is_overridden('test_dataloader', datamodule):
                 model.test_dataloader = datamodule.test_dataloader
 
+            self.datamodule = datamodule
+
     def run_pretrain_routine(self, model: LightningModule):
         """Sanity check a few things before starting actual training.
@@ -1279,9 +1295,7 @@ def test(
         )
 
         # Attach datamodule to get setup/prepare_data added to model before the call to it below
-        self.__attach_datamodule(model or self.get_model(), datamodule)
-
-        self.setup('test')
+        self.__attach_datamodule(model or self.get_model(), datamodule, 'test')
 
         if model is not None:
             results = self.__test_given_model(model, test_dataloaders)
@@ -1294,7 +1308,6 @@ def test(
 
     def __test_using_best_weights(self, ckpt_path, test_dataloaders):
         model = self.get_model()
-        model.setup('test')
 
         # if user requests the best checkpoint but we don't have it, error
         if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
@@ -1340,8 +1353,6 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
         return results
 
     def __test_given_model(self, model, test_dataloaders):
-        # setup hook
-        model.setup('test')
 
         # attach data
         if test_dataloaders is not None:
@@ -1370,6 +1381,16 @@ def barrier(self, name):
             # wait for all processes to catch up
             torch_xla.core.xla_model.rendezvous(f'pl.Trainer.{name}')
 
+    def call_setup_hook(self, model):
+        # call setup after the ddp process has connected
+        stage_name = 'test' if self.testing else 'fit'
+        if self.datamodule is not None:
+            called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit
+            if not called:
+                self.datamodule.setup(stage_name)
+        self.setup(stage_name)
+        model.setup(stage_name)
+
 
 class _PatchDataLoader(object):
     r"""
diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py
index d863c85605af7..a55a9a718ea9d 100644
--- a/tests/base/datamodules.py
+++ b/tests/base/datamodules.py
@@ -5,19 +5,28 @@
 
 
 class TrialMNISTDataModule(LightningDataModule):
+
     def __init__(self, data_dir: str = './'):
         super().__init__()
         self.data_dir = data_dir
+        self.non_picklable = None
 
     def prepare_data(self):
         TrialMNIST(self.data_dir, train=True, download=True)
         TrialMNIST(self.data_dir, train=False, download=True)
 
-    def setup(self):
-        mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
-        self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
-        self.dims = tuple(self.mnist_train[0][0].shape)
-        self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=32, download=True)
+    def setup(self, stage: str = None):
+
+        if stage == 'fit' or stage is None:
+            mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
+            self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
+            self.dims = self.mnist_train[0][0].shape
+
+        if stage == 'test' or stage is None:
+            self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=32, download=True)
+            self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
+
+        self.non_picklable = lambda x: x**2
 
     def train_dataloader(self):
         return DataLoader(self.mnist_train, batch_size=32)
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 76f62590f904a..ec66afb71ca22 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -1,10 +1,68 @@
 import pickle
+from argparse import ArgumentParser
+
 import torch
 import pytest
+
 from pytorch_lightning import Trainer
-from tests.base.datamodules import TrialMNISTDataModule
 from tests.base import EvalModelTemplate
-from argparse import ArgumentParser
+from tests.base.datamodules import TrialMNISTDataModule
+from tests.base.develop_utils import reset_seed
+
+
+def test_can_prepare_data(tmpdir):
+
+    dm = TrialMNISTDataModule()
+    trainer = Trainer()
+    trainer.datamodule = dm
+
+    # 1 no DM
+    # prepare_data_per_node = True
+    # local rank = 0 (True)
+    trainer.prepare_data_per_node = True
+    trainer.local_rank = 0
+    assert trainer.can_prepare_data()
+
+    # local rank = 1 (False)
+    trainer.local_rank = 1
+    assert not trainer.can_prepare_data()
+
+    # prepare_data_per_node = False (prepare across all nodes)
+    # global rank = 0 (True)
+    trainer.prepare_data_per_node = False
+    trainer.node_rank = 0
+    trainer.local_rank = 0
+    assert trainer.can_prepare_data()
+
+    # global rank = 1 (False)
+    trainer.node_rank = 1
+    trainer.local_rank = 0
+    assert not trainer.can_prepare_data()
+    trainer.node_rank = 0
+    trainer.local_rank = 1
+    assert not trainer.can_prepare_data()
+
+    # 2 dm
+    # prepare_data_per_node = True
+    # local rank = 0 (True)
+    trainer.prepare_data_per_node = True
+    trainer.local_rank = 0
+
+    # is_overridden prepare data = True
+    # has been called
+    # False
+    dm._has_prepared_data = True
+    assert not trainer.can_prepare_data()
+
+    # has not been called
+    # True
+    dm._has_prepared_data = False
+    assert trainer.can_prepare_data()
+
+    # is_overridden prepare data = False
+    # True
+    dm.prepare_data = None
+    assert trainer.can_prepare_data()
 
 
 def test_base_datamodule(tmpdir):
@@ -13,6 +71,66 @@ def test_base_datamodule(tmpdir):
     dm.setup()
 
 
+def test_base_datamodule_with_verbose_setup(tmpdir):
+    dm = TrialMNISTDataModule()
+    dm.prepare_data()
+    dm.setup('fit')
+    dm.setup('test')
+
+
+def test_data_hooks_called(tmpdir):
+    dm = TrialMNISTDataModule()
+    assert dm.has_prepared_data is False
+    assert dm.has_setup_fit is False
+    assert dm.has_setup_test is False
+
+    dm.prepare_data()
+    assert dm.has_prepared_data is True
+    assert dm.has_setup_fit is False
+    assert dm.has_setup_test is False
+
+    dm.setup()
+    assert dm.has_prepared_data is True
+    assert dm.has_setup_fit is True
+    assert dm.has_setup_test is True
+
+
+def test_data_hooks_called_verbose(tmpdir):
+    dm = TrialMNISTDataModule()
+    assert dm.has_prepared_data is False
+    assert dm.has_setup_fit is False
+    assert dm.has_setup_test is False
+
+    dm.prepare_data()
+    assert dm.has_prepared_data is True
+    assert dm.has_setup_fit is False
+    assert dm.has_setup_test is False
+
+    dm.setup('fit')
+    assert dm.has_prepared_data is True
+    assert dm.has_setup_fit is True
+    assert dm.has_setup_test is False
+
+    dm.setup('test')
+    assert dm.has_prepared_data is True
+    assert dm.has_setup_fit is True
+    assert dm.has_setup_test is True
+
+
+def test_data_hooks_called_with_stage_kwarg(tmpdir):
+    dm = TrialMNISTDataModule()
+    dm.prepare_data()
+    assert dm.has_prepared_data is True
+
+    dm.setup(stage='fit')
+    assert dm.has_setup_fit is True
+    assert dm.has_setup_test is False
+
+    dm.setup(stage='test')
+    assert dm.has_setup_fit is True
+    assert dm.has_setup_test is True
+
+
 def test_dm_add_argparse_args(tmpdir):
     parser = ArgumentParser()
     parser = TrialMNISTDataModule.add_argparse_args(parser)
@@ -34,17 +152,8 @@ def test_dm_pickle_after_init(tmpdir):
     pickle.dumps(dm)
 
 
-def test_dm_pickle_after_setup(tmpdir):
-    dm = TrialMNISTDataModule()
-    dm.prepare_data()
-    dm.setup()
-    pickle.dumps(dm)
-
-
 def test_train_loop_only(tmpdir):
     dm = TrialMNISTDataModule(tmpdir)
-    dm.prepare_data()
-    dm.setup()
 
     model = EvalModelTemplate()
     model.validation_step = None
@@ -59,18 +168,17 @@ def test_train_loop_only(tmpdir):
         max_epochs=3,
         weights_summary=None,
     )
-    trainer.fit(model, dm)
 
     # fit model
-    result = trainer.fit(model)
+    result = trainer.fit(model, dm)
     assert result == 1
-    assert trainer.callback_metrics['loss'] < 0.50
+    assert trainer.callback_metrics['loss'] < 0.6
 
 
 def test_train_val_loop_only(tmpdir):
+    reset_seed()
+
     dm = TrialMNISTDataModule(tmpdir)
-    dm.prepare_data()
-    dm.setup()
 
     model = EvalModelTemplate()
     model.validation_step = None
@@ -82,18 +190,32 @@ def test_train_val_loop_only(tmpdir):
         max_epochs=3,
         weights_summary=None,
     )
-    trainer.fit(model, dm)
 
     # fit model
-    result = trainer.fit(model)
+    result = trainer.fit(model, dm)
     assert result == 1
-    assert trainer.callback_metrics['loss'] < 0.50
+    assert trainer.callback_metrics['loss'] < 0.6
+
+
+def test_test_loop_only(tmpdir):
+    reset_seed()
+
+    dm = TrialMNISTDataModule(tmpdir)
+
+    model = EvalModelTemplate()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=3,
+        weights_summary=None,
+    )
+    trainer.test(model, datamodule=dm)
 
 
 def test_full_loop(tmpdir):
+    reset_seed()
+
     dm = TrialMNISTDataModule(tmpdir)
-    dm.prepare_data()
-    dm.setup()
 
     model = EvalModelTemplate()
 
@@ -102,10 +224,9 @@ def test_full_loop(tmpdir):
         max_epochs=3,
         weights_summary=None,
     )
-    trainer.fit(model, dm)
 
     # fit model
-    result = trainer.fit(model)
+    result = trainer.fit(model, dm)
     assert result == 1
 
     # test
@@ -116,9 +237,9 @@ def test_full_loop(tmpdir):
 
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires multi-GPU machine")
 def test_full_loop_single_gpu(tmpdir):
+    reset_seed()
+
     dm = TrialMNISTDataModule(tmpdir)
-    dm.prepare_data()
-    dm.setup()
 
     model = EvalModelTemplate()
 
@@ -128,10 +249,9 @@ def test_full_loop_single_gpu(tmpdir):
         weights_summary=None,
         gpus=1
     )
-    trainer.fit(model, dm)
 
     # fit model
-    result = trainer.fit(model)
+    result = trainer.fit(model, dm)
     assert result == 1
 
     # test
@@ -142,9 +262,9 @@ def test_full_loop_single_gpu(tmpdir):
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_full_loop_dp(tmpdir):
+    reset_seed()
+
     dm = TrialMNISTDataModule(tmpdir)
-    dm.prepare_data()
-    dm.setup()
 
     model = EvalModelTemplate()
 
@@ -155,10 +275,9 @@ def test_full_loop_dp(tmpdir):
         distributed_backend='dp',
         gpus=2
     )
-    trainer.fit(model, dm)
 
     # fit model
-    result = trainer.fit(model)
+    result = trainer.fit(model, dm)
     assert result == 1
 
     # test
@@ -172,9 +291,9 @@ def test_full_loop_ddp_spawn(tmpdir):
     import os
     os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
 
+    reset_seed()
+
     dm = TrialMNISTDataModule(tmpdir)
-    dm.prepare_data()
-    dm.setup()
 
     model = EvalModelTemplate()
 
@@ -185,10 +304,9 @@ def test_full_loop_ddp_spawn(tmpdir):
         distributed_backend='ddp_spawn',
         gpus=[0, 1]
    )
-    trainer.fit(model, dm)
 
     # fit model
-    result = trainer.fit(model)
+    result = trainer.fit(model, dm)
     assert result == 1
 
     # test