
Commit

🚧 .
nateraw committed Jul 22, 2020
1 parent 081b2a8 commit fd80559
Showing 4 changed files with 248 additions and 324 deletions.
243 changes: 243 additions & 0 deletions pytorch_lightning/core/data_hooks.py
@@ -0,0 +1,243 @@
from typing import List, Union

from torch.utils.data import DataLoader

from pytorch_lightning.utilities import rank_zero_warn

class DataHooks:

    def setup(self, stage: str):
        """
        Called at the beginning of fit and test.
        This is a good hook when you need to build models dynamically or adjust something about them.
        This hook is called on every process when using DDP.

        Args:
            stage: either 'fit' or 'test'

        Example::

            class LitModel(...):
                def __init__(self):
                    self.l1 = None

                def prepare_data(self):
                    download_data()
                    tokenize()

                    # don't do this
                    self.something = some_state

                def setup(self, stage):
                    data = load_data(...)
                    self.l1 = nn.Linear(28, data.num_classes)
        """

    def prepare_data(self) -> None:
        """
        Use this to download and prepare data.

        .. warning:: Do not assign state to the model here (use `setup` instead),
            since this is NOT called on every GPU in DDP/TPU.

        Example::

            def prepare_data(self):
                # good
                download_data()
                tokenize()
                etc()

                # bad
                self.split = data_split
                self.some_state = some_other_state()

        In DDP, ``prepare_data`` can be called in two ways, controlled by ``Trainer(prepare_data_per_node)``:

        1. Once per node. This is the default and is only called on LOCAL_RANK=0.
        2. Once in total. Only called on GLOBAL_RANK=0.

        Example::

            # DEFAULT
            # called once per node on LOCAL_RANK=0 of that node
            Trainer(prepare_data_per_node=True)

            # called once in total, on GLOBAL_RANK=0 (great for shared file systems)
            Trainer(prepare_data_per_node=False)

        This is called before requesting the dataloaders:

        .. code-block:: python

            model.prepare_data()
            if ddp/tpu: init()
            model.setup(stage)
            model.train_dataloader()
            model.val_dataloader()
            model.test_dataloader()
        """

    def train_dataloader(self) -> DataLoader:
        """
        Implement a PyTorch DataLoader for training.

        Return:
            Single PyTorch :class:`~torch.utils.data.DataLoader`.

        The dataloader you return will not be called every epoch unless you set
        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.

        For data processing use the following pattern:

            - download in :meth:`prepare_data`
            - process and split in :meth:`setup`

        However, the above are only necessary for distributed processing.

        .. warning:: do not assign state in prepare_data

        - :meth:`~pytorch_lightning.trainer.Trainer.fit`
        - ...
        - :meth:`prepare_data`
        - :meth:`setup`
        - :meth:`train_dataloader`

        Note:
            Lightning adds the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Example:
            .. code-block:: python

                def train_dataloader(self):
                    transform = transforms.Compose([transforms.ToTensor(),
                                                    transforms.Normalize((0.5,), (1.0,))])
                    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform,
                                    download=True)
                    loader = torch.utils.data.DataLoader(
                        dataset=dataset,
                        batch_size=self.batch_size,
                        shuffle=True
                    )
                    return loader
        """
        rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')

    def tng_dataloader(self):  # todo: remove in v1.0.0
        """
        Warnings:
            Deprecated in v0.5.0. Use :meth:`train_dataloader` instead. Will be removed in 1.0.0.
        """
        output = self.train_dataloader()
        rank_zero_warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0"
                       " and this method will be removed in v1.0.0", DeprecationWarning)
        return output

    def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        r"""
        Implement one or multiple PyTorch DataLoaders for testing.

        The dataloader you return will not be called every epoch unless you set
        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.

        For data processing use the following pattern:

            - download in :meth:`prepare_data`
            - process and split in :meth:`setup`

        However, the above are only necessary for distributed processing.

        .. warning:: do not assign state in prepare_data

        - :meth:`~pytorch_lightning.trainer.Trainer.fit`
        - ...
        - :meth:`prepare_data`
        - :meth:`setup`
        - :meth:`train_dataloader`
        - :meth:`val_dataloader`
        - :meth:`test_dataloader`

        Note:
            Lightning adds the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Return:
            Single or multiple PyTorch DataLoaders.

        Example:
            .. code-block:: python

                def test_dataloader(self):
                    transform = transforms.Compose([transforms.ToTensor(),
                                                    transforms.Normalize((0.5,), (1.0,))])
                    dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform,
                                    download=True)
                    loader = torch.utils.data.DataLoader(
                        dataset=dataset,
                        batch_size=self.batch_size,
                        shuffle=False
                    )
                    return loader

        Note:
            If you don't need a test dataset and a :meth:`test_step`, you don't need to implement
            this method.
        """

    def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        r"""
        Implement one or multiple PyTorch DataLoaders for validation.

        The dataloader you return will not be called every epoch unless you set
        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.

        It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.

        - :meth:`~pytorch_lightning.trainer.Trainer.fit`
        - ...
        - :meth:`prepare_data`
        - :meth:`train_dataloader`
        - :meth:`val_dataloader`
        - :meth:`test_dataloader`

        Note:
            Lightning adds the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Return:
            Single or multiple PyTorch DataLoaders.

        Examples:
            .. code-block:: python

                def val_dataloader(self):
                    transform = transforms.Compose([transforms.ToTensor(),
                                                    transforms.Normalize((0.5,), (1.0,))])
                    dataset = MNIST(root='/path/to/mnist/', train=False,
                                    transform=transform, download=True)
                    loader = torch.utils.data.DataLoader(
                        dataset=dataset,
                        batch_size=self.batch_size,
                        shuffle=False
                    )
                    return loader

                # can also return multiple dataloaders
                def val_dataloader(self):
                    return [loader_a, loader_b, ..., loader_n]

        Note:
            If you don't need a validation dataset and a :meth:`validation_step`, you don't need to
            implement this method.

        Note:
            In the case where you return multiple validation dataloaders, the :meth:`validation_step`
            will have an argument ``dataset_idx`` which matches the order here.
        """
83 changes: 3 additions & 80 deletions pytorch_lightning/core/datamodule.py
@@ -1,13 +1,14 @@
import inspect
from abc import abstractmethod
from abc import ABC, abstractmethod
from argparse import ArgumentParser, Namespace
from typing import Union, List, Tuple, Any

from pytorch_lightning.core.data_hooks import DataHooks
from pytorch_lightning.utilities import rank_zero_warn, parsing
from torch.utils.data import DataLoader


class LightningDataModule(object): # pragma: no cover
class LightningDataModule(DataHooks): # pragma: no cover
"""
A DataModule standardizes the training, val, test splits, data preparation and transforms.
The main advantage is consistent data splits and transforms across models.
Expand Down Expand Up @@ -82,84 +83,6 @@ def size(self, dim=None) -> Union[Tuple, int]:

return self.dims

    @abstractmethod
    def prepare_data(self, *args, **kwargs):
        """
        Use this to download and prepare data.
        In distributed (GPU, TPU), this will only be called once.
        This is called before requesting the dataloaders:

        .. warning:: Do not assign anything to the model in this step since this will only be called on 1 GPU.

        Pseudocode::

            model.prepare_data()
            model.train_dataloader()
            model.val_dataloader()
            model.test_dataloader()

        Example::

            def prepare_data(self):
                download_imagenet()
                clean_imagenet()
                cache_imagenet()
        """

    @abstractmethod
    def setup(self, stage):
        """
        Use this to make assignments to the class.
        """

    @abstractmethod
    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        """
        Implement a PyTorch DataLoader for training.

        Return:
            Single PyTorch :class:`~torch.utils.data.DataLoader`.

        Note:
            Lightning adds the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Example::

            def train_dataloader(self):
                dataset = MNIST(root=PATH, train=True, transform=transforms.ToTensor(), download=False)
                loader = torch.utils.data.DataLoader(dataset=dataset)
                return loader
        """
        rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')

    @abstractmethod
    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        r"""
        Implement a PyTorch DataLoader for training.

        Return:
            Single PyTorch :class:`~torch.utils.data.DataLoader`.

        Note:
            Lightning adds the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Note:
            You can also return a list of DataLoaders

        Example::

            def val_dataloader(self):
                dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
                loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
                return loader
        """

    @abstractmethod
    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        r"""
        Implement a PyTorch DataLoader for training.

        Return:
            Single PyTorch :class:`~torch.utils.data.DataLoader`.

        Note:
            Lightning adds the correct sampler for distributed and arbitrary hardware.
            There is no need to set it yourself.

        Note:
            You can also return a list of DataLoaders

        Example::

            def test_dataloader(self):
                dataset = MNIST(root=PATH, train=False, transform=transforms.ToTensor(), download=False)
                loader = torch.utils.data.DataLoader(dataset=dataset, shuffle=False)
                return loader
        """

    @classmethod
    def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
        r"""Extends existing argparse by default `LightningDataModule` attributes.
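
Because LightningDataModule now inherits from DataHooks, a concrete datamodule overrides the same prepare_data / setup / *_dataloader hooks documented in data_hooks.py rather than redefining its own abstract copies. A minimal hedged sketch of the resulting usage (class name, dataset, and split sizes are assumptions, not shown in this diff):

# Illustrative sketch only; MNIST, paths, and split sizes are assumptions.
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

from pytorch_lightning.core.datamodule import LightningDataModule


class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = './data', batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        # download once (per node or in total, depending on prepare_data_per_node)
        MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage):
        full = MNIST(self.data_dir, train=True, transform=transforms.ToTensor())
        self.train_set, self.val_set = random_split(full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False)


# Manual wiring; calling the hooks yourself mirrors the pseudocode in prepare_data's docstring.
dm = MNISTDataModule()
dm.prepare_data()
dm.setup('fit')
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
# These loaders can then be handed to Trainer.fit(model, train_dataloader=..., val_dataloaders=...).
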
28 changes: 0 additions & 28 deletions pytorch_lightning/core/hooks.py
@@ -17,34 +17,6 @@

class ModelHooks(Module):

    def setup(self, stage: str):
        """
        Called at the beginning of fit and test.
        This is a good hook when you need to build models dynamically or adjust something about them.
        This hook is called on every process when using DDP.

        Args:
            stage: either 'fit' or 'test'

        Example::

            class LitModel(...):
                def __init__(self):
                    self.l1 = None

                def prepare_data(self):
                    download_data()
                    tokenize()

                    # don't do this
                    self.something = else

                def setup(stage):
                    data = Load_data(...)
                    self.l1 = nn.Linear(28, data.num_classes)
        """

    def teardown(self, stage: str):
        """
        Called at the end of fit and test.
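
With setup moved out of ModelHooks (presumably into the new DataHooks), teardown remains its counterpart at the end of fit and test. A small hedged sketch of pairing the two in a LightningModule; the scratch-directory use case is purely illustrative and assumes the LightningModule still exposes setup through the new mixin:

# Illustrative sketch only; other required LightningModule methods are omitted.
import shutil
import tempfile

import pytorch_lightning as pl


class LitModelWithScratchSpace(pl.LightningModule):
    def setup(self, stage: str):
        # called at the beginning of fit/test on every process
        self.scratch_dir = tempfile.mkdtemp()

    def teardown(self, stage: str):
        # called at the end of fit/test; release whatever setup created
        shutil.rmtree(self.scratch_dir, ignore_errors=True)
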