Skip to content

Commit

Permalink
Add test_dataloaders to test method (#1434)
Browse files Browse the repository at this point in the history
* Add test_dataloaders to test method

* Remove test_dataloaders from .fit()

* Fix code comment

* Fix tests

* Add test_dataloaders to test method (#1393)

* Fix failing tests

* Update docs (#1393)
  • Loading branch information
rohitgr7 committed Apr 10, 2020
1 parent 4c34d16 commit e79ae18
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 31 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added `rank_zero_warn` for warning only in rank 0 ([#1428](https://github.com/PyTorchLightning/pytorch-lightning/pull/1428))
- Added `test_dataloaders` parameter to `Trainer.test()` ([#1393](https://github.com/PyTorchLightning/pytorch-lightning/issues/1393))

### Changed

- Removed `test_dataloaders` parameter from `Trainer.fit()` ([#1393](https://github.com/PyTorchLightning/pytorch-lightning/issues/1393))

### Fixed

Expand Down
41 changes: 23 additions & 18 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,7 @@ def fit(
self,
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[DataLoader] = None,
test_dataloaders: Optional[DataLoader] = None
val_dataloaders: Optional[DataLoader] = None
):
r"""
Runs the full optimization routine.
Expand All @@ -630,14 +629,10 @@ def fit(
Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped
test_dataloaders: Either a single
Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined test_dataloaders method this will be skipped
Example::
# Option 1,
# Define the train_dataloader(), test_dataloader() and val_dataloader() fxs
# Define the train_dataloader() and val_dataloader() fxs
# in the lightningModule
# RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
trainer = Trainer()
Expand All @@ -647,23 +642,21 @@ def fit(
# Option 2
# in production cases we might want to pass different datasets to the same model
# Recommended for PRODUCTION SYSTEMS
train, val, test = DataLoader(...), DataLoader(...), DataLoader(...)
train, val = DataLoader(...), DataLoader(...)
trainer = Trainer()
model = LightningModule()
trainer.fit(model, train_dataloader=train,
val_dataloader=val, test_dataloader=test)
trainer.fit(model, train_dataloader=train, val_dataloader=val)
# Option 1 & 2 can be mixed, for example the training set can be
# defined as part of the model, and validation/test can then be
# feed to .fit()
# defined as part of the model, and validation can then be feed to .fit()
"""
# bind logger and other properties
model.logger = self.logger
self.copy_trainer_model_properties(model)

# set up the passed in dataloaders (if needed)
self.__attach_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)

# check that model is configured correctly
self.check_model_configuration(model)
Expand Down Expand Up @@ -748,7 +741,7 @@ def __set_random_port(self):
default_port = random.randint(10000, 19000)
os.environ['MASTER_PORT'] = str(default_port)

def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders):
def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloader is not None:
Expand Down Expand Up @@ -864,32 +857,44 @@ def run_pretrain_routine(self, model: LightningModule):
# CORE TRAINING LOOP
self.train()

def test(self, model: Optional[LightningModule] = None):
def test(self, model: Optional[LightningModule] = None, test_dataloaders: Optional[DataLoader] = None):
r"""
Separates from fit to make sure you never run on your test set until you want to.
Args:
model: The model to test.
test_dataloaders: Either a single
Pytorch Dataloader or a list of them, specifying validation samples.
Example::
# Option 1
# run test after fitting
test = DataLoader(...)
trainer = Trainer()
model = LightningModule()
trainer.fit()
trainer.test()
trainer.fit(model)
trainer.test(test_dataloaders=test)
# Option 2
# run test from a loaded model
test = DataLoader(...)
model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
trainer = Trainer()
trainer.test(model)
trainer.test(model, test_dataloaders=test)
"""

self.testing = True

if test_dataloaders is not None:
if model is not None:
self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
else:
self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders)

if model is not None:
self.model = model
self.fit(model)
Expand Down
28 changes: 15 additions & 13 deletions tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,12 @@ class CurrentTestModel(
model = CurrentTestModel(hparams)
trainer = Trainer(**trainer_options)
fit_options = dict(train_dataloader=model._dataloader(train=True),
val_dataloaders=model._dataloader(train=False),
test_dataloaders=model._dataloader(train=False))
val_dataloaders=model._dataloader(train=False))
test_options = dict(test_dataloaders=model._dataloader(train=False))

result = trainer.fit(model, **fit_options)

trainer.test()
trainer.test(**test_options)

assert result == 1
assert len(trainer.val_dataloaders) == 1, \
Expand Down Expand Up @@ -300,11 +300,12 @@ class CurrentTestModel(
trainer = Trainer(**trainer_options)
fit_options = dict(train_dataloader=model._dataloader(train=True),
val_dataloaders=[model._dataloader(train=False),
model._dataloader(train=False)],
test_dataloaders=[model._dataloader(train=False),
model._dataloader(train=False)])
model._dataloader(train=False)])
test_options = dict(test_dataloaders=[model._dataloader(train=False),
model._dataloader(train=False)])

results = trainer.fit(model, **fit_options)
trainer.test()
trainer.test(**test_options)

assert len(trainer.val_dataloaders) == 2, \
f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
Expand Down Expand Up @@ -342,10 +343,11 @@ class CurrentTestModel(

# fit model
trainer = Trainer(**trainer_options)
fit_options = dict(val_dataloaders=model._dataloader(train=False),
test_dataloaders=model._dataloader(train=False))
fit_options = dict(val_dataloaders=model._dataloader(train=False))
test_options = dict(test_dataloaders=model._dataloader(train=False))

_ = trainer.fit(model, **fit_options)
trainer.test()
trainer.test(**test_options)

assert len(trainer.val_dataloaders) == 1, \
f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
Expand Down Expand Up @@ -511,8 +513,8 @@ class CurrentTestModel(
)

fit_options = dict(train_dataloader=model._dataloader(train=True),
val_dataloaders=model._dataloader(train=False),
test_dataloaders=model._dataloader(train=False))
val_dataloaders=model._dataloader(train=False))
test_options = dict(test_dataloaders=model._dataloader(train=False))

trainer = Trainer(**trainer_options)

Expand All @@ -524,7 +526,7 @@ class CurrentTestModel(
trainer.fit(model, **fit_options)

with pytest.warns(UserWarning, match='test'):
trainer.test()
trainer.test(**test_options)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
Expand Down

0 comments on commit e79ae18

Please sign in to comment.