Commit

🚧 .
nateraw committed Jul 23, 2020
1 parent e912951 commit caa3016
Showing 2 changed files with 42 additions and 8 deletions.
38 changes: 30 additions & 8 deletions pytorch_lightning/core/datamodule.py
@@ -39,6 +39,10 @@ def __init__(self):
                 super().__init__()
             def prepare_data(self):
                 # download, split, etc...
+                # only called on rank 0
+            def setup(self):
+                # make assignments here
+                # called on every process in DDP
             def train_dataloader(self):
                 train_split = Dataset(...)
                 return DataLoader(train_split)
@@ -72,6 +76,9 @@ def __init__(
 
     @property
     def train_transforms(self):
+        """
+        Optional transforms you can apply to train dataset
+        """
         return self._train_transforms
 
     @train_transforms.setter
@@ -80,6 +87,9 @@ def train_transforms(self, t):
 
     @property
    def val_transforms(self):
+        """
+        Optional transforms you can apply to validation dataset
+        """
         return self._val_transforms
 
     @val_transforms.setter
@@ -88,6 +98,9 @@ def val_transforms(self, t):
 
     @property
     def test_transforms(self):
+        """
+        Optional transforms you can apply to test dataset
+        """
         return self._test_transforms
 
     @test_transforms.setter
@@ -96,9 +109,9 @@ def test_transforms(self, t):
 
     def size(self, dim=None) -> Union[Tuple, int]:
         """
-        Return the dimension of each input
-        Either as a tuple or list of tuples
+        Return the dimension of each input either as a tuple or list of tuples.
         """
+
         if dim is not None:
             return self.dims[dim]
@@ -109,20 +122,29 @@ def prepare_data(self, *args, **kwargs):
         """
         Use this to download and prepare data.
         In distributed (GPU, TPU), this will only be called once.
         This is called before requesting the dataloaders:
-        .. warning:: Do not assign anything to the model in this step since this will only be called on 1 GPU.
+        .. warning:: Do not assign anything to the datamodule in this step since this will only be called on 1 GPU.
         Pseudocode::
-            model.prepare_data()
-            model.train_dataloader()
-            model.val_dataloader()
-            model.test_dataloader()
+            dm.prepare_data()
+            dm.setup()
         Example::
             def prepare_data(self):
                 download_imagenet()
                 clean_imagenet()
                 cache_imagenet()
         """
 
+    @abstractmethod
+    def setup(self, *args, **kwargs):
+        """
+        Use this to load your data from file, split it, etc. You are safe to make state assignments here.
+        This hook is called on every process when using DDP.
+        Example::
+            def setup(self):
+                data = load_data(...)
+                self.train_ds, self.val_ds, self.test_ds = split_data(data)
+        """
+
     @abstractmethod
     def train_dataloader(self, *args, **kwargs) -> DataLoader:
         """
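Taken together, these datamodule changes describe a two-step flow: prepare_data downloads once on rank 0 and must not assign state, setup assigns state on every process, and the dataloader hooks consume that state. A minimal sketch of a subclass written against this API follows; the import path matches this diff, but MyDataModule, download_dataset, and load_dataset are hypothetical placeholders, and the dataset is assumed to hold 60k examples:

    from torch.utils.data import DataLoader, random_split

    from pytorch_lightning.core.datamodule import LightningDataModule


    class MyDataModule(LightningDataModule):
        def __init__(self, data_dir='./data', batch_size=32):
            super().__init__()
            self.data_dir = data_dir
            self.batch_size = batch_size

        def prepare_data(self):
            # runs once, on rank 0 only: download to disk, do not assign state
            download_dataset(self.data_dir)  # hypothetical helper

        def setup(self):
            # runs on every process in DDP: safe to make state assignments
            data = load_dataset(self.data_dir)  # hypothetical helper, assumed 60k examples
            self.train_ds, self.val_ds, self.test_ds = random_split(data, [50000, 5000, 5000])

        def train_dataloader(self):
            return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

        def val_dataloader(self):
            return DataLoader(self.val_ds, batch_size=self.batch_size)

        def test_dataloader(self):
            return DataLoader(self.test_ds, batch_size=self.batch_size)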
12 changes: 12 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -939,6 +939,12 @@ def fit(
         if hasattr(model, 'hparams'):
             parsing.clean_namespace(model.hparams)
 
+        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
+        if (train_dataloader or val_dataloaders) and datamodule:
+            raise MisconfigurationException(
+                'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
+            )
+
         # set up the passed in dataloaders (if needed)
         self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
         self.__attach_datamodule(model, datamodule)
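The effect on the fit call, as a hedged usage sketch (model, dm, and the loaders stand in for objects built elsewhere; the keyword names follow this diff):

    # either pass the loaders explicitly...
    trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

    # ...or let the datamodule provide them
    trainer.fit(model, datamodule=dm)

    # but not both: this now raises MisconfigurationException
    trainer.fit(model, train_dataloader=train_loader, datamodule=dm)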
@@ -1323,6 +1329,12 @@ def test(
         if self.global_rank != 0:
             return
 
+        # If you supply a datamodule you can't supply test_dataloaders
+        if test_dataloaders and datamodule:
+            raise MisconfigurationException(
+                'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
+            )
+
         # Attach datamodule to get setup/prepare_data added to model before the call to it below
         self.__attach_datamodule(model or self.get_model(), datamodule)
 
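And the matching sketch for the test path, under the same assumptions:

    # loaders come from the datamodule's test_dataloader hook
    trainer.test(model, datamodule=dm)

    # mixing explicit test loaders with a datamodule now raises MisconfigurationException
    trainer.test(model, test_dataloaders=test_loader, datamodule=dm)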
