Skip to content

Commit

Permalink
Bugfix/_has_len (#2293)
Browse files Browse the repository at this point in the history
* deal with NotImplementedError raised by torchtext

* deal with NotImplementedError raised by torchtext

* Added tests for dataloader which raise NotImplementedError in __len__()

* Fixed some typos

Co-authored-by: Thomas Schaaf <tschaaf@cs.cmu.edu>
  • Loading branch information
thschaaf and Thomas Schaaf committed Jun 20, 2020
1 parent 3256fe4 commit 554fb47
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def _has_len(dataloader: DataLoader) -> bool:
return True
except TypeError:
return False
except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used
return False


class TrainerDataLoadingMixin(ABC):
Expand Down
26 changes: 26 additions & 0 deletions tests/base/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,29 @@ def __next__(self):
except StopIteration:
self.iter = iter(self.dataloader)
return next(self.iter)


class CustomNotImplementedErrorDataloader:

def __init__(self, dataloader):
self.dataloader = dataloader
self.iter = iter(dataloader)
self.count = 0

def __len__(self):
"""raise NotImplementedError"""
raise NotImplementedError

def __iter__(self):
self.count = 0
return self

def __next__(self):
if self.count >= 50:
raise StopIteration
self.count = self.count + 1
try:
return next(self.iter)
except StopIteration:
self.iter = iter(self.dataloader)
return next(self.iter)
4 changes: 4 additions & 0 deletions tests/base/model_test_dataloaders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod

from tests.base.dataloaders import CustomInfDataloader
from tests.base.dataloaders import CustomNotImplementedErrorDataloader


class TestDataloaderVariations(ABC):
Expand All @@ -15,6 +16,9 @@ def test_dataloader(self):
def test_dataloader__infinite(self):
return CustomInfDataloader(self.dataloader(train=False))

def test_dataloader__not_implemented_error(self):
return CustomNotImplementedErrorDataloader(self.dataloader(train=False))

def test_dataloader__empty(self):
return None

Expand Down
4 changes: 4 additions & 0 deletions tests/base/model_train_dataloaders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod

from tests.base.dataloaders import CustomInfDataloader
from tests.base.dataloaders import CustomNotImplementedErrorDataloader


class TrainDataloaderVariations(ABC):
Expand All @@ -15,6 +16,9 @@ def train_dataloader(self):
def train_dataloader__infinite(self):
return CustomInfDataloader(self.dataloader(train=True))

def train_dataloader__not_implemented_error(self):
return CustomNotImplementedErrorDataloader(self.dataloader(train=True))

def train_dataloader__zero_length(self):
dataloader = self.dataloader(train=True)
dataloader.dataset.data = dataloader.dataset.data[:0]
Expand Down
4 changes: 4 additions & 0 deletions tests/base/model_valid_dataloaders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod

from tests.base.dataloaders import CustomInfDataloader
from tests.base.dataloaders import CustomNotImplementedErrorDataloader


class ValDataloaderVariations(ABC):
Expand All @@ -18,3 +19,6 @@ def val_dataloader__multiple(self):

def val_dataloader__infinite(self):
return CustomInfDataloader(self.dataloader(train=False))

def val_dataloader__not_implemented_error(self):
return CustomNotImplementedErrorDataloader(self.dataloader(train=False))
74 changes: 74 additions & 0 deletions tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,18 @@ def test_train_inf_dataloader_error(tmpdir):
trainer.fit(model)


@pytest.mark.skip('TODO: speed up this test')
def test_train_not_implemented_error_dataloader_error(tmpdir):
"""Test not_implemented_error train data loader (e.g. IterableDataset)"""
model = EvalModelTemplate()
model.train_dataloader = model.train_dataloader__not_implemented_error

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)

with pytest.raises(MisconfigurationException, match='not_implemented_error DataLoader'):
trainer.fit(model)


@pytest.mark.skip('TODO: speed up this test')
def test_val_inf_dataloader_error(tmpdir):
"""Test inf train data loader (e.g. IterableDataset)"""
Expand All @@ -307,6 +319,18 @@ def test_val_inf_dataloader_error(tmpdir):
trainer.fit(model)


@pytest.mark.skip('TODO: speed up this test')
def test_val_not_implemented_error_dataloader_error(tmpdir):
"""Test not_implemented_error train data loader (e.g. IterableDataset)"""
model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__not_implemented_error

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)

with pytest.raises(MisconfigurationException, match='not_implemented_error DataLoader'):
trainer.fit(model)


@pytest.mark.skip('TODO: speed up this test')
def test_test_inf_dataloader_error(tmpdir):
"""Test inf train data loader (e.g. IterableDataset)"""
Expand All @@ -319,6 +343,18 @@ def test_test_inf_dataloader_error(tmpdir):
trainer.test(model)


@pytest.mark.skip('TODO: speed up this test')
def test_test_not_implemented_error_dataloader_error(tmpdir):
"""Test not_implemented_error train data loader (e.g. IterableDataset)"""
model = EvalModelTemplate()
model.test_dataloader = model.test_dataloader__not_implemented_error

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)

with pytest.raises(MisconfigurationException, match='not_implemented_error DataLoader'):
trainer.test(model)


@pytest.mark.parametrize('check_interval', [50, 1.0])
@pytest.mark.skip('TODO: speed up this test')
def test_inf_train_dataloader(tmpdir, check_interval):
Expand All @@ -337,6 +373,24 @@ def test_inf_train_dataloader(tmpdir, check_interval):
assert result == 1


@pytest.mark.parametrize('check_interval', [50, 1.0])
@pytest.mark.skip('TODO: speed up this test')
def test_not_implemented_error_train_dataloader(tmpdir, check_interval):
"""Test not_implemented_error train data loader (e.g. IterableDataset)"""

model = EvalModelTemplate()
model.train_dataloader = model.train_dataloader__not_implemented_error

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_check_interval=check_interval
)
result = trainer.fit(model)
# verify training completed
assert result == 1


@pytest.mark.parametrize('check_interval', [1.0])
@pytest.mark.skip('TODO: speed up this test')
def test_inf_val_dataloader(tmpdir, check_interval):
Expand All @@ -357,6 +411,26 @@ def test_inf_val_dataloader(tmpdir, check_interval):
assert result == 1


@pytest.mark.parametrize('check_interval', [1.0])
@pytest.mark.skip('TODO: speed up this test')
def test_not_implemented_error_dataloader(tmpdir, check_interval):
"""Test not_implemented_error data loader (e.g. IterableDataset)"""

model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__not_implemented_error

# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_check_interval=check_interval,
)
result = trainer.fit(model)

# verify training completed
assert result == 1


def test_error_on_zero_len_dataloader(tmpdir):
""" Test that error is raised if a zero-length dataloader is defined """

Expand Down

0 comments on commit 554fb47

Please sign in to comment.