new way of passing dataloaders #759

Merged · 24 commits · Feb 19, 2020
Changes from 3 commits
57 changes: 52 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -713,17 +713,64 @@ def tng_tqdm_dic(self):
     # -----------------------------
     # MODEL TRAINING
     # -----------------------------
-    def fit(self, model):
+    def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader=None):
         r"""
         Runs the full optimization routine.
 
-        Example::
+        Args:
+            model (LightningModule): Model to fit.
+                Example::
 
-            trainer = Trainer()
-            model = LightningModule()
+                    trainer = Trainer()
+                    model = LightningModule()
+                    trainer.fit(model)
 
+            train_dataloader (:class:`.torch.utils.data.DataLoader`): A PyTorch
+                DataLoader with training samples. If the model has
+                a predefined train_dataloader method this will be skipped.
+
+            val_dataloader (:class:`.torch.utils.data.DataLoader`): A PyTorch
+                DataLoader with validation samples. If the model has
+                a predefined val_dataloader method this will be skipped.
+
+            test_dataloader (:class:`.torch.utils.data.DataLoader`): A PyTorch
+                DataLoader with test samples. If the model has
+                a predefined test_dataloader method this will be skipped.
 
-            trainer.fit()
         """
+        if train_dataloader:
Review comment (Member):
I would make it a function since the three are almost the same:

def _set_dataloader(dataloader, attribute):
    if isinstance(dataloader, torch.utils.data.DataLoader):
        if getattr(model, attribute)() is None:
            setattr(model, attribute, lambda: dataloader)
        else:
            logging.info(f'Model has predefined `{attribute}`, '
                         f'will skip the `{attribute}` passed to fit method.')
    elif dataloader:
        raise ValueError(f'`{attribute}` needs to be an instance '
                         'of `torch.utils.data.DataLoader`')

_set_dataloader(train_dataloader, 'train_dataloader')
...
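For reference, a self-contained version of the suggested helper that can be run outside the Trainer (the `DemoModel` stub is purely illustrative, and `model` is passed in explicitly instead of being captured from the enclosing `fit` scope):

import logging

import torch
import torch.utils.data


class DemoModel:
    # Minimal stand-in for a LightningModule; only the dataloader hooks matter here.
    def train_dataloader(self):
        return None

    def val_dataloader(self):
        return None

    def test_dataloader(self):
        return None


def _set_dataloader(model, dataloader, attribute):
    # Attach `dataloader` to `model` under `attribute` unless the model
    # already provides its own dataloader for that split.
    if isinstance(dataloader, torch.utils.data.DataLoader):
        if getattr(model, attribute)() is None:
            setattr(model, attribute, lambda: dataloader)
        else:
            logging.info(f'Model has predefined `{attribute}`, '
                         f'will skip the `{attribute}` passed to fit method.')
    elif dataloader is not None:
        raise ValueError(f'`{attribute}` needs to be an instance '
                         'of `torch.utils.data.DataLoader`')


model = DemoModel()
loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.randn(8, 2)))
_set_dataloader(model, loader, 'train_dataloader')
assert model.train_dataloader() is loader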

+            if isinstance(train_dataloader, torch.utils.data.DataLoader):
+                if model.train_dataloader() is None:
+                    model.train_dataloader = lambda: train_dataloader
+                else:
+                    logging.info('Model has predefined train_dataloader, '
+                                 'will skip the train_dataloader passed to fit method')
+            else:
+                raise ValueError('train_dataloader needs to be an instance '
+                                 'of torch.utils.data.DataLoader')
+
+        if val_dataloader:
+            if isinstance(val_dataloader, torch.utils.data.DataLoader):
+                if model.val_dataloader() is None:
+                    model.val_dataloader = lambda: val_dataloader
+                else:
+                    logging.info('Model has predefined val_dataloader, '
+                                 'will skip the val_dataloader passed to fit method')
+            else:
Review comment (Member):
You may merge this with the first `if`, i.e. `elif val_dataloader:` (a sketch follows below).
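For concreteness, the merged control flow for the validation loader, written as a standalone sketch (the function name `_attach_val_dataloader` is illustrative, not from the PR):

import logging

import torch.utils.data


def _attach_val_dataloader(model, val_dataloader):
    if isinstance(val_dataloader, torch.utils.data.DataLoader):
        if model.val_dataloader() is None:
            model.val_dataloader = lambda: val_dataloader
        else:
            logging.info('Model has predefined val_dataloader, '
                         'will skip the val_dataloader passed to fit method')
    # merged branch: a truthy value that is not a DataLoader is rejected directly
    elif val_dataloader:
        raise ValueError('val_dataloader needs to be an instance '
                         'of torch.utils.data.DataLoader')

This keeps the type check and the error path in a single if/elif chain, so each loader argument is handled in one place.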

+                raise ValueError('val_dataloader needs to be an instance '
+                                 'of torch.utils.data.DataLoader')
+
+        if test_dataloader:
+            if isinstance(test_dataloader, torch.utils.data.DataLoader):
+                if model.test_dataloader() is None:
+                    model.test_dataloader = lambda: test_dataloader
+                else:
+                    logging.info('Model has predefined test_dataloader, '
+                                 'will skip the test_dataloader passed to fit method')
+            else:
+                raise ValueError('test_dataloader needs to be an instance '
+                                 'of torch.utils.data.DataLoader')
 
         # when using multi-node or DDP within a node start each module in a separate process
         if self.use_ddp2:
             task = int(os.environ['SLURM_LOCALID'])
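Taken together, the new signature lets users drive training without defining dataloader hooks on the module. A minimal usage sketch, assuming a hypothetical `MyLightningModule` subclass and toy tensors (none of this code is from the PR itself):

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import Trainer

# Toy data; any torch.utils.data.DataLoader works here.
train_loader = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)
val_loader = DataLoader(TensorDataset(torch.randn(16, 32), torch.randint(0, 2, (16,))), batch_size=8)

model = MyLightningModule()  # hypothetical LightningModule subclass
trainer = Trainer()

# Loaders passed here are used only when the model does not already
# define the corresponding *_dataloader method.
trainer.fit(model, train_dataloader=train_loader, val_dataloader=val_loader)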