Call DataModule hooks implicitly in trainer #2755
Merged

21 commits
Commits:

- 8dc681a (nateraw): :sparkles: call dm hooks in trainer implicitly
- 5064b7d (nateraw): :white_check_mark: update tests
- 0e43c0b (nateraw): :pencil: remove unused stage arg from dm docs
- 8550fb3 (nateraw): :white_check_mark: update tests
- 94c1eb1 (nateraw): :white_check_mark: update tests
- 05a16d7 (nateraw): :construction: include stage in datamodule.setup
- d55bcd7 (nateraw): :pencil: docs
- 981378c (nateraw): :pencil: docs
- 9331b60 (williamFalcon): added more dm tests
- 6be261b (williamFalcon): added more dm tests
- 5233ac7 (nateraw): :bug: call dm.setup everywhere
- cb1b848 (nateraw): :fire: pickle tests now implied by accelerator tests
- 1b77442 (nateraw): :art: set dm as attr of trainer
- f059cc4 (nateraw): :bug: .
- a3be9e7 (nateraw): :construction: wip
- ecc2875 (williamFalcon): add can prepare test
- c54ac9d (williamFalcon): add can prepare test
- 2bdb10e (williamFalcon): verified setup in fit
- 8107ef9 (williamFalcon): fixed setup call
- 57e942a (williamFalcon): fixed setup call
- c138110 (williamFalcon): fixed setup call
@@ -11,33 +11,63 @@ Data preparation in PyTorch follows 5 steps:

A DataModule is simply a collection of a train_dataloader, val_dataloader(s), test_dataloader(s) along with the
matching transforms and the data processing/download steps required.

.. code-block:: python

    import pytorch_lightning as pl
    from torch.utils.data import random_split, DataLoader

    # Note - you must have torchvision installed for this example
    from torchvision.datasets import MNIST
    from torchvision import transforms


    class MNISTDataModule(pl.LightningDataModule):

        def __init__(self, data_dir: str = './'):
            super().__init__()
            self.data_dir = data_dir
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])

            # self.dims is returned when you call dm.size().
            # Setting default dims here because we know them.
            # Could optionally be assigned dynamically in dm.setup()
            self.dims = (1, 28, 28)

        def prepare_data(self):
            # download
            MNIST(self.data_dir, train=True, download=True)
            MNIST(self.data_dir, train=False, download=True)

        def setup(self, stage=None):

            # Assign train/val datasets for use in dataloaders
            if stage == 'fit' or stage is None:
                mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

                # Optionally...
                # self.dims = tuple(self.mnist_train[0][0].shape)

            # Assign test dataset for use in dataloader(s)
            if stage == 'test' or stage is None:
                self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

                # Optionally...
                # self.dims = tuple(self.mnist_test[0][0].shape)

        def train_dataloader(self):
            return DataLoader(self.mnist_train, batch_size=32)

        def val_dataloader(self):
            return DataLoader(self.mnist_val, batch_size=32)

        def test_dataloader(self):
            return DataLoader(self.mnist_test, batch_size=32)

Review comment on ``train_dataloader``: maybe add ``batch_size`` as a param?

.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.

---------------
@@ -60,11 +90,13 @@ settings.

- tokenize
- etc...

.. code-block:: python

    class MNISTDataModule(pl.LightningDataModule):
        def prepare_data(self):
            # download
            MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
            MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

.. warning:: ``prepare_data`` is called from a single GPU. Do not use it to assign state (``self.x = y``).
@@ -77,33 +109,46 @@ There are also data operations you might want to perform on every GPU. Use setup

- perform train/val/test splits
- etc...

.. code-block:: python

    import pytorch_lightning as pl


    class MNISTDataModule(pl.LightningDataModule):

        def setup(self, stage: Optional[str] = None):

            # Assign train/val split(s) for use in dataloaders
            if stage == 'fit' or stage is None:
                mnist_full = MNIST(self.data_dir, train=True, download=True)
                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
                self.dims = self.mnist_train[0][0].shape

            # Assign test split(s) for use in dataloaders
            if stage == 'test' or stage is None:
                self.mnist_test = MNIST(self.data_dir, train=False, download=True)
                self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)

.. warning:: ``setup`` is called from every GPU. Setting state here is okay.
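The ``getattr(self, 'dims', ...)`` idiom above only falls back to the test sample's shape when ``dims`` was not already set during the ``'fit'`` stage. A tiny sketch of that pattern (``Holder`` is a made-up bare object standing in for a DataModule):

```python
class Holder:
    """Hypothetical bare object standing in for a DataModule."""
    pass


# test-only run: no dims set yet, so the fallback value wins
h = Holder()
h.dims = getattr(h, 'dims', (1, 28, 28))
assert h.dims == (1, 28, 28)

# fit ran first and already set dims; getattr keeps the existing value
h2 = Holder()
h2.dims = (3, 32, 32)
h2.dims = getattr(h2, 'dims', (1, 28, 28))
assert h2.dims == (3, 32, 32)
```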
train_dataloader
^^^^^^^^^^^^^^^^
Use this method to generate the train dataloader. This is also a good place to add default transformations.

.. code-block:: python

    import pytorch_lightning as pl


    class MNISTDataModule(pl.LightningDataModule):
        def train_dataloader(self):
            transforms = transform_lib.Compose([
                transform_lib.ToTensor(),
                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
            ])
            return DataLoader(self.train_dataset, transform=transforms, batch_size=64)

However, to decouple your data from transforms you can parametrize them via ``__init__``.
@@ -119,32 +164,41 @@ val_dataloader
^^^^^^^^^^^^^^
Use this method to generate the val dataloader. This is also a good place to add default transformations.

.. code-block:: python

    import pytorch_lightning as pl


    class MNISTDataModule(pl.LightningDataModule):
        def val_dataloader(self):
            transforms = transform_lib.Compose([
                transform_lib.ToTensor(),
                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
            ])
            return DataLoader(self.val_dataset, transform=transforms, batch_size=64)
test_dataloader
^^^^^^^^^^^^^^^
Use this method to generate the test dataloader. This is also a good place to add default transformations.

.. code-block:: python

    import pytorch_lightning as pl


    class MNISTDataModule(pl.LightningDataModule):
        def test_dataloader(self):
            transforms = transform_lib.Compose([
                transform_lib.ToTensor(),
                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
            ])
            return DataLoader(self.test_dataset, transform=transforms, batch_size=64)

------------------

Using a DataModule
------------------

The recommended way to use a DataModule is simply:

.. code-block:: python
@@ -162,12 +216,13 @@ still ensures the method runs on the correct devices)

    dm = MNISTDataModule()
    dm.prepare_data()
    dm.setup('fit')

    model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
    trainer.fit(model, dm)

    dm.setup('test')
    trainer.test(datamodule=dm)

----------------
@@ -184,12 +239,14 @@ DataModules have a few key advantages:

    dm = MNISTDataModule()
    dm.prepare_data()

    dm.setup('fit')
    for batch in dm.train_dataloader():
        ...
    for batch in dm.val_dataloader():
        ...

    dm.setup('test')
    for batch in dm.test_dataloader():
        ...
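Because a DataModule is an ordinary object, its dataloaders can be exercised without a Trainer, exactly as the loop above shows. A torch-free sketch of that flow (``ToyDM`` and its list-valued batches are made up for illustration; a real DataLoader yields tensors):

```python
class ToyDM:
    """Hypothetical stand-in: batches are plain lists, not tensors."""

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            self.train, self.val = list(range(8)), [8, 9]
        if stage == 'test' or stage is None:
            self.test = [10, 11, 12, 13]

    @staticmethod
    def _loader(data, batch_size=2):
        # yield fixed-size batches, mimicking a DataLoader
        for i in range(0, len(data), batch_size):
            yield data[i:i + batch_size]

    def train_dataloader(self):
        return self._loader(self.train)

    def val_dataloader(self):
        return self._loader(self.val)

    def test_dataloader(self):
        return self._loader(self.test)


dm = ToyDM()
dm.setup('fit')
assert list(dm.train_dataloader()) == [[0, 1], [2, 3], [4, 5], [6, 7]]
assert list(dm.val_dataloader()) == [[8, 9]]

dm.setup('test')
assert list(dm.test_dataloader()) == [[10, 11], [12, 13]]
```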
Review conversation:

- "why not test code?"
- "there are a ton of unnecessary doctests."
- "What do you mean by unnecessary, that there is no need for examples? Otherwise, all examples should be tested so that they stay aligned with the actual code base. cc: @awaelchli"
- "Agreed, this code block would be perfect for a doctest, and it is as simple as adding ``.. testcode``. Even if we don't make any assertions here, Python will parse the code and run the import statements. It would help us keep the docs up-to-date with the API."
- "Yeah, we do not need anything extra, but it checks syntax and eventually the number of passed arguments or kwargs."
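The reviewers' point about doctests can be illustrated with the standard library's ``doctest`` module: examples embedded in a docstring are actually executed, so an API change that breaks them fails the check. A minimal sketch (the ``scale`` function is made up, not from this PR):

```python
import doctest


def scale(x, factor=2):
    """Multiply x by factor.

    >>> scale(3)
    6
    >>> scale(3, factor=10)
    30
    """
    return x * factor


# collect and run the docstring examples programmatically
finder = doctest.DocTestFinder()
runner = doctest.DocTestRunner()
for case in finder.find(scale, 'scale'):
    runner.run(case)

# both examples ran and none failed
assert runner.tries == 2
assert runner.failures == 0
```

Sphinx's ``.. testcode::`` directive works the same way at docs-build time, which is what keeps examples aligned with the code base.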