Skip to content

Commit

Permalink
Add useful errors when model is not configured correctly (#1199)
Browse files Browse the repository at this point in the history
* add check_model_configuration method

* trying to fix errors

* trying to fix tests

* added test_epoch_end to lightning template

* fix tests

* fix new test after rebase

* fix spelling

* added more checks

* updated formatting

* added tests

* fixed CHANGELOG

* Apply suggestions from code review

* move test to new module

* change check on configure_optimizers

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
3 people committed Apr 2, 2020
1 parent ddb5913 commit 2912239
Show file tree
Hide file tree
Showing 7 changed files with 271 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
- Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
- Allow to upload models on W&B ([#1339](https://github.com/PyTorchLightning/pytorch-lightning/pull/1339))
- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))

### Changed

Expand Down
31 changes: 31 additions & 0 deletions pl_examples/basic_examples/lightning_module_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,37 @@ def test_dataloader(self):
log.info('Test data loader called.')
return self.__dataloader(train=False)

def test_step(self, batch, batch_idx):
"""
Lightning calls this during testing, similar to val_step
:param batch:
:return:val
"""
output = self.validation_step(batch, batch_idx)
# Rename output keys
output['test_loss'] = output.pop('val_loss')
output['test_acc'] = output.pop('val_acc')

return output

def test_epoch_end(self, outputs):
"""
Called at the end of test to aggregate outputs, similar to validation_epoch_end
:param outputs: list of individual outputs of each validation step
:return:
"""
results = self.validation_step_end(outputs)

# rename some keys
results['progress_bar'].update({
'test_loss': results['progress_bar'].pop('val_loss'),
'test_acc': results['progress_bar'].pop('val_acc'),
})
results['log'] = results['progress_bar']
results['test_loss'] = results.pop('val_loss')

return results

@staticmethod
def add_model_specific_args(parent_parser, root_dir): # pragma: no-cover
"""
Expand Down
74 changes: 62 additions & 12 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,9 @@ def fit(
# set up the passed in dataloaders (if needed)
self.__attach_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)

# check that model is configured correctly
self.check_model_configuration(model)

# download the data and do whatever transforms we need
# do before any spawn calls so that the model can assign properties
# only on proc 0 because no spawn has happened yet
Expand Down Expand Up @@ -736,24 +739,12 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloader is not None:
if not self.is_overriden('training_step', model):
raise MisconfigurationException(
'You called `.fit()` with a `train_dataloader` but did not define `training_step()`')

model.train_dataloader = _PatchDataLoader(train_dataloader)

if val_dataloaders is not None:
if not self.is_overriden('validation_step', model):
raise MisconfigurationException(
'You called `.fit()` with a `val_dataloaders` but did not define `validation_step()`')

model.val_dataloader = _PatchDataLoader(val_dataloaders)

if test_dataloaders is not None:
if not self.is_overriden('test_step', model):
raise MisconfigurationException(
'You called `.fit()` with a `test_dataloaders` but did not define `test_step()`')

model.test_dataloader = _PatchDataLoader(test_dataloaders)

def run_pretrain_routine(self, model: LightningModule):
Expand Down Expand Up @@ -902,6 +893,62 @@ def test(self, model: Optional[LightningModule] = None):

self.testing = False

def check_model_configuration(self, model: LightningModule):
    r"""
    Checks that the model is configured correctly before training is started.

    Args:
        model: The model to check.

    Raises:
        MisconfigurationException: if a required hook is missing, or if a
            step hook is defined without its matching dataloader (and
            vice versa).
    """
    # Check training_step, train_dataloader, configure_optimizers methods.
    # BUGFIX: messages now name the real hook `train_dataloader()` —
    # there is no `training_dataloader()` hook.
    if not self.is_overriden('training_step', model):
        raise MisconfigurationException(
            'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
            ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.')

    if not self.is_overriden('train_dataloader', model):
        raise MisconfigurationException(
            'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
            ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.')

    if not self.is_overriden('configure_optimizers', model):
        raise MisconfigurationException(
            'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
            ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.')

    # Check val_dataloader, validation_step and validation_epoch_end
    if self.is_overriden('val_dataloader', model):
        if not self.is_overriden('validation_step', model):
            raise MisconfigurationException('You have passed in a `val_dataloader()`'
                                            ' but have not defined `validation_step()`.')
        else:
            # Legal but probably unintended: stats will not be accumulated.
            if not self.is_overriden('validation_epoch_end', model):
                warnings.warn('You have defined a `val_dataloader()` and have'
                              ' defined a `validation_step()`, you may also want to'
                              ' define `validation_epoch_end()` for accumulating stats.',
                              RuntimeWarning)
    else:
        if self.is_overriden('validation_step', model):
            raise MisconfigurationException('You have defined `validation_step()`,'
                                            ' but have not passed in a `val_dataloader()`.')

    # Check test_dataloader, test_step and test_epoch_end
    if self.is_overriden('test_dataloader', model):
        if not self.is_overriden('test_step', model):
            raise MisconfigurationException('You have passed in a `test_dataloader()`'
                                            ' but have not defined `test_step()`.')
        else:
            if not self.is_overriden('test_epoch_end', model):
                warnings.warn('You have defined a `test_dataloader()` and'
                              ' have defined a `test_step()`, you may also want to'
                              ' define `test_epoch_end()` for accumulating stats.',
                              RuntimeWarning)
    else:
        if self.is_overriden('test_step', model):
            raise MisconfigurationException('You have defined `test_step()`,'
                                            ' but have not passed in a `test_dataloader()`.')


class _PatchDataLoader(object):
r"""
Expand All @@ -916,5 +963,8 @@ class _PatchDataLoader(object):
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
    """Capture the dataloader(s) to be served by later calls."""
    self.dataloader = dataloader

    # Expose a __code__ attribute so the override check (is_overriden)
    # treats this patch object like a user-defined method.
    self.__code__ = self.__call__.__code__

def __call__(self) -> Union[List[DataLoader], DataLoader]:
    """Return the dataloader(s) captured at construction time."""
    patched = self.dataloader
    return patched
1 change: 1 addition & 0 deletions tests/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
LightValStepFitSingleDataloaderMixin,
LightValStepFitMultipleDataloadersMixin,
LightTrainDataloader,
LightValidationDataloader,
LightTestDataloader,
LightInfTrainDataloader,
LightInfValDataloader,
Expand Down
10 changes: 10 additions & 0 deletions tests/base/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,13 @@ def train_dataloader(self):
return self._dataloader(train=True)


class LightValidationDataloader:
    """Mixin contributing a minimal validation dataloader."""

    def val_dataloader(self):
        # Delegate to the host model's shared dataloader factory in eval mode.
        return self._dataloader(train=False)


class LightTestDataloader:
"""Simple test dataloader."""

Expand Down Expand Up @@ -412,6 +419,9 @@ def test_step(self, batch, batch_idx, dataloader_idx, **kwargs):
class LightTestFitSingleTestDataloadersMixin:
"""Test fit single test dataloaders mixin."""

def test_dataloader(self):
return self._dataloader(train=False)

def test_step(self, batch, batch_idx, *args, **kwargs):
"""
Lightning calls this inside the validation loop
Expand Down
12 changes: 12 additions & 0 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,15 @@ class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
def val_dataloader(self):
    """Validation loader: reuse the shared eval-mode dataloader."""
    return self._dataloader(train=False)

def validation_step(self, batch, batch_idx, *args, **kwargs):
    """Stub validation step returning a fixed loss (v0.6 API model)."""
    return dict(val_loss=0.6)

def validation_end(self, outputs):
    """Deprecated-API aggregation hook returning a fixed loss."""
    return dict(val_loss=0.6)

def test_dataloader(self):
return self._dataloader(train=False)

def test_end(self, outputs):
return {'test_loss': 0.6}

Expand All @@ -79,9 +85,15 @@ class ModelVer0_7(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
def val_dataloader(self):
    """Validation loader: reuse the shared eval-mode dataloader."""
    return self._dataloader(train=False)

def validation_step(self, batch, batch_idx, *args, **kwargs):
    """Stub validation step returning a fixed loss (v0.7 API model)."""
    return dict(val_loss=0.7)

def validation_end(self, outputs):
    """Deprecated-API aggregation hook returning a fixed loss."""
    return dict(val_loss=0.7)

def test_dataloader(self):
return self._dataloader(train=False)

def test_end(self, outputs):
return {'test_loss': 0.7}

Expand Down
154 changes: 154 additions & 0 deletions tests/trainer/test_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import pytest

import tests.base.utils as tutils
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import (
TestModelBase,
LightValidationDataloader,
LightTestDataloader,
LightValidationStepMixin,
LightValStepFitSingleDataloaderMixin,
LightTrainDataloader,
LightTestStepMixin,
LightTestFitMultipleTestDataloadersMixin,
)


def test_error_on_no_train_step(tmpdir):
    """ Test that an error is thrown when no `training_step()` is defined """
    tutils.reset_seed()

    # A bare module with neither training_step nor dataloaders.
    class NoTrainStepModel(LightningModule):
        def forward(self, x):
            pass

    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)

    with pytest.raises(MisconfigurationException):
        trainer.fit(NoTrainStepModel())


def test_error_on_no_train_dataloader(tmpdir):
    """ Test that an error is thrown when no `train_dataloader()` is defined """
    tutils.reset_seed()
    hparams = tutils.get_default_hparams()

    # TestModelBase defines training_step but no train_dataloader.
    class NoTrainDataloaderModel(TestModelBase):
        pass

    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)

    with pytest.raises(MisconfigurationException):
        trainer.fit(NoTrainDataloaderModel(hparams))


def test_error_on_no_configure_optimizers(tmpdir):
    """ Test that an error is thrown when no `configure_optimizers()` is defined """
    tutils.reset_seed()

    # Has a training step and a train dataloader, but no optimizers.
    class NoOptimizersModel(LightTrainDataloader, LightningModule):
        def forward(self, x):
            pass

        def training_step(self, batch, batch_idx, optimizer_idx=None):
            pass

    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)

    with pytest.raises(MisconfigurationException):
        trainer.fit(NoOptimizersModel())


def test_warning_on_wrong_validation_settings(tmpdir):
    """ Test the following cases related to validation configuration of model:
    * error if `val_dataloader()` is overridden but `validation_step()` is not
    * if both `val_dataloader()` and `validation_step()` are overridden,
      throw a warning if `validation_epoch_end()` is not defined
    * error if `validation_step()` is overridden but `val_dataloader()` is not
    """
    tutils.reset_seed()
    hparams = tutils.get_default_hparams()

    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)

    # val_dataloader without validation_step -> error
    class DataloaderOnlyModel(LightTrainDataloader,
                              LightValidationDataloader,
                              TestModelBase):
        pass

    with pytest.raises(MisconfigurationException):
        trainer.fit(DataloaderOnlyModel(hparams))

    # val_dataloader + validation_step without validation_epoch_end -> warning
    class NoEpochEndModel(LightTrainDataloader,
                          LightValidationStepMixin,
                          TestModelBase):
        pass

    with pytest.warns(RuntimeWarning):
        trainer.fit(NoEpochEndModel(hparams))

    # validation_step without val_dataloader -> error
    class StepOnlyModel(LightTrainDataloader,
                        LightValStepFitSingleDataloaderMixin,
                        TestModelBase):
        pass

    with pytest.raises(MisconfigurationException):
        trainer.fit(StepOnlyModel(hparams))


# NOTE(review): function name has a typo ("settigs"); kept unchanged so the
# test's identity is stable — renaming would be a separate change.
def test_warning_on_wrong_test_settigs(tmpdir):
    """ Test the following cases related to test configuration of model:
    * error if `test_dataloader()` is overridden but `test_step()` is not
    * if both `test_dataloader()` and `test_step()` are overridden,
      throw a warning if `test_epoch_end()` is not defined
    * error if `test_step()` is overridden but `test_dataloader()` is not
    """
    tutils.reset_seed()
    hparams = tutils.get_default_hparams()

    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)

    # test_dataloader without test_step -> error
    class DataloaderOnlyModel(LightTrainDataloader,
                              LightTestDataloader,
                              TestModelBase):
        pass

    with pytest.raises(MisconfigurationException):
        trainer.fit(DataloaderOnlyModel(hparams))

    # test_dataloader + test_step without test_epoch_end -> warning
    class NoEpochEndModel(LightTrainDataloader,
                          LightTestStepMixin,
                          TestModelBase):
        pass

    with pytest.warns(RuntimeWarning):
        trainer.fit(NoEpochEndModel(hparams))

    # test_step without test_dataloader -> error
    class StepOnlyModel(LightTrainDataloader,
                        LightTestFitMultipleTestDataloadersMixin,
                        TestModelBase):
        pass

    with pytest.raises(MisconfigurationException):
        trainer.fit(StepOnlyModel(hparams))

0 comments on commit 2912239

Please sign in to comment.