From 6e1a87cba2b7db1100e68d5d52c3b007c2579734 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 20 Mar 2020 14:48:02 +0100 Subject: [PATCH 01/14] add check_model_configuration method --- pytorch_lightning/trainer/trainer.py | 55 ++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3f98bf2185929..1be3cbec5cd2d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -571,6 +571,9 @@ def fit( # set up the passed in dataloaders (if needed) self.__attach_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders) + # check that model is configured correctly + self.check_model_configuration(model) + # download the data and do whatever transforms we need # do before any spawn calls so that the model can assign properties # only on proc 0 because no spawn has happened yet @@ -655,24 +658,12 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations if train_dataloader is not None: - if not self.is_overriden('training_step', model): - m = 'You called .fit() with a train_dataloader but did not define training_step()' - raise MisconfigurationException(m) - model.train_dataloader = _PatchDataLoader(train_dataloader) if val_dataloaders is not None: - if not self.is_overriden('validation_step', model): - m = 'You called .fit() with a val_dataloaders but did not define validation_step()' - raise MisconfigurationException(m) - model.val_dataloader = _PatchDataLoader(val_dataloaders) if test_dataloaders is not None: - if not self.is_overriden('test_step', model): - m = 'You called .fit() with a test_dataloaders but did not define test_step()' - raise MisconfigurationException(m) - model.test_dataloader = _PatchDataLoader(test_dataloaders) def init_optimizers( @@ -877,6 +868,46 @@ def test(self, model: Optional[LightningModule] = None): self.testing = False + def check_model_configuration(self, model: LightningModule): + if not self.is_overriden('training_step', model): + m = ('No training_step() method defined. Lightning expects as minimum ' + 'a training_step() and training_dataloader() to be defined.') + raise MisconfigurationException(m) + + if not self.is_overriden('train_dataloader', model): + m = ('No train_dataloader() defined. Lightning expects as minimum ' + 'a training_step() and training_dataloader() to be defined.') + raise MisconfigurationException(m) + + if not self.is_overriden('configure_optimizers', model): + m = ('configure_optimizers() method not defined by user, will default ' + 'to Adam optimizer with learning rate set to 0.0001.') + warnings.warn(m) + + if self.is_overriden('val_dataloader', model): + if not self.is_overriden('validation_step', model): + m = ('You have passed in a val_dataloader() but have not defined ' + 'validation_step()') + raise MisconfigurationException(m) + else: + if not self.is_overriden('validation_epoch_end', model): + m = ('You have defined a val_dataloader() and have defined ' + 'a validation_step(), you may also want to define ' + 'validation_epoch_end() for accumulating stats') + warnings.warn(m) + + if self.is_overriden('test_dataloader', model): + if not self.is_overriden('test_step', model): + m = ('You have passed in a test_dataloader() but have not defined ' + 'test_step()') + raise MisconfigurationException(m) + else: + if not self.is_overriden('test_epoch_end', model): + m = ('You have defined a test_dataloader() and have defined ' + 'a test_step(), you may also want to define ' + 'test_epoch_end() for accumulating stats') + warnings.warn(m) + class _PatchDataLoader(object): r''' From 12dfc7e8ce6d4b1fe4372e3dc9f7f7734e7bd47e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 20 Mar 2020 18:38:57 +0100 Subject: [PATCH 02/14] trying to fix errors --- .../lightning_module_template.py | 33 +++++++++++++++++++ tests/models/__init__.py | 4 +-- tests/trainer/test_trainer.py | 4 +-- 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/pl_examples/basic_examples/lightning_module_template.py b/pl_examples/basic_examples/lightning_module_template.py index 1880bffa7b7ea..00f899a604624 100644 --- a/pl_examples/basic_examples/lightning_module_template.py +++ b/pl_examples/basic_examples/lightning_module_template.py @@ -223,7 +223,40 @@ def val_dataloader(self): def test_dataloader(self): log.info('Test data loader called.') return self.__dataloader(train=False) + + def test_step(self, batch, batch_idx): + """ + Lightning calls this during testing + :param batch: + :return: + """ + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self.forward(x) + + loss_val = self.loss(y, y_hat) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc) + if self.on_gpu: + val_acc = val_acc.cuda(loss_val.device.index) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp or self.trainer.use_ddp2: + loss_val = loss_val.unsqueeze(0) + val_acc = val_acc.unsqueeze(0) + + output = OrderedDict({ + 'test_loss': loss_val, + 'test_acc': val_acc, + }) + + # can also return just a scalar instead of a dict (return loss_val) + return output + @staticmethod def add_model_specific_args(parent_parser, root_dir): # pragma: no-cover """ diff --git a/tests/models/__init__.py b/tests/models/__init__.py index 67206a63d0fe6..b256ea37c3463 100644 --- a/tests/models/__init__.py +++ b/tests/models/__init__.py @@ -30,8 +30,8 @@ class LightningTestModel(LightTrainDataloader, - LightValidationMixin, - LightTestMixin, + LightValidationStepMixin, + LightTestStepMixin, TestModelBase): """Most common test case. Validation and test dataloaders.""" diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 38112e1b6be23..54e687f4ccea2 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -24,7 +24,7 @@ LightValidationStepMixin, LightValidationMultipleDataloadersMixin, LightTrainDataloader, - LightTestDataloader, + LightTestStepMixin, ) @@ -601,7 +601,7 @@ def test_testpass_overrides(tmpdir): class LocalModel(LightTrainDataloader, TestModelBase): pass - class LocalModelNoEnd(LightTrainDataloader, LightTestDataloader, LightEmptyTestStep, TestModelBase): + class LocalModelNoEnd(LightTrainDataloader, LightTestStepMixin, LightEmptyTestStep, TestModelBase): pass class LocalModelNoStep(LightTrainDataloader, TestModelBase): From 45d9e0e8e22e254ebf07b335310e4af3c51e61bf Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 23 Mar 2020 10:50:07 +0100 Subject: [PATCH 03/14] trying to fix tests --- pytorch_lightning/trainer/trainer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1be3cbec5cd2d..9a0cb68d2d37e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -920,5 +920,11 @@ class _PatchDataLoader(object): def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): self.dataloader = dataloader + # Assign __code__, needed for checking if method has been overriden + if isinstance(self.dataloader, (list, tuple)): + self.__code__ = self.dataloader[0].__iter__.__code__ + else: + self.__code__ = self.dataloader.__iter__.__code__ + def __call__(self) -> Union[List[DataLoader], DataLoader]: return self.dataloader From ac609498fca15053749c760e86210a30470572b4 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 23 Mar 2020 11:14:55 +0100 Subject: [PATCH 04/14] added test_epoch_end to lightning template --- CHANGELOG.md | 1 + .../lightning_module_template.py | 47 +++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0af36b4cf5408..f81a2b1a593f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104)) - Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130)) - Added a check that stops the training when loss or weights contain NaN or inf values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097)) +- Added model configuratio checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199)) ### Changed diff --git a/pl_examples/basic_examples/lightning_module_template.py b/pl_examples/basic_examples/lightning_module_template.py index 00f899a604624..ecb097bb5f676 100644 --- a/pl_examples/basic_examples/lightning_module_template.py +++ b/pl_examples/basic_examples/lightning_module_template.py @@ -223,40 +223,39 @@ def val_dataloader(self): def test_dataloader(self): log.info('Test data loader called.') return self.__dataloader(train=False) - + def test_step(self, batch, batch_idx): """ - Lightning calls this during testing + Lightning calls this during testing, similar to val_step :param batch: - :return: + :return:val """ - x, y = batch - x = x.view(x.size(0), -1) - y_hat = self.forward(x) + output = self.validation_step(batch, batch_idx) + # Rename output keys + output['test_loss'] = output.pop('val_loss') + output['test_acc'] = output.pop('val_acc') - loss_val = self.loss(y, y_hat) + return output - # acc - labels_hat = torch.argmax(y_hat, dim=1) - val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) - val_acc = torch.tensor(val_acc) + def test_epoch_end(self, outputs): + """ + Called at the end of test to aggregate outputs, similar to validation_epoch_end + :param outputs: list of individual outputs of each validation step + :return: + """ + results = self.validation_step_end(outputs) - if self.on_gpu: - val_acc = val_acc.cuda(loss_val.device.index) + # Rename output + tqdm_dict = results['progress_bar'] + tqdm_dict['test_loss'] = tqdm_dict.pop('val_loss') + tqdm_dict['test_acc'] = tqdm_dict.pop('val_acc') - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp or self.trainer.use_ddp2: - loss_val = loss_val.unsqueeze(0) - val_acc = val_acc.unsqueeze(0) + results['progress_bar'] = tqdm_dict + results['log'] = tqdm_dict + results['test_loss'] = results.pop('val_loss') - output = OrderedDict({ - 'test_loss': loss_val, - 'test_acc': val_acc, - }) + return results - # can also return just a scalar instead of a dict (return loss_val) - return output - @staticmethod def add_model_specific_args(parent_parser, root_dir): # pragma: no-cover """ From 1c64dd882fb42404497fec90027ec64a3162c994 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 23 Mar 2020 13:56:21 +0100 Subject: [PATCH 05/14] fix tests --- tests/models/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/__init__.py b/tests/models/__init__.py index b256ea37c3463..67206a63d0fe6 100644 --- a/tests/models/__init__.py +++ b/tests/models/__init__.py @@ -30,8 +30,8 @@ class LightningTestModel(LightTrainDataloader, - LightValidationStepMixin, - LightTestStepMixin, + LightValidationMixin, + LightTestMixin, TestModelBase): """Most common test case. Validation and test dataloaders.""" From 00c686d7af6cb7611f95d77d147308587f05068f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 23 Mar 2020 14:13:45 +0100 Subject: [PATCH 06/14] fix new test after rebase --- tests/test_deprecated.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index a79eb7451305f..0fef5ddf6c16b 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -64,6 +64,9 @@ class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase): def val_dataloader(self): return self._dataloader(train=False) + def validation_step(self, batch, batch_idx, *args, **kwargs): + return {'val_loss': 0.6} + def validation_end(self, outputs): return {'val_loss': 0.6} @@ -77,6 +80,9 @@ class ModelVer0_7(LightTrainDataloader, LightEmptyTestStep, TestModelBase): def val_dataloader(self): return self._dataloader(train=False) + def validation_step(self, batch, batch_idx, *args, **kwargs): + return {'val_loss': 0.7} + def validation_end(self, outputs): return {'val_loss': 0.7} From 210f92b230692bd44a29bcb180b64a5f185b0bb9 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 23 Mar 2020 15:07:05 +0100 Subject: [PATCH 07/14] fix spelling --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a455d82d905e9..c543ef45aa1dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104)) - Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130)) - Added a check that stops the training when loss or weights contain NaN or inf values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097)) -- Added model configuratio checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199)) +- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199)) ### Changed From 3ee083f7236eb983f2266885ec1d10a1ddc6b7ec Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 24 Mar 2020 09:57:14 +0100 Subject: [PATCH 08/14] added more checks --- pytorch_lightning/trainer/trainer.py | 25 +++++++++++++++++++++---- tests/models/mixins.py | 3 +++ tests/test_deprecated.py | 6 ++++++ 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 715fef89c6e77..d36df5ca544ef 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -878,6 +878,14 @@ def test(self, model: Optional[LightningModule] = None): self.testing = False def check_model_configuration(self, model: LightningModule): + r""" + Checks that the model is configured correctly before training is started. + + Args: + model: The model to test. + + """ + # Check training_step, train_dataloader, configure_optimizer methods if not self.is_overriden('training_step', model): m = ('No training_step() method defined. Lightning expects as minimum ' 'a training_step() and training_dataloader() to be defined.') @@ -893,6 +901,7 @@ def check_model_configuration(self, model: LightningModule): 'to Adam optimizer with learning rate set to 0.0001.') warnings.warn(m) + # Check val_dataloader, validation_step and validation_epoch_end if self.is_overriden('val_dataloader', model): if not self.is_overriden('validation_step', model): m = ('You have passed in a val_dataloader() but have not defined ' @@ -904,7 +913,13 @@ def check_model_configuration(self, model: LightningModule): 'a validation_step(), you may also want to define ' 'validation_epoch_end() for accumulating stats') warnings.warn(m) + else: + if self.is_overriden('validation_step', model): + m = ('You have defined validation_step(), but have not passed ' + 'in a val_dataloader().') + raise MisconfigurationException(m) + # Check test_dataloader, test_step and test_epoch_end if self.is_overriden('test_dataloader', model): if not self.is_overriden('test_step', model): m = ('You have passed in a test_dataloader() but have not defined ' @@ -916,6 +931,11 @@ def check_model_configuration(self, model: LightningModule): 'a test_step(), you may also want to define ' 'test_epoch_end() for accumulating stats') warnings.warn(m) + else: + if self.is_overriden('test_step', model): + m = ('You have defined test_step, but have not passed in ' + 'a test_dataloader().') + raise MisconfigurationException(m) class _PatchDataLoader(object): @@ -931,10 +951,7 @@ def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): self.dataloader = dataloader # Assign __code__, needed for checking if method has been overriden - if isinstance(self.dataloader, (list, tuple)): - self.__code__ = self.dataloader[0].__iter__.__code__ - else: - self.__code__ = self.dataloader.__iter__.__code__ + self.__code__ = self.__call__.__code__ def __call__(self) -> Union[List[DataLoader], DataLoader]: return self.dataloader diff --git a/tests/models/mixins.py b/tests/models/mixins.py index 0be691726e209..7a00e539ff926 100644 --- a/tests/models/mixins.py +++ b/tests/models/mixins.py @@ -405,6 +405,9 @@ def test_step(self, batch, batch_idx, dataloader_idx, **kwargs): class LightTestFitSingleTestDataloadersMixin: """Test fit single test dataloaders mixin.""" + def test_dataloader(self): + return self._dataloader(train=False) + def test_step(self, batch, batch_idx, *args, **kwargs): """ Lightning calls this inside the validation loop diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 0fef5ddf6c16b..9a45e22c21e61 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -70,6 +70,9 @@ def validation_step(self, batch, batch_idx, *args, **kwargs): def validation_end(self, outputs): return {'val_loss': 0.6} + def test_dataloader(self): + return self._dataloader(train=False) + def test_end(self, outputs): return {'test_loss': 0.6} @@ -86,6 +89,9 @@ def validation_step(self, batch, batch_idx, *args, **kwargs): def validation_end(self, outputs): return {'val_loss': 0.7} + def test_dataloader(self): + return self._dataloader(train=False) + def test_end(self, outputs): return {'test_loss': 0.7} From e7abda19c258df96eb75493fc18b268ca1b76f58 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 25 Mar 2020 17:12:38 +0100 Subject: [PATCH 09/14] updated formating --- pytorch_lightning/trainer/trainer.py | 54 +++++++++++++--------------- 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d36df5ca544ef..da4e6b38bd6ca 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -887,55 +887,51 @@ def check_model_configuration(self, model: LightningModule): """ # Check training_step, train_dataloader, configure_optimizer methods if not self.is_overriden('training_step', model): - m = ('No training_step() method defined. Lightning expects as minimum ' - 'a training_step() and training_dataloader() to be defined.') - raise MisconfigurationException(m) + raise MisconfigurationException('No training_step() method defined.' + ' Lightning expects as minimum a `training_step()` and' + ' `training_dataloader()` to be defined.') if not self.is_overriden('train_dataloader', model): - m = ('No train_dataloader() defined. Lightning expects as minimum ' - 'a training_step() and training_dataloader() to be defined.') - raise MisconfigurationException(m) + raise MisconfigurationException('No train_dataloader() defined.' + ' Lightning expects as minimum a `training_step()` and' + ' `training_dataloader()` to be defined.') if not self.is_overriden('configure_optimizers', model): - m = ('configure_optimizers() method not defined by user, will default ' - 'to Adam optimizer with learning rate set to 0.0001.') - warnings.warn(m) + warnings.warn('`configure_optimizers()` method not defined by user,' + ' will default to Adam optimizer with learning rate set to 0.0001.', + RuntimeWarning) # Check val_dataloader, validation_step and validation_epoch_end if self.is_overriden('val_dataloader', model): if not self.is_overriden('validation_step', model): - m = ('You have passed in a val_dataloader() but have not defined ' - 'validation_step()') - raise MisconfigurationException(m) + raise MisconfigurationException('You have passed in a `val_dataloader()`' + ' but have not defined `validation_step()`.') else: if not self.is_overriden('validation_epoch_end', model): - m = ('You have defined a val_dataloader() and have defined ' - 'a validation_step(), you may also want to define ' - 'validation_epoch_end() for accumulating stats') - warnings.warn(m) + warnings.warn('You have defined a `val_dataloader()` and have' + ' defined a `validation_step()`, you may also want to' + ' define `validation_epoch_end()` for accumulating stats.', + RuntimeWarning) else: if self.is_overriden('validation_step', model): - m = ('You have defined validation_step(), but have not passed ' - 'in a val_dataloader().') - raise MisconfigurationException(m) + raise MisconfigurationException('You have defined `validation_step()`,' + ' but have not passed in a val_dataloader().') # Check test_dataloader, test_step and test_epoch_end if self.is_overriden('test_dataloader', model): if not self.is_overriden('test_step', model): - m = ('You have passed in a test_dataloader() but have not defined ' - 'test_step()') - raise MisconfigurationException(m) + raise MisconfigurationException('You have passed in a `test_dataloader()`' + ' but have not defined `test_step()`.') else: if not self.is_overriden('test_epoch_end', model): - m = ('You have defined a test_dataloader() and have defined ' - 'a test_step(), you may also want to define ' - 'test_epoch_end() for accumulating stats') - warnings.warn(m) + warnings.warn('You have defined a `test_dataloader()` and' + ' have defined a `test_step()`, you may also want to' + ' define `test_epoch_end()` for accumulating stats.', + RuntimeWarning) else: if self.is_overriden('test_step', model): - m = ('You have defined test_step, but have not passed in ' - 'a test_dataloader().') - raise MisconfigurationException(m) + raise MisconfigurationException('You have defined `test_step()`,' + ' but have not passed in a `test_dataloader()`.') class _PatchDataLoader(object): From 7c1eee1a6069a99a9a3be4fe144a61ce5fa96772 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 28 Mar 2020 16:50:04 +0100 Subject: [PATCH 10/14] added tests --- tests/base/__init__.py | 3 +- tests/base/mixins.py | 7 ++ tests/trainer/test_trainer.py | 126 +++++++++++++++++++++++++++++++++- 3 files changed, 134 insertions(+), 2 deletions(-) diff --git a/tests/base/__init__.py b/tests/base/__init__.py index 1e68469871d25..faf36b8836e6d 100644 --- a/tests/base/__init__.py +++ b/tests/base/__init__.py @@ -18,6 +18,7 @@ LightValStepFitSingleDataloaderMixin, LightValStepFitMultipleDataloadersMixin, LightTrainDataloader, + LightValidationDataloader, LightTestDataloader, LightInfTrainDataloader, LightInfValDataloader, @@ -25,7 +26,7 @@ LightTestOptimizerWithSchedulingMixin, LightTestMultipleOptimizersWithSchedulingMixin, LightTestOptimizersWithMixedSchedulingMixin, - LightTestReduceLROnPlateauMixin + LightTestReduceLROnPlateauMixin, ) diff --git a/tests/base/mixins.py b/tests/base/mixins.py index 7a00e539ff926..d1c3331b68a76 100644 --- a/tests/base/mixins.py +++ b/tests/base/mixins.py @@ -206,6 +206,13 @@ def train_dataloader(self): return self._dataloader(train=True) +class LightValidationDataloader: + """Simple validation dataloader.""" + + def val_dataloader(self): + return self._dataloader(train=False) + + class LightTestDataloader: """Simple test dataloader.""" diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3e9c0bf480d40..5a8df68325d3b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -7,7 +7,7 @@ import torch import tests.base.utils as tutils -from pytorch_lightning import Trainer +from pytorch_lightning import Trainer, LightningModule from pytorch_lightning.callbacks import ( EarlyStopping, ModelCheckpoint, @@ -19,11 +19,15 @@ TestModelBase, DictHparamsModel, LightningTestModel, + LightValidationDataloader, + LightTestDataloader, LightEmptyTestStep, LightValidationStepMixin, LightValidationMultipleDataloadersMixin, + LightValStepFitSingleDataloaderMixin, LightTrainDataloader, LightTestStepMixin, + LightTestFitMultipleTestDataloadersMixin, ) @@ -624,3 +628,123 @@ def test_epoch_end(self, outputs): model = LightningTestModel(hparams) Trainer().test(model) + + +def test_error_on_no_train_step(tmpdir): + """ Test that an error is thrown when no `training_step()` is defined """ + tutils.reset_seed() + + class CurrentTestModel(LightningModule): + def forward(self, x): + pass + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + with pytest.raises(MisconfigurationException) as error: + model = CurrentTestModel() + trainer.fit(model) + + +def test_error_on_no_train_dataloader(tmpdir): + """ Test that an error is thrown when no `training_dataloader()` is defined """ + tutils.reset_seed() + hparams = tutils.get_default_hparams() + + class CurrentTestModel(TestModelBase): + pass + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) + + +def test_warning_on_wrong_validation_settings(tmpdir): + """ Test the following cases related to validation configuration of model: + * error if `val_dataloader()` is overriden but `validation_step()` is not + * if both `val_dataloader()` and `validation_step()` is overriden, + throw warning if `val_epoch_end()` is not defined + * error if `validation_step()` is overriden but `val_dataloader()` is not + """ + tutils.reset_seed() + hparams = tutils.get_default_hparams() + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + class CurrentTestModel(LightTrainDataloader, + LightValidationDataloader, + TestModelBase): + pass + + # check val_dataloader -> val_step + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) + + class CurrentTestModel(LightTrainDataloader, + LightValidationStepMixin, + TestModelBase): + pass + + # check val_dataloader + val_step -> val_epoch_end + with pytest.warns(RuntimeWarning): + model = CurrentTestModel(hparams) + trainer.fit(model) + + class CurrentTestModel(LightTrainDataloader, + LightValStepFitSingleDataloaderMixin, + TestModelBase): + pass + + # check val_step -> val_dataloader + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) + + +def test_warning_on_wrong_test_settigs(tmpdir): + """ Test the following cases related to test configuration of model: + * error if `test_dataloader()` is overriden but `test_step()` is not + * if both `test_dataloader()` and `test_step()` is overriden, + throw warning if `test_epoch_end()` is not defined + * error if `test_step()` is overriden but `test_dataloader()` is not + """ + tutils.reset_seed() + hparams = tutils.get_default_hparams() + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + class CurrentTestModel(LightTrainDataloader, + LightTestDataloader, + TestModelBase): + pass + + # check test_dataloader -> test_step + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) + + class CurrentTestModel(LightTrainDataloader, + LightTestStepMixin, + TestModelBase): + pass + + # check test_dataloader + test_step -> test_epoch_end + with pytest.warns(RuntimeWarning): + model = CurrentTestModel(hparams) + trainer.fit(model) + + class CurrentTestModel(LightTrainDataloader, + LightTestFitMultipleTestDataloadersMixin, + TestModelBase): + pass + + # check test_step -> test_dataloader + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) From 2fd5abc582c61cafd805d2fb7c7ef211f185e410 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sun, 29 Mar 2020 15:07:07 +0200 Subject: [PATCH 11/14] fixed CHANGELOG --- CHANGELOG.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8cd899e113e3..9e1c9897031f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,13 +11,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for hierarchical `dict` ([#1152](https://github.com/PyTorchLightning/pytorch-lightning/pull/1152)) - Added `TrainsLogger` class ([#1122](https://github.com/PyTorchLightning/pytorch-lightning/pull/1122)) - Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946)) -- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104)) -- Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130)) -- Added a check that stops the training when loss or weights contain NaN or inf values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097)) -- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199)) - Added support for `IterableDataset` in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104)) - Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130)) - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097)) +- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199)) ### Changed From 6653a54b898b415ff7589161a75b0c1426fdbc31 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 29 Mar 2020 21:55:06 +0200 Subject: [PATCH 12/14] Apply suggestions from code review --- pytorch_lightning/trainer/trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 442b28eb1d313..299c63c2f33fd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -962,14 +962,14 @@ def check_model_configuration(self, model: LightningModule): """ # Check training_step, train_dataloader, configure_optimizer methods if not self.is_overriden('training_step', model): - raise MisconfigurationException('No training_step() method defined.' - ' Lightning expects as minimum a `training_step()` and' - ' `training_dataloader()` to be defined.') + raise MisconfigurationException( + 'No `training_step()` method defined. Lightning expects as minimum a `training_step()` and' + ' `training_dataloader()` to be defined.') if not self.is_overriden('train_dataloader', model): - raise MisconfigurationException('No train_dataloader() defined.' - ' Lightning expects as minimum a `training_step()` and' - ' `training_dataloader()` to be defined.') + raise MisconfigurationException( + 'No `train_dataloader()` defined. Lightning expects as minimum a `training_step()` and' + ' `training_dataloader()` to be defined.') if not self.is_overriden('configure_optimizers', model): warnings.warn('`configure_optimizers()` method not defined by user,' From d9a22e9ddf3f8e345fd1534c0ae7552cb1537a6d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 30 Mar 2020 13:51:21 +0200 Subject: [PATCH 13/14] move test to new module --- .../lightning_module_template.py | 13 +- tests/trainer/test_checks.py | 135 ++++++++++++++++++ tests/trainer/test_trainer.py | 126 +--------------- 3 files changed, 142 insertions(+), 132 deletions(-) create mode 100755 tests/trainer/test_checks.py diff --git a/pl_examples/basic_examples/lightning_module_template.py b/pl_examples/basic_examples/lightning_module_template.py index 8cfa7cd5f2153..13eff7ff4b333 100644 --- a/pl_examples/basic_examples/lightning_module_template.py +++ b/pl_examples/basic_examples/lightning_module_template.py @@ -262,13 +262,12 @@ def test_epoch_end(self, outputs): """ results = self.validation_step_end(outputs) - # Rename output - tqdm_dict = results['progress_bar'] - tqdm_dict['test_loss'] = tqdm_dict.pop('val_loss') - tqdm_dict['test_acc'] = tqdm_dict.pop('val_acc') - - results['progress_bar'] = tqdm_dict - results['log'] = tqdm_dict + # rename some keys + results['progress_bar'].update({ + 'test_loss': results['progress_bar'].pop('val_loss'), + 'test_acc': results['progress_bar'].pop('val_acc'), + }) + results['log'] = results['progress_bar'] results['test_loss'] = results.pop('val_loss') return results diff --git a/tests/trainer/test_checks.py b/tests/trainer/test_checks.py new file mode 100755 index 0000000000000..e9d8dfb0355fa --- /dev/null +++ b/tests/trainer/test_checks.py @@ -0,0 +1,135 @@ +import pytest + +import tests.base.utils as tutils +from pytorch_lightning import Trainer, LightningModule +from pytorch_lightning.utilities.debugging import MisconfigurationException +from tests.base import ( + TestModelBase, + LightValidationDataloader, + LightTestDataloader, + LightValidationStepMixin, + LightValStepFitSingleDataloaderMixin, + LightTrainDataloader, + LightTestStepMixin, + LightTestFitMultipleTestDataloadersMixin, +) + + +def test_error_on_no_train_step(tmpdir): + """ Test that an error is thrown when no `training_step()` is defined """ + tutils.reset_seed() + + class CurrentTestModel(LightningModule): + def forward(self, x): + pass + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + with pytest.raises(MisconfigurationException): + model = CurrentTestModel() + trainer.fit(model) + + +def test_error_on_no_train_dataloader(tmpdir): + """ Test that an error is thrown when no `training_dataloader()` is defined """ + tutils.reset_seed() + hparams = tutils.get_default_hparams() + + class CurrentTestModel(TestModelBase): + pass + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) + + +def test_warning_on_wrong_validation_settings(tmpdir): + """ Test the following cases related to validation configuration of model: + * error if `val_dataloader()` is overriden but `validation_step()` is not + * if both `val_dataloader()` and `validation_step()` is overriden, + throw warning if `val_epoch_end()` is not defined + * error if `validation_step()` is overriden but `val_dataloader()` is not + """ + tutils.reset_seed() + hparams = tutils.get_default_hparams() + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + class CurrentTestModel(LightTrainDataloader, + LightValidationDataloader, + TestModelBase): + pass + + # check val_dataloader -> val_step + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) + + class CurrentTestModel(LightTrainDataloader, + LightValidationStepMixin, + TestModelBase): + pass + + # check val_dataloader + val_step -> val_epoch_end + with pytest.warns(RuntimeWarning): + model = CurrentTestModel(hparams) + trainer.fit(model) + + class CurrentTestModel(LightTrainDataloader, + LightValStepFitSingleDataloaderMixin, + TestModelBase): + pass + + # check val_step -> val_dataloader + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) + + +def test_warning_on_wrong_test_settigs(tmpdir): + """ Test the following cases related to test configuration of model: + * error if `test_dataloader()` is overriden but `test_step()` is not + * if both `test_dataloader()` and `test_step()` is overriden, + throw warning if `test_epoch_end()` is not defined + * error if `test_step()` is overriden but `test_dataloader()` is not + """ + tutils.reset_seed() + hparams = tutils.get_default_hparams() + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + class CurrentTestModel(LightTrainDataloader, + LightTestDataloader, + TestModelBase): + pass + + # check test_dataloader -> test_step + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) + + class CurrentTestModel(LightTrainDataloader, + LightTestStepMixin, + TestModelBase): + pass + + # check test_dataloader + test_step -> test_epoch_end + with pytest.warns(RuntimeWarning): + model = CurrentTestModel(hparams) + trainer.fit(model) + + class CurrentTestModel(LightTrainDataloader, + LightTestFitMultipleTestDataloadersMixin, + TestModelBase): + pass + + # check test_step -> test_dataloader + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5a8df68325d3b..3e9c0bf480d40 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -7,7 +7,7 @@ import torch import tests.base.utils as tutils -from pytorch_lightning import Trainer, LightningModule +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ( EarlyStopping, ModelCheckpoint, @@ -19,15 +19,11 @@ TestModelBase, DictHparamsModel, LightningTestModel, - LightValidationDataloader, - LightTestDataloader, LightEmptyTestStep, LightValidationStepMixin, LightValidationMultipleDataloadersMixin, - LightValStepFitSingleDataloaderMixin, LightTrainDataloader, LightTestStepMixin, - LightTestFitMultipleTestDataloadersMixin, ) @@ -628,123 +624,3 @@ def test_epoch_end(self, outputs): model = LightningTestModel(hparams) Trainer().test(model) - - -def test_error_on_no_train_step(tmpdir): - """ Test that an error is thrown when no `training_step()` is defined """ - tutils.reset_seed() - - class CurrentTestModel(LightningModule): - def forward(self, x): - pass - - trainer_options = dict(default_save_path=tmpdir, max_epochs=1) - trainer = Trainer(**trainer_options) - - with pytest.raises(MisconfigurationException) as error: - model = CurrentTestModel() - trainer.fit(model) - - -def test_error_on_no_train_dataloader(tmpdir): - """ Test that an error is thrown when no `training_dataloader()` is defined """ - tutils.reset_seed() - hparams = tutils.get_default_hparams() - - class CurrentTestModel(TestModelBase): - pass - - trainer_options = dict(default_save_path=tmpdir, max_epochs=1) - trainer = Trainer(**trainer_options) - - with pytest.raises(MisconfigurationException): - model = CurrentTestModel(hparams) - trainer.fit(model) - - -def test_warning_on_wrong_validation_settings(tmpdir): - """ Test the following cases related to validation configuration of model: - * error if `val_dataloader()` is overriden but `validation_step()` is not - * if both `val_dataloader()` and `validation_step()` is overriden, - throw warning if `val_epoch_end()` is not defined - * error if `validation_step()` is overriden but `val_dataloader()` is not - """ - tutils.reset_seed() - hparams = tutils.get_default_hparams() - - trainer_options = dict(default_save_path=tmpdir, max_epochs=1) - trainer = Trainer(**trainer_options) - - class CurrentTestModel(LightTrainDataloader, - LightValidationDataloader, - TestModelBase): - pass - - # check val_dataloader -> val_step - with pytest.raises(MisconfigurationException): - model = CurrentTestModel(hparams) - trainer.fit(model) - - class CurrentTestModel(LightTrainDataloader, - LightValidationStepMixin, - TestModelBase): - pass - - # check val_dataloader + val_step -> val_epoch_end - with pytest.warns(RuntimeWarning): - model = CurrentTestModel(hparams) - trainer.fit(model) - - class CurrentTestModel(LightTrainDataloader, - LightValStepFitSingleDataloaderMixin, - TestModelBase): - pass - - # check val_step -> val_dataloader - with pytest.raises(MisconfigurationException): - model = CurrentTestModel(hparams) - trainer.fit(model) - - -def test_warning_on_wrong_test_settigs(tmpdir): - """ Test the following cases related to test configuration of model: - * error if `test_dataloader()` is overriden but `test_step()` is not - * if both `test_dataloader()` and `test_step()` is overriden, - throw warning if `test_epoch_end()` is not defined - * error if `test_step()` is overriden but `test_dataloader()` is not - """ - tutils.reset_seed() - hparams = tutils.get_default_hparams() - - trainer_options = dict(default_save_path=tmpdir, max_epochs=1) - trainer = Trainer(**trainer_options) - - class CurrentTestModel(LightTrainDataloader, - LightTestDataloader, - TestModelBase): - pass - - # check test_dataloader -> test_step - with pytest.raises(MisconfigurationException): - model = CurrentTestModel(hparams) - trainer.fit(model) - - class CurrentTestModel(LightTrainDataloader, - LightTestStepMixin, - TestModelBase): - pass - - # check test_dataloader + test_step -> test_epoch_end - with pytest.warns(RuntimeWarning): - model = CurrentTestModel(hparams) - trainer.fit(model) - - class CurrentTestModel(LightTrainDataloader, - LightTestFitMultipleTestDataloadersMixin, - TestModelBase): - pass - - # check test_step -> test_dataloader - with pytest.raises(MisconfigurationException): - model = CurrentTestModel(hparams) - trainer.fit(model) From 8b6b0cb5d7bf67f89062dfa1ec0fcde516a0d1fb Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 1 Apr 2020 14:41:48 +0200 Subject: [PATCH 14/14] change check on configure_optimizers --- pytorch_lightning/trainer/trainer.py | 14 +++++++------- tests/trainer/test_checks.py | 21 ++++++++++++++++++++- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 73472445de3f8..451874485e8e7 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -988,18 +988,18 @@ def check_model_configuration(self, model: LightningModule): # Check training_step, train_dataloader, configure_optimizer methods if not self.is_overriden('training_step', model): raise MisconfigurationException( - 'No `training_step()` method defined. Lightning expects as minimum a `training_step()` and' - ' `training_dataloader()` to be defined.') + 'No `training_step()` method defined. Lightning `Trainer` expects as minimum a' + ' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.') if not self.is_overriden('train_dataloader', model): raise MisconfigurationException( - 'No `train_dataloader()` defined. Lightning expects as minimum a `training_step()` and' - ' `training_dataloader()` to be defined.') + 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a' + ' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.') if not self.is_overriden('configure_optimizers', model): - warnings.warn('`configure_optimizers()` method not defined by user,' - ' will default to Adam optimizer with learning rate set to 0.0001.', - RuntimeWarning) + raise MisconfigurationException( + 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a' + ' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.') # Check val_dataloader, validation_step and validation_epoch_end if self.is_overriden('val_dataloader', model): diff --git a/tests/trainer/test_checks.py b/tests/trainer/test_checks.py index e9d8dfb0355fa..1dc9819de5e0d 100755 --- a/tests/trainer/test_checks.py +++ b/tests/trainer/test_checks.py @@ -2,7 +2,7 @@ import tests.base.utils as tutils from pytorch_lightning import Trainer, LightningModule -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import ( TestModelBase, LightValidationDataloader, @@ -47,6 +47,25 @@ class CurrentTestModel(TestModelBase): trainer.fit(model) +def test_error_on_no_configure_optimizers(tmpdir): + """ Test that an error is thrown when no `configure_optimizers()` is defined """ + tutils.reset_seed() + + class CurrentTestModel(LightTrainDataloader, LightningModule): + def forward(self, x): + pass + + def training_step(self, batch, batch_idx, optimizer_idx=None): + pass + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + with pytest.raises(MisconfigurationException): + model = CurrentTestModel() + trainer.fit(model) + + def test_warning_on_wrong_validation_settings(tmpdir): """ Test the following cases related to validation configuration of model: * error if `val_dataloader()` is overriden but `validation_step()` is not