Add test_dataloaders to test method #1434

Merged · 7 commits · Apr 10, 2020
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added

 - Added `rank_zero_warn` for warning only in rank 0 ([#1428](https://github.com/PyTorchLightning/pytorch-lightning/pull/1428))
+- Added `test_dataloaders` parameter to `Trainer.test()` ([#1393](https://github.com/PyTorchLightning/pytorch-lightning/issues/1393))
+
+### Changed
+
+- Removed `test_dataloaders` parameter from `Trainer.fit()` ([#1393](https://github.com/PyTorchLightning/pytorch-lightning/issues/1393))

 ### Fixed
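Together, the two changelog entries describe one migration. A minimal before/after sketch in the style of the docstring examples further down (`train`, `val`, and `test` are placeholder DataLoaders, not part of the diff):

    # before #1393: test loaders were passed to fit()
    trainer.fit(model, train_dataloader=train,
                val_dataloaders=val, test_dataloaders=test)
    trainer.test()

    # after #1393: test loaders are passed to test() itself
    trainer.fit(model, train_dataloader=train, val_dataloaders=val)
    trainer.test(test_dataloaders=test)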
41 changes: 23 additions & 18 deletions pytorch_lightning/trainer/trainer.py
@@ -612,8 +612,7 @@ def fit(
         self,
         model: LightningModule,
         train_dataloader: Optional[DataLoader] = None,
-        val_dataloaders: Optional[DataLoader] = None,
-        test_dataloaders: Optional[DataLoader] = None
+        val_dataloaders: Optional[DataLoader] = None
     ):
         r"""
         Runs the full optimization routine.
@@ -629,14 +628,10 @@ def fit(
                 Pytorch Dataloader or a list of them, specifying validation samples.
                 If the model has a predefined val_dataloaders method this will be skipped

-            test_dataloaders: Either a single
-                Pytorch Dataloader or a list of them, specifying validation samples.
-                If the model has a predefined test_dataloaders method this will be skipped
-
         Example::

             # Option 1,
-            # Define the train_dataloader(), test_dataloader() and val_dataloader() fxs
+            # Define the train_dataloader() and val_dataloader() fxs
             # in the lightningModule
             # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
             trainer = Trainer()
@@ -646,23 +641,21 @@ def fit(
             # Option 2
             # in production cases we might want to pass different datasets to the same model
             # Recommended for PRODUCTION SYSTEMS
-            train, val, test = DataLoader(...), DataLoader(...), DataLoader(...)
+            train, val = DataLoader(...), DataLoader(...)
             trainer = Trainer()
             model = LightningModule()
-            trainer.fit(model, train_dataloader=train,
-                        val_dataloader=val, test_dataloader=test)
+            trainer.fit(model, train_dataloader=train, val_dataloader=val)

             # Option 1 & 2 can be mixed, for example the training set can be
-            # defined as part of the model, and validation/test can then be
-            # feed to .fit()
+            # defined as part of the model, and validation can then be fed to .fit()

         """
         # bind logger and other properties
         model.logger = self.logger
         self.copy_trainer_model_properties(model)

         # set up the passed in dataloaders (if needed)
-        self.__attach_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)
+        self.__attach_dataloaders(model, train_dataloader, val_dataloaders)

         # check that model is configured correctly
         self.check_model_configuration(model)
@@ -747,7 +740,7 @@ def __set_random_port(self):
         default_port = random.randint(10000, 19000)
         os.environ['MASTER_PORT'] = str(default_port)

-    def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders):
+    def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
         # when dataloader is passed via fit, patch the train_dataloader
         # functions to overwrite with these implementations
         if train_dataloader is not None:
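The hunk above only shows the start of the patching logic. A rough sketch of the general technique it gestures at (the `_PatchDataLoader` helper and `attach_dataloaders` function here are illustrative stand-ins, not the project's actual implementation):

    class _PatchDataLoader:
        # Callable wrapper so a plain DataLoader can stand in for a
        # LightningModule's train/val/test dataloader hook.
        def __init__(self, dataloader):
            self.dataloader = dataloader

        def __call__(self):
            return self.dataloader

    def attach_dataloaders(model, train_dataloader=None,
                           val_dataloaders=None, test_dataloaders=None):
        # Overwrite the model's dataloader hooks with any loaders
        # that were passed in explicitly; leave the rest untouched.
        if train_dataloader is not None:
            model.train_dataloader = _PatchDataLoader(train_dataloader)
        if val_dataloaders is not None:
            model.val_dataloader = _PatchDataLoader(val_dataloaders)
        if test_dataloaders is not None:
            model.test_dataloader = _PatchDataLoader(test_dataloaders)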
@@ -863,32 +856,44 @@ def run_pretrain_routine(self, model: LightningModule):
         # CORE TRAINING LOOP
         self.train()

-    def test(self, model: Optional[LightningModule] = None):
+    def test(self, model: Optional[LightningModule] = None, test_dataloaders: Optional[DataLoader] = None):
         r"""

         Separates from fit to make sure you never run on your test set until you want to.

         Args:
             model: The model to test.

+            test_dataloaders: Either a single
+                PyTorch DataLoader or a list of them, specifying test samples.
+
         Example::

             # Option 1
             # run test after fitting
+            test = DataLoader(...)
             trainer = Trainer()
             model = LightningModule()

-            trainer.fit()
-            trainer.test()
+            trainer.fit(model)
+            trainer.test(test_dataloaders=test)

             # Option 2
             # run test from a loaded model
+            test = DataLoader(...)
             model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
             trainer = Trainer()
-            trainer.test(model)
+            trainer.test(model, test_dataloaders=test)
         """

         self.testing = True

+        if test_dataloaders is not None:
+            if model is not None:
+                self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
+            else:
+                self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders)
+
         if model is not None:
             self.model = model
             self.fit(model)
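With these changes, the attach step works whether or not a model is passed: dataloaders bind to the given model, or to the already-fitted `trainer.model` when `model` is None. A condensed usage sketch (`train`, `val`, and `test` are placeholder DataLoaders and `MyModule` is a hypothetical LightningModule subclass):

    # test right after fitting: loaders attach to the fitted trainer.model
    trainer = Trainer()
    trainer.fit(model, train_dataloader=train, val_dataloaders=val)
    trainer.test(test_dataloaders=test)

    # test a restored model: loaders attach to the model passed in
    model = MyModule.load_from_checkpoint('path/to/checkpoint.ckpt')
    Trainer().test(model, test_dataloaders=test)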
28 changes: 15 additions & 13 deletions tests/trainer/test_dataloaders.py
@@ -260,12 +260,12 @@ class CurrentTestModel(
     model = CurrentTestModel(hparams)
     trainer = Trainer(**trainer_options)
     fit_options = dict(train_dataloader=model._dataloader(train=True),
-                       val_dataloaders=model._dataloader(train=False),
-                       test_dataloaders=model._dataloader(train=False))
+                       val_dataloaders=model._dataloader(train=False))
+    test_options = dict(test_dataloaders=model._dataloader(train=False))

     result = trainer.fit(model, **fit_options)

-    trainer.test()
+    trainer.test(**test_options)

     assert result == 1
     assert len(trainer.val_dataloaders) == 1, \
@@ -300,11 +300,12 @@ class CurrentTestModel(
     trainer = Trainer(**trainer_options)
     fit_options = dict(train_dataloader=model._dataloader(train=True),
                        val_dataloaders=[model._dataloader(train=False),
-                                        model._dataloader(train=False)],
-                       test_dataloaders=[model._dataloader(train=False),
-                                         model._dataloader(train=False)])
+                                        model._dataloader(train=False)])
+    test_options = dict(test_dataloaders=[model._dataloader(train=False),
+                                          model._dataloader(train=False)])

     results = trainer.fit(model, **fit_options)
-    trainer.test()
+    trainer.test(**test_options)

     assert len(trainer.val_dataloaders) == 2, \
         f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
@@ -342,10 +343,11 @@ class CurrentTestModel(

     # fit model
     trainer = Trainer(**trainer_options)
-    fit_options = dict(val_dataloaders=model._dataloader(train=False),
-                       test_dataloaders=model._dataloader(train=False))
+    fit_options = dict(val_dataloaders=model._dataloader(train=False))
+    test_options = dict(test_dataloaders=model._dataloader(train=False))

     _ = trainer.fit(model, **fit_options)
-    trainer.test()
+    trainer.test(**test_options)

     assert len(trainer.val_dataloaders) == 1, \
         f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
@@ -511,8 +513,8 @@ class CurrentTestModel(
     )

     fit_options = dict(train_dataloader=model._dataloader(train=True),
-                       val_dataloaders=model._dataloader(train=False),
-                       test_dataloaders=model._dataloader(train=False))
+                       val_dataloaders=model._dataloader(train=False))
+    test_options = dict(test_dataloaders=model._dataloader(train=False))

     trainer = Trainer(**trainer_options)

@@ -524,7 +526,7 @@ class CurrentTestModel(
     trainer.fit(model, **fit_options)

     with pytest.warns(UserWarning, match='test'):
-        trainer.test()
+        trainer.test(**test_options)


 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
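All of the updated tests apply the same mechanical change; condensed, reusing the suite's own `model._dataloader` helper as it appears above:

    fit_options = dict(train_dataloader=model._dataloader(train=True),
                       val_dataloaders=model._dataloader(train=False))
    test_options = dict(test_dataloaders=model._dataloader(train=False))

    trainer.fit(model, **fit_options)   # fit() no longer accepts test loaders
    trainer.test(**test_options)        # they are passed to test() instead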