Skip to content

Commit

Permalink
📝 docs
Browse files Browse the repository at this point in the history
  • Loading branch information
nateraw committed Jul 30, 2020
1 parent d1924be commit bcc8720
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
35 changes: 29 additions & 6 deletions docs/source/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,50 @@ matching transforms and data processing/downloads steps required.

.. code-block:: python
class MNISTDataModule(LightningDataModule):
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = './'):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# self.dims is returned when you call dm.size()
# Setting default dims here because we know them.
# Could optionally be assigned dynamically in dm.setup()
self.dims = (1, 28, 28)
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage):
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit':
mnist_full = MNIST(self.data_dir, train=True, download=True)
if stage == 'fit' or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Optionally...
# self.dims = tuple(self.mnist_train[0][0].shape)
# Assign test dataset for use in dataloader(s)
if stage == 'test':
self.mnist_test = MNIST(self.data_dir, train=False, download=True)
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
# Optionally...
# self.dims = tuple(self.mnist_test[0][0].shape)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
Expand Down
15 changes: 15 additions & 0 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,21 @@ def test_train_val_loop_only(tmpdir):
assert trainer.callback_metrics['loss'] < 0.6


def test_test_loop_only(tmpdir):
reset_seed()

dm = TrialMNISTDataModule(tmpdir)

model = EvalModelTemplate()

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=3,
weights_summary=None,
)
trainer.test(model, datamodule=dm)


def test_full_loop(tmpdir):
reset_seed()

Expand Down

0 comments on commit bcc8720

Please sign in to comment.