diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst index 12b26b4a23857d..8bc7cfc56447bd 100644 --- a/docs/source/datamodules.rst +++ b/docs/source/datamodules.rst @@ -13,27 +13,50 @@ matching transforms and data processing/downloads steps required. .. code-block:: python - class MNISTDataModule(LightningDataModule): + import pytorch_lightning as pl + from torch.utils.data import random_split, DataLoader + + # Note - you must have torchvision installed for this example + from torchvision.datasets import MNIST + from torchvision import transforms + + + class MNISTDataModule(pl.LightningDataModule): def __init__(self, data_dir: str = './'): super().__init__() self.data_dir = data_dir + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + + # self.dims is returned when you call dm.size() + # Setting default dims here because we know them. + # Could optionally be assigned dynamically in dm.setup() + self.dims = (1, 28, 28) def prepare_data(self): # download MNIST(self.data_dir, train=True, download=True) MNIST(self.data_dir, train=False, download=True) - def setup(self, stage): + def setup(self, stage=None): # Assign train/val datasets for use in dataloaders - if stage == 'fit': - mnist_full = MNIST(self.data_dir, train=True, download=True) + if stage == 'fit' or stage is None: + mnist_full = MNIST(self.data_dir, train=True, transform=self.transform) self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) + # Optionally... + # self.dims = tuple(self.mnist_train[0][0].shape) + # Assign test dataset for use in dataloader(s) - if stage == 'test': - self.mnist_test = MNIST(self.data_dir, train=False, download=True) + if stage == 'test' or stage is None: + self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform) + + # Optionally... + # self.dims = tuple(self.mnist_test[0][0].shape) def train_dataloader(self): return DataLoader(self.mnist_train, batch_size=32) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 929c5c3d2c805f..b525a2b344c73d 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -159,6 +159,21 @@ def test_train_val_loop_only(tmpdir): assert trainer.callback_metrics['loss'] < 0.6 +def test_test_loop_only(tmpdir): + reset_seed() + + dm = TrialMNISTDataModule(tmpdir) + + model = EvalModelTemplate() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + weights_summary=None, + ) + trainer.test(model, datamodule=dm) + + def test_full_loop(tmpdir): reset_seed()