Error on zero length dataloaders (Lightning-AI#1280)
* error_on_zero_length

* update CHANGELOG.md

* added test

* Update pytorch_lightning/trainer/data_loading.py

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
3 people authored and akarnachev committed Apr 3, 2020
1 parent 4a92684 commit a683186
Showing 5 changed files with 43 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
- Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
- Added informative errors if user-defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))

### Changed

6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -26,9 +26,13 @@


def _has_len(dataloader: DataLoader) -> bool:
    """ Checks if a given Dataloader has the __len__ method implemented, i.e.
    whether it is a finite or an infinite dataloader. """
    try:
        # try getting the length
        if len(dataloader) == 0:
            raise ValueError('Dataloader returned 0 length. Please make sure'
                             ' that your Dataloader at least returns 1 batch')
        return True
    except TypeError:
        return False
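The guard above can be exercised in isolation. Below is a torch-free sketch of the same logic; `has_len` and `InfiniteLoader` are illustrative stand-ins for this example, not Lightning or PyTorch APIs:

```python
def has_len(dataloader) -> bool:
    """Mirrors the checked behavior: True for a finite loader, False for a
    loader without __len__, ValueError for a zero-length loader."""
    try:
        if len(dataloader) == 0:
            raise ValueError('Dataloader returned 0 length. Please make sure'
                             ' that your Dataloader at least returns 1 batch')
        return True
    except TypeError:
        # objects without __len__ raise TypeError -> treated as infinite
        return False


class InfiniteLoader:
    """Iterable with no __len__, standing in for an infinite dataloader."""
    def __iter__(self):
        while True:
            yield 0


print(has_len([1, 2, 3]))         # finite "loader" -> True
print(has_len(InfiniteLoader()))  # no __len__ -> False
```

Note that the `ValueError` is raised inside the `try` block but deliberately not caught there, since only `TypeError` (the "no `__len__`" case) is handled.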
3 changes: 2 additions & 1 deletion tests/base/__init__.py
@@ -25,7 +25,8 @@
    LightTestOptimizerWithSchedulingMixin,
    LightTestMultipleOptimizersWithSchedulingMixin,
    LightTestOptimizersWithMixedSchedulingMixin,
    LightTestReduceLROnPlateauMixin,
    LightZeroLenDataloader
)


10 changes: 10 additions & 0 deletions tests/base/mixins.py
@@ -252,6 +252,16 @@ def test_dataloader(self):
        return CustomInfDataloader(self._dataloader(train=False))


class LightZeroLenDataloader:
    """ Simple dataloader that has zero length. """

    def train_dataloader(self):
        dataloader = self._dataloader(train=True)
        dataloader.dataset.data = dataloader.dataset.data[:0]
        dataloader.dataset.targets = dataloader.dataset.targets[:0]
        return dataloader


class LightEmptyTestStep:
    """Empty test step."""

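The truncation trick used by `LightZeroLenDataloader` — slicing a dataset's backing storage to `[:0]` — can be shown without torch. `ListDataset` below is a hypothetical stand-in for a map-style dataset with `data` and `targets` attributes:

```python
class ListDataset:
    """Minimal map-style dataset backed by plain lists."""
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]


ds = ListDataset(list(range(10)), list(range(10)))
print(len(ds))  # 10

# Same trick as LightZeroLenDataloader: empty the storage in place,
# so len(ds) becomes 0 while the dataset object stays valid.
ds.data = ds.data[:0]
ds.targets = ds.targets[:0]
print(len(ds))  # 0
```

A zero-length dataset yields a zero-length dataloader, which is exactly what the `_has_len` check is meant to reject.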
26 changes: 25 additions & 1 deletion tests/trainer/test_dataloaders.py
@@ -16,7 +16,8 @@
    LightTrainDataloader,
    LightInfTrainDataloader,
    LightInfValDataloader,
    LightInfTestDataloader,
    LightZeroLenDataloader
)


@@ -458,3 +459,26 @@ class CurrentTestModel(

    # verify training completed
    assert result == 1


def test_error_on_zero_len_dataloader(tmpdir):
    """ Test that an error is raised if a zero-length dataloader is defined. """
    tutils.reset_seed()

    class CurrentTestModel(
        LightZeroLenDataloader,
        LightningTestModel
    ):
        pass

    hparams = tutils.get_default_hparams()
    model = CurrentTestModel(hparams)

    # fit model
    with pytest.raises(ValueError):
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            test_percent_check=0.5
        )
        trainer.fit(model)
