Datamodule #2668

Merged (18 commits, Jul 24, 2020)
12 changes: 8 additions & 4 deletions pytorch_lightning/__init__.py
@@ -7,8 +7,10 @@
__copyright__ = 'Copyright (c) 2018-2020, %s.' % __author__
__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning'
# this has to be simple string, see: https://github.com/pypa/twine/issues/522
__docs__ = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." \
" Scale your models. Write less boilerplate."
__docs__ = (
"PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers."
" Scale your models. Write less boilerplate."
)
__long_docs__ = """
Lightning is a way to organize your PyTorch code to decouple the science code from the engineering.
It's more of a style-guide than a framework.
@@ -47,10 +49,11 @@

if __LIGHTNING_SETUP__:
import sys # pragma: no-cover

sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover
# We are not importing the rest of the lightning during the build process, as it may not be compiled yet
else:
from pytorch_lightning.core import LightningModule, data_loader
from pytorch_lightning.core import LightningDataModule, LightningModule, data_loader
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities.seed import seed_everything
@@ -59,13 +62,14 @@

__all__ = [
'Trainer',
'LightningDataModule',
'LightningModule',
'Callback',
'data_loader',
'seed_everything',
'metrics',
'EvalResult',
'TrainResult'
'TrainResult',
]

# necessary for regular bolts imports. Skip exception since bolts is not always installed
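With `LightningDataModule` now re-exported in `__all__`, the class can be imported straight from the package root. A minimal sanity check, assuming this branch is installed:

    # hypothetical quick check of the new top-level export
    import pytorch_lightning as pl
    from pytorch_lightning import LightningDataModule

    assert 'LightningDataModule' in pl.__all__
    assert LightningDataModule is pl.LightningDataModule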
3 changes: 2 additions & 1 deletion pytorch_lightning/core/__init__.py
@@ -336,8 +336,9 @@ def training_step(self, batch, batch_idx):

"""

from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.decorators import data_loader
from pytorch_lightning.core.lightning import LightningModule

__all__ = ['LightningModule', 'data_loader']
__all__ = ['LightningDataModule', 'LightningModule', 'data_loader']
# __call__ = __all__
314 changes: 314 additions & 0 deletions pytorch_lightning/core/datamodule.py
@@ -0,0 +1,314 @@
import inspect
from abc import abstractmethod
from argparse import ArgumentParser, Namespace
from typing import Any, List, Tuple, Union

from torch.utils.data import DataLoader

from pytorch_lightning.utilities import parsing, rank_zero_only, rank_zero_warn


class _DataModuleWrapper(type):
def __call__(cls, *args, **kwargs):
"""A wrapper for LightningDataModule that:

1. Runs the user-defined subclass's __init__
2. Ensures prepare_data() runs only on rank 0
"""

# Wrap cls's prepare_data function with rank_zero_only
cls.prepare_data = rank_zero_only(cls.prepare_data)

# Get instance of LightningDataModule by mocking its __init__ via __call__
obj = type.__call__(cls, *args, **kwargs)

return obj
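The net effect of this metaclass is that instantiating any subclass of the `LightningDataModule` defined below first wraps its `prepare_data` with `rank_zero_only` and only then runs the user's `__init__`, so in a distributed job the download logic executes on rank 0 alone. A rough sketch of the behaviour (the subclass name is illustrative):

    class MNISTDataModule(LightningDataModule):
        def prepare_data(self):
            # heavy, download-style work; only rank 0 ever executes this body
            print('downloading...')

    dm = MNISTDataModule()   # _DataModuleWrapper.__call__ wraps prepare_data before __init__ runs
    dm.prepare_data()        # returns None without running the body on ranks other than 0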


class LightningDataModule(object, metaclass=_DataModuleWrapper): # pragma: no cover
"""
A DataModule standardizes the training, val, test splits, data preparation and transforms.
The main advantage is consistent data splits, data preparation and transforms across models.

Example::

class MyDataModule(LightningDataModule):
def __init__(self):
super().__init__()
def prepare_data(self):
# download, split, etc...
# only called on 1 GPU/TPU in distributed
def setup(self):
# make assignments here (val/train/test split)
# called on every process in DDP
def train_dataloader(self):
train_split = Dataset(...)
return DataLoader(train_split)
def val_dataloader(self):
val_split = Dataset(...)
return DataLoader(val_split)
def test_dataloader(self):
test_split = Dataset(...)
return DataLoader(test_split)

A DataModule implements 5 key methods:

* **prepare_data** (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode).
* **setup** (things to do on every accelerator in distributed mode).
* **train_dataloader** the training dataloader.
* **val_dataloader** the val dataloader(s).
* **test_dataloader** the test dataloader(s).


This allows you to share a full dataset without explaining how to download,
split, transform and process the data.

"""

name: str = ...

def __init__(
self, train_transforms=None, val_transforms=None, test_transforms=None,
):
super().__init__()
self._train_transforms = train_transforms
self._val_transforms = val_transforms
self._test_transforms = test_transforms
self.dims = ()

@property
def train_transforms(self):
"""
Optional transforms (or collection of transforms) you can apply to the train dataset
"""
return self._train_transforms

@train_transforms.setter
def train_transforms(self, t):
self._train_transforms = t

@property
def val_transforms(self):
"""
Optional transforms (or collection of transforms) you can apply to the validation dataset
"""
return self._val_transforms

@val_transforms.setter
def val_transforms(self, t):
self._val_transforms = t

@property
def test_transforms(self):
"""
Optional transforms (or collection of transforms) you can apply to the test dataset
"""
return self._test_transforms

@test_transforms.setter
def test_transforms(self, t):
self._test_transforms = t
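The three transform arguments accepted by `__init__` above are surfaced through these read/write properties, so they can be given at construction time or swapped out later. A small illustrative use (`MyDataModule` is a hypothetical subclass and torchvision is an assumed dependency):

    from torchvision import transforms

    dm = MyDataModule(train_transforms=transforms.ToTensor(),
                      val_transforms=transforms.ToTensor())
    dm.test_transforms = transforms.ToTensor()   # setters allow assignment after construction
    print(dm.train_transforms)                   # ToTensor()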

def size(self, dim=None) -> Union[Tuple, int]:
"""
Return the dimension of each input either as a tuple or list of tuples.
"""

if dim is not None:
return self.dims[dim]

return self.dims
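`size()` simply indexes into `self.dims`, which a subclass is expected to set to the shape of a single sample; for example (hypothetical datamodule):

    class MNISTDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            self.dims = (1, 28, 28)   # channels, height, width of one sample

    dm = MNISTDataModule()
    print(dm.size())    # (1, 28, 28)
    print(dm.size(1))   # 28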

@abstractmethod
def prepare_data(self, *args, **kwargs):
"""
Use this to download and prepare data.
In distributed (GPU, TPU), this will only be called once.

.. warning:: Do not assign anything to the datamodule in this step since this will only be called on 1 GPU.

Pseudocode::

dm.prepare_data()
dm.setup()

Example::

def prepare_data(self):
download_imagenet()
clean_imagenet()
cache_imagenet()
"""

@abstractmethod
def setup(self, *args, **kwargs):
"""
Use this to load your data from file, split it, etc. You are safe to make state assignments here.
This hook is called on every process when using DDP.

Example::

def setup(self):
data = load_data(...)
self.train_ds, self.val_ds, self.test_ds = split_data(data)
"""

@abstractmethod
def train_dataloader(self, *args, **kwargs) -> DataLoader:
"""
Implement a PyTorch DataLoader for training.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.

Example::

def train_dataloader(self):
dataset = MNIST(root=PATH, train=True, transform=transforms.ToTensor(), download=False)
loader = torch.utils.data.DataLoader(dataset=dataset)
return loader

"""
rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')

@abstractmethod
def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
r"""
Implement a PyTorch DataLoader for validation.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.
Note:
You can also return a list of DataLoaders

Example::

def val_dataloader(self):
dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
return loader
"""

@abstractmethod
def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
r"""
Implement a PyTorch DataLoader for testing.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.
Note:
You can also return a list of DataLoaders

Example::

def test_dataloader(self):
dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
return loader
"""

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extends existing argparse by default `LightningDataModule` attributes.
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False,)
added_args = [x.dest for x in parser._actions]

blacklist = ['kwargs']
depr_arg_names = blacklist + added_args
depr_arg_names = set(depr_arg_names)

allowed_types = (str, float, int, bool)

# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (
at for at in cls.get_init_arguments_and_types() if at[0] not in depr_arg_names
):
arg_types = [at for at in allowed_types if at in arg_types]
if not arg_types:
# skip argument with not supported type
continue
arg_kwargs = {}
if bool in arg_types:
arg_kwargs.update(nargs="?")
# if the only arg type is bool
if len(arg_types) == 1:
# redefine the type for ArgParser needed
def use_type(x):
return bool(parsing.str_to_bool(x))

else:
# filter out the bool as we need to use more general
use_type = [at for at in arg_types if at is not bool][0]
else:
use_type = arg_types[0]

if arg_default == inspect._empty:
arg_default = None

parser.add_argument(
f'--{arg}',
dest=arg,
default=arg_default,
type=use_type,
help=f'autogenerated by plb.{cls.__name__}',
**arg_kwargs,
)

return parser
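Only constructor arguments annotated as `str`, `float`, `int` or `bool` survive the type filter above, so a subclass needs type hints on `__init__` for this to add anything to the parser. A hypothetical example (`ArgumentParser` is already imported at the top of this file):

    class MyDataModule(LightningDataModule):
        def __init__(self, data_dir: str = './data', batch_size: int = 32):
            super().__init__()
            self.data_dir = data_dir
            self.batch_size = batch_size

    parser = ArgumentParser(add_help=False)
    parser = MyDataModule.add_argparse_args(parser)
    args = parser.parse_args(['--data_dir', '/tmp/mnist', '--batch_size', '64'])
    print(args.batch_size)   # 64, parsed as int via the annotation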

@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
"""
Create an instance from CLI arguments.

Args:
args: The parser or namespace to take arguments from. Only known arguments will be
parsed and passed to the :class:`LightningDataModule`.
**kwargs: Additional keyword arguments that may override ones in the parser or namespace.
These must be valid DataModule arguments.

Example::

parser = ArgumentParser(add_help=False)
parser = LightningDataModule.add_argparse_args(parser)
args = parser.parse_args()
module = LightningDataModule.from_argparse_args(args)

"""
if isinstance(args, ArgumentParser):
args = cls.parse_argparser(args)
params = vars(args)

# we only want to pass in valid DataModule args, the rest may be user specific
valid_kwargs = inspect.signature(cls.__init__).parameters
datamodule_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
datamodule_kwargs.update(**kwargs)

return cls(**datamodule_kwargs)

@classmethod
def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
r"""Scans the DataModule signature and returns argument names, types and default values.
Returns:
List with tuples of 3 values:
(argument name, set with argument types, argument default value).
"""
datamodule_default_params = inspect.signature(cls.__init__).parameters
name_type_default = []
for arg in datamodule_default_params:
arg_type = datamodule_default_params[arg].annotation
arg_default = datamodule_default_params[arg].default
try:
arg_types = tuple(arg_type.__args__)
except AttributeError:
arg_types = (arg_type,)

name_type_default.append((arg, arg_types, arg_default))

return name_type_default
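For the same hypothetical subclass, this helper reports one tuple per `__init__` parameter, including `self`, whose annotation and default are both `inspect._empty`:

    MyDataModule.get_init_arguments_and_types()
    # roughly: [('self', (<empty>,), <empty>),
    #           ('data_dir', (str,), './data'),
    #           ('batch_size', (int,), 32)]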