Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new way of passing dataloaders #759

Merged
merged 24 commits into from
Feb 19, 2020
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytorch_lightning/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def data_loader(fn):
"""
wraps(fn)
attr_name = '_lazy_' + fn.__name__

@wraps(fn)
def _get_data_loader(self):
try:
value = getattr(self, attr_name)
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,6 @@ def tbptt_split_batch(self, batch, split_size):
return splits

@data_loader
@abstractmethod
def train_dataloader(self):
"""Implement a PyTorch DataLoader

Expand All @@ -894,8 +893,8 @@ def train_dataloader(self):
)
return loader


"""
return None

@data_loader
def tng_dataloader(self): # todo: remove in v0.8.0
Expand Down
73 changes: 70 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
parse_gpu_ids,
determine_root_gpu_device
)

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
Expand Down Expand Up @@ -734,17 +734,53 @@ def tng_tqdm_dic(self):
# -----------------------------
# MODEL TRAINING
# -----------------------------
def fit(self, model):
def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader=None):
r"""
Runs the full optimization routine.

Args:
model (LightningModule): Model to fit.

train_dataloader (:class:`.torch.utils.data.DataLoader`): A Pytorch
DataLoader with training samples. If the model has
a predefined train_dataloader method this will be skipped.

val_dataloader (:class:`.torch.utils.data.DataLoader`): Either a single
Pytorch Dataloader or a list of them, specifying validation samples. If the model has
a predefined val_dataloader method this will be skipped

test_dataloader (:class:`.torch.utils.data.DataLoader`): Either a single
Pytorch Dataloader or a list of them, specifying test samples. If the model has
a predefined test_dataloader method this will be skipped

Example::

# Option 1,
# Basic usecase, dataloaders defined as part of the model
trainer = Trainer()
model = LightningModule()
trainer.fit(model)

# Option 2
# Dataloaders passed to fit method
train, val, test = DataLoader(...), DataLoader(...), DataLoader(...)
trainer = Trainer()
model = LightningModule()
trainer.fit(model, train_dataloader=train,
val_dataloader=val, test_dataloader=test)

# Option 1 & 2 can be mixed, for example the training set can be
Borda marked this conversation as resolved.
Show resolved Hide resolved
# defined as part of the model, and validation/test can then be
# feed to .fit()

trainer.fit()
"""

# Update the dataloader attributes of the model with the ones supplied here,
# if they are not already defined in model
_set_dataloader(model, train_dataloader, 'train_dataloader')
_set_dataloader(model, val_dataloader, 'val_dataloader')
_set_dataloader(model, test_dataloader, 'test_dataloader')

# when using multi-node or DDP within a node start each module in a separate process
if self.use_ddp2:
task = int(os.environ['SLURM_LOCALID'])
Expand Down Expand Up @@ -935,3 +971,34 @@ def test(self, model=None):
self.fit(model)
else:
self.run_evaluation(test=True)


def _set_dataloader(model, dataloader, attribute):
    r"""Validate a dataloader passed to ``.fit()`` and attach it to the model.

    If the model still inherits `attribute` from the ``LightningModule`` base
    class (i.e. the user did not override it), the supplied `dataloader` is
    checked and installed on the model. If the model defines its own
    `attribute`, the supplied dataloader is skipped with a warning.

    Args:
        model: ``LightningModule`` instance to update.
        dataloader: A ``torch.utils.data.DataLoader`` — or, for val/test,
            optionally a list of them — or ``None`` when nothing was passed.
        attribute: One of ``'train_dataloader'``, ``'val_dataloader'`` or
            ``'test_dataloader'``.

    Raises:
        ValueError: If `dataloader` is neither ``None`` nor (a list of)
            ``torch.utils.data.DataLoader`` instances.
    """
    # Check if attribute comes directly from the base class or was
    # overridden in the user subclass.
    if LightningModule.__qualname__ in getattr(model, attribute).__qualname__:
        # Val and test should be lists of dataloaders; train stays a single one.
        if attribute != 'train_dataloader' and not isinstance(dataloader, list):
            dataloader = [dataloader]

        # Check if input is correct
        is_single = isinstance(dataloader, torch.utils.data.DataLoader)
        is_valid_list = (isinstance(dataloader, list) and
                         all(isinstance(d, torch.utils.data.DataLoader) for d in dataloader))

        if is_single or is_valid_list:
            # Overwrite the (abstract) dataloader method with a closure
            # returning the supplied loader(s).
            def dl():
                return dataloader
            dl.__name__ = attribute
            setattr(model, attribute, dl)

        # BUGFIX: the original condition was `not dataloader and ...`, which
        # silently ignored invalid truthy inputs (e.g. a string) while raising
        # for a plain None — exactly inverted.
        elif dataloader and dataloader != [None]:
            raise ValueError(f'`{attribute}` needs to be an instance of '
                             '`torch.utils.data.DataLoader` or a list of, instead got '
                             f'`{dataloader}`')

    else:
        warnings.warn(f'Model has predefined `{attribute}`, '
                      f'will skip `{attribute}` passed to fit method')
3 changes: 2 additions & 1 deletion tests/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import torch

from .base import LightningTestModelBase
from .base import (LightningTestModelBase,
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
LightningTestModelBaseWithoutDataloader)
from .mixins import (
LightningValidationStepMixin,
LightningValidationMixin,
Expand Down
20 changes: 14 additions & 6 deletions tests/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, root, train=True, transform=None, target_transform=None,
self.targets = self.targets[:num_samples]


class LightningTestModelBase(LightningModule):
class _LightningTestModelBase(LightningModule):
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
"""
Base LightningModule for testing. Implements only the required
interface
Expand All @@ -48,7 +48,7 @@ def __init__(self, hparams, force_remove_distributed_sampler=False):
:param hparams:
"""
# init superclass
super(LightningTestModelBase, self).__init__()
super(_LightningTestModelBase, self).__init__()
self.hparams = hparams

self.batch_size = hparams.batch_size
Expand Down Expand Up @@ -178,10 +178,6 @@ def _dataloader(self, train):

return loader

@data_loader
def train_dataloader(self):
return self._dataloader(train=True)

@staticmethod
def add_model_specific_args(parent_parser, root_dir): # pragma: no cover
"""
Expand Down Expand Up @@ -218,3 +214,15 @@ def add_model_specific_args(parent_parser, root_dir): # pragma: no cover
options=[32, 64, 128, 256], tunable=False,
help='batch size will be divided over all gpus being used across all nodes')
return parser


class LightningTestModelBase(_LightningTestModelBase):
    """Test model variant that ships with a pre-defined train dataloader."""

    @data_loader
    def train_dataloader(self):
        loader = self._dataloader(train=True)
        return loader


class LightningTestModelBaseWithoutDataloader(_LightningTestModelBase):
    """Test model variant that deliberately omits a train dataloader."""
101 changes: 100 additions & 1 deletion tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tests.models import (
LightningTestModel,
LightningTestModelBase,
LightningTestModelBaseWithoutDataloader,
LightningValidationStepMixin,
LightningValidationMultipleDataloadersMixin,
LightningTestMultipleDataloadersMixin,
Expand Down Expand Up @@ -367,7 +368,7 @@ class CurrentTestModel(

# verify there are 2 val loaders
assert len(trainer.get_val_dataloaders()) == 2, \
'Multiple val_dataloaders not initiated properly'
'Multiple val_dataloaders not initiated properly'
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

# make sure predictions are good for each val set
for dataloader in trainer.get_val_dataloaders():
Expand Down Expand Up @@ -411,5 +412,103 @@ class CurrentTestModel(
trainer.test()


def test_dataloaders_passed_to_fit(tmpdir):
    """Verify that dataloaders can be passed to fit"""
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModelBaseWithoutDataloader):
        pass

    hparams = tutils.get_hparams()

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # only train passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    trainer.fit(model, train_dataloader=model._dataloader(train=True))

    # train and val passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    trainer.fit(model,
                train_dataloader=model._dataloader(train=True),
                val_dataloader=model._dataloader(train=False))
    assert len(trainer.get_val_dataloaders()) == 1, \
        'val_dataloaders not initiated properly'

    # train, val and test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    trainer.fit(model,
                train_dataloader=model._dataloader(train=True),
                val_dataloader=model._dataloader(train=False),
                test_dataloader=model._dataloader(train=False))
    assert len(trainer.get_val_dataloaders()) == 1, \
        'val_dataloaders not initiated properly'
    assert len(trainer.get_test_dataloaders()) == 1, \
        'test_dataloaders not initiated properly'

    # train plus multiple val and multiple test passed to fit
    model = CurrentTestModel(hparams)
    trainer = Trainer(**trainer_options)
    trainer.fit(model,
                train_dataloader=model._dataloader(train=True),
                val_dataloader=[model._dataloader(train=False),
                                model._dataloader(train=False)],
                test_dataloader=[model._dataloader(train=False),
                                 model._dataloader(train=False)])
    assert len(trainer.get_val_dataloaders()) == 2, \
        'Multiple val_dataloaders not initiated properly'
    assert len(trainer.get_test_dataloaders()) == 2, \
        'Multiple test_dataloaders not initiated properly'


def test_mixing_of_dataloader_options(tmpdir):
    """Verify that dataloaders can be passed to fit"""
    tutils.reset_seed()

    class CurrentTestModel(LightningTestModelBase):
        pass

    hparams = tutils.get_hparams()
    model = CurrentTestModel(hparams)

    # logger file to get meta
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.2
    )

    # fit model with only a val loader supplied
    trainer = Trainer(**trainer_options)
    trainer.fit(model, val_dataloader=model._dataloader(train=False))

    # fit model with val and test loaders supplied
    trainer = Trainer(**trainer_options)
    trainer.fit(model,
                val_dataloader=model._dataloader(train=False),
                test_dataloader=model._dataloader(train=False))
    assert len(trainer.get_val_dataloaders()) == 1, \
        'val_dataloaders not initiated properly'
    assert len(trainer.get_test_dataloaders()) == 1, \
        'test_dataloaders not initiated properly'

# if __name__ == '__main__':
# pytest.main([__file__])