diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d8d2a96ca10d..b932c7bece594 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280)) - Allow to upload models on W&B ([#1339](https://github.com/PyTorchLightning/pytorch-lightning/pull/1339)) - Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283)) +- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199)) ### Changed diff --git a/pl_examples/basic_examples/lightning_module_template.py b/pl_examples/basic_examples/lightning_module_template.py index 9c65bac9ac7e0..b6444cee7fb6e 100644 --- a/pl_examples/basic_examples/lightning_module_template.py +++ b/pl_examples/basic_examples/lightning_module_template.py @@ -232,6 +232,37 @@ def test_dataloader(self): log.info('Test data loader called.') return self.__dataloader(train=False) + def test_step(self, batch, batch_idx): + """ + Lightning calls this during testing, similar to val_step + :param batch: + :return:val + """ + output = self.validation_step(batch, batch_idx) + # Rename output keys + output['test_loss'] = output.pop('val_loss') + output['test_acc'] = output.pop('val_acc') + + return output + + def test_epoch_end(self, outputs): + """ + Called at the end of test to aggregate outputs, similar to validation_epoch_end + :param outputs: list of individual outputs of each validation step + :return: + """ + results = self.validation_step_end(outputs) + + # rename some keys + results['progress_bar'].update({ + 'test_loss': results['progress_bar'].pop('val_loss'), + 'test_acc': results['progress_bar'].pop('val_acc'), + }) + results['log'] = results['progress_bar'] + results['test_loss'] = results.pop('val_loss') + + return results + @staticmethod def add_model_specific_args(parent_parser, root_dir): # pragma: no-cover """ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7dbcfac467b9e..9ba9c177bd7c6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -653,6 +653,9 @@ def fit( # set up the passed in dataloaders (if needed) self.__attach_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders) + # check that model is configured correctly + self.check_model_configuration(model) + # download the data and do whatever transforms we need # do before any spawn calls so that the model can assign properties # only on proc 0 because no spawn has happened yet @@ -737,24 +740,12 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da # when dataloader is passed via fit, patch the train_dataloader # functions to overwrite with these implementations if train_dataloader is not None: - if not self.is_overriden('training_step', model): - raise MisconfigurationException( - 'You called `.fit()` with a `train_dataloader` but did not define `training_step()`') - model.train_dataloader = _PatchDataLoader(train_dataloader) if val_dataloaders is not None: - if not self.is_overriden('validation_step', model): - raise MisconfigurationException( - 'You called `.fit()` with a `val_dataloaders` but did not define `validation_step()`') - model.val_dataloader = _PatchDataLoader(val_dataloaders) if test_dataloaders is not None: - if not self.is_overriden('test_step', model): - raise MisconfigurationException( - 'You called `.fit()` with a `test_dataloaders` but did not define `test_step()`') - model.test_dataloader = _PatchDataLoader(test_dataloaders) def run_pretrain_routine(self, model: LightningModule): @@ -903,6 +894,62 @@ def test(self, model: Optional[LightningModule] = None): self.testing = False + def check_model_configuration(self, model: LightningModule): + r""" + Checks that the model is configured correctly before training is started. + + Args: + model: The model to test. + + """ + # Check training_step, train_dataloader, configure_optimizer methods + if not self.is_overriden('training_step', model): + raise MisconfigurationException( + 'No `training_step()` method defined. Lightning `Trainer` expects as minimum a' + ' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.') + + if not self.is_overriden('train_dataloader', model): + raise MisconfigurationException( + 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a' + ' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.') + + if not self.is_overriden('configure_optimizers', model): + raise MisconfigurationException( + 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a' + ' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.') + + # Check val_dataloader, validation_step and validation_epoch_end + if self.is_overriden('val_dataloader', model): + if not self.is_overriden('validation_step', model): + raise MisconfigurationException('You have passed in a `val_dataloader()`' + ' but have not defined `validation_step()`.') + else: + if not self.is_overriden('validation_epoch_end', model): + warnings.warn('You have defined a `val_dataloader()` and have' + ' defined a `validation_step()`, you may also want to' + ' define `validation_epoch_end()` for accumulating stats.', + RuntimeWarning) + else: + if self.is_overriden('validation_step', model): + raise MisconfigurationException('You have defined `validation_step()`,' + ' but have not passed in a val_dataloader().') + + # Check test_dataloader, test_step and test_epoch_end + if self.is_overriden('test_dataloader', model): + if not self.is_overriden('test_step', model): + raise MisconfigurationException('You have passed in a `test_dataloader()`' + ' but have not defined `test_step()`.') + else: + if not self.is_overriden('test_epoch_end', model): + warnings.warn('You have defined a `test_dataloader()` and' + ' have defined a `test_step()`, you may also want to' + ' define `test_epoch_end()` for accumulating stats.', + RuntimeWarning) + else: + if self.is_overriden('test_step', model): + raise MisconfigurationException('You have defined `test_step()`,' + ' but have not passed in a `test_dataloader()`.') + class _PatchDataLoader(object): r""" @@ -917,5 +964,8 @@ class _PatchDataLoader(object): def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): self.dataloader = dataloader + # Assign __code__, needed for checking if method has been overriden + self.__code__ = self.__call__.__code__ + def __call__(self) -> Union[List[DataLoader], DataLoader]: return self.dataloader diff --git a/tests/base/__init__.py b/tests/base/__init__.py index a5728c0f77d85..638663442f834 100644 --- a/tests/base/__init__.py +++ b/tests/base/__init__.py @@ -18,6 +18,7 @@ LightValStepFitSingleDataloaderMixin, LightValStepFitMultipleDataloadersMixin, LightTrainDataloader, + LightValidationDataloader, LightTestDataloader, LightInfTrainDataloader, LightInfValDataloader, diff --git a/tests/base/mixins.py b/tests/base/mixins.py index 02a9c16cfaa6b..afffd3768c41c 100644 --- a/tests/base/mixins.py +++ b/tests/base/mixins.py @@ -203,6 +203,13 @@ def train_dataloader(self): return self._dataloader(train=True) +class LightValidationDataloader: + """Simple validation dataloader.""" + + def val_dataloader(self): + return self._dataloader(train=False) + + class LightTestDataloader: """Simple test dataloader.""" @@ -412,6 +419,9 @@ def test_step(self, batch, batch_idx, dataloader_idx, **kwargs): class LightTestFitSingleTestDataloadersMixin: """Test fit single test dataloaders mixin.""" + def test_dataloader(self): + return self._dataloader(train=False) + def test_step(self, batch, batch_idx, *args, **kwargs): """ Lightning calls this inside the validation loop diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 378c0f915b78d..9bde9437bda06 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -66,9 +66,15 @@ class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase): def val_dataloader(self): return self._dataloader(train=False) + def validation_step(self, batch, batch_idx, *args, **kwargs): + return {'val_loss': 0.6} + def validation_end(self, outputs): return {'val_loss': 0.6} + def test_dataloader(self): + return self._dataloader(train=False) + def test_end(self, outputs): return {'test_loss': 0.6} @@ -79,9 +85,15 @@ class ModelVer0_7(LightTrainDataloader, LightEmptyTestStep, TestModelBase): def val_dataloader(self): return self._dataloader(train=False) + def validation_step(self, batch, batch_idx, *args, **kwargs): + return {'val_loss': 0.7} + def validation_end(self, outputs): return {'val_loss': 0.7} + def test_dataloader(self): + return self._dataloader(train=False) + def test_end(self, outputs): return {'test_loss': 0.7} diff --git a/tests/trainer/test_checks.py b/tests/trainer/test_checks.py new file mode 100755 index 0000000000000..1dc9819de5e0d --- /dev/null +++ b/tests/trainer/test_checks.py @@ -0,0 +1,154 @@ +import pytest + +import tests.base.utils as tutils +from pytorch_lightning import Trainer, LightningModule +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import ( + TestModelBase, + LightValidationDataloader, + LightTestDataloader, + LightValidationStepMixin, + LightValStepFitSingleDataloaderMixin, + LightTrainDataloader, + LightTestStepMixin, + LightTestFitMultipleTestDataloadersMixin, +) + + +def test_error_on_no_train_step(tmpdir): + """ Test that an error is thrown when no `training_step()` is defined """ + tutils.reset_seed() + + class CurrentTestModel(LightningModule): + def forward(self, x): + pass + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + with pytest.raises(MisconfigurationException): + model = CurrentTestModel() + trainer.fit(model) + + +def test_error_on_no_train_dataloader(tmpdir): + """ Test that an error is thrown when no `training_dataloader()` is defined """ + tutils.reset_seed() + hparams = tutils.get_default_hparams() + + class CurrentTestModel(TestModelBase): + pass + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) + + +def test_error_on_no_configure_optimizers(tmpdir): + """ Test that an error is thrown when no `configure_optimizers()` is defined """ + tutils.reset_seed() + + class CurrentTestModel(LightTrainDataloader, LightningModule): + def forward(self, x): + pass + + def training_step(self, batch, batch_idx, optimizer_idx=None): + pass + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + with pytest.raises(MisconfigurationException): + model = CurrentTestModel() + trainer.fit(model) + + +def test_warning_on_wrong_validation_settings(tmpdir): + """ Test the following cases related to validation configuration of model: + * error if `val_dataloader()` is overriden but `validation_step()` is not + * if both `val_dataloader()` and `validation_step()` is overriden, + throw warning if `val_epoch_end()` is not defined + * error if `validation_step()` is overriden but `val_dataloader()` is not + """ + tutils.reset_seed() + hparams = tutils.get_default_hparams() + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + class CurrentTestModel(LightTrainDataloader, + LightValidationDataloader, + TestModelBase): + pass + + # check val_dataloader -> val_step + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) + + class CurrentTestModel(LightTrainDataloader, + LightValidationStepMixin, + TestModelBase): + pass + + # check val_dataloader + val_step -> val_epoch_end + with pytest.warns(RuntimeWarning): + model = CurrentTestModel(hparams) + trainer.fit(model) + + class CurrentTestModel(LightTrainDataloader, + LightValStepFitSingleDataloaderMixin, + TestModelBase): + pass + + # check val_step -> val_dataloader + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) + + +def test_warning_on_wrong_test_settigs(tmpdir): + """ Test the following cases related to test configuration of model: + * error if `test_dataloader()` is overriden but `test_step()` is not + * if both `test_dataloader()` and `test_step()` is overriden, + throw warning if `test_epoch_end()` is not defined + * error if `test_step()` is overriden but `test_dataloader()` is not + """ + tutils.reset_seed() + hparams = tutils.get_default_hparams() + + trainer_options = dict(default_save_path=tmpdir, max_epochs=1) + trainer = Trainer(**trainer_options) + + class CurrentTestModel(LightTrainDataloader, + LightTestDataloader, + TestModelBase): + pass + + # check test_dataloader -> test_step + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model) + + class CurrentTestModel(LightTrainDataloader, + LightTestStepMixin, + TestModelBase): + pass + + # check test_dataloader + test_step -> test_epoch_end + with pytest.warns(RuntimeWarning): + model = CurrentTestModel(hparams) + trainer.fit(model) + + class CurrentTestModel(LightTrainDataloader, + LightTestFitMultipleTestDataloadersMixin, + TestModelBase): + pass + + # check test_step -> test_dataloader + with pytest.raises(MisconfigurationException): + model = CurrentTestModel(hparams) + trainer.fit(model)