new way of passing dataloaders #759

Merged: 24 commits, Feb 19, 2020
2 changes: 1 addition & 1 deletion pytorch_lightning/core/decorators.py
@@ -10,7 +10,7 @@ def data_loader(fn):
"""
wraps(fn)
attr_name = '_lazy_' + fn.__name__

@wraps(fn)
def _get_data_loader(self):
try:
value = getattr(self, attr_name)
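For context: the hunk above removes a stray `wraps(fn)` call (whose return value was never used) and instead applies `@wraps(fn)` to the inner function, so the wrapped dataloader method keeps its original `__name__`/`__qualname__`. Below is a minimal sketch of the lazy-caching pattern the decorator appears to implement, reconstructed from the visible lines; the `AttributeError` branch is an assumption, since the diff does not show it:

    from functools import wraps

    def data_loader(fn):
        """Cache the result of a dataloader method on first access."""
        attr_name = '_lazy_' + fn.__name__

        @wraps(fn)  # preserves fn.__name__ / fn.__qualname__ on the wrapper
        def _get_data_loader(self):
            try:
                value = getattr(self, attr_name)   # reuse a previously built loader
            except AttributeError:
                value = fn(self)                   # build the loader once (assumed)
                setattr(self, attr_name, value)
            return value

        return _get_data_loader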
3 changes: 1 addition & 2 deletions pytorch_lightning/core/lightning.py
@@ -868,7 +868,6 @@ def tbptt_split_batch(self, batch, split_size):
        return splits

    @data_loader
    @abstractmethod
    def train_dataloader(self):
        """Implement a PyTorch DataLoader

@@ -894,8 +893,8 @@ def train_dataloader(self):
            )
            return loader


        """
        return None

    @data_loader
    def tng_dataloader(self):  # todo: remove in v0.8.0
91 changes: 88 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -18,7 +18,7 @@
    parse_gpu_ids,
    determine_root_gpu_device
)

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -734,17 +734,56 @@ def tng_tqdm_dic(self):
    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
    def fit(self, model):
    def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader=None):
        r"""
        Runs the full optimization routine.

        Args:
            model (LightningModule): Model to fit.

            train_dataloader (:class:`torch.utils.data.DataLoader`): A PyTorch
                DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.

            val_dataloader (:class:`torch.utils.data.DataLoader`): Either a single
                PyTorch DataLoader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloader method this will be skipped.

            test_dataloader (:class:`torch.utils.data.DataLoader`): Either a single
                PyTorch DataLoader or a list of them, specifying test samples.
                If the model has a predefined test_dataloader method this will be skipped.

        Example::

            # Option 1
            # Define the train_dataloader(), test_dataloader() and val_dataloader() methods
            # in the LightningModule
            # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
            trainer = Trainer()
            model = LightningModule()
            trainer.fit(model)

            # Option 2
            # in production cases we might want to pass different datasets to the same model
            # Recommended for PRODUCTION SYSTEMS
            train, val, test = DataLoader(...), DataLoader(...), DataLoader(...)
            trainer = Trainer()
            model = LightningModule()
            trainer.fit(model, train_dataloader=train,
                        val_dataloader=val, test_dataloader=test)

            # Options 1 & 2 can be mixed: for example, the training set can be
            # defined as part of the model, and the validation/test loaders can then be
            # fed to .fit()
            trainer.fit(model, val_dataloader=val, test_dataloader=test)
        """

        # Update the dataloader attributes of the model with the ones supplied here,
        # if they are not already defined in the model
        _set_dataloader(model, train_dataloader, 'train_dataloader')
        _set_dataloader(model, val_dataloader, 'val_dataloader')
        _set_dataloader(model, test_dataloader, 'test_dataloader')

        # when using multi-node or DDP within a node, start each module in a separate process
        if self.use_ddp2:
            task = int(os.environ['SLURM_LOCALID'])
@@ -935,3 +974,49 @@ def test(self, model=None):
            self.fit(model)
        else:
            self.run_evaluation(test=True)


def _set_dataloader(model, dataloader, attribute):
    r'''
    Check dataloaders passed to the .fit() method: verify that they are PyTorch
    DataLoader objects and decide whether or not to overwrite the corresponding
    dataloader method in the model.

    Args:
        model (LightningModule): The model to check

        dataloader: If a PyTorch DataLoader (or a list of PyTorch DataLoaders)
            is passed, it will be incorporated into the model as model.attribute.
            If the attribute already exists in the model, the user is warned and the
            passed dataloader is skipped. If the argument is not a DataLoader,
            a ValueError is raised.

        attribute (str): The attribute to save the dataloader under

    '''
    # Check if attribute comes directly from base class or
    # derived in user subclass
    if LightningModule.__qualname__ in getattr(model, attribute).__qualname__:
        # Val and test should be list of dataloaders
        dataloader = dataloader if attribute == 'train_dataloader' or \
            (attribute != 'train_dataloader' and isinstance(dataloader, list)) else [dataloader]

        # Check we are given valid dataloaders
        is_dataloader = isinstance(dataloader, torch.utils.data.DataLoader)
        is_dataloader_list = isinstance(dataloader, list)
        if is_dataloader_list:
            valid_loaders = all(isinstance(d, torch.utils.data.DataLoader) for d in dataloader)
        if is_dataloader or is_dataloader_list and valid_loaders:

            # Overwrite abstract methods
            dl = lambda: dataloader
            dl.__name__ = attribute
            setattr(model, attribute, dl)

        elif dataloader and dataloader != [None]:
            raise ValueError(f'`{attribute}` needs to be an instance of '
                             '`torch.utils.data.DataLoader` or a list of '
                             'DataLoaders, instead got %r' % dataloader)

    elif dataloader:  # if default (None) is passed, do not warn the user
        warnings.warn(f'Model has predefined `{attribute}`, '
                      f'will skip `{attribute}={dataloader}` passed to fit method')
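A minimal, self-contained sketch (not part of the diff) of the `__qualname__` check that `_set_dataloader` above uses to decide whether a dataloader method is still the base-class default or has been overridden by the user; `Base` and `UserModel` are illustrative names only:

    class Base:
        def train_dataloader(self):
            return None

    class UserModel(Base):
        def train_dataloader(self):
            return 'user-defined loader'

    # Inherited method: __qualname__ is 'Base.train_dataloader', so it contains 'Base'
    print(Base.__qualname__ in Base().train_dataloader.__qualname__)       # True

    # Overridden method: __qualname__ is 'UserModel.train_dataloader', no 'Base' in it
    print(Base.__qualname__ in UserModel().train_dataloader.__qualname__)  # False

When the check is True (the model still uses the base-class default), fit() replaces the attribute with a function returning the passed dataloader; otherwise it warns and keeps the model's own method.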
3 changes: 2 additions & 1 deletion tests/models/__init__.py
@@ -2,7 +2,8 @@

import torch

from .base import LightningTestModelBase
from .base import (LightningTestModelBase,
LightningTestModelBaseWithoutDataloader)
from .mixins import (
    LightningValidationStepMixin,
    LightningValidationMixin,
20 changes: 14 additions & 6 deletions tests/models/base.py
@@ -36,7 +36,7 @@ def __init__(self, root, train=True, transform=None, target_transform=None,
        self.targets = self.targets[:num_samples]


class LightningTestModelBase(LightningModule):
class TestModelBase(LightningModule):
    """
    Base LightningModule for testing. Implements only the required
    interface
@@ -48,7 +48,7 @@ def __init__(self, hparams, force_remove_distributed_sampler=False):
        :param hparams:
        """
        # init superclass
        super(LightningTestModelBase, self).__init__()
        super(TestModelBase, self).__init__()
        self.hparams = hparams

        self.batch_size = hparams.batch_size
@@ -178,10 +178,6 @@ def _dataloader(self, train):

        return loader

    @data_loader
    def train_dataloader(self):
        return self._dataloader(train=True)

    @staticmethod
    def add_model_specific_args(parent_parser, root_dir):  # pragma: no cover
        """
@@ -218,3 +214,15 @@ def add_model_specific_args(parent_parser, root_dir):  # pragma: no cover
                        options=[32, 64, 128, 256], tunable=False,
                        help='batch size will be divided over all gpus being used across all nodes')
        return parser


class LightningTestModelBase(TestModelBase):
    """ with pre-defined train dataloader """
    @data_loader
    def train_dataloader(self):
        return self._dataloader(train=True)


class LightningTestModelBaseWithoutDataloader(TestModelBase):
    """ without pre-defined train dataloader """
    pass
159 changes: 159 additions & 0 deletions tests/test_trainer.py
@@ -11,6 +11,7 @@
from tests.models import (
    LightningTestModel,
    LightningTestModelBase,
    LightningTestModelBaseWithoutDataloader,
    LightningValidationStepMixin,
    LightningValidationMultipleDataloadersMixin,
    LightningTestMultipleDataloadersMixin,
@@ -411,5 +412,163 @@ class CurrentTestModel(
    trainer.test()


def test_train_dataloaders_passed_to_fit(tmpdir):
    """ Verify that train dataloader can be passed to fit """
    tutils.reset_seed()

    class CurrentTestModel(
        LightningTestModelBaseWithoutDataloader
    ):
        pass

    hparams = tutils.get_hparams()

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # only train passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True))
    results = trainer.fit(model, **fit_options)


def test_train_val_dataloaders_passed_to_fit(tmpdir):
    """ Verify that train & val dataloader can be passed to fit """
    tutils.reset_seed()

    class CurrentTestModel(
        LightningTestModelBaseWithoutDataloader
    ):
        pass

    hparams = tutils.get_hparams()

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # train, val passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloader=model._dataloader(train=False))
    results = trainer.fit(model, **fit_options)
    assert len(trainer.get_val_dataloaders()) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'


def test_all_dataloaders_passed_to_fit(tmpdir):
    """ Verify train, val & test dataloader can be passed to fit """
    tutils.reset_seed()

    class CurrentTestModel(
        LightningTestModelBaseWithoutDataloader
    ):
        pass

    hparams = tutils.get_hparams()

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # train, val and test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloader=model._dataloader(train=False),
                       test_dataloader=model._dataloader(train=False))
    results = trainer.fit(model, **fit_options)

    assert len(trainer.get_val_dataloaders()) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
    assert len(trainer.get_test_dataloaders()) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'


def test_multiple_dataloaders_passed_to_fit(tmpdir):
    """ Verify that multiple val & test dataloaders can be passed to fit """
    tutils.reset_seed()

    class CurrentTestModel(
        LightningTestModelBaseWithoutDataloader
    ):
        pass

    hparams = tutils.get_hparams()

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # train, multiple val and multiple test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloader=[model._dataloader(train=False),
                                       model._dataloader(train=False)],
                       test_dataloader=[model._dataloader(train=False),
                                        model._dataloader(train=False)])
    results = trainer.fit(model, **fit_options)

    assert len(trainer.get_val_dataloaders()) == 2, \
        f'Multiple `val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
    assert len(trainer.get_test_dataloaders()) == 2, \
        f'Multiple `test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'


def test_mixing_of_dataloader_options(tmpdir):
    """ Verify that model-defined and fit()-passed dataloaders can be mixed """
    tutils.reset_seed()

    class CurrentTestModel(
        LightningTestModelBase
    ):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # fit model with only val_dataloader passed
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloader=model._dataloader(train=False))
    results = trainer.fit(model, **fit_options)

    # fit model with val and test dataloaders passed
    trainer = Trainer(**trainer_options)
    fit_options = dict(val_dataloader=model._dataloader(train=False),
                       test_dataloader=model._dataloader(train=False))
    results = trainer.fit(model, **fit_options)
    assert len(trainer.get_val_dataloaders()) == 1, \
        f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
    assert len(trainer.get_test_dataloaders()) == 1, \
        f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'

# if __name__ == '__main__':
# pytest.main([__file__])