new way of passing dataloaders (#759)
* new way of passing dataloaders

* fixed docs

* fixed code style to follow flake8

* allow val/test to be lists of dataloaders and smarter checking

* added test

* fix flake error

* fix linking to new test model

* split into multiple tests

* fix naming and typo

* minor documentation changes

* remove random file

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* better error/warning message

* final adjustments

* update CHANGELOG.md

Co-authored-by: William Falcon <waf2107@columbia.edu>
SkafteNicki and williamFalcon committed Feb 19, 2020
1 parent b9b5a93 commit ffd6e69
Showing 7 changed files with 267 additions and 13 deletions.
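In short, this commit lets DataLoaders be passed directly to Trainer.fit() instead of (or mixed with) dataloader methods defined on the LightningModule. A minimal sketch of the new call pattern, assuming the post-commit API; the dummy dataset and the bare model name are illustrative only:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import Trainer

    # Dummy data for illustration
    ds = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
    train, val, test = DataLoader(ds, batch_size=32), DataLoader(ds), DataLoader(ds)

    model = MyLightningModule()  # hypothetical model without *_dataloader methods
    trainer = Trainer(max_epochs=1)

    # New in this commit: loaders go straight to fit();
    # val/test may also be lists of DataLoaders
    trainer.fit(model, train_dataloader=train,
                val_dataloader=val, test_dataloader=test)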
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a tool for profiling training runs ([#782](https://github.com/PyTorchLightning/pytorch-lightning/pull/782))
- Improved flexibility for naming of TensorBoard logs, can now set `version` to a `str` to just save to that directory, and use `name=''` to prevent experiment-name directory ([#804](https://github.com/PyTorchLightning/pytorch-lightning/pull/804))
- Added option to specify `step` key when logging metrics ([#808](https://github.com/PyTorchLightning/pytorch-lightning/pull/808))
- Added `train_dataloader`, `val_dataloader` and `test_dataloader` arguments to `Trainer.fit()`, for alternative data passing ([#759](https://github.com/PyTorchLightning/pytorch-lightning/pull/759))
### Changed
- Changed default TQDM to use `tqdm.auto` for prettier outputs in IPython notebooks ([#752](https://github.com/PyTorchLightning/pytorch-lightning/pull/752))
- Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767))
2 changes: 1 addition & 1 deletion pytorch_lightning/core/decorators.py
@@ -10,7 +10,7 @@ def data_loader(fn):
"""
-    wraps(fn)
attr_name = '_lazy_' + fn.__name__

+    @wraps(fn)
def _get_data_loader(self):
try:
value = getattr(self, attr_name)
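For context, the only change here is that the stray no-op call wraps(fn) becomes a proper @wraps(fn) decorator on the inner function, so the memoized dataloader keeps the wrapped method's name and docstring. The hunk is truncated, so the rest of the body below is a sketch inferred from the visible lines, not the verbatim source:

    from functools import wraps

    def data_loader(fn):
        """Decorator that lazily builds and caches a dataloader."""
        attr_name = '_lazy_' + fn.__name__

        @wraps(fn)  # the actual fix: apply wraps as a decorator
        def _get_data_loader(self):
            try:
                value = getattr(self, attr_name)
            except AttributeError:
                value = fn(self)                 # build the loader once
                setattr(self, attr_name, value)  # memoize it
            return value

        return _get_data_loader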
3 changes: 1 addition & 2 deletions pytorch_lightning/core/lightning.py
@@ -869,7 +869,6 @@ def tbptt_split_batch(self, batch, split_size):
return splits

@data_loader
-    @abstractmethod
def train_dataloader(self):
"""Implement a PyTorch DataLoader
@@ -895,8 +894,8 @@ def train_dataloader(self):
)
return loader
"""
+        return None

@data_loader
def tng_dataloader(self): # todo: remove in v0.8.0
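Note the mechanism behind these two small edits: dropping @abstractmethod and returning None turns train_dataloader into an optional hook, so a model that defines no loaders can still be instantiated and have them injected by fit(). A hedged sketch of what this enables (class name hypothetical):

    from pytorch_lightning.core.lightning import LightningModule

    class PlainModel(LightningModule):
        # No *_dataloader overrides needed: previously this was impossible
        # because train_dataloader was abstract; now the base stub returns
        # None and Trainer.fit() can supply the loaders instead.
        def training_step(self, batch, batch_idx):
            ...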
91 changes: 88 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -18,7 +18,7 @@
parse_gpu_ids,
determine_root_gpu_device
)

+from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -840,17 +840,56 @@ def tng_tqdm_dic(self):
# -----------------------------
# MODEL TRAINING
# -----------------------------
-    def fit(self, model):
+    def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader=None):
r"""
Runs the full optimization routine.
Args:
model (LightningModule): Model to fit.
train_dataloader (:class:`.torch.utils.data.DataLoader`): A Pytorch
DataLoader with training samples. If the model has
a predefined train_dataloader method this will be skipped.
val_dataloader (:class:`.torch.utils.data.DataLoader`): Either a single
Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloader method this will be skipped
test_dataloader (:class:`.torch.utils.data.DataLoader`): Either a single
Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloader method this will be skipped
Example::
# Option 1,
# Define the train_dataloader(), test_dataloader() and val_dataloader() fxs
# in the lightningModule
# RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
trainer = Trainer()
model = LightningModule()
trainer.fit(model)
# Option 2
# in production cases we might want to pass different datasets to the same model
# Recommended for PRODUCTION SYSTEMS
train, val, test = DataLoader(...), DataLoader(...), DataLoader(...)
trainer = Trainer()
model = LightningModule()
trainer.fit(model, train_dataloader=train,
val_dataloader=val, test_dataloader=test)
# Option 1 & 2 can be mixed, for example the training set can be
# defined as part of the model, and validation/test can then be
# feed to .fit()
trainer.fit()
"""

# Update the dataloader attributes of the model with the ones supplied here,
# if they are not already defined in the model
_set_dataloader(model, train_dataloader, 'train_dataloader')
_set_dataloader(model, val_dataloader, 'val_dataloader')
_set_dataloader(model, test_dataloader, 'test_dataloader')

# when using multi-node or DDP within a node start each module in a separate process
if self.use_ddp2:
task = int(os.environ['SLURM_LOCALID'])
@@ -1048,3 +1087,49 @@ def test(self, model=None):
self.fit(model)
else:
self.run_evaluation(test=True)


def _set_dataloader(model, dataloader, attribute):
    r'''
    Check a dataloader passed to the .fit() method: verify that it is a PyTorch
    DataLoader (or a list of them) and decide whether or not to overwrite the
    corresponding dataloader in the model.

    Args:
        model (LightningModule): The model to check
        dataloader: If a PyTorch dataloader (or a list of PyTorch dataloaders)
            is passed, it will be incorporated into the model as model.attribute.
            If the attribute already exists in the model, a warning is issued and
            the passed dataloader is skipped. If the object passed is not a
            dataloader, a ValueError is raised.
        attribute (str): The attribute to save the dataloader under
    '''
    # Check whether the attribute still comes from the base class,
    # i.e. the user did not override it in a subclass
    if LightningModule.__qualname__ in getattr(model, attribute).__qualname__:
        # Val and test should be lists of dataloaders; wrap a single loader
        if attribute != 'train_dataloader' and not isinstance(dataloader, list):
            dataloader = [dataloader]

        # Check we are given valid dataloaders
        is_dataloader = isinstance(dataloader, torch.utils.data.DataLoader)
        is_dataloader_list = isinstance(dataloader, list)
        valid_loaders = None
        if is_dataloader_list:
            valid_loaders = all(isinstance(d, torch.utils.data.DataLoader) for d in dataloader)
        if is_dataloader or (is_dataloader_list and valid_loaders):

            # Overwrite the (formerly abstract) method: expose the loader
            # through a zero-argument callable so it behaves like the usual
            # model-defined *_dataloader() method
            def dl():
                return dataloader
            dl.__name__ = attribute
            setattr(model, attribute, dl)

        elif dataloader and dataloader != [None]:
            raise ValueError(f'`{attribute}` needs to be an instance of '
                             '`torch.utils.data.DataLoader` or a list of '
                             f'DataLoaders, instead got {dataloader!r}')

    elif dataloader:  # if default (None) is passed, do not warn the user
        warnings.warn(f'Model has predefined `{attribute}`,'
                      f' will skip `{attribute}={dataloader}` passed to fit method.')
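To make the checking rules concrete, a small illustrative sketch of how _set_dataloader behaves; `model` is assumed to be a LightningModule that does not override any *_dataloader method:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.randn(8, 4), torch.zeros(8))
    loader = DataLoader(ds, batch_size=4)

    _set_dataloader(model, loader, 'train_dataloader')          # attached as model.train_dataloader()
    _set_dataloader(model, [loader, loader], 'val_dataloader')  # val/test accept lists

    assert model.train_dataloader() is loader
    assert len(model.val_dataloader()) == 2

    _set_dataloader(model, loader, 'train_dataloader')   # warns: attribute now predefined, loader skipped
    _set_dataloader(model, 'oops', 'test_dataloader')    # raises ValueError: not a DataLoader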
2 changes: 1 addition & 1 deletion tests/models/__init__.py
@@ -2,7 +2,7 @@

import torch

-from .base import LightningTestModelBase
+from .base import LightningTestModelBase, LightningTestModelBaseWithoutDataloader
from .mixins import (
LightningValidationStepMixin,
LightningValidationMixin,
20 changes: 14 additions & 6 deletions tests/models/base.py
@@ -36,7 +36,7 @@ def __init__(self, root, train=True, transform=None, target_transform=None,
self.targets = self.targets[:num_samples]


-class LightningTestModelBase(LightningModule):
+class TestModelBase(LightningModule):
"""
Base LightningModule for testing. Implements only the required
interface
@@ -48,7 +48,7 @@ def __init__(self, hparams, force_remove_distributed_sampler=False):
:param hparams:
"""
# init superclass
-        super(LightningTestModelBase, self).__init__()
+        super(TestModelBase, self).__init__()
self.hparams = hparams

self.batch_size = hparams.batch_size
@@ -178,10 +178,6 @@ def _dataloader(self, train):

return loader

-    @data_loader
-    def train_dataloader(self):
-        return self._dataloader(train=True)

@staticmethod
def add_model_specific_args(parent_parser, root_dir): # pragma: no cover
"""
@@ -218,3 +214,15 @@ def add_model_specific_args(parent_parser, root_dir):  # pragma: no cover
options=[32, 64, 128, 256], tunable=False,
help='batch size will be divided over all gpus being used across all nodes')
return parser


class LightningTestModelBase(TestModelBase):
""" with pre-defined train dataloader """
@data_loader
def train_dataloader(self):
return self._dataloader(train=True)


class LightningTestModelBaseWithoutDataloader(TestModelBase):
""" without pre-defined train dataloader """
pass
161 changes: 161 additions & 0 deletions tests/test_trainer.py
@@ -13,6 +13,7 @@
from tests.models import (
LightningTestModel,
LightningTestModelBase,
LightningTestModelBaseWithoutDataloader,
LightningValidationStepMixin,
LightningValidationMultipleDataloadersMixin,
LightningTestMultipleDataloadersMixin,
@@ -449,6 +450,165 @@ class CurrentTestModel(
trainer.test()


def test_train_dataloaders_passed_to_fit(tmpdir):
""" Verify that train dataloader can be passed to fit """
tutils.reset_seed()

class CurrentTestModel(
LightningTestModelBaseWithoutDataloader
):
pass

hparams = tutils.get_hparams()

# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
)

# only train passed to fit
model = CurrentTestModel(hparams)
trainer = Trainer(**trainer_options)
fit_options = dict(train_dataloader=model._dataloader(train=True))
results = trainer.fit(model, **fit_options)


def test_train_val_dataloaders_passed_to_fit(tmpdir):
""" Verify that train & val dataloader can be passed to fit """
tutils.reset_seed()

class CurrentTestModel(
LightningTestModelBaseWithoutDataloader
):
pass

hparams = tutils.get_hparams()

# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
)

# train, val passed to fit
model = CurrentTestModel(hparams)
trainer = Trainer(**trainer_options)
fit_options = dict(train_dataloader=model._dataloader(train=True),
val_dataloader=model._dataloader(train=False))
results = trainer.fit(model, **fit_options)
assert len(trainer.get_val_dataloaders()) == 1, \
f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'


def test_all_dataloaders_passed_to_fit(tmpdir):
""" Verify train, val & test dataloader can be passed to fit """
tutils.reset_seed()

class CurrentTestModel(
LightningTestModelBaseWithoutDataloader
):
pass

hparams = tutils.get_hparams()

# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
)

# train, val and test passed to fit
model = CurrentTestModel(hparams)
trainer = Trainer(**trainer_options)
fit_options = dict(train_dataloader=model._dataloader(train=True),
val_dataloader=model._dataloader(train=False),
test_dataloader=model._dataloader(train=False))
results = trainer.fit(model, **fit_options)

assert len(trainer.get_val_dataloaders()) == 1, \
f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
assert len(trainer.get_test_dataloaders()) == 1, \
f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'


def test_multiple_dataloaders_passed_to_fit(tmpdir):
""" Verify that multiple val & test dataloaders can be passed to fit """
tutils.reset_seed()

class CurrentTestModel(
LightningTestModelBaseWithoutDataloader
):
pass

hparams = tutils.get_hparams()

# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
)

# train, multiple val and multiple test passed to fit
model = CurrentTestModel(hparams)
trainer = Trainer(**trainer_options)
fit_options = dict(train_dataloader=model._dataloader(train=True),
val_dataloader=[model._dataloader(train=False),
model._dataloader(train=False)],
test_dataloader=[model._dataloader(train=False),
model._dataloader(train=False)])
results = trainer.fit(model, **fit_options)

assert len(trainer.get_val_dataloaders()) == 2, \
f'Multiple `val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
assert len(trainer.get_test_dataloaders()) == 2, \
f'Multiple `test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'


def test_mixing_of_dataloader_options(tmpdir):
"""Verify that dataloaders can be passed to fit"""
tutils.reset_seed()

class CurrentTestModel(
LightningTestModelBase
):
pass

hparams = tutils.get_hparams()
model = CurrentTestModel(hparams)

# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
)

# fit model
trainer = Trainer(**trainer_options)
fit_options = dict(val_dataloader=model._dataloader(train=False))
results = trainer.fit(model, **fit_options)

# fit model
trainer = Trainer(**trainer_options)
fit_options = dict(val_dataloader=model._dataloader(train=False),
test_dataloader=model._dataloader(train=False))
results = trainer.fit(model, **fit_options)
assert len(trainer.get_val_dataloaders()) == 1, \
f'`val_dataloaders` not initiated properly, got {trainer.get_val_dataloaders()}'
assert len(trainer.get_test_dataloaders()) == 1, \
f'`test_dataloaders` not initiated properly, got {trainer.get_test_dataloaders()}'


def _init_steps_model():
"""private method for initializing a model with 5% train epochs"""
tutils.reset_seed()
@@ -533,5 +693,6 @@ def test_trainer_min_steps_and_epochs(tmpdir):
assert trainer.global_step >= math.floor(num_train_samples * 1.5) and \
trainer.current_epoch > 0, "Model did not train for at least min_steps"


# if __name__ == '__main__':
# pytest.main([__file__])
