Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add useful errors when model is not configured correctly #1199

Merged
merged 21 commits into from
Apr 2, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
- Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
- Added a check that stops the training when loss or weights contain NaN or inf values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))

### Changed

Expand Down
32 changes: 32 additions & 0 deletions pl_examples/basic_examples/lightning_module_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,38 @@ def test_dataloader(self):
log.info('Test data loader called.')
return self.__dataloader(train=False)

def test_step(self, batch, batch_idx):
    """
    Lightning calls this during testing; it delegates to ``validation_step``
    and renames the resulting metric keys to their test equivalents.

    :param batch: current batch of test data
    :param batch_idx: index of the current batch
    :return: dict containing ``test_loss`` and ``test_acc``
    """
    result = self.validation_step(batch, batch_idx)
    # Re-key the validation metrics under their test names.
    result['test_loss'] = result.pop('val_loss')
    result['test_acc'] = result.pop('val_acc')
    return result

def test_epoch_end(self, outputs):
    """
    Called at the end of test to aggregate outputs, similar to
    ``validation_epoch_end``.

    :param outputs: list of individual outputs of each test step
    :return: results dict with ``val_*`` keys renamed to ``test_*``
    """
    # ``outputs`` is the full list of per-step results, so it must go through
    # the *epoch*-end aggregation hook; the original called
    # validation_step_end, which operates on a single step's output.
    results = self.validation_epoch_end(outputs)

    # Rename the validation metrics in the progress-bar dict.
    tqdm_dict = results['progress_bar']
    tqdm_dict['test_loss'] = tqdm_dict.pop('val_loss')
    tqdm_dict['test_acc'] = tqdm_dict.pop('val_acc')

    results['progress_bar'] = tqdm_dict
    results['log'] = tqdm_dict
    results['test_loss'] = results.pop('val_loss')

    return results

@staticmethod
def add_model_specific_args(parent_parser, root_dir): # pragma: no-cover
"""
Expand Down
78 changes: 66 additions & 12 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,9 @@ def fit(
# set up the passed in dataloaders (if needed)
self.__attach_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)

# check that model is configured correctly
self.check_model_configuration(model)

# download the data and do whatever transforms we need
# do before any spawn calls so that the model can assign properties
# only on proc 0 because no spawn has happened yet
Expand Down Expand Up @@ -664,24 +667,12 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloader is not None:
if not self.is_overriden('training_step', model):
m = 'You called .fit() with a train_dataloader but did not define training_step()'
raise MisconfigurationException(m)

model.train_dataloader = _PatchDataLoader(train_dataloader)

if val_dataloaders is not None:
if not self.is_overriden('validation_step', model):
m = 'You called .fit() with a val_dataloaders but did not define validation_step()'
raise MisconfigurationException(m)

model.val_dataloader = _PatchDataLoader(val_dataloaders)

if test_dataloaders is not None:
if not self.is_overriden('test_step', model):
m = 'You called .fit() with a test_dataloaders but did not define test_step()'
raise MisconfigurationException(m)

model.test_dataloader = _PatchDataLoader(test_dataloaders)

def init_optimizers(
Expand Down Expand Up @@ -886,6 +877,66 @@ def test(self, model: Optional[LightningModule] = None):

self.testing = False

def check_model_configuration(self, model: 'LightningModule'):
    r"""
    Checks that the model is configured correctly before training is started.

    Args:
        model: The model to check.

    Raises:
        MisconfigurationException: if a mandatory hook (``training_step``,
            ``train_dataloader``) is missing, or if a dataloader/step pair is
            only half defined.

    """
    # Check training_step, train_dataloader, configure_optimizer methods.
    # NOTE: the messages name train_dataloader() — the actual hook — rather
    # than the nonexistent training_dataloader() the original text suggested.
    if not self.is_overriden('training_step', model):
        m = ('No training_step() method defined. Lightning expects as minimum '
             'a training_step() and train_dataloader() to be defined.')
        raise MisconfigurationException(m)

    if not self.is_overriden('train_dataloader', model):
        m = ('No train_dataloader() defined. Lightning expects as minimum '
             'a training_step() and train_dataloader() to be defined.')
        raise MisconfigurationException(m)

    # Missing optimizers is survivable: warn and fall back to a default.
    if not self.is_overriden('configure_optimizers', model):
        m = ('configure_optimizers() method not defined by user, will default '
             'to Adam optimizer with learning rate set to 0.0001.')
        warnings.warn(m)

    # Check val_dataloader, validation_step and validation_epoch_end:
    # a dataloader without a step (or vice versa) is an error; a missing
    # epoch-end aggregator only merits a warning.
    if self.is_overriden('val_dataloader', model):
        if not self.is_overriden('validation_step', model):
            m = ('You have passed in a val_dataloader() but have not defined '
                 'validation_step()')
            raise MisconfigurationException(m)
        else:
            if not self.is_overriden('validation_epoch_end', model):
                m = ('You have defined a val_dataloader() and have defined '
                     'a validation_step(), you may also want to define '
                     'validation_epoch_end() for accumulating stats')
                warnings.warn(m)
    else:
        if self.is_overriden('validation_step', model):
            m = ('You have defined validation_step(), but have not passed '
                 'in a val_dataloader().')
            raise MisconfigurationException(m)

    # Check test_dataloader, test_step and test_epoch_end — same policy as
    # the validation hooks above.
    if self.is_overriden('test_dataloader', model):
        if not self.is_overriden('test_step', model):
            m = ('You have passed in a test_dataloader() but have not defined '
                 'test_step()')
            raise MisconfigurationException(m)
        else:
            if not self.is_overriden('test_epoch_end', model):
                m = ('You have defined a test_dataloader() and have defined '
                     'a test_step(), you may also want to define '
                     'test_epoch_end() for accumulating stats')
                warnings.warn(m)
    else:
        if self.is_overriden('test_step', model):
            m = ('You have defined test_step(), but have not passed in '
                 'a test_dataloader().')
            raise MisconfigurationException(m)


class _PatchDataLoader(object):
r"""
Expand All @@ -899,5 +950,8 @@ class _PatchDataLoader(object):
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
    """Wrap an already-instantiated dataloader (or list of dataloaders) so it
    can be patched onto a model in place of a dataloader hook."""
    self.dataloader = dataloader

    # Assign __code__, needed for checking if the method has been overridden;
    # the override check compares code objects, so the patched attribute must
    # expose one just like a real method would.
    self.__code__ = self.__call__.__code__

def __call__(self) -> Union[List[DataLoader], DataLoader]:
    """Return the stored dataloader(s) when the patched hook is invoked."""
    return self.dataloader
3 changes: 3 additions & 0 deletions tests/models/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,9 @@ def test_step(self, batch, batch_idx, dataloader_idx, **kwargs):
class LightTestFitSingleTestDataloadersMixin:
"""Test fit single test dataloaders mixin."""

def test_dataloader(self):
    # Reuse the shared helper with train=False to obtain the non-training loader.
    return self._dataloader(train=False)

def test_step(self, batch, batch_idx, *args, **kwargs):
"""
Lightning calls this inside the validation loop
Expand Down
12 changes: 12 additions & 0 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,15 @@ class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
def val_dataloader(self):
    # Non-training split via the shared helper (train=False).
    return self._dataloader(train=False)

def validation_step(self, batch, batch_idx, *args, **kwargs):
    """Stub validation step reporting a fixed 0.6 loss (v0.6-era schema)."""
    return dict(val_loss=0.6)

def validation_end(self, outputs):
    """Deprecated-style aggregation hook; ignores outputs, returns fixed loss."""
    return dict(val_loss=0.6)

def test_dataloader(self):
    # Non-training split via the shared helper (train=False).
    return self._dataloader(train=False)

def test_end(self, outputs):
    """Deprecated-style test aggregation hook; returns a fixed test_loss."""
    return {'test_loss': 0.6}

Expand All @@ -77,9 +83,15 @@ class ModelVer0_7(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
def val_dataloader(self):
    # Non-training split via the shared helper (train=False).
    return self._dataloader(train=False)

def validation_step(self, batch, batch_idx, *args, **kwargs):
    """Stub validation step reporting a fixed 0.7 loss (v0.7-era schema)."""
    return dict(val_loss=0.7)

def validation_end(self, outputs):
    """Deprecated-style aggregation hook; ignores outputs, returns fixed loss."""
    return dict(val_loss=0.7)

def test_dataloader(self):
    # Non-training split via the shared helper (train=False).
    return self._dataloader(train=False)

def test_end(self, outputs):
    """Deprecated-style test aggregation hook; returns a fixed test_loss."""
    return {'test_loss': 0.7}

Expand Down
4 changes: 2 additions & 2 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
LightValidationStepMixin,
LightValidationMultipleDataloadersMixin,
LightTrainDataloader,
LightTestDataloader,
LightTestStepMixin,
)


Expand Down Expand Up @@ -601,7 +601,7 @@ def test_testpass_overrides(tmpdir):
class LocalModel(LightTrainDataloader, TestModelBase):
pass

class LocalModelNoEnd(LightTrainDataloader, LightTestDataloader, LightEmptyTestStep, TestModelBase):
class LocalModelNoEnd(LightTrainDataloader, LightTestStepMixin, LightEmptyTestStep, TestModelBase):
pass

class LocalModelNoStep(LightTrainDataloader, TestModelBase):
Expand Down