
Add test_dataloaders to test method #1434

Merged: 7 commits, Apr 10, 2020
Changes from 3 commits
28 changes: 16 additions & 12 deletions pytorch_lightning/trainer/trainer.py
@@ -612,8 +612,7 @@ def fit(
self,
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
- val_dataloaders: Optional[DataLoader] = None,
- test_dataloaders: Optional[DataLoader] = None
+ val_dataloaders: Optional[DataLoader] = None
):
r"""
Runs the full optimization routine.
@@ -629,10 +628,6 @@ def fit(
Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped

- test_dataloaders: Either a single
-     Pytorch Dataloader or a list of them, specifying validation samples.
-     If the model has a predefined test_dataloaders method this will be skipped
-
Example::

# Option 1,
@@ -646,11 +641,10 @@ def fit(
# Option 2
# in production cases we might want to pass different datasets to the same model
# Recommended for PRODUCTION SYSTEMS
- train, val, test = DataLoader(...), DataLoader(...), DataLoader(...)
+ train, val = DataLoader(...), DataLoader(...)
trainer = Trainer()
model = LightningModule()
- trainer.fit(model, train_dataloader=train,
-             val_dataloader=val, test_dataloader=test)
+ trainer.fit(model, train_dataloader=train, val_dataloaders=val)

# Option 1 & 2 can be mixed, for example the training set can be
# defined as part of the model, and validation/test can then be
@@ -662,7 +656,7 @@ def fit(
self.copy_trainer_model_properties(model)

# set up the passed in dataloaders (if needed)
- self.__attach_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)
+ self.__attach_dataloaders(model, train_dataloader, val_dataloaders)

# check that model is configured correctly
self.check_model_configuration(model)
@@ -747,7 +741,7 @@ def __set_random_port(self):
default_port = random.randint(10000, 19000)
os.environ['MASTER_PORT'] = str(default_port)

- def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders):
+ def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloader is not None:
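
For readers unfamiliar with the mechanism the comment above describes: when loaders are passed in explicitly, the trainer overrides the corresponding dataloader hooks on the model. The snippet below is a simplified sketch of that idea, assuming the standard LightningModule hook names (train_dataloader, val_dataloader, test_dataloader); the standalone `_attach_dataloaders` is an illustrative stand-in, not the method's actual body:

    # Illustrative sketch of the patching idea: each explicitly passed loader
    # replaces the matching LightningModule hook, so later trainer code can
    # keep calling model.train_dataloader() etc. and receive the passed loader.
    def _attach_dataloaders(model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
        if train_dataloader is not None:
            model.train_dataloader = lambda: train_dataloader
        if val_dataloaders is not None:
            model.val_dataloader = lambda: val_dataloaders
        if test_dataloaders is not None:
            model.test_dataloader = lambda: test_dataloaders

With all three parameters now defaulting to None, `test()` can attach only a test loader without supplying dummy train/val arguments, which is exactly what the new call sites below rely on.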
@@ -863,14 +857,17 @@ def run_pretrain_routine(self, model: LightningModule):
# CORE TRAINING LOOP
self.train()

- def test(self, model: Optional[LightningModule] = None):
+ def test(self, model: Optional[LightningModule] = None, test_dataloaders: Optional[DataLoader] = None):
r"""

Separates from fit to make sure you never run on your test set until you want to.

Args:
model: The model to test.

+ test_dataloaders: Either a single
+     Pytorch Dataloader or a list of them, specifying test samples.
+
Example::

# Option 1
@@ -889,6 +886,13 @@ def test(self, model: Optional[LightningModule] = None):
"""

self.testing = True
+
+ if test_dataloaders is not None:
+     if model is not None:
+         self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
+     else:
+         self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders)
+
if model is not None:
self.model = model
self.fit(model)
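
With this change, a test loader can be supplied at the call site instead of being baked into the model. A minimal usage sketch under the new signature (`LitModel` and `test_set` are placeholder names, not part of this PR):

    from torch.utils.data import DataLoader
    from pytorch_lightning import Trainer

    model = LitModel()  # placeholder: any LightningModule subclass
    test_loader = DataLoader(test_set, batch_size=32)  # test_set: any torch Dataset

    trainer = Trainer()
    trainer.fit(model)

    # Option 1: pass the model and its test loader together
    trainer.test(model, test_dataloaders=test_loader)

    # Option 2: omit the model after fit(); the loader is attached to
    # the model the trainer already holds (the self.model branch above)
    trainer.test(test_dataloaders=test_loader)

Branching on `model` inside `test()` keeps both styles working: an explicitly passed model gets the loaders attached directly, while a bare `trainer.test(...)` call falls back to the model retained from the last `fit()`.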