From e79ae18cae352505524215b6f3617f862ad024f2 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 10 Apr 2020 21:14:03 +0530 Subject: [PATCH] Add test_dataloaders to test method (#1434) * Add test_dataloaders to test method * Remove test_dataloaders from .fit() * Fix code comment * Fix tests * Add test_dataloaders to test method (#1393) * Fix failing tests * Update docs (#1393) --- CHANGELOG.md | 5 ++++ pytorch_lightning/trainer/trainer.py | 41 ++++++++++++++++------------ tests/trainer/test_dataloaders.py | 28 ++++++++++--------- 3 files changed, 43 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd29d715a8686..6380ec2e7d4d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Added `rank_zero_warn` for warning only in rank 0 ([#1428](https://github.com/PyTorchLightning/pytorch-lightning/pull/1428)) +- Added `test_dataloaders` parameter to `Trainer.test()` ([#1393](https://github.com/PyTorchLightning/pytorch-lightning/issues/1393)) + +### Changed + +- Removed `test_dataloaders` parameter from `Trainer.fit()` ([#1393](https://github.com/PyTorchLightning/pytorch-lightning/issues/1393)) ### Fixed diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3e3b7c9679df9..2286e1ba2c626 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -613,8 +613,7 @@ def fit( self, model: LightningModule, train_dataloader: Optional[DataLoader] = None, - val_dataloaders: Optional[DataLoader] = None, - test_dataloaders: Optional[DataLoader] = None + val_dataloaders: Optional[DataLoader] = None ): r""" Runs the full optimization routine. @@ -630,14 +629,10 @@ def fit( Pytorch Dataloader or a list of them, specifying validation samples. 
If the model has a predefined val_dataloaders method this will be skipped - test_dataloaders: Either a single - Pytorch Dataloader or a list of them, specifying validation samples. - If the model has a predefined test_dataloaders method this will be skipped - Example:: # Option 1, - # Define the train_dataloader(), test_dataloader() and val_dataloader() fxs + # Define the train_dataloader() and val_dataloader() fxs # in the lightningModule # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY trainer = Trainer() @@ -647,15 +642,13 @@ def fit( # Option 2 # in production cases we might want to pass different datasets to the same model # Recommended for PRODUCTION SYSTEMS - train, val, test = DataLoader(...), DataLoader(...), DataLoader(...) + train, val = DataLoader(...), DataLoader(...) trainer = Trainer() model = LightningModule() - trainer.fit(model, train_dataloader=train, - val_dataloader=val, test_dataloader=test) + trainer.fit(model, train_dataloader=train, val_dataloader=val) # Option 1 & 2 can be mixed, for example the training set can be - # defined as part of the model, and validation/test can then be - # feed to .fit() + # defined as part of the model, and validation can then be fed to .fit() """ # bind logger and other properties @@ -663,7 +656,7 @@ def fit( self.copy_trainer_model_properties(model) # set up the passed in dataloaders (if needed) - self.__attach_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders) + self.__attach_dataloaders(model, train_dataloader, val_dataloaders) # check that model is configured correctly self.check_model_configuration(model) @@ -748,7 +741,7 @@ def __set_random_port(self): default_port = random.randint(10000, 19000) os.environ['MASTER_PORT'] = str(default_port) - def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders): + def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None): # when dataloader 
is passed via fit, patch the train_dataloader # functions to overwrite with these implementations if train_dataloader is not None: @@ -864,7 +857,7 @@ def run_pretrain_routine(self, model: LightningModule): # CORE TRAINING LOOP self.train() - def test(self, model: Optional[LightningModule] = None): + def test(self, model: Optional[LightningModule] = None, test_dataloaders: Optional[DataLoader] = None): r""" Separates from fit to make sure you never run on your test set until you want to. @@ -872,24 +865,36 @@ def test(self, model: Optional[LightningModule] = None): Args: model: The model to test. + test_dataloaders: Either a single + Pytorch Dataloader or a list of them, specifying test samples. + Example:: # Option 1 # run test after fitting + test = DataLoader(...) trainer = Trainer() model = LightningModule() - trainer.fit() - trainer.test() + trainer.fit(model) + trainer.test(test_dataloaders=test) # Option 2 # run test from a loaded model + test = DataLoader(...) model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt') trainer = Trainer() - trainer.test(model) + trainer.test(model, test_dataloaders=test) """ self.testing = True + + if test_dataloaders is not None: + if model is not None: + self.__attach_dataloaders(model, test_dataloaders=test_dataloaders) + else: + self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders) + if model is not None: self.model = model self.fit(model) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 408774430c398..42d2fe7d1bd56 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -260,12 +260,12 @@ class CurrentTestModel( model = CurrentTestModel(hparams) trainer = Trainer(**trainer_options) fit_options = dict(train_dataloader=model._dataloader(train=True), - val_dataloaders=model._dataloader(train=False), - test_dataloaders=model._dataloader(train=False)) + val_dataloaders=model._dataloader(train=False)) + test_options 
= dict(test_dataloaders=model._dataloader(train=False)) result = trainer.fit(model, **fit_options) - trainer.test() + trainer.test(**test_options) assert result == 1 assert len(trainer.val_dataloaders) == 1, \ @@ -300,11 +300,12 @@ class CurrentTestModel( trainer = Trainer(**trainer_options) fit_options = dict(train_dataloader=model._dataloader(train=True), val_dataloaders=[model._dataloader(train=False), - model._dataloader(train=False)], - test_dataloaders=[model._dataloader(train=False), - model._dataloader(train=False)]) + model._dataloader(train=False)]) + test_options = dict(test_dataloaders=[model._dataloader(train=False), + model._dataloader(train=False)]) + results = trainer.fit(model, **fit_options) - trainer.test() + trainer.test(**test_options) assert len(trainer.val_dataloaders) == 2, \ f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' @@ -342,10 +343,11 @@ class CurrentTestModel( # fit model trainer = Trainer(**trainer_options) - fit_options = dict(val_dataloaders=model._dataloader(train=False), - test_dataloaders=model._dataloader(train=False)) + fit_options = dict(val_dataloaders=model._dataloader(train=False)) + test_options = dict(test_dataloaders=model._dataloader(train=False)) + _ = trainer.fit(model, **fit_options) - trainer.test() + trainer.test(**test_options) assert len(trainer.val_dataloaders) == 1, \ f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}' @@ -511,8 +513,8 @@ class CurrentTestModel( ) fit_options = dict(train_dataloader=model._dataloader(train=True), - val_dataloaders=model._dataloader(train=False), - test_dataloaders=model._dataloader(train=False)) + val_dataloaders=model._dataloader(train=False)) + test_options = dict(test_dataloaders=model._dataloader(train=False)) trainer = Trainer(**trainer_options) @@ -524,7 +526,7 @@ class CurrentTestModel( trainer.fit(model, **fit_options) with pytest.warns(UserWarning, match='test'): - trainer.test() + 
trainer.test(**test_options) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')