Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test_dataloaders to test method #1434

Merged
merged 7 commits into from
Apr 10, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ def __set_random_port(self):
default_port = random.randint(10000, 19000)
os.environ['MASTER_PORT'] = str(default_port)

def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders):
def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
Borda marked this conversation as resolved.
Show resolved Hide resolved
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloader is not None:
Expand Down Expand Up @@ -863,14 +863,20 @@ def run_pretrain_routine(self, model: LightningModule):
# CORE TRAINING LOOP
self.train()

def test(self, model: Optional[LightningModule] = None):
def test(self,
model: Optional[LightningModule] = None,
test_dataloaders: Optional[DataLoader] = None
):
r"""

Separates from fit to make sure you never run on your test set until you want to.

Args:
model: The model to test.

test_dataloaders: Either a single
Pytorch Dataloader or a list of them, specifying validation samples.

Example::

# Option 1
Expand All @@ -889,9 +895,13 @@ def test(self, model: Optional[LightningModule] = None):
"""

self.testing = True

if (test_dataloaders is not None) and (self.model is not None):
self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders)

if model is not None:
self.model = model
self.fit(model)
self.fit(model, test_dataloaders=test_dataloaders)
elif self.use_ddp or self.use_tpu: # pragma: no-cover
# attempt to load weights from a spawn
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
Expand Down