Call DataModule hooks implicitly in trainer #2755

Merged · 21 commits · Aug 2, 2020
197 changes: 127 additions & 70 deletions docs/source/datamodules.rst
@@ -11,33 +11,63 @@ Data preparation in PyTorch follows 5 steps:
A DataModule is simply a collection of a train_dataloader, val_dataloader(s) and test_dataloader(s), along with the
matching transforms and the data processing/download steps required.

.. code-block:: python
Member: why not test code?

Contributor: there are a ton of unnecessary doctests.

Member: What do you mean by unnecessary? That there is no need for examples? Otherwise, all examples should be tested so that they stay aligned with the actual code base...
cc: @awaelchli

Member: agree, this code block would be perfect for a doctest, and it is as simple as adding ``.. testcode::``. Even if we don't make any assertions here, Python will parse the code and run the import statements. It would help us keep the docs up to date with the API.

Member: yeah, we do not need anything extra, but it checks the syntax and, where applicable, the number of passed arguments or kwargs.


import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTDataModule(pl.LightningDataModule):

def __init__(self, data_dir: str = './'):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])

# self.dims is returned when you call dm.size()
# Setting default dims here because we know them.
# Could optionally be assigned dynamically in dm.setup()
self.dims = (1, 28, 28)

def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)

def setup(self, stage=None):

# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

# Optionally...
# self.dims = tuple(self.mnist_train[0][0].shape)

>>> import pytorch_lightning as pl
>>> class MNISTDataModule(pl.LightningDataModule):
... def prepare_data(self):
... # download
... MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
... MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
...
... def setup(self, stage):
... mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
... mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
... # train/val split
... mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
...
... # assign to use in dataloaders
... self.train_dataset = mnist_train
... self.val_dataset = mnist_val
... self.test_dataset = mnist_test
...
... def train_dataloader(self):
... return DataLoader(self.train_dataset, batch_size=64)
...
... def val_dataloader(self):
... return DataLoader(self.val_dataset, batch_size=64)
...
... def test_dataloader(self):
... return DataLoader(self.test_dataset, batch_size=64)
# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

# Optionally...
# self.dims = tuple(self.mnist_test[0][0].shape)

def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
Contributor: maybe add batch_size as a param?


def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)

def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
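
Picking up the reviewer's suggestion above, ``batch_size`` could also be exposed as an ``__init__`` parameter instead of being hard-coded; a minimal sketch (illustrative, not part of this diff):

.. code-block:: python

    class MNISTDataModule(pl.LightningDataModule):

        def __init__(self, data_dir: str = './', batch_size: int = 32):
            super().__init__()
            self.data_dir = data_dir
            self.batch_size = batch_size

        # prepare_data/setup stay exactly as above

        def train_dataloader(self):
            return DataLoader(self.mnist_train, batch_size=self.batch_size)

        def val_dataloader(self):
            return DataLoader(self.mnist_val, batch_size=self.batch_size)

        def test_dataloader(self):
            return DataLoader(self.mnist_test, batch_size=self.batch_size)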

.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
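
Applied to the example above, that gives (a small sketch, assuming ``prepare_data()`` has already been called):

.. code-block:: python

    dm = MNISTDataModule()

    dm.setup('fit')     # assigns mnist_train / mnist_val only
    dm.setup('test')    # assigns mnist_test only
    dm.setup()          # stage=None: runs both branches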

---------------

@@ -60,11 +90,13 @@ settings.
- tokenize
- etc...

>>> class MNISTDataModule(pl.LightningDataModule):
... def prepare_data(self):
... # download
... MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
... MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
.. code-block:: python

class MNISTDataModule(pl.LightningDataModule):
def prepare_data(self):
# download
MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

.. warning:: `prepare_data` is called from a single GPU. Do not use it to assign state (`self.x = y`).
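
In other words, anything assigned inside ``prepare_data`` exists only in the single process that ran it, so dataset objects belong in ``setup``. A small illustrative sketch of the split:

.. code-block:: python

    class MNISTDataModule(pl.LightningDataModule):

        def prepare_data(self):
            # runs on one process only: download, but assign nothing to self
            MNIST(os.getcwd(), train=True, download=True)

        def setup(self, stage=None):
            # runs on every process/GPU: safe to assign state here
            self.mnist_train = MNIST(os.getcwd(), train=True, transform=transforms.ToTensor())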

@@ -77,33 +109,46 @@ There are also data operations you might want to perform on every GPU. Use setup
- perform train/val/test splits
- etc...

>>> import pytorch_lightning as pl
>>> class MNISTDataModule(pl.LightningDataModule):
... def setup(self, stage):
... mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
... mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
... # train/val split
... mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
...
... # assign to use in dataloaders
... self.train_dataset = mnist_train
... self.val_dataset = mnist_val
... self.test_dataset = mnist_test
.. code-block:: python

import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):

def setup(self, stage: Optional[str] = None):

# Assign Train/val split(s) for use in Dataloaders
if stage == 'fit' or stage is None:
mnist_full = MNIST(self.data_dir, train=True, download=True)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
self.dims = self.mnist_train[0][0].shape

# Assign Test split(s) for use in Dataloaders
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, download=True)
self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)


.. warning:: `setup` is called from every GPU. Setting state here is okay.


train_dataloader
^^^^^^^^^^^^^^^^
Use this method to generate the train dataloader. This is also a good place to define default transformations.

>>> import pytorch_lightning as pl
>>> class MNISTDataModule(pl.LightningDataModule):
... def train_dataloader(self):
... transforms = transform_lib.Compose([
... transform_lib.ToTensor(),
... transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
... ])
... return DataLoader(self.train_dataset, transform=transforms, batch_size=64)
.. code-block:: python

import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        transforms = transform_lib.Compose([
            transform_lib.ToTensor(),
            transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
        ])
        # DataLoader has no ``transform`` argument; set it on the (torchvision) dataset instead
        self.train_dataset.transform = transforms
        return DataLoader(self.train_dataset, batch_size=64)

However, to decouple your data from transforms you can parametrize them via `__init__`.
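
A rough sketch of that pattern (illustrative only, with a hypothetical ``train_transforms`` argument):

.. code-block:: python

    class MNISTDataModule(pl.LightningDataModule):

        def __init__(self, train_transforms=None):
            super().__init__()
            # fall back to a sensible default when nothing is passed in
            self.train_transforms = train_transforms or transform_lib.ToTensor()

        def train_dataloader(self):
            mnist_train = MNIST(os.getcwd(), train=True, transform=self.train_transforms)
            return DataLoader(mnist_train, batch_size=64)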

@@ -119,32 +164,41 @@ val_dataloader
^^^^^^^^^^^^^^
Use this method to generate the val dataloader. This is also a good place to define default transformations.

>>> import pytorch_lightning as pl
>>> class MNISTDataModule(pl.LightningDataModule):
... def val_dataloader(self):
... transforms = transform_lib.Compose([
... transform_lib.ToTensor(),
... transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
... ])
... return DataLoader(self.val_dataset, transform=transforms, batch_size=64)
.. code-block:: python

import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
    def val_dataloader(self):
        transforms = transform_lib.Compose([
            transform_lib.ToTensor(),
            transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
        ])
        # DataLoader has no ``transform`` argument; set it on the (torchvision) dataset instead
        self.val_dataset.transform = transforms
        return DataLoader(self.val_dataset, batch_size=64)

test_dataloader
^^^^^^^^^^^^^^^
Use this method to generate the test dataloader. This is also a good place to define default transformations.

>>> import pytorch_lightning as pl
>>> class MNISTDataModule(pl.LightningDataModule):
... def test_dataloader(self):
... transforms = transform_lib.Compose([
... transform_lib.ToTensor(),
... transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
... ])
... return DataLoader(self.test_dataset, transform=transforms, batch_size=64)
.. code-block:: python

import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
    def test_dataloader(self):
        transforms = transform_lib.Compose([
            transform_lib.ToTensor(),
            transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
        ])
        # DataLoader has no ``transform`` argument; set it on the (torchvision) dataset instead
        self.test_dataset.transform = transforms
        return DataLoader(self.test_dataset, batch_size=64)

------------------

Using a DataModule
------------------

The recommended way to use a DataModule is simply:

.. code-block:: python
@@ -162,12 +216,13 @@ still ensures the method runs on the correct devices)

dm = MNISTDataModule()
dm.prepare_data()
- dm.setup()
+ dm.setup('fit')

model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
trainer.fit(model, dm)

- trainer.test(model, datamodule=dm)
+ dm.setup('test')
+ trainer.test(datamodule=dm)
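
With the hooks now called implicitly by the trainer (the point of this PR), the manual calls above are only needed when the model itself requires information from the datamodule; otherwise the flow reduces to roughly the following sketch (assuming the model can be constructed without it):

.. code-block:: python

    dm = MNISTDataModule()
    model = Model()
    trainer = Trainer()

    # prepare_data() and setup('fit') are invoked for you inside fit()
    trainer.fit(model, dm)

    # setup('test') is invoked for you inside test()
    trainer.test(datamodule=dm)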

----------------

@@ -184,12 +239,14 @@ DataModules have a few key advantages:

dm = MNISTDataModule()
dm.prepare_data()
- dm.setup()

+ dm.setup('fit')
for batch in dm.train_dataloader():
...
for batch in dm.val_dataloader():
...

+ dm.setup('test')
for batch in dm.test_dataloader():
...

4 changes: 1 addition & 3 deletions pytorch_lightning/accelerator_backends/cpu_backend.py
@@ -26,9 +26,7 @@ def setup(self, model):
raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')

# call setup after the ddp process has connected
- if not self.trainer.testing:
-     self.trainer.setup('fit')
-     model.setup('fit')
+ self.trainer.call_setup_hook(model)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
@@ -106,9 +106,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
)

# call setup after the ddp process has connected
- if not self.trainer.testing:
-     self.trainer.setup('fit')
-     model.setup('fit')
+ self.trainer.call_setup_hook(model)

# on world_size=0 let everyone know training is starting
if self.trainer.is_global_zero:
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerator_backends/dp_backend.py
@@ -33,9 +33,7 @@ def __init__(self, trainer):

def setup(self, model):
# call setup after the ddp process has connected
- if not self.trainer.testing:
-     self.trainer.setup('fit')
-     model.setup('fit')
+ self.trainer.call_setup_hook(model)

# put model on correct device
model.cuda(self.trainer.root_gpu)
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerator_backends/gpu_backend.py
@@ -31,9 +31,7 @@ def __init__(self, trainer):
def setup(self, model):

# call setup
- if not self.trainer.testing:
-     self.trainer.setup('fit')
-     model.setup('fit')
+ self.trainer.call_setup_hook(model)

model.cuda(self.trainer.root_gpu)

5 changes: 2 additions & 3 deletions pytorch_lightning/accelerator_backends/tpu_backend.py
@@ -102,9 +102,8 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
"""
if not trainer:
trainer = self.trainer
- if not trainer.testing:
-     trainer.setup('fit')
-     model.setup('fit')

+ trainer.call_setup_hook(model)

# setup TPU training
self.__setup_tpu_training(model, trainer)
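
Each backend now delegates to a single ``trainer.call_setup_hook(model)``. The helper's exact implementation lives elsewhere in this PR; a rough sketch of what it has to do (illustrative only, assuming the trainer keeps a ``testing`` flag and a reference to the attached datamodule) looks like:

    class Trainer:
        # ... existing attributes: self.testing, self.datamodule, self.setup(stage), ...

        def call_setup_hook(self, model):
            # illustrative sketch only, not the verbatim implementation from this PR
            stage = 'test' if self.testing else 'fit'

            # the new part: the trainer also drives the DataModule hook
            datamodule = getattr(self, 'datamodule', None)
            if datamodule is not None:
                datamodule.setup(stage)

            self.setup(stage)
            model.setup(stage)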