Fix misuse of transforms in docs (#3546)
* 📝 docs

* 📝 docs

* 📝 docs

* 📝 docs

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
3 people authored Sep 18, 2020
1 parent a9c0ed9 commit c46de8a
Showing 1 changed file with 34 additions and 29 deletions.
63 changes: 34 additions & 29 deletions docs/source/datamodules.rst
@@ -165,6 +165,7 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.


---------------

LightningDataModule API
@@ -203,6 +204,7 @@ There are also data operations you might want to perform on every GPU. Use setup
 - count number of classes
 - build vocabulary
 - perform train/val/test splits
+- apply transforms (defined explicitly in your datamodule or assigned in init)
 - etc...

 .. code-block:: python
@@ -216,13 +218,23 @@ There are also data operations you might want to perform on every GPU. Use setup
             # Assign Train/val split(s) for use in Dataloaders
             if stage == 'fit' or stage is None:
-                mnist_full = MNIST(self.data_dir, train=True, download=True)
+                mnist_full = MNIST(
+                    self.data_dir,
+                    train=True,
+                    download=True,
+                    transform=self.transform
+                )
                 self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
                 self.dims = self.mnist_train[0][0].shape
             # Assign Test split(s) for use in Dataloaders
             if stage == 'test' or stage is None:
-                self.mnist_test = MNIST(self.data_dir, train=False, download=True)
+                self.mnist_test = MNIST(
+                    self.data_dir,
+                    train=False,
+                    download=True,
+                    transform=self.transform
+                )
                 self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
@@ -231,7 +243,7 @@ There are also data operations you might want to perform on every GPU. Use setup

 train_dataloader
 ^^^^^^^^^^^^^^^^
-Use this method to generate the train dataloader. This is also a good place to place default transformations.
+Use this method to generate the train dataloader. Usually you just wrap the dataset you defined in ``setup``.

 .. code-block:: python
@@ -240,25 +252,12 @@ Use this method to generate the train dataloader. This is also a good place to p
     class MNISTDataModule(pl.LightningDataModule):
         def train_dataloader(self):
-            transforms = transform_lib.Compose([
-                transform_lib.ToTensor(),
-                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-            ])
-            return DataLoader(self.train_dataset, transform=transforms, batch_size=64)
+            return DataLoader(self.mnist_train, batch_size=64)

-However, to decouple your data from transforms you can parametrize them via `__init__`.
-
-.. code-block:: python
-    class MNISTDataModule(pl.LightningDataModule):
-        def __init__(self, train_transforms, val_transforms, test_transforms):
-            self.train_transforms = train_transforms
-            self.val_transforms = val_transforms
-            self.test_transforms = test_transforms

 val_dataloader
 ^^^^^^^^^^^^^^
-Use this method to generate the val dataloader. This is also a good place to place default transformations.
+Use this method to generate the val dataloader. Usually you just wrap the dataset you defined in ``setup``.

 .. code-block:: python
@@ -267,15 +266,12 @@ Use this method to generate the val dataloader. This is also a good place to pla
     class MNISTDataModule(pl.LightningDataModule):
         def val_dataloader(self):
-            transforms = transform_lib.Compose([
-                transform_lib.ToTensor(),
-                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-            ])
-            return DataLoader(self.val_dataset, transform=transforms, batch_size=64)
+            return DataLoader(self.mnist_val, batch_size=64)

 test_dataloader
 ^^^^^^^^^^^^^^^
-Use this method to generate the test dataloader. This is also a good place to place default transformations.
+Use this method to generate the test dataloader. Usually you just wrap the dataset you defined in ``setup``.

 .. code-block:: python
@@ -284,11 +280,7 @@ Use this method to generate the test dataloader. This is also a good place to pl
     class MNISTDataModule(pl.LightningDataModule):
         def test_dataloader(self):
-            transforms = transform_lib.Compose([
-                transform_lib.ToTensor(),
-                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-            ])
-            return DataLoader(self.test_dataset, transform=transforms, batch_size=64)
+            return DataLoader(self.mnist_test, batch_size=64)

 transfer_batch_to_device
 ^^^^^^^^^^^^^^^^^^^^^^^^
@@ -306,6 +298,19 @@ Override to define how you want to move an arbitrary batch to a device
             batch['x'].to(device)
             return batch

+.. note:: To decouple your data from transforms you can parametrize them via `__init__`.
+
+.. code-block:: python
+    class MNISTDataModule(pl.LightningDataModule):
+        def __init__(self, train_transforms, val_transforms, test_transforms):
+            super().__init__()
+            self.train_transforms = train_transforms
+            self.val_transforms = val_transforms
+            self.test_transforms = test_transforms

 ------------------

 Using a DataModule
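
Note on the change as a whole: the removed examples passed ``transform=`` to ``DataLoader``, which does not accept that argument. The updated docs instead hand the transform to the dataset inside ``setup`` and let each ``*_dataloader`` method simply wrap that dataset. The sketch below illustrates that division of labour in plain Python, assuming nothing beyond the standard library. ``ToyDataset``, ``toy_loader`` and ``ToyDataModule`` are hypothetical stand-ins for ``MNIST``, ``DataLoader`` and ``MNISTDataModule``, not part of PyTorch or Lightning.

```python
# Hypothetical stand-ins (no torch dependency) mirroring the pattern in the
# diff: the dataset applies the transform per item, and the dataloader
# method only wraps the dataset that setup() built.


class ToyDataset:
    """Minimal map-style dataset that applies `transform` in __getitem__."""

    def __init__(self, samples, transform=None):
        self.samples = list(samples)
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        x, y = self.samples[index]
        if self.transform is not None:
            x = self.transform(x)
        return x, y


def toy_loader(dataset, batch_size):
    """Yield lists of items; there is deliberately no `transform` argument."""
    batch = []
    for i in range(len(dataset)):
        batch.append(dataset[i])
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # emit the final, possibly smaller, batch
        yield batch


class ToyDataModule:
    def __init__(self, transform=None, batch_size=2):
        # The transform is a constructor parameter, so data and transforms
        # stay decoupled.
        self.transform = transform
        self.batch_size = batch_size

    def setup(self, stage=None):
        raw = [(0, "a"), (128, "b"), (255, "c")]  # made-up (pixel, label) pairs
        if stage == "fit" or stage is None:
            # The transform is attached here, at dataset construction.
            self.train = ToyDataset(raw, transform=self.transform)

    def train_dataloader(self):
        # Only wraps the dataset defined in setup().
        return toy_loader(self.train, batch_size=self.batch_size)


dm = ToyDataModule(transform=lambda pixel: pixel / 255.0)
dm.setup("fit")
batches = list(dm.train_dataloader())
```

Each item in ``batches`` is already transformed by the time the loader sees it, which is why the loader needs no knowledge of transforms.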