Warn user when IterableDataset has __len__ defined #2437

Merged (11 commits) on Jul 1, 2020
Changes from 2 commits
32 changes: 20 additions & 12 deletions pytorch_lightning/trainer/data_loading.py
@@ -41,19 +41,33 @@
HOROVOD_AVAILABLE = True


def _has_iterable_dataset(dataloader: DataLoader):
    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
        and isinstance(dataloader.dataset, IterableDataset)

Review comment (Member): If you encapsulate the expression in `()`, you do not need to use `\`.
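For reference, a sketch of the parenthesized form the reviewer is suggesting (illustrative only, not part of this diff):

```python
def _has_iterable_dataset(dataloader: DataLoader):
    # Uses the same module-level names (DataLoader, IterableDataset, ITERABLE_DATASET_EXISTS)
    # as the surrounding file; parentheses around the whole expression make the
    # trailing backslash continuation unnecessary.
    return (ITERABLE_DATASET_EXISTS
            and hasattr(dataloader, 'dataset')
            and isinstance(dataloader.dataset, IterableDataset))
```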


def _has_len(dataloader: DataLoader) -> bool:
""" Checks if a given Dataloader has __len__ method implemented i.e. if
it is a finite dataloader or infinite dataloader """
it is a finite dataloader or infinite dataloader. """

    try:
        # try getting the length
        if len(dataloader) == 0:
            raise ValueError('`Dataloader` returned 0 length.'
                             ' Please make sure that your Dataloader at least returns 1 batch')
        return True
        has_len = True
    except TypeError:
        return False
        has_len = False
    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
        return False
        has_len = False
Review comment (Member) on lines 61 to +64: If the resulting action is the same, let's write it in one except clause:

    except (TypeError, NotImplementedError):  # e.g. raised by torchtext if a batch_size_fn is used
        has_len = False
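Applied to the block above, the whole try/except would then read as follows (a sketch of the suggestion, not the code committed in this PR):

```python
    try:
        # try getting the length
        if len(dataloader) == 0:
            raise ValueError('`Dataloader` returned 0 length.'
                             ' Please make sure that your Dataloader at least returns 1 batch')
        has_len = True
    except (TypeError, NotImplementedError):  # e.g. raised by torchtext if a batch_size_fn is used
        has_len = False
```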


    if has_len and _has_iterable_dataset(dataloader):
        rank_zero_warn(
            'Your `IterableDataset` has `__len__` defined.'
            ' In combination with multi-processing data loading (e.g. batch size > 1),'
            ' this can lead to unintended side effects since the samples will be duplicated.'
        )
    return has_len


class TrainerDataLoadingMixin(ABC):
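To illustrate the sample duplication the new warning refers to, here is a minimal sketch (the dataset below is hypothetical and not part of this PR): with multi-process loading and no per-worker sharding, each worker replays the full dataset, so the value reported by `__len__` no longer matches what is actually iterated.

```python
from torch.utils.data import DataLoader, IterableDataset


class RangeIterable(IterableDataset):
    """Hypothetical iterable dataset that also defines __len__."""

    def __init__(self, n: int):
        self.n = n

    def __iter__(self):
        # No per-worker sharding: every DataLoader worker replays the full range.
        return iter(range(self.n))

    def __len__(self):
        return self.n


if __name__ == '__main__':
    loader = DataLoader(RangeIterable(4), num_workers=2)
    # Each of the 2 workers yields all 4 samples, so every sample appears twice.
    print(sorted(int(batch) for batch in loader))  # [0, 0, 1, 1, 2, 2, 3, 3]
```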
@@ -131,9 +145,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
        # don't manipulate iterable datasets
        is_dataloader = isinstance(dataloader, DataLoader)

        is_iterable_ds = False
        if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'):
            is_iterable_ds = isinstance(dataloader.dataset, IterableDataset)
        is_iterable_ds = _has_iterable_dataset(dataloader)

        if not is_dataloader or is_iterable_ds:
            return dataloader
@@ -285,11 +297,7 @@ def _reset_eval_dataloader(
        # datasets could be none, 1 or 2+
        if len(dataloaders) != 0:
            for i, dataloader in enumerate(dataloaders):
                try:
                    num_batches = len(dataloader)
                except (TypeError, NotImplementedError):
                    num_batches = float('inf')

                num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
                self._worker_check(dataloader, f'{mode} dataloader {i}')

                # percent or num_steps
28 changes: 27 additions & 1 deletion tests/trainer/test_dataloaders.py
@@ -3,10 +3,11 @@
import pytest
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset
from torch.utils.data.dataset import Subset, IterableDataset

import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.data_loading import _has_len, _has_iterable_dataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate

@@ -487,6 +488,31 @@ def test_warning_with_few_workers(tmpdir, ckpt_path):
    trainer.test(**test_options)


def test_warning_with_iterable_dataset_and_len(tmpdir):
    """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
    model = EvalModelTemplate()
    original_dataset = model.train_dataloader().dataset

    class IterableWithLen(IterableDataset):

        def __iter__(self):
            return iter(original_dataset)

        def __len__(self):
            return len(original_dataset)

    dataloader = DataLoader(IterableWithLen(), batch_size=16)
    assert _has_len(dataloader)
    assert _has_iterable_dataset(dataloader)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=3,
    )
    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
        trainer.test(model, test_dataloaders=[dataloader])


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
def test_dataloader_reinit_for_subclass():
