diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst
index 5ffe5de763b698..12b26b4a23857d 100644
--- a/docs/source/datamodules.rst
+++ b/docs/source/datamodules.rst
@@ -11,33 +11,40 @@ Data preparation in PyTorch follows 5 steps:
 A DataModule is simply a collection of a train_dataloader, val_dataloader(s), test_dataloader(s) along
 with the matching transforms and data processing/downloads steps required.
 
+.. code-block:: python
+
+    from torch.utils.data import DataLoader, random_split
+    from torchvision import transforms
+    from torchvision.datasets import MNIST
+
+    from pytorch_lightning import LightningDataModule
+
+
+    class MNISTDataModule(LightningDataModule):
+
+        def __init__(self, data_dir: str = './'):
+            super().__init__()
+            self.data_dir = data_dir
+
+        def prepare_data(self):
+            # download only (runs on a single process)
+            MNIST(self.data_dir, train=True, download=True)
+            MNIST(self.data_dir, train=False, download=True)
+
+        def setup(self, stage):
+
+            # assign train/val datasets for use in dataloaders
+            if stage == 'fit':
+                mnist_full = MNIST(self.data_dir, train=True, transform=transforms.ToTensor())
+                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
+
+            # assign test dataset for use in dataloader(s)
+            if stage == 'test':
+                self.mnist_test = MNIST(self.data_dir, train=False, transform=transforms.ToTensor())
+
+        def train_dataloader(self):
+            return DataLoader(self.mnist_train, batch_size=32)
+
+        def val_dataloader(self):
+            return DataLoader(self.mnist_val, batch_size=32)
+
+        def test_dataloader(self):
+            return DataLoader(self.mnist_test, batch_size=32)
 
-    >>> import pytorch_lightning as pl
-    >>> class MNISTDataModule(pl.LightningDataModule):
-    ...     def prepare_data(self):
-    ...         # download
-    ...         MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
-    ...         MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
-    ...
-    ...     def setup(self):
-    ...         mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
-    ...         mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
-    ...         # train/val split
-    ...         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
-    ...
-    ...         # assign to use in dataloaders
-    ...         self.train_dataset = mnist_train
-    ...         self.val_dataset = mnist_val
-    ...         self.test_dataset = mnist_test
-    ...
-    ...     def train_dataloader(self):
-    ...         return DataLoader(self.train_dataset, batch_size=64)
-    ...
-    ...     def val_dataloader(self):
-    ...         return DataLoader(self.val_dataset, batch_size=64)
-    ...
-    ...     def test_dataloader(self):
-    ...         return DataLoader(self.test_dataset, batch_size=64)
+
+.. note:: ``setup`` expects a string argument ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
 
 ---------------
 
@@ -60,11 +67,13 @@ settings.
 - tokenize
 - etc...
 
-    >>> class MNISTDataModule(pl.LightningDataModule):
-    ...     def prepare_data(self):
-    ...         # download
-    ...         MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
-    ...         MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
+.. code-block:: python
+
+    class MNISTDataModule(pl.LightningDataModule):
+        def prepare_data(self):
+            # download
+            MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
+            MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
 
 .. warning:: `prepare_data` is called from a single GPU. Do not use it to assign state (`self.x = y`).
 
@@ -77,33 +86,46 @@ There are also data operations you might want to perform on every GPU. Use setup
 - perform train/val/test splits
 - etc...
 
-    >>> import pytorch_lightning as pl
-    >>> class MNISTDataModule(pl.LightningDataModule):
-    ...     def setup(self):
-    ...         mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
-    ...         mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
-    ...         # train/val split
-    ...         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
-    ...
-    ...         # assign to use in dataloaders
-    ...         self.train_dataset = mnist_train
-    ...         self.val_dataset = mnist_val
-    ...         self.test_dataset = mnist_test
+.. code-block:: python
+
+    from typing import Optional
+
+    import pytorch_lightning as pl
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+
+        def setup(self, stage: Optional[str] = None):
+
+            # assign train/val split(s) for use in dataloaders
+            if stage == 'fit' or stage is None:
+                mnist_full = MNIST(self.data_dir, train=True, transform=transforms.ToTensor())
+                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
+                self.dims = self.mnist_train[0][0].shape
+
+            # assign test split(s) for use in dataloaders
+            if stage == 'test' or stage is None:
+                self.mnist_test = MNIST(self.data_dir, train=False, transform=transforms.ToTensor())
+                self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
+
 
 .. warning:: `setup` is called from every GPU. Setting state here is okay.
 
+
 train_dataloader
 ^^^^^^^^^^^^^^^^
 
 Use this method to generate the train dataloader. This is also a good place to place default transformations.
 
-    >>> import pytorch_lightning as pl
-    >>> class MNISTDataModule(pl.LightningDataModule):
-    ...     def train_dataloader(self):
-    ...         transforms = transform_lib.Compose([
-    ...             transform_lib.ToTensor(),
-    ...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-    ...         ])
-    ...         return DataLoader(self.train_dataset, transform=transforms, batch_size=64)
+.. code-block:: python
+
+    import pytorch_lightning as pl
+    from torchvision import transforms as transform_lib
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+        def train_dataloader(self):
+            transforms = transform_lib.Compose([
+                transform_lib.ToTensor(),
+                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
+            ])
+            # transforms belong to the Dataset; ``DataLoader`` has no ``transform`` argument
+            mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms)
+            return DataLoader(mnist_train, batch_size=64)
 
 However, to decouple your data from transforms you can parametrize them via `__init__`.
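+
+A minimal sketch of that pattern (illustrative only; ``train_transforms`` is a hypothetical
+constructor argument here, not part of the Lightning API):
+
+.. code-block:: python
+
+    import pytorch_lightning as pl
+    from torchvision import transforms as transform_lib
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+        def __init__(self, train_transforms=None):
+            super().__init__()
+            # fall back to a plain ToTensor when no transform is passed in
+            self.train_transforms = train_transforms or transform_lib.ToTensor()
+
+        def train_dataloader(self):
+            mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=self.train_transforms)
+            return DataLoader(mnist_train, batch_size=64)
+
+
+    # the caller now controls the transforms, not the DataModule
+    dm = MNISTDataModule(
+        train_transforms=transform_lib.Compose([
+            transform_lib.RandomRotation(10),
+            transform_lib.ToTensor(),
+        ])
+    )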
@@ -119,32 +141,41 @@ val_dataloader
 ^^^^^^^^^^^^^^
 
 Use this method to generate the val dataloader. This is also a good place to place default transformations.
 
-    >>> import pytorch_lightning as pl
-    >>> class MNISTDataModule(pl.LightningDataModule):
-    ...     def val_dataloader(self):
-    ...         transforms = transform_lib.Compose([
-    ...             transform_lib.ToTensor(),
-    ...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-    ...         ])
-    ...         return DataLoader(self.val_dataset, transform=transforms, batch_size=64)
+.. code-block:: python
+
+    import pytorch_lightning as pl
+    from torchvision import transforms as transform_lib
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+        def val_dataloader(self):
+            transforms = transform_lib.Compose([
+                transform_lib.ToTensor(),
+                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
+            ])
+            # transforms belong to the Dataset; ``DataLoader`` has no ``transform`` argument
+            # (illustrative split; in practice reuse the split created in ``setup``)
+            mnist_full = MNIST(os.getcwd(), train=True, download=False, transform=transforms)
+            _, mnist_val = random_split(mnist_full, [55000, 5000])
+            return DataLoader(mnist_val, batch_size=64)
 
 
 test_dataloader
 ^^^^^^^^^^^^^^^
 
 Use this method to generate the test dataloader. This is also a good place to place default transformations.
 
-    >>> import pytorch_lightning as pl
-    >>> class MNISTDataModule(pl.LightningDataModule):
-    ...     def test_dataloader(self):
-    ...         transforms = transform_lib.Compose([
-    ...             transform_lib.ToTensor(),
-    ...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-    ...         ])
-    ...         return DataLoader(self.test_dataset, transform=transforms, batch_size=64)
+.. code-block:: python
+
+    import pytorch_lightning as pl
+    from torchvision import transforms as transform_lib
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+        def test_dataloader(self):
+            transforms = transform_lib.Compose([
+                transform_lib.ToTensor(),
+                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
+            ])
+            # transforms belong to the Dataset; ``DataLoader`` has no ``transform`` argument
+            mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms)
+            return DataLoader(mnist_test, batch_size=64)
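+
+Since the same ``Compose`` is repeated in each of the dataloaders above, one option is to
+build it once in a small helper. A minimal sketch (``_default_transforms`` is a hypothetical
+method name, not a Lightning hook):
+
+.. code-block:: python
+
+    class MNISTDataModule(pl.LightningDataModule):
+        def _default_transforms(self):
+            # one transform pipeline shared by the train/val/test dataloaders
+            return transform_lib.Compose([
+                transform_lib.ToTensor(),
+                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
+            ])
+
+        def test_dataloader(self):
+            mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=self._default_transforms())
+            return DataLoader(mnist_test, batch_size=64)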
 
 ------------------
 Using a DataModule
 ------------------
 
+The recommended way to use a DataModule is simply:
 
 .. code-block:: python
 
@@ -162,12 +193,13 @@ still ensures the method runs on the correct devices)
 
     dm = MNISTDataModule()
     dm.prepare_data()
-    dm.setup()
+    dm.setup('fit')
     model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
     trainer.fit(model, dm)
 
-    trainer.test(model, datamodule=dm)
+    dm.setup('test')
+    trainer.test(datamodule=dm)
 
 
 ----------------
 
@@ -184,12 +216,14 @@ DataModules have a few key advantages:
 
     dm = MNISTDataModule()
    dm.prepare_data()
-    dm.setup()
+    dm.setup('fit')
 
     for batch in dm.train_dataloader():
         ...
     for batch in dm.val_dataloader():
         ...
+
+    dm.setup('test')
     for batch in dm.test_dataloader():
         ...
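+
+Because a DataModule is self-contained, the same instance can be reused across different
+models without duplicating any data code. A minimal sketch (``LitClassifier`` and
+``LitAutoEncoder`` are placeholder LightningModule names):
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+
+    dm = MNISTDataModule()
+
+    # both models see identical splits, transforms and loaders
+    Trainer().fit(LitClassifier(), dm)
+    Trainer().fit(LitAutoEncoder(), dm)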