Call DataModule hooks implicitly in trainer (#2755)
* ✨ call dm hooks in trainer implicitly

* ✅ update tests

* 📝 remove unused stage arg from dm docs

* ✅ update tests

* ✅ update tests

* 🚧 include stage in datamodule.setup

* 📝 docs

* 📝 docs

* added more dm tests

* added more dm tests

* 🐛 call dm.setup everywhere

* 🔥 pickle tests now implied by accelerator tests

* 🎨 set dm as attr of trainer

* 🐛 .

* 🚧 wip

* add can prepare test

* add can prepare test

* verified setup in fit

* fixed setup call

* fixed setup call

* fixed setup call

Co-authored-by: William Falcon <waf2107@columbia.edu>
nateraw and williamFalcon committed Aug 2, 2020
1 parent f9ccb0f commit 036bcea
Showing 12 changed files with 427 additions and 146 deletions.
197 changes: 127 additions & 70 deletions docs/source/datamodules.rst
@@ -11,33 +11,63 @@ Data preparation in PyTorch follows 5 steps:
A DataModule is simply a collection of a train_dataloader, val_dataloader(s), test_dataloader(s) along with the
matching transforms and data processing/downloads steps required.

>>> import pytorch_lightning as pl
>>> class MNISTDataModule(pl.LightningDataModule):
...     def prepare_data(self):
...         # download
...         MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
...         MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
...
...     def setup(self, stage):
...         mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
...         mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
...         # train/val split
...         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
...
...         # assign to use in dataloaders
...         self.train_dataset = mnist_train
...         self.val_dataset = mnist_val
...         self.test_dataset = mnist_test
...
...     def train_dataloader(self):
...         return DataLoader(self.train_dataset, batch_size=64)
...
...     def val_dataloader(self):
...         return DataLoader(self.val_dataset, batch_size=64)
...
...     def test_dataloader(self):
...         return DataLoader(self.test_dataset, batch_size=64)

.. code-block:: python

    import pytorch_lightning as pl
    from torch.utils.data import random_split, DataLoader

    # Note - you must have torchvision installed for this example
    from torchvision.datasets import MNIST
    from torchvision import transforms


    class MNISTDataModule(pl.LightningDataModule):

        def __init__(self, data_dir: str = './'):
            super().__init__()
            self.data_dir = data_dir
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])

            # self.dims is returned when you call dm.size()
            # Setting default dims here because we know them.
            # Could optionally be assigned dynamically in dm.setup()
            self.dims = (1, 28, 28)

        def prepare_data(self):
            # download
            MNIST(self.data_dir, train=True, download=True)
            MNIST(self.data_dir, train=False, download=True)

        def setup(self, stage=None):
            # Assign train/val datasets for use in dataloaders
            if stage == 'fit' or stage is None:
                mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

                # Optionally...
                # self.dims = tuple(self.mnist_train[0][0].shape)

            # Assign test dataset for use in dataloader(s)
            if stage == 'test' or stage is None:
                self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

                # Optionally...
                # self.dims = tuple(self.mnist_test[0][0].shape)

        def train_dataloader(self):
            return DataLoader(self.mnist_train, batch_size=32)

        def val_dataloader(self):
            return DataLoader(self.mnist_val, batch_size=32)

        def test_dataloader(self):
            return DataLoader(self.mnist_test, batch_size=32)
.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
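
As a quick illustration of the note above, here is a minimal sketch of how the ``stage`` argument plays out once the trainer calls the hooks implicitly (``Model`` and the ``Trainer`` arguments are assumed for this sketch and are not part of the diff):

.. code-block:: python

    dm = MNISTDataModule()
    model = Model()  # hypothetical LightningModule for this sketch
    trainer = pl.Trainer(max_epochs=1)

    # with this commit, trainer.fit() is expected to call dm.prepare_data()
    # and dm.setup('fit') itself before training starts
    trainer.fit(model, dm)

    # and trainer.test() to call dm.setup('test') before requesting the test dataloader
    trainer.test(datamodule=dm)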

---------------

@@ -60,11 +90,13 @@ settings.
- tokenize
- etc...

>>> class MNISTDataModule(pl.LightningDataModule):
...     def prepare_data(self):
...         # download
...         MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
...         MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

.. code-block:: python

    class MNISTDataModule(pl.LightningDataModule):
        def prepare_data(self):
            # download
            MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
            MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
.. warning:: `prepare_data` is called from a single GPU. Do not use it to assign state (`self.x = y`).
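
To make the warning concrete, a small sketch of the anti-pattern versus the intended split (the class names here are illustrative and not part of the diff):

.. code-block:: python

    import os

    import pytorch_lightning as pl
    from torchvision.datasets import MNIST


    class BadDataModule(pl.LightningDataModule):
        def prepare_data(self):
            # prepare_data runs in a single process, so this attribute
            # would be missing in the other GPUs' processes
            self.mnist = MNIST(os.getcwd(), train=True, download=True)


    class GoodDataModule(pl.LightningDataModule):
        def prepare_data(self):
            # download only, no state assignment
            MNIST(os.getcwd(), train=True, download=True)

        def setup(self, stage=None):
            # setup runs on every process, so assigning state here is safe
            self.mnist = MNIST(os.getcwd(), train=True, download=False)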

@@ -77,33 +109,46 @@ There are also data operations you might want to perform on every GPU. Use setup
- perform train/val/test splits
- etc...

>>> import pytorch_lightning as pl
>>> class MNISTDataModule(pl.LightningDataModule):
...     def setup(self, stage):
...         mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
...         mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor())
...         # train/val split
...         mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
...
...         # assign to use in dataloaders
...         self.train_dataset = mnist_train
...         self.val_dataset = mnist_val
...         self.test_dataset = mnist_test

.. code-block:: python

    import pytorch_lightning as pl


    class MNISTDataModule(pl.LightningDataModule):
        def setup(self, stage: Optional[str] = None):
            # Assign Train/val split(s) for use in Dataloaders
            if stage == 'fit' or stage is None:
                mnist_full = MNIST(self.data_dir, train=True, download=True)
                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
                self.dims = self.mnist_train[0][0].shape

            # Assign Test split(s) for use in Dataloaders
            if stage == 'test' or stage is None:
                self.mnist_test = MNIST(self.data_dir, train=False, download=True)
                self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
.. warning:: `setup` is called from every GPU. Setting state here is okay.
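
A conceptual sketch of why per-process state is fine here (a simplified illustration, not the actual Lightning internals; ``run_on_one_worker`` is a made-up name):

.. code-block:: python

    # every distributed worker runs its own copy of the hooks, so attributes
    # assigned inside setup() exist in each process
    def run_on_one_worker(rank: int, dm: MNISTDataModule):
        dm.setup('fit')                  # builds dm.mnist_train / dm.mnist_val locally
        loader = dm.train_dataloader()   # each rank gets its own working dataloader
        for batch in loader:
            ...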


train_dataloader
^^^^^^^^^^^^^^^^
Use this method to generate the train dataloader. This is also a good place to place default transformations.

>>> import pytorch_lightning as pl
>>> class MNISTDataModule(pl.LightningDataModule):
...     def train_dataloader(self):
...         transforms = transform_lib.Compose([
...             transform_lib.ToTensor(),
...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
...         ])
...         return DataLoader(self.train_dataset, transform=transforms, batch_size=64)

.. code-block:: python

    import pytorch_lightning as pl


    class MNISTDataModule(pl.LightningDataModule):
        def train_dataloader(self):
            transforms = transform_lib.Compose([
                transform_lib.ToTensor(),
                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
            ])
            return DataLoader(self.train_dataset, transform=transforms, batch_size=64)
However, to decouple your data from transforms you can parametrize them via `__init__`.
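
For example, a sketch of the parametrized variant (the remainder of this section is collapsed in this view; the argument names here are illustrative):

.. code-block:: python

    import os

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision import transforms as transform_lib


    class MNISTDataModule(pl.LightningDataModule):
        def __init__(self, train_transforms=None, batch_size: int = 64):
            super().__init__()
            # transforms are injected here instead of being hard-coded in train_dataloader()
            self.train_transforms = train_transforms or transform_lib.ToTensor()
            self.batch_size = batch_size

        def train_dataloader(self):
            dataset = MNIST(os.getcwd(), train=True, download=False, transform=self.train_transforms)
            return DataLoader(dataset, batch_size=self.batch_size)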

@@ -119,32 +164,41 @@ val_dataloader
^^^^^^^^^^^^^^
Use this method to generate the val dataloader. This is also a good place to place default transformations.

>>> import pytorch_lightning as pl
>>> class MNISTDataModule(pl.LightningDataModule):
...     def val_dataloader(self):
...         transforms = transform_lib.Compose([
...             transform_lib.ToTensor(),
...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
...         ])
...         return DataLoader(self.val_dataset, transform=transforms, batch_size=64)

.. code-block:: python

    import pytorch_lightning as pl


    class MNISTDataModule(pl.LightningDataModule):
        def val_dataloader(self):
            transforms = transform_lib.Compose([
                transform_lib.ToTensor(),
                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
            ])
            return DataLoader(self.val_dataset, transform=transforms, batch_size=64)
test_dataloader
^^^^^^^^^^^^^^^
Use this method to generate the test dataloader. This is also a good place to place default transformations.

>>> import pytorch_lightning as pl
>>> class MNISTDataModule(pl.LightningDataModule):
...     def test_dataloader(self):
...         transforms = transform_lib.Compose([
...             transform_lib.ToTensor(),
...             transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
...         ])
...         return DataLoader(self.test_dataset, transform=transforms, batch_size=64)

.. code-block:: python

    import pytorch_lightning as pl


    class MNISTDataModule(pl.LightningDataModule):
        def test_dataloader(self):
            transforms = transform_lib.Compose([
                transform_lib.ToTensor(),
                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
            ])
            return DataLoader(self.test_dataset, transform=transforms, batch_size=64)
------------------

Using a DataModule
------------------

The recommended way to use a DataModule is simply:

.. code-block:: python
@@ -162,12 +216,13 @@ still ensures the method runs on the correct devices)
dm = MNISTDataModule()
dm.prepare_data()
dm.setup()
dm.setup('fit')
model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
trainer.fit(model, dm)
trainer.test(model, datamodule=dm)
dm.setup('test')
trainer.test(datamodule=dm)
----------------

@@ -184,12 +239,14 @@ DataModules have a few key advantages:
dm = MNISTDataModule()
dm.prepare_data()
dm.setup()
dm.setup('fit')
for batch in dm.train_dataloader():
    ...
for batch in dm.val_dataloader():
    ...

dm.setup('test')
for batch in dm.test_dataloader():
    ...
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerator_backends/cpu_backend.py
@@ -26,9 +26,7 @@ def setup(self, model):
raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')

# call setup after the ddp process has connected
if not self.trainer.testing:
    self.trainer.setup('fit')
    model.setup('fit')
self.trainer.call_setup_hook(model)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
@@ -106,9 +106,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
)

# call setup after the ddp process has connected
if not self.trainer.testing:
    self.trainer.setup('fit')
    model.setup('fit')
self.trainer.call_setup_hook(model)

# on world_size=0 let everyone know training is starting
if self.trainer.is_global_zero:
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerator_backends/dp_backend.py
@@ -33,9 +33,7 @@ def __init__(self, trainer):

def setup(self, model):
# call setup after the ddp process has connected
if not self.trainer.testing:
    self.trainer.setup('fit')
    model.setup('fit')
self.trainer.call_setup_hook(model)

# put model on correct device
model.cuda(self.trainer.root_gpu)
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerator_backends/gpu_backend.py
@@ -31,9 +31,7 @@ def __init__(self, trainer):
def setup(self, model):

# call setup
if not self.trainer.testing:
    self.trainer.setup('fit')
    model.setup('fit')
self.trainer.call_setup_hook(model)

model.cuda(self.trainer.root_gpu)

5 changes: 2 additions & 3 deletions pytorch_lightning/accelerator_backends/tpu_backend.py
@@ -102,9 +102,8 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
"""
if not trainer:
    trainer = self.trainer
if not trainer.testing:
    trainer.setup('fit')
    model.setup('fit')

trainer.call_setup_hook(model)

# setup TPU training
self.__setup_tpu_training(model, trainer)
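
The trainer-side helper ``call_setup_hook`` used above is added in one of the files not shown on this page. A rough sketch of what such a hook needs to do, inferred only from the replaced call sites and the commit description ("call dm hooks in trainer implicitly", "set dm as attr of trainer"), not the exact implementation:

.. code-block:: python

    def call_setup_hook(self, model):
        # choose the stage from what the trainer is currently doing
        stage_name = 'test' if self.testing else 'fit'

        # the datamodule attached to the trainer (if any) gets its setup hook called too
        if self.datamodule is not None:
            self.datamodule.setup(stage_name)

        # then the trainer's own setup and the model's setup, as before
        self.setup(stage_name)
        model.setup(stage_name)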
