Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add useful errors when model is not configured correctly #1199

Merged
merged 21 commits into from
Apr 2, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
- Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))

### Changed

Expand Down
31 changes: 31 additions & 0 deletions pl_examples/basic_examples/lightning_module_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,37 @@ def test_dataloader(self):
log.info('Test data loader called.')
return self.__dataloader(train=False)

def test_step(self, batch, batch_idx):
    """
    Lightning calls this for every test batch; it delegates to
    ``validation_step`` and renames the resulting ``val_*`` keys to
    ``test_*`` so they are logged under the test namespace.
    :param batch: the current test batch
    :param batch_idx: index of the batch within the test dataloader
    :return: dict with ``test_loss`` and ``test_acc`` (plus any other
        keys produced by ``validation_step``)
    """
    result = self.validation_step(batch, batch_idx)

    # Rename the validation keys in place, same order as before.
    for old_key, new_key in (('val_loss', 'test_loss'), ('val_acc', 'test_acc')):
        result[new_key] = result.pop(old_key)

    return result

def test_epoch_end(self, outputs):
    """
    Called at the end of testing to aggregate the outputs of all test
    steps, similar to validation_epoch_end.
    :param outputs: list of individual outputs of each test step
    :return: aggregated results dict with ``test_*`` keys
    """
    # BUG FIX: aggregate with validation_epoch_end(), which expects the
    # full list of step outputs. The previous call to
    # validation_step_end() operates on a single step's output and would
    # not aggregate the list (its default passthrough returns the list
    # itself, which has no 'progress_bar'/'val_loss' keys).
    results = self.validation_epoch_end(outputs)

    # rename some keys
    results['progress_bar'].update({
        'test_loss': results['progress_bar'].pop('val_loss'),
        'test_acc': results['progress_bar'].pop('val_acc'),
    })
    results['log'] = results['progress_bar']
    results['test_loss'] = results.pop('val_loss')

    return results

@staticmethod
def add_model_specific_args(parent_parser, root_dir): # pragma: no-cover
"""
Expand Down
74 changes: 62 additions & 12 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,9 @@ def fit(
# set up the passed in dataloaders (if needed)
self.__attach_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)

# check that model is configured correctly
self.check_model_configuration(model)

# download the data and do whatever transforms we need
# do before any spawn calls so that the model can assign properties
# only on proc 0 because no spawn has happened yet
Expand Down Expand Up @@ -735,24 +738,12 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloader is not None:
if not self.is_overriden('training_step', model):
raise MisconfigurationException(
'You called `.fit()` with a `train_dataloader` but did not define `training_step()`')

model.train_dataloader = _PatchDataLoader(train_dataloader)

if val_dataloaders is not None:
if not self.is_overriden('validation_step', model):
raise MisconfigurationException(
'You called `.fit()` with a `val_dataloaders` but did not define `validation_step()`')

model.val_dataloader = _PatchDataLoader(val_dataloaders)

if test_dataloaders is not None:
if not self.is_overriden('test_step', model):
raise MisconfigurationException(
'You called `.fit()` with a `test_dataloaders` but did not define `test_step()`')

model.test_dataloader = _PatchDataLoader(test_dataloaders)

def init_optimizers(
Expand Down Expand Up @@ -957,6 +948,62 @@ def test(self, model: Optional[LightningModule] = None):

self.testing = False

def check_model_configuration(self, model: LightningModule):
    r"""
    Checks that the model is configured correctly before training is started.

    Args:
        model: The model to test.

    Raises:
        MisconfigurationException: if a required hook is missing, or a
            step/dataloader pair is defined inconsistently.

    """
    # Check training_step, train_dataloader, configure_optimizer methods
    if not self.is_overriden('training_step', model):
        # BUG FIX: the hook is named `train_dataloader()`, not
        # `training_dataloader()` (see the is_overriden check below) —
        # the old message pointed users at a method that does not exist.
        raise MisconfigurationException(
            'No `training_step()` method defined. Lightning expects as minimum a `training_step()` and'
            ' `train_dataloader()` to be defined.')

    if not self.is_overriden('train_dataloader', model):
        raise MisconfigurationException(
            'No `train_dataloader()` defined. Lightning expects as minimum a `training_step()` and'
            ' `train_dataloader()` to be defined.')

    if not self.is_overriden('configure_optimizers', model):
        warnings.warn('`configure_optimizers()` method not defined by user,'
                      ' will default to Adam optimizer with learning rate set to 0.0001.',
                      RuntimeWarning)

    # Check val_dataloader, validation_step and validation_epoch_end
    if self.is_overriden('val_dataloader', model):
        if not self.is_overriden('validation_step', model):
            raise MisconfigurationException('You have passed in a `val_dataloader()`'
                                            ' but have not defined `validation_step()`.')
        else:
            # Having both hooks is valid; only suggest the aggregation hook.
            if not self.is_overriden('validation_epoch_end', model):
                warnings.warn('You have defined a `val_dataloader()` and have'
                              ' defined a `validation_step()`, you may also want to'
                              ' define `validation_epoch_end()` for accumulating stats.',
                              RuntimeWarning)
    else:
        if self.is_overriden('validation_step', model):
            raise MisconfigurationException('You have defined `validation_step()`,'
                                            ' but have not passed in a val_dataloader().')

    # Check test_dataloader, test_step and test_epoch_end
    if self.is_overriden('test_dataloader', model):
        if not self.is_overriden('test_step', model):
            raise MisconfigurationException('You have passed in a `test_dataloader()`'
                                            ' but have not defined `test_step()`.')
        else:
            if not self.is_overriden('test_epoch_end', model):
                warnings.warn('You have defined a `test_dataloader()` and'
                              ' have defined a `test_step()`, you may also want to'
                              ' define `test_epoch_end()` for accumulating stats.',
                              RuntimeWarning)
    else:
        if self.is_overriden('test_step', model):
            raise MisconfigurationException('You have defined `test_step()`,'
                                            ' but have not passed in a `test_dataloader()`.')


class _PatchDataLoader(object):
r"""
Expand All @@ -970,5 +1017,8 @@ class _PatchDataLoader(object):
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
    # Expose a __code__ attribute so the trainer's override detection
    # (`is_overriden`) treats this patched callable like a real method.
    self.__code__ = self.__call__.__code__
    # Capture the dataloader(s) to hand back on every call.
    self.dataloader = dataloader

def __call__(self) -> Union[List[DataLoader], DataLoader]:
    # Return the dataloader(s) captured when the trainer patched the model.
    return self.dataloader
1 change: 1 addition & 0 deletions tests/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
LightValStepFitSingleDataloaderMixin,
LightValStepFitMultipleDataloadersMixin,
LightTrainDataloader,
LightValidationDataloader,
LightTestDataloader,
LightInfTrainDataloader,
LightInfValDataloader,
Expand Down
10 changes: 10 additions & 0 deletions tests/base/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,13 @@ def train_dataloader(self):
return self._dataloader(train=True)


class LightValidationDataloader:
    """Mixin that supplies only a validation dataloader."""

    def val_dataloader(self):
        # train=False selects the evaluation split from the shared helper.
        loader = self._dataloader(train=False)
        return loader


class LightTestDataloader:
"""Simple test dataloader."""

Expand Down Expand Up @@ -412,6 +419,9 @@ def test_step(self, batch, batch_idx, dataloader_idx, **kwargs):
class LightTestFitSingleTestDataloadersMixin:
"""Test fit single test dataloaders mixin."""

def test_dataloader(self):
return self._dataloader(train=False)

def test_step(self, batch, batch_idx, *args, **kwargs):
"""
Lightning calls this inside the validation loop
Expand Down
12 changes: 12 additions & 0 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,15 @@ class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
def val_dataloader(self):
    # v0.6-style model: validation split from the shared dataloader helper.
    return self._dataloader(train=False)

def validation_step(self, batch, batch_idx, *args, **kwargs):
    # Fixed dummy loss; only the key layout matters for deprecation tests.
    return dict(val_loss=0.6)

def validation_end(self, outputs):
    # Deprecated-API hook: return a constant aggregated loss.
    return dict(val_loss=0.6)

def test_dataloader(self):
    # v0.6-style model: test split from the shared dataloader helper.
    return self._dataloader(train=False)

def test_end(self, outputs):
return {'test_loss': 0.6}

Expand All @@ -77,9 +83,15 @@ class ModelVer0_7(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
def val_dataloader(self):
    # v0.7-style model: validation split from the shared dataloader helper.
    return self._dataloader(train=False)

def validation_step(self, batch, batch_idx, *args, **kwargs):
    # Fixed dummy loss; only the key layout matters for deprecation tests.
    return dict(val_loss=0.7)

def validation_end(self, outputs):
    # Deprecated-API hook: return a constant aggregated loss.
    return dict(val_loss=0.7)

def test_dataloader(self):
    # v0.7-style model: test split from the shared dataloader helper.
    return self._dataloader(train=False)

def test_end(self, outputs):
return {'test_loss': 0.7}

Expand Down
135 changes: 135 additions & 0 deletions tests/trainer/test_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import pytest

import tests.base.utils as tutils
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.utilities.debugging import MisconfigurationException
from tests.base import (
TestModelBase,
LightValidationDataloader,
LightTestDataloader,
LightValidationStepMixin,
LightValStepFitSingleDataloaderMixin,
LightTrainDataloader,
LightTestStepMixin,
LightTestFitMultipleTestDataloadersMixin,
)


def test_error_on_no_train_step(tmpdir):
    """ Test that fitting a model without `training_step()` raises. """
    tutils.reset_seed()

    # Bare module: no training_step, no dataloaders.
    class CurrentTestModel(LightningModule):
        def forward(self, x):
            pass

    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)

    with pytest.raises(MisconfigurationException):
        model = CurrentTestModel()
        trainer.fit(model)


def test_error_on_no_train_dataloader(tmpdir):
    """ Test that fitting a model without `train_dataloader()` raises. """
    tutils.reset_seed()
    hparams = tutils.get_default_hparams()

    # Base model defines training_step but no train_dataloader.
    class CurrentTestModel(TestModelBase):
        pass

    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)

    with pytest.raises(MisconfigurationException):
        model = CurrentTestModel(hparams)
        trainer.fit(model)


def test_warning_on_wrong_validation_settings(tmpdir):
    """ Test the following cases related to validation configuration of model:
        * error if `val_dataloader()` is overriden but `validation_step()` is not
        * if both `val_dataloader()` and `validation_step()` are overriden,
            throw a warning if `validation_epoch_end()` is not defined
        * error if `validation_step()` is overriden but `val_dataloader()` is not
    """
    tutils.reset_seed()
    hparams = tutils.get_default_hparams()

    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)

    # case 1: val_dataloader without validation_step -> error
    class ModelValLoaderOnly(LightTrainDataloader,
                             LightValidationDataloader,
                             TestModelBase):
        pass

    with pytest.raises(MisconfigurationException):
        model = ModelValLoaderOnly(hparams)
        trainer.fit(model)

    # case 2: val_dataloader + validation_step, no validation_epoch_end -> warning
    class ModelNoValEpochEnd(LightTrainDataloader,
                             LightValidationStepMixin,
                             TestModelBase):
        pass

    with pytest.warns(RuntimeWarning):
        model = ModelNoValEpochEnd(hparams)
        trainer.fit(model)

    # case 3: validation_step without val_dataloader -> error
    class ModelValStepOnly(LightTrainDataloader,
                           LightValStepFitSingleDataloaderMixin,
                           TestModelBase):
        pass

    with pytest.raises(MisconfigurationException):
        model = ModelValStepOnly(hparams)
        trainer.fit(model)


def test_warning_on_wrong_test_settigs(tmpdir):
    """ Test the following cases related to test configuration of model:
        * error if `test_dataloader()` is overriden but `test_step()` is not
        * if both `test_dataloader()` and `test_step()` are overriden,
            throw a warning if `test_epoch_end()` is not defined
        * error if `test_step()` is overriden but `test_dataloader()` is not

    NOTE(review): the function name has a typo ('settigs'); it is kept
    unchanged to preserve the collected test id.
    """
    tutils.reset_seed()
    hparams = tutils.get_default_hparams()

    trainer = Trainer(default_save_path=tmpdir, max_epochs=1)

    # case 1: test_dataloader without test_step -> error
    class ModelTestLoaderOnly(LightTrainDataloader,
                              LightTestDataloader,
                              TestModelBase):
        pass

    with pytest.raises(MisconfigurationException):
        model = ModelTestLoaderOnly(hparams)
        trainer.fit(model)

    # case 2: test_dataloader + test_step, no test_epoch_end -> warning
    class ModelNoTestEpochEnd(LightTrainDataloader,
                              LightTestStepMixin,
                              TestModelBase):
        pass

    with pytest.warns(RuntimeWarning):
        model = ModelNoTestEpochEnd(hparams)
        trainer.fit(model)

    # case 3: test_step without test_dataloader -> error
    class ModelTestStepOnly(LightTrainDataloader,
                            LightTestFitMultipleTestDataloadersMixin,
                            TestModelBase):
        pass

    with pytest.raises(MisconfigurationException):
        model = ModelTestStepOnly(hparams)
        trainer.fit(model)
4 changes: 2 additions & 2 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
LightValidationStepMixin,
LightValidationMultipleDataloadersMixin,
LightTrainDataloader,
LightTestDataloader,
LightTestStepMixin,
)


Expand Down Expand Up @@ -495,7 +495,7 @@ def test_testpass_overrides(tmpdir):
class LocalModel(LightTrainDataloader, TestModelBase):
pass

class LocalModelNoEnd(LightTrainDataloader, LightTestDataloader, LightEmptyTestStep, TestModelBase):
class LocalModelNoEnd(LightTrainDataloader, LightTestStepMixin, LightEmptyTestStep, TestModelBase):
pass

class LocalModelNoStep(LightTrainDataloader, TestModelBase):
Expand Down