Datamodule (#2668)
* ✨ Add copy of pl_bolts datamodule to lightning

* ✨ add datamodule to necessary init files

* 🚧 add datamodule property to LightningModule

* 🚧 .

* 🎨 Let DataModule do its own thing

* 🚧 add back setup and run both hooks implicitly

* 🚧 .

* 🐛 fix add_argparse_args

* 💄 apply black formatting and isort

* 📝 docstrings

* 📝 .

* 📝 .

* 🐛 overwrite cls prepare_data instead of instance

* 📝 .

* ✅ add some tests

* Update datamodule.py

* Update datamodule.py

* Update datamodule.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
nateraw and williamFalcon committed Jul 24, 2020
1 parent 938ec5a commit 1caf8be
Showing 8 changed files with 589 additions and 145 deletions.
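
For orientation before the diff, here is a minimal sketch of how the new LightningDataModule introduced by this commit is meant to be subclassed and driven by hand, following the hooks and the prepare_data/setup ordering documented in pytorch_lightning/core/datamodule.py below. The MNISTDataModule class, the PATH constant and the split sizes are illustrative assumptions, not part of this commit.

from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

from pytorch_lightning import LightningDataModule

PATH = './data'  # hypothetical download location, not part of this commit


class MNISTDataModule(LightningDataModule):
    """Illustrative subclass wiring up the five hooks described in the class docstring."""

    def prepare_data(self):
        # Wrapped with rank_zero_only by the metaclass: download once, assign no state here.
        MNIST(PATH, train=True, download=True)
        MNIST(PATH, train=False, download=True)

    def setup(self):
        # Called on every process; state assignments are safe here.
        full = MNIST(PATH, train=True, transform=transforms.ToTensor())
        self.train_ds, self.val_ds = random_split(full, [55000, 5000])
        self.test_ds = MNIST(PATH, train=False, transform=transforms.ToTensor())

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=32)


# Mirrors the pseudocode in the prepare_data docstring: prepare first, then setup.
dm = MNISTDataModule()
dm.prepare_data()
dm.setup()
train_loader = dm.train_dataloader()
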
12 changes: 8 additions & 4 deletions pytorch_lightning/__init__.py
@@ -7,8 +7,10 @@
__copyright__ = 'Copyright (c) 2018-2020, %s.' % __author__
__homepage__ = 'https://github.com/PyTorchLightning/pytorch-lightning'
# this has to be simple string, see: https://github.com/pypa/twine/issues/522
__docs__ = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers." \
" Scale your models. Write less boilerplate."
__docs__ = (
"PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers."
" Scale your models. Write less boilerplate."
)
__long_docs__ = """
Lightning is a way to organize your PyTorch code to decouple the science code from the engineering.
It's more of a style-guide than a framework.
@@ -47,10 +49,11 @@

if __LIGHTNING_SETUP__:
import sys # pragma: no-cover

sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover
# We are not importing the rest of the lightning during the build process, as it may not be compiled yet
else:
from pytorch_lightning.core import LightningModule, data_loader
from pytorch_lightning.core import LightningDataModule, LightningModule, data_loader
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities.seed import seed_everything
@@ -59,13 +62,14 @@

__all__ = [
'Trainer',
'LightningDataModule',
'LightningModule',
'Callback',
'data_loader',
'seed_everything',
'metrics',
'EvalResult',
'TrainResult'
'TrainResult',
]

# necessary for regular bolts imports. Skip exception since bolts is not always installed
3 changes: 2 additions & 1 deletion pytorch_lightning/core/__init__.py
@@ -336,8 +336,9 @@ def training_step(self, batch, batch_idx):
"""

from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.decorators import data_loader
from pytorch_lightning.core.lightning import LightningModule

__all__ = ['LightningModule', 'data_loader']
__all__ = ['LightningDataModule', 'LightningModule', 'data_loader']
# __call__ = __all__
314 changes: 314 additions & 0 deletions pytorch_lightning/core/datamodule.py
@@ -0,0 +1,314 @@
import inspect
from abc import abstractmethod
from argparse import ArgumentParser, Namespace
from typing import Any, List, Tuple, Union

from torch.utils.data import DataLoader

from pytorch_lightning.utilities import parsing, rank_zero_only, rank_zero_warn


class _DataModuleWrapper(type):
def __call__(cls, *args, **kwargs):
"""A wrapper for LightningDataModule that:
1. Runs the user-defined subclass's __init__
2. Ensures prepare_data() runs only on rank 0
"""

# Wrap cls's prepare_data function with rank_zero_only
cls.prepare_data = rank_zero_only(cls.prepare_data)

# Get instance of LightningDataModule by mocking its __init__ via __call__
obj = type.__call__(cls, *args, **kwargs)

return obj


class LightningDataModule(object, metaclass=_DataModuleWrapper): # pragma: no cover
"""
A DataModule standardizes the training, val, test splits, data preparation and transforms.
The main advantage is consistent data splits, data preparation and transforms across models.
Example::
class MyDataModule(LightningDataModule):
def __init__(self):
super().__init__()
def prepare_data(self):
# download, split, etc...
# only called on 1 GPU/TPU in distributed
def setup(self):
# make assignments here (val/train/test split)
# called on every process in DDP
def train_dataloader(self):
train_split = Dataset(...)
return DataLoader(train_split)
def val_dataloader(self):
val_split = Dataset(...)
return DataLoader(val_split)
def test_dataloader(self):
test_split = Dataset(...)
return DataLoader(test_split)
A DataModule implements 5 key methods:
* **prepare_data** (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode).
* **setup** (things to do on every accelerator in distributed mode).
* **train_dataloader** the training dataloader.
* **val_dataloader** the val dataloader(s).
* **test_dataloader** the test dataloader(s).
This allows you to share a full dataset without explaining how to download,
split, transform and process the data.
"""

name: str = ...

def __init__(
self, train_transforms=None, val_transforms=None, test_transforms=None,
):
super().__init__()
self._train_transforms = train_transforms
self._val_transforms = val_transforms
self._test_transforms = test_transforms
self.dims = ()

@property
def train_transforms(self):
"""
Optional transforms (or collection of transforms) you can apply to train dataset
"""
return self._train_transforms

@train_transforms.setter
def train_transforms(self, t):
self._train_transforms = t

@property
def val_transforms(self):
"""
Optional transforms (or collection of transforms) you can apply to validation dataset
"""
return self._val_transforms

@val_transforms.setter
def val_transforms(self, t):
self._val_transforms = t

@property
def test_transforms(self):
"""
Optional transforms (or collection of transforms) you can apply to test dataset
"""
return self._test_transforms

@test_transforms.setter
def test_transforms(self, t):
self._test_transforms = t

def size(self, dim=None) -> Union[Tuple, int]:
"""
Return the dimension of each input either as a tuple or list of tuples.
"""

if dim is not None:
return self.dims[dim]

return self.dims

@abstractmethod
def prepare_data(self, *args, **kwargs):
"""
Use this to download and prepare data.
In distributed (GPU, TPU), this will only be called once.
.. warning:: Do not assign anything to the datamodule in this step since this will only be called on 1 GPU.
Pseudocode::
dm.prepare_data()
dm.setup()
Example::
def prepare_data(self):
download_imagenet()
clean_imagenet()
cache_imagenet()
"""

@abstractmethod
def setup(self, *args, **kwargs):
"""
Use this to load your data from file, split it, etc. You are safe to make state assignments here.
This hook is called on every process when using DDP.
Example::
def setup(self):
data = load_data(...)
self.train_ds, self.val_ds, self.test_ds = split_data(data)
"""

@abstractmethod
def train_dataloader(self, *args, **kwargs) -> DataLoader:
"""
Implement a PyTorch DataLoader for training.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.
Example::
def train_dataloader(self):
dataset = MNIST(root=PATH, train=True, transform=transforms.ToTensor(), download=False)
loader = torch.utils.data.DataLoader(dataset=dataset)
return loader
"""
rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')

@abstractmethod
def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
r"""
Implement a PyTorch DataLoader for validation.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.
Note:
You can also return a list of DataLoaders
Example::
def val_dataloader(self):
dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
return loader
"""

@abstractmethod
def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
r"""
Implement a PyTorch DataLoader for testing.
Return:
Single PyTorch :class:`~torch.utils.data.DataLoader`.
Note:
Lightning adds the correct sampler for distributed and arbitrary hardware.
There is no need to set it yourself.
Note:
You can also return a list of DataLoaders
Example::
def test_dataloader(self):
dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
return loader
"""

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extends existing argparse by default `LightningDataModule` attributes.
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False,)
added_args = [x.dest for x in parser._actions]

blacklist = ['kwargs']
depr_arg_names = blacklist + added_args
depr_arg_names = set(depr_arg_names)

allowed_types = (str, float, int, bool)

# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (
at for at in cls.get_init_arguments_and_types() if at[0] not in depr_arg_names
):
arg_types = [at for at in allowed_types if at in arg_types]
if not arg_types:
# skip argument with not supported type
continue
arg_kwargs = {}
if bool in arg_types:
arg_kwargs.update(nargs="?")
# if the only arg type is bool
if len(arg_types) == 1:
# redefine the type for ArgParser needed
def use_type(x):
return bool(parsing.str_to_bool(x))

else:
# filter out the bool as we need to use more general
use_type = [at for at in arg_types if at is not bool][0]
else:
use_type = arg_types[0]

if arg_default == inspect._empty:
arg_default = None

parser.add_argument(
f'--{arg}',
dest=arg,
default=arg_default,
type=use_type,
help=f'autogenerated by plb.{cls.__name__}',
**arg_kwargs,
)

return parser

@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
"""
Create an instance from CLI arguments.
Args:
args: The parser or namespace to take arguments from. Only known arguments will be
parsed and passed to the :class:`LightningDataModule`.
**kwargs: Additional keyword arguments that may override ones in the parser or namespace.
These must be valid DataModule arguments.
Example::
parser = ArgumentParser(add_help=False)
parser = LightningDataModule.add_argparse_args(parser)
args = parser.parse_args()
module = LightningDataModule.from_argparse_args(args)
"""
if isinstance(args, ArgumentParser):
args = cls.parse_argparser(args)
params = vars(args)

# we only want to pass in valid DataModule args, the rest may be user specific
valid_kwargs = inspect.signature(cls.__init__).parameters
datamodule_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
datamodule_kwargs.update(**kwargs)

return cls(**datamodule_kwargs)

@classmethod
def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
r"""Scans the DataModule signature and returns argument names, types and default values.
Returns:
List with tuples of 3 values:
(argument name, set with argument types, argument default value).
"""
datamodule_default_params = inspect.signature(cls.__init__).parameters
name_type_default = []
for arg in datamodule_default_params:
arg_type = datamodule_default_params[arg].annotation
arg_default = datamodule_default_params[arg].default
try:
arg_types = tuple(arg_type.__args__)
except AttributeError:
arg_types = (arg_type,)

name_type_default.append((arg, arg_types, arg_default))

return name_type_default
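
To round out the argparse helpers defined just above, a small sketch of how add_argparse_args and from_argparse_args compose. The DummyDataModule class and its data_dir/batch_size/shuffle arguments are made up for illustration; only typed, defaulted __init__ arguments are picked up by get_init_arguments_and_types and exposed on the parser.

from argparse import ArgumentParser

from pytorch_lightning import LightningDataModule


class DummyDataModule(LightningDataModule):
    # Hypothetical datamodule; the annotated, defaulted arguments below are what
    # get_init_arguments_and_types() scans and add_argparse_args() turns into flags.
    def __init__(self, data_dir: str = './data', batch_size: int = 32, shuffle: bool = True):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.shuffle = shuffle


parser = ArgumentParser(add_help=False)
parser = DummyDataModule.add_argparse_args(parser)
args = parser.parse_args(['--data_dir', '/tmp/data', '--batch_size', '64'])

# from_argparse_args forwards only the names that appear in DummyDataModule.__init__;
# explicit keyword arguments override values taken from the namespace.
dm = DummyDataModule.from_argparse_args(args, shuffle=False)
assert dm.batch_size == 64 and dm.shuffle is False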