new way of passing dataloaders #759

Merged
24 commits merged Feb 19, 2020
Changes from 4 commits
3 changes: 1 addition & 2 deletions pytorch_lightning/core/decorators.py
@@ -9,9 +9,8 @@ def data_loader(fn):
:return:
"""

wraps(fn)
attr_name = '_lazy_' + fn.__name__

@wraps(fn)
def _get_data_loader(self):
try:
value = getattr(self, attr_name)
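For context on the decorator change above: the edit applies functools.wraps as a decorator on the inner function instead of the earlier bare wraps(fn) call (which had no effect), so the caching wrapper keeps the wrapped method's name and docstring. A minimal sketch of this lazy-caching pattern, illustrative only and not the exact library code:

from functools import wraps

def data_loader(fn):
    """Cache the result of a dataloader-returning method on first access."""
    attr_name = '_lazy_' + fn.__name__

    @wraps(fn)  # keeps fn.__name__ and fn.__doc__ on the wrapper
    def _get_data_loader(self):
        try:
            value = getattr(self, attr_name)
        except AttributeError:
            # first access: call the original method and cache its result
            value = fn(self)
            setattr(self, attr_name, value)
        return value

    return _get_data_loader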
64 changes: 58 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -19,7 +19,7 @@
parse_gpu_ids,
determine_root_gpu_device
)

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -713,17 +713,38 @@ def tng_tqdm_dic(self):
# -----------------------------
# MODEL TRAINING
# -----------------------------
def fit(self, model):
def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader=None):
r"""
Runs the full optimization routine.

Args:
model (LightningModule): Model to fit.

Example::

trainer = Trainer()
model = LightningModule()
trainer.fit(model)

train_dataloader (:class:`torch.utils.data.DataLoader`): A PyTorch
DataLoader with training samples. If the model has
a predefined train_dataloader method, this will be skipped.

val_dataloader (:class:`torch.utils.data.DataLoader`): Either a single
PyTorch DataLoader or a list of them, specifying validation samples. If the model has
a predefined val_dataloader method, this will be skipped.

test_dataloader (:class:`torch.utils.data.DataLoader`): Either a single
PyTorch DataLoader or a list of them, specifying test samples. If the model has
a predefined test_dataloader method, this will be skipped.
"""

# Update the dataloader attributes of the model with the ones supplied here,
# if they are not already defined in the model
_set_dataloader(model, train_dataloader, 'train_dataloader')
_set_dataloader(model, val_dataloader, 'val_dataloader')
_set_dataloader(model, test_dataloader, 'test_dataloader')

# when using multi-node or DDP within a node start each module in a separate process
if self.use_ddp2:
task = int(os.environ['SLURM_LOCALID'])
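As a usage sketch of the new fit signature introduced here (the module subclass and the data are illustrative, not part of this PR):

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import Trainer

# Illustrative dataloaders; in practice these wrap your real datasets
train_dl = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=8)
val_dl = DataLoader(TensorDataset(torch.randn(16, 10), torch.randn(16, 1)), batch_size=8)

model = MyLightningModule()  # hypothetical user subclass of LightningModule
trainer = Trainer()

# Dataloaders passed to fit() are only attached if the model does not already
# define train_dataloader / val_dataloader / test_dataloader itself.
trainer.fit(model, train_dataloader=train_dl, val_dataloader=val_dl)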
@@ -911,3 +932,34 @@ def test(self, model=None):
self.fit(model)
else:
self.run_evaluation(test=True)


def _set_dataloader(model, dataloader, attribute):
r''' Check whether the dataloaders passed to the .fit() method are PyTorch DataLoader
objects (or lists of them) and whether we should overwrite the corresponding dataloader
in the model '''

# Check whether the attribute comes directly from the base class
# or is overridden in the user subclass
if LightningModule.__qualname__ in getattr(model, attribute).__qualname__:
# Val and test should be a list of dataloaders
dataloader = dataloader if attribute == 'train_dataloader' or \
(attribute != 'train_dataloader' and isinstance(dataloader, list)) else [dataloader]

# Check if input is correct
if isinstance(dataloader, torch.utils.data.DataLoader) or \
isinstance(dataloader, list) and all(isinstance(d, torch.utils.data.DataLoader) for d in dataloader):

# Overwrite abstract methods
dl = lambda: dataloader
dl.__name__ = attribute
setattr(model, attribute, dl)

elif dataloader:
raise ValueError(f'`{attribute}` needs to be an instance of '
'`torch.utils.data.DataLoader` or a list of them, instead got '
f'`{dataloader}`')

else:
logging.info(f'Model has predefined `{attribute}`, '
f'will skip the `{attribute}` passed to the fit method')
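A small standalone sketch of the __qualname__ check used in _set_dataloader, with illustrative class names rather than library code:

class Base:
    def train_dataloader(self):
        raise NotImplementedError

class LazyModel(Base):
    pass  # inherits train_dataloader from Base

class UserModel(Base):
    def train_dataloader(self):
        return ['user-defined loader']

# A method inherited from Base keeps 'Base' in its __qualname__,
# so it is safe to overwrite it with the dataloader passed to fit().
print(Base.__qualname__ in LazyModel().train_dataloader.__qualname__)   # True  -> overwrite
print(Base.__qualname__ in UserModel().train_dataloader.__qualname__)   # False -> keep user's method

In the PR, the same containment test against LightningModule.__qualname__ decides whether the dataloader supplied to fit() should replace the model's attribute or be skipped.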