Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Lightning DataModules #130

Merged
merged 16 commits into from
Jul 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions docs/source/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ Step 4, also needs special care to make sure that it's only done on 1 GPU in a m
In addition, there are other challenges such as models that are built using information from the dataset
such as needing to know image dimensions or number of classes.

A datamodule simplifies all of these parts and integrates seamlessly into Lightning.
A datamodule simplifies all of these parts and has been integrated directly into Lightning in version 0.9.0.
You can view the documentation for the datamodule in the `Pytorch Lightning docs here. <https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html>`_

.. code-block:: python

Expand Down Expand Up @@ -92,7 +93,7 @@ Use this to build your own consistent train, validation, test splits.

Example::

from pl_bolts.datamodules import LightningDataModule
from pytorch_lightning import LightningDataModule

class MyDataModule(LightningDataModule):

Expand Down Expand Up @@ -133,12 +134,6 @@ Example::
return self.dm.test_dataloader()


DataModule class
^^^^^^^^^^^^^^^^

.. autoclass:: pl_bolts.datamodules.lightning_datamodule.LightningDataModule
:noindex:

-------------

DummyDataset
Expand Down
1 change: 0 additions & 1 deletion pl_bolts/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
from pl_bolts.datamodules.lightning_datamodule import LightningDataModule
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset, SklearnDataModule, TensorDataset, TensorDataModule
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
Expand Down
27 changes: 10 additions & 17 deletions pl_bolts/datamodules/cifar10_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Optional, Sequence

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms as transform_lib
from torchvision.datasets import CIFAR10

from pl_bolts.datamodules.cifar10_dataset import TrialCIFAR10
from pl_bolts.datamodules.lightning_datamodule import LightningDataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization


Expand All @@ -19,6 +19,7 @@ def __init__(
data_dir,
val_split=5000,
num_workers=16,
batch_size=32,
*args,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is tricky as it tends to confuse user why some arguments were not used...

**kwargs,
):
Expand Down Expand Up @@ -54,13 +55,15 @@ def __init__(
data_dir: where to save/load the data
val_split: how many of the training images to use for the validation split
num_workers: how many workers to use for loading data
batch_size: number of examples per training/eval step
"""
super().__init__(*args, **kwargs)
self.dims = (3, 32, 32)
self.DATASET = CIFAR10
self.data_dir = data_dir
self.val_split = val_split
self.num_workers = num_workers
self.batch_size = batch_size

@property
def num_classes(self):
Expand All @@ -77,12 +80,9 @@ def prepare_data(self):
self.DATASET(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor(), **self.extra_args)
self.DATASET(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor(), **self.extra_args)

def train_dataloader(self, batch_size):
def train_dataloader(self):
"""
CIFAR train set removes a subset to use for validation

Args:
batch_size: size of batch
"""
transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms

Expand All @@ -91,20 +91,17 @@ def train_dataloader(self, batch_size):
dataset_train, _ = random_split(dataset, [train_length - self.val_split, self.val_split])
loader = DataLoader(
dataset_train,
batch_size=batch_size,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
)
return loader

def val_dataloader(self, batch_size):
def val_dataloader(self):
"""
CIFAR10 val set uses a subset of the training set for validation

Args:
batch_size: size of batch
"""
transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms

Expand All @@ -113,28 +110,24 @@ def val_dataloader(self, batch_size):
_, dataset_val = random_split(dataset, [train_length - self.val_split, self.val_split])
loader = DataLoader(
dataset_val,
batch_size=batch_size,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True
)
return loader

def test_dataloader(self, batch_size):
def test_dataloader(self):
"""
CIFAR10 test set uses the test split

Args:
batch_size: size of batch
transforms: custom transforms
"""
transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms

dataset = self.DATASET(self.data_dir, train=False, download=False, transform=transforms, **self.extra_args)
loader = DataLoader(
dataset,
batch_size=batch_size,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
Expand Down
3 changes: 1 addition & 2 deletions pl_bolts/datamodules/fashion_mnist_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms as transform_lib
from torchvision.datasets import FashionMNIST

from pl_bolts.datamodules.lightning_datamodule import LightningDataModule


class FashionMNISTDataModule(LightningDataModule):

Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/imagenet_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torchvision import transforms as transform_lib

from pl_bolts.datamodules.imagenet_dataset import UnlabeledImagenet
from pl_bolts.datamodules.lightning_datamodule import LightningDataModule
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization


Expand Down
Loading