Skip to content

Commit

Permalink
🚧 include stage in datamodule.setup
Browse files Browse the repository at this point in the history
  • Loading branch information
nateraw committed Jul 30, 2020
1 parent 7cd23cc commit 7b49a56
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 35 deletions.
85 changes: 69 additions & 16 deletions pytorch_lightning/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import inspect
from abc import abstractmethod
from argparse import ArgumentParser, Namespace
from typing import Any, List, Tuple, Union
from typing import Any, List, Optional, Tuple, Union

from torch.utils.data import DataLoader

Expand All @@ -32,35 +32,56 @@ def __call__(cls, *args, **kwargs):
3. Lets you check prepare_data and setup to see if they've been called
"""

# Wrap cls's prepare_data function with rank_zero_only
cls.prepare_data = rank_zero_only(cls.prepare_data)

# prepare_data and setup wrapped w/ function to track if they've been called.
# Usage: your_dm.setup.has_been_called & your_dm.prepare_data.has_been_called
cls.prepare_data = track_func_calls(cls.prepare_data)
cls.setup = track_func_calls(cls.setup)
# Track prepare_data calls and make sure it runs on rank zero
cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data))
# Track setup calls
cls.setup = track_data_hook_calls(cls.setup)

# Get instance of LightningDataModule by mocking its __init__ via __call__
obj = type.__call__(cls, *args, **kwargs)

return obj


def track_func_calls(fn):
"""A decorator that checks if a function has been called.
def track_data_hook_calls(fn):
    """A decorator that records on the instance that prepare_data/setup have been called.

    - When dm.prepare_data() is called, dm.has_prepared_data gets set to True
    - When dm.setup('fit') is called, dm.has_setup_fit gets set to True
    - When dm.setup('test') is called, dm.has_setup_test gets set to True
    - When dm.setup() is called without stage arg, both dm.has_setup_fit and dm.has_setup_test get set to True

    The decision is made from ``fn.__name__``, so only functions literally named
    ``setup`` or ``prepare_data`` update any flag. The flags are written before
    ``fn`` itself runs.

    Args:
        fn (function): Function that will be tracked to see if it has been called.

    Returns:
        function: Decorated function that tracks its call status and saves it to private attrs in its obj instance.
    """

    @functools.wraps(fn)
    def wrapped_fn(*args, **kwargs):

        # The object instance from which setup or prepare_data was called
        obj = args[0]

        # If calling setup, we check the stage and assign stage-specific bool args
        if fn.__name__ == 'setup':

            # Get stage either by grabbing from args or checking kwargs.
            # If not provided, set call status of 'fit' and 'test' to True.
            # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
            stage = args[1] if len(args) > 1 else kwargs.get('stage', None)

            if stage == 'fit' or stage is None:
                obj._has_setup_fit = True

            if stage == 'test' or stage is None:
                obj._has_setup_test = True

        if fn.__name__ == 'prepare_data':
            obj._has_prepared_data = True

        return fn(*args, **kwargs)

    return wrapped_fn

Expand Down Expand Up @@ -116,6 +137,11 @@ def __init__(
self._test_transforms = test_transforms
self.dims = ()

# Private attrs to keep track of whether or not data hooks have been called yet
self._has_prepared_data = False
self._has_setup_fit = False
self._has_setup_test = False

@property
def train_transforms(self):
"""
Expand Down Expand Up @@ -159,6 +185,33 @@ def size(self, dim=None) -> Union[Tuple, int]:

return self.dims

@property
def has_prepared_data(self) -> bool:
    """Tell whether datamodule.prepare_data() has been called.

    Reads the private ``_has_prepared_data`` attribute of the instance.

    Returns:
        bool: True if datamodule.prepare_data() has been called. False by default.
    """
    return self._has_prepared_data

@property
def has_setup_fit(self) -> bool:
    """Tell whether datamodule.setup('fit') has been called.

    Reads the private ``_has_setup_fit`` attribute of the instance.

    Returns:
        bool: True if datamodule.setup('fit') has been called. False by default.
    """
    return self._has_setup_fit

@property
def has_setup_test(self) -> bool:
    """Tell whether datamodule.setup('test') has been called.

    Reads the private ``_has_setup_test`` attribute of the instance.

    Returns:
        bool: True if datamodule.setup('test') has been called. False by default.
    """
    return self._has_setup_test

@abstractmethod
def prepare_data(self, *args, **kwargs):
"""
Expand All @@ -181,14 +234,14 @@ def prepare_data(self):
"""

@abstractmethod
def setup(self, *args, **kwargs):
def setup(self, stage: Optional[str] = None):
"""
Use this to load your data from file, split it, etc. You are safe to make state assignments here.
This hook is called on every process when using DDP.
Example::
def setup(self):
def setup(self, stage):
data = load_data(...)
self.train_ds, self.val_ds, self.test_ds = split_data(data)
"""
Expand Down
24 changes: 18 additions & 6 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ def fit(

# set up the passed in dataloaders (if needed)
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
self.__attach_datamodule(model, datamodule)
self.__attach_datamodule(model, datamodule, 'fit')

# check that model is configured correctly
self.config_validator.verify_loop_configurations(model)
Expand Down Expand Up @@ -1152,17 +1152,29 @@ def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=Non
if test_dataloaders is not None:
model.test_dataloader = _PatchDataLoader(test_dataloaders)

def __attach_datamodule(self, model, datamodule=None):
def __attach_datamodule(self, model, datamodule, stage):

# We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
datamodule = datamodule or getattr(model, 'datamodule', None)

# If we have a datamodule, attach necessary hooks + dataloaders
if datamodule:
if self.is_overridden('prepare_data', datamodule) and not datamodule.prepare_data.has_been_called:

# If datamodule.prepare_data() has not been called yet, call it
if self.is_overridden('prepare_data', datamodule) and not datamodule.has_prepared_data:
datamodule.prepare_data()
if self.is_overridden('setup', datamodule) and not datamodule.setup.has_been_called:
datamodule.setup()

# If datamodule.setup('fit') has not been called yet, call it
if stage == 'fit':
if self.is_overridden('setup', datamodule) and not datamodule.has_setup_fit:
datamodule.setup('fit')

# If datamodule.setup('test') has not been called yet, call it
if stage == 'test':
if self.is_overridden('setup', datamodule) and not datamodule.has_setup_test:
datamodule.setup('test')

# Override loader hooks
if self.is_overridden('train_dataloader', datamodule):
model.train_dataloader = datamodule.train_dataloader
if self.is_overridden('val_dataloader', datamodule):
Expand Down Expand Up @@ -1367,7 +1379,7 @@ def test(
)

# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.__attach_datamodule(model or self.get_model(), datamodule)
self.__attach_datamodule(model or self.get_model(), datamodule, 'test')

self.setup('test')

Expand Down
17 changes: 11 additions & 6 deletions tests/base/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@ def __init__(self, data_dir: str = './'):
def prepare_data(self):
    """Instantiate TrialMNIST for the train split and then the test split with ``download=True``."""
    # Same two constructor calls, in the same order: train=True first, then train=False
    for is_train in (True, False):
        TrialMNIST(self.data_dir, train=is_train, download=True)

def setup(self):
mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
self.dims = tuple(self.mnist_train[0][0].shape)
self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=32, download=True)

def setup(self, stage: str = None):
    """Build the MNIST splits needed for the requested stage.

    Args:
        stage: ``'fit'`` builds the train/val splits, ``'test'`` builds the test
            split, and ``None`` builds both.
    """

    if stage == 'fit' or stage is None:
        mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
        self.dims = self.mnist_train[0][0].shape

    if stage == 'test' or stage is None:
        self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=32, download=True)
        # Keep dims already derived from the fit split. Otherwise fall back to the test split.
        # A plain getattr default is not enough: an existing-but-empty ``dims`` (e.g. ``()``)
        # would be returned as-is and the sample shape would never be recorded.
        self.dims = getattr(self, 'dims', None) or self.mnist_test[0][0].shape

def train_dataloader(self):
    """Return a DataLoader over ``self.mnist_train`` using a batch size of 32."""
    train_split = self.mnist_train
    return DataLoader(train_split, batch_size=32)
Expand Down
68 changes: 61 additions & 7 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,64 @@ def test_base_datamodule(tmpdir):
dm.setup()


def test_dm_has_been_called(tmpdir):
def test_base_datamodule_with_verbose_setup(tmpdir):
    """Smoke test: prepare_data followed by explicit per-stage setup calls must not raise."""
    dm = TrialMNISTDataModule()
    dm.prepare_data()
    dm.setup('fit')
    dm.setup('test')


def test_data_hooks_called(tmpdir):
    """Check the tracking flags when setup() is called without a stage argument."""
    dm = TrialMNISTDataModule()
    assert dm.has_prepared_data is False
    assert dm.has_setup_fit is False
    assert dm.has_setup_test is False

    dm.prepare_data()
    assert dm.has_prepared_data is True
    assert dm.has_setup_fit is False
    assert dm.has_setup_test is False

    # No stage given: both the fit and test flags are expected to flip
    dm.setup()
    assert dm.has_prepared_data is True
    assert dm.has_setup_fit is True
    assert dm.has_setup_test is True


def test_data_hooks_called_verbose(tmpdir):
    """Check the tracking flags after each hook call when stages are passed positionally."""
    dm = TrialMNISTDataModule()

    def check_flags(prepared, fit, test):
        # Identity checks, so each flag must be an actual bool
        assert dm.has_prepared_data is prepared
        assert dm.has_setup_fit is fit
        assert dm.has_setup_test is test

    check_flags(False, False, False)

    dm.prepare_data()
    check_flags(True, False, False)

    dm.setup('fit')
    check_flags(True, True, False)

    dm.setup('test')
    check_flags(True, True, True)


def test_data_hooks_called_with_stage_kwarg(tmpdir):
    """Check the setup flags when the stage is passed as the ``stage=`` keyword."""
    dm = TrialMNISTDataModule()
    dm.prepare_data()
    assert dm.has_prepared_data is True

    # (stage to set up, expected has_setup_fit, expected has_setup_test) after the call
    expectations = (
        ('fit', True, False),
        ('test', True, True),
    )
    for stage_name, fit_done, test_done in expectations:
        dm.setup(stage=stage_name)
        assert dm.has_setup_fit is fit_done
        assert dm.has_setup_test is test_done


def test_dm_add_argparse_args(tmpdir):
Expand Down Expand Up @@ -58,6 +104,14 @@ def test_dm_pickle_after_setup(tmpdir):
pickle.dumps(dm)


def test_dm_pickle_after_setup_verbose(tmpdir):
    """The datamodule must be picklable after prepare_data and both per-stage setup calls."""
    dm = TrialMNISTDataModule()
    dm.prepare_data()
    for stage_name in ('fit', 'test'):
        dm.setup(stage_name)
    pickle.dumps(dm)


def test_train_loop_only(tmpdir):
dm = TrialMNISTDataModule(tmpdir)

Expand Down

0 comments on commit 7b49a56

Please sign in to comment.