Commit 7839df4: 📝 docs
nateraw committed Jul 30, 2020
1 parent 7b49a56 commit 7839df4
Showing 1 changed file with 104 additions and 70 deletions.
174 changes: 104 additions & 70 deletions docs/source/datamodules.rst
@@ -11,33 +11,40 @@ Data preparation in PyTorch follows 5 steps:
A DataModule is simply a collection of a train_dataloader, val_dataloader(s), test_dataloader(s) along with the
matching transforms and data processing/downloads steps required.

+.. code-block:: python
+
+    class MNISTDataModule(LightningDataModule):
+
+        def __init__(self, data_dir: str = './'):
+            super().__init__()
+            self.data_dir = data_dir
+
+        def prepare_data(self):
+            # download
+            MNIST(self.data_dir, train=True, download=True)
+            MNIST(self.data_dir, train=False, download=True)
+
+        def setup(self, stage):
+            # Assign train/val datasets for use in dataloaders
+            if stage == 'fit':
+                # ToTensor so the DataLoader can collate samples into batches
+                mnist_full = MNIST(self.data_dir, train=True, download=True, transform=transforms.ToTensor())
+                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
+
+            # Assign test dataset for use in dataloader(s)
+            if stage == 'test':
+                self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=transforms.ToTensor())
+
+        def train_dataloader(self):
+            return DataLoader(self.mnist_train, batch_size=32)
+
+        def val_dataloader(self):
+            return DataLoader(self.mnist_val, batch_size=32)
+
+        def test_dataloader(self):
+            return DataLoader(self.mnist_test, batch_size=32)

->>> import pytorch_lightning as pl
->>> class MNISTDataModule(pl.LightningDataModule):
-...     def prepare_data(self):
-...         # download
-...         MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
-...         MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
-...
-...     def setup(self):
-...         mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
-...         mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
-...         # train/val split
-...         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
-...
-...         # assign to use in dataloaders
-...         self.train_dataset = mnist_train
-...         self.val_dataset = mnist_val
-...         self.test_dataset = mnist_test
-...
-...     def train_dataloader(self):
-...         return DataLoader(self.train_dataset, batch_size=64)
-...
-...     def val_dataloader(self):
-...         return DataLoader(self.val_dataset, batch_size=64)
-...
-...     def test_dataloader(self):
-...         return DataLoader(self.test_dataset, batch_size=64)

+.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
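
For orientation, a minimal sketch of how the ``stage`` argument is driven in practice; ``LitModel`` stands in for any LightningModule here:

.. code-block:: python

    import pytorch_lightning as pl

    dm = MNISTDataModule()
    model = LitModel()  # hypothetical LightningModule

    trainer = pl.Trainer()
    trainer.fit(model, dm)       # Lightning calls dm.setup('fit') before training
    trainer.test(datamodule=dm)  # Lightning calls dm.setup('test') before testing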

---------------

@@ -60,11 +67,13 @@ settings.
- tokenize
- etc...

->>> class MNISTDataModule(pl.LightningDataModule):
-...     def prepare_data(self):
-...         # download
-...         MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
-...         MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
+.. code-block:: python
+
+    class MNISTDataModule(pl.LightningDataModule):
+
+        def prepare_data(self):
+            # download
+            MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
+            MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

.. warning:: `prepare_data` is called from a single GPU. Do not use it to assign state (`self.x = y`).
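
To make that warning concrete, an illustrative sketch of the failure mode (the class names are hypothetical):

.. code-block:: python

    import os

    import pytorch_lightning as pl
    from torchvision.datasets import MNIST


    class BadDataModule(pl.LightningDataModule):

        def prepare_data(self):
            # BAD: prepare_data runs on a single process, so in a multi-GPU run
            # the other processes never see this attribute
            self.mnist_train = MNIST(os.getcwd(), train=True, download=True)


    class GoodDataModule(pl.LightningDataModule):

        def prepare_data(self):
            # download only; no state assigned
            MNIST(os.getcwd(), train=True, download=True)

        def setup(self, stage):
            # setup runs on every GPU, so assigning state here is safe
            self.mnist_train = MNIST(os.getcwd(), train=True, download=False)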

@@ -77,33 +86,46 @@ There are also data operations you might want to perform on every GPU. Use setup to do things like:
- perform train/val/test splits
- etc...

->>> import pytorch_lightning as pl
->>> class MNISTDataModule(pl.LightningDataModule):
-...     def setup(self):
-...         mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
-...         mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
-...         # train/val split
-...         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
-...
-...         # assign to use in dataloaders
-...         self.train_dataset = mnist_train
-...         self.val_dataset = mnist_val
-...         self.test_dataset = mnist_test

+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+
+        def setup(self, stage: Optional[str] = None):
+
+            # Assign train/val split(s) for use in dataloaders
+            if stage == 'fit' or stage is None:
+                # ToTensor so samples are tensors and expose ``.shape``
+                mnist_full = MNIST(self.data_dir, train=True, download=True, transform=transforms.ToTensor())
+                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
+                self.dims = self.mnist_train[0][0].shape
+
+            # Assign test split(s) for use in dataloaders
+            if stage == 'test' or stage is None:
+                self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=transforms.ToTensor())
+                # keep dims from 'fit' if already set, else compute from a test sample
+                self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)

.. warning:: `setup` is called from every GPU. Setting state here is okay.
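
Because ``setup`` runs on every GPU, attributes assigned there (such as ``self.dims`` above) can be read back safely afterwards, for example to size a model. A minimal sketch:

.. code-block:: python

    dm = MNISTDataModule()
    dm.prepare_data()
    dm.setup('fit')
    print(dm.dims)  # shape of a single sample, e.g. torch.Size([1, 28, 28])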


train_dataloader
^^^^^^^^^^^^^^^^
Use this method to generate the train dataloader. This is also a good place to apply default transformations.

->>> import pytorch_lightning as pl
->>> class MNISTDataModule(pl.LightningDataModule):
-...     def train_dataloader(self):
-...         transforms = transform_lib.Compose([
-...             transform_lib.ToTensor(),
-...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-...         ])
-...         return DataLoader(self.train_dataset, transform=transforms, batch_size=64)

+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+
+        def train_dataloader(self):
+            transforms = transform_lib.Compose([
+                transform_lib.ToTensor(),
+                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
+            ])
+            # ``DataLoader`` takes no ``transform`` argument; transforms are
+            # applied by the Dataset itself
+            train_dataset = MNIST(os.getcwd(), train=True, download=False, transform=transforms)
+            return DataLoader(train_dataset, batch_size=64)

However, to decouple your data from its transforms, you can parametrize them via `__init__`, as sketched below.
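
A minimal sketch of that idea (the ``train_transform`` parameter is illustrative, not a fixed Lightning API):

.. code-block:: python

    import os

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader
    from torchvision import transforms as transform_lib
    from torchvision.datasets import MNIST


    class MNISTDataModule(pl.LightningDataModule):

        def __init__(self, train_transform=None):
            super().__init__()
            # fall back to a bare ToTensor if the caller passes nothing
            self.train_transform = train_transform or transform_lib.ToTensor()

        def train_dataloader(self):
            dataset = MNIST(os.getcwd(), train=True, download=False,
                            transform=self.train_transform)
            return DataLoader(dataset, batch_size=64)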

@@ -119,32 +141,41 @@ val_dataloader
^^^^^^^^^^^^^^
Use this method to generate the val dataloader. This is also a good place to apply default transformations.

->>> import pytorch_lightning as pl
->>> class MNISTDataModule(pl.LightningDataModule):
-...     def val_dataloader(self):
-...         transforms = transform_lib.Compose([
-...             transform_lib.ToTensor(),
-...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-...         ])
-...         return DataLoader(self.val_dataset, transform=transforms, batch_size=64)

+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+
+        def val_dataloader(self):
+            transforms = transform_lib.Compose([
+                transform_lib.ToTensor(),
+                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
+            ])
+            # ``DataLoader`` takes no ``transform`` argument; set it on the
+            # dataset underlying the validation split (shared with the train split)
+            self.val_dataset.dataset.transform = transforms
+            return DataLoader(self.val_dataset, batch_size=64)

test_dataloader
^^^^^^^^^^^^^^^
Use this method to generate the test dataloader. This is also a good place to apply default transformations.

->>> import pytorch_lightning as pl
->>> class MNISTDataModule(pl.LightningDataModule):
-...     def test_dataloader(self):
-...         transforms = transform_lib.Compose([
-...             transform_lib.ToTensor(),
-...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-...         ])
-...         return DataLoader(self.test_dataset, transform=transforms, batch_size=64)

+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+
+    class MNISTDataModule(pl.LightningDataModule):
+
+        def test_dataloader(self):
+            transforms = transform_lib.Compose([
+                transform_lib.ToTensor(),
+                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
+            ])
+            # ``DataLoader`` takes no ``transform`` argument; transforms are
+            # applied by the Dataset itself
+            test_dataset = MNIST(os.getcwd(), train=False, download=False, transform=transforms)
+            return DataLoader(test_dataset, batch_size=64)

------------------

Using a DataModule
------------------

The recommended way to use a DataModule is simply:

.. code-block:: python
@@ -162,12 +193,13 @@ still ensures the method runs on the correct devices)
    dm = MNISTDataModule()
    dm.prepare_data()
-   dm.setup()
+   dm.setup('fit')

    model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
    trainer.fit(model, dm)

-   trainer.test(model, datamodule=dm)
+   dm.setup('test')
+   trainer.test(datamodule=dm)
----------------

@@ -184,12 +216,14 @@ DataModules have a few key advantages:
    dm = MNISTDataModule()
    dm.prepare_data()
-   dm.setup()
+   dm.setup('fit')

    for batch in dm.train_dataloader():
        ...
    for batch in dm.val_dataloader():
        ...

+   dm.setup('test')
    for batch in dm.test_dataloader():
        ...
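
For a quick sanity check of what those loops yield, a minimal sketch (assuming the datasets were built with a ``ToTensor`` transform, as in the examples above):

.. code-block:: python

    dm = MNISTDataModule()
    dm.prepare_data()
    dm.setup('fit')

    x, y = next(iter(dm.train_dataloader()))
    print(x.shape, y.shape)  # e.g. torch.Size([32, 1, 28, 28]) torch.Size([32])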