Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/_has_len #2293

Merged
merged 5 commits into from
Jun 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def _has_len(dataloader: DataLoader) -> bool:
return True
except TypeError:
return False
except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used
return False


class TrainerDataLoadingMixin(ABC):
Expand Down
26 changes: 26 additions & 0 deletions tests/base/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,29 @@ def __next__(self):
except StopIteration:
self.iter = iter(self.dataloader)
return next(self.iter)


class CustomNotImplementedErrorDataloader:
    """Dataloader wrapper whose ``len()`` raises ``NotImplementedError``.

    Mimics loaders (e.g. torchtext with a ``batch_size_fn``) whose length
    cannot be determined.  Iteration cycles through the wrapped dataloader
    endlessly, but each pass through this wrapper is capped at 50 batches.
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iter = iter(dataloader)
        self.count = 0

    def __len__(self):
        """Always raise ``NotImplementedError`` instead of returning a length."""
        raise NotImplementedError

    def __iter__(self):
        # Restart the 50-batch budget for each fresh iteration.
        self.count = 0
        return self

    def __next__(self):
        if self.count >= 50:
            raise StopIteration
        self.count += 1
        try:
            batch = next(self.iter)
        except StopIteration:
            # Wrapped loader exhausted: start it over and keep going.
            self.iter = iter(self.dataloader)
            batch = next(self.iter)
        return batch
4 changes: 4 additions & 0 deletions tests/base/model_test_dataloaders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod

from tests.base.dataloaders import CustomInfDataloader
from tests.base.dataloaders import CustomNotImplementedErrorDataloader


class TestDataloaderVariations(ABC):
Expand All @@ -15,6 +16,9 @@ def test_dataloader(self):
def test_dataloader__infinite(self):
    """Test dataloader wrapped so it cycles without end."""
    base_loader = self.dataloader(train=False)
    return CustomInfDataloader(base_loader)

def test_dataloader__not_implemented_error(self):
    """Test dataloader whose ``__len__`` raises ``NotImplementedError``."""
    base_loader = self.dataloader(train=False)
    return CustomNotImplementedErrorDataloader(base_loader)

def test_dataloader__empty(self):
    """Variation that returns no test dataloader at all."""
    return None

Expand Down
4 changes: 4 additions & 0 deletions tests/base/model_train_dataloaders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod

from tests.base.dataloaders import CustomInfDataloader
from tests.base.dataloaders import CustomNotImplementedErrorDataloader


class TrainDataloaderVariations(ABC):
Expand All @@ -15,6 +16,9 @@ def train_dataloader(self):
def train_dataloader__infinite(self):
    """Train dataloader wrapped so it cycles without end."""
    base_loader = self.dataloader(train=True)
    return CustomInfDataloader(base_loader)

def train_dataloader__not_implemented_error(self):
    """Train dataloader whose ``__len__`` raises ``NotImplementedError``."""
    base_loader = self.dataloader(train=True)
    return CustomNotImplementedErrorDataloader(base_loader)

def train_dataloader__zero_length(self):
dataloader = self.dataloader(train=True)
dataloader.dataset.data = dataloader.dataset.data[:0]
Expand Down
4 changes: 4 additions & 0 deletions tests/base/model_valid_dataloaders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod

from tests.base.dataloaders import CustomInfDataloader
from tests.base.dataloaders import CustomNotImplementedErrorDataloader


class ValDataloaderVariations(ABC):
Expand All @@ -18,3 +19,6 @@ def val_dataloader__multiple(self):

def val_dataloader__infinite(self):
    """Val dataloader wrapped so it cycles without end."""
    base_loader = self.dataloader(train=False)
    return CustomInfDataloader(base_loader)

def val_dataloader__not_implemented_error(self):
    """Val dataloader whose ``__len__`` raises ``NotImplementedError``."""
    base_loader = self.dataloader(train=False)
    return CustomNotImplementedErrorDataloader(base_loader)
74 changes: 74 additions & 0 deletions tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,18 @@ def test_train_inf_dataloader_error(tmpdir):
trainer.fit(model)


@pytest.mark.skip('TODO: speed up this test')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has to be resolved before merge

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the "CustomNotImplementedErrorDataloader" is a version of 'CustomInfDataloader' with a minor change the behavior should be very similar.

The tests derived by copy and paste of the corresponding tests for 'CustomInfDataloader'. The decorators were copied under the assumption that they are correct or useful. Given that I am not sure about them, it makes sense to remove them. @Borda Please advise if that is a reasonable solution.

On a different note alternatively CustomNotImplementedErrorDataloader could be derived from CustomInfDataloader with the len method added. This would reduce duplicated code. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Borda After removing the decorators pytest failed. This is because I also changed the match string from "infinite DataLoader" to "not_implemented_error DataLoader" when I copied the test. This does not match the MisconfigurationException exception string raised from pytorch_lightning/trainer/data_loading.py. I am changing the match strings to "infinite DataLoader" and comment skipping the tests out.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Borda I created a new PR #2307 with the changes to make the tests work. You will see a lot of commits there. It seems difficult for the tests to pass on all machines. It seems that the machines cancel the job if it takes more than 15 minutes of total time. When I enabled all the tests, they sometimes finished in time on a few machines, but never on all of them. On my laptop I have observed, when running the tests locally, that they sometimes just hang (a single test >15 minutes). Maybe this is a problem of pytest in my environment (macOS), or a more general issue. This merge certainly fixes the issue I observed, and the tests in the new PR are technically working. I suggest continuing with the commits of the new PR.

def test_train_not_implemented_error_dataloader_error(tmpdir):
    """Test that a train dataloader whose ``__len__`` raises NotImplementedError
    (e.g. torchtext with a ``batch_size_fn``) triggers the same
    MisconfigurationException as an infinite dataloader when
    ``val_check_interval`` is fractional."""
    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__not_implemented_error

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)

    # _has_len() treats a NotImplementedError from __len__ like an infinite
    # loader, so the raised exception message mentions 'infinite DataLoader'
    # (the previous match string 'not_implemented_error DataLoader' never matched).
    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.fit(model)


@pytest.mark.skip('TODO: speed up this test')
def test_val_inf_dataloader_error(tmpdir):
"""Test inf train data loader (e.g. IterableDataset)"""
Expand All @@ -307,6 +319,18 @@ def test_val_inf_dataloader_error(tmpdir):
trainer.fit(model)


@pytest.mark.skip('TODO: speed up this test')
def test_val_not_implemented_error_dataloader_error(tmpdir):
    """Test that a val dataloader whose ``__len__`` raises NotImplementedError
    (e.g. torchtext with a ``batch_size_fn``) triggers the same
    MisconfigurationException as an infinite dataloader when
    ``limit_val_batches`` is fractional."""
    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__not_implemented_error

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)

    # The exception raised for a length-less loader says 'infinite DataLoader';
    # the previous match string 'not_implemented_error DataLoader' never matched.
    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.fit(model)


@pytest.mark.skip('TODO: speed up this test')
def test_test_inf_dataloader_error(tmpdir):
"""Test inf train data loader (e.g. IterableDataset)"""
Expand All @@ -319,6 +343,18 @@ def test_test_inf_dataloader_error(tmpdir):
trainer.test(model)


@pytest.mark.skip('TODO: speed up this test')
def test_test_not_implemented_error_dataloader_error(tmpdir):
    """Test that a test dataloader whose ``__len__`` raises NotImplementedError
    (e.g. torchtext with a ``batch_size_fn``) triggers the same
    MisconfigurationException as an infinite dataloader when
    ``limit_test_batches`` is fractional."""
    model = EvalModelTemplate()
    model.test_dataloader = model.test_dataloader__not_implemented_error

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)

    # The exception raised for a length-less loader says 'infinite DataLoader';
    # the previous match string 'not_implemented_error DataLoader' never matched.
    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
        trainer.test(model)


@pytest.mark.parametrize('check_interval', [50, 1.0])
@pytest.mark.skip('TODO: speed up this test')
def test_inf_train_dataloader(tmpdir, check_interval):
Expand All @@ -337,6 +373,24 @@ def test_inf_train_dataloader(tmpdir, check_interval):
assert result == 1


@pytest.mark.parametrize('check_interval', [50, 1.0])
@pytest.mark.skip('TODO: speed up this test')
def test_not_implemented_error_train_dataloader(tmpdir, check_interval):
    """Test that fit() completes with a train dataloader whose ``__len__``
    raises NotImplementedError (e.g. torchtext with a ``batch_size_fn``);
    such a loader is handled like an infinite/iterable-style dataset."""

    model = EvalModelTemplate()
    model.train_dataloader = model.train_dataloader__not_implemented_error

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_check_interval=check_interval
    )
    result = trainer.fit(model)
    # verify training completed
    assert result == 1


@pytest.mark.parametrize('check_interval', [1.0])
@pytest.mark.skip('TODO: speed up this test')
def test_inf_val_dataloader(tmpdir, check_interval):
Expand All @@ -357,6 +411,26 @@ def test_inf_val_dataloader(tmpdir, check_interval):
assert result == 1


@pytest.mark.parametrize('check_interval', [1.0])
@pytest.mark.skip('TODO: speed up this test')
def test_not_implemented_error_dataloader(tmpdir, check_interval):
    """Test that fit() completes with a *val* dataloader whose ``__len__``
    raises NotImplementedError (e.g. torchtext with a ``batch_size_fn``)."""

    model = EvalModelTemplate()
    model.val_dataloader = model.val_dataloader__not_implemented_error

    # logger file to get meta
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        val_check_interval=check_interval,
    )
    result = trainer.fit(model)

    # verify training completed
    assert result == 1


def test_error_on_zero_len_dataloader(tmpdir):
""" Test that error is raised if a zero-length dataloader is defined """

Expand Down