Warn user when IterableDataset has __len__ defined (#2437)
* add warning when checking len

* added test

* changelog

* pep

* do not show warning below 1.4

* try version parse

* comments

* xfail

* Update requirements/base.txt

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/trainer/data_loading.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* version

Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
  • Loading branch information
4 people committed Jul 1, 2020
1 parent 325852c commit 927f305
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 17 deletions.
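
For context: with multi-process loading (num_workers > 1), a DataLoader hands every worker process a full replica of an IterableDataset, so unless __iter__ shards the stream per worker, each sample is emitted once per worker while __len__ still reports the unsharded size. A minimal sketch of that failure mode, illustrative only and not part of this commit:

from torch.utils.data import DataLoader, IterableDataset


class RangeIterable(IterableDataset):
    """An iterable dataset that (misleadingly) also defines __len__."""

    def __iter__(self):
        # No per-worker sharding: every worker replays the full stream.
        return iter(range(4))

    def __len__(self):
        return 4


if __name__ == '__main__':
    loader = DataLoader(RangeIterable(), num_workers=2)
    print(sorted(int(batch) for batch in loader))
    # [0, 0, 1, 1, 2, 2, 3, 3] -- every sample is duplicated once per worker,
    # even though len(loader) reports only 4.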
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added reduce ddp results on eval ([#2434](https://github.com/PyTorchLightning/pytorch-lightning/pull/2434))
 
+- Added a warning when an `IterableDataset` has `__len__` defined ([#2437](https://github.com/PyTorchLightning/pytorch-lightning/pull/2437))
+
 ### Changed
 
 
39 changes: 24 additions & 15 deletions pytorch_lightning/trainer/data_loading.py
@@ -1,8 +1,10 @@
+import multiprocessing
 import platform
 from abc import ABC, abstractmethod
+from distutils.version import LooseVersion
 from typing import Union, List, Tuple, Callable, Optional
-import multiprocessing
 
+import torch
 import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -41,19 +43,33 @@
 HOROVOD_AVAILABLE = True
 
 
+def _has_iterable_dataset(dataloader: DataLoader):
+    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
+        and isinstance(dataloader.dataset, IterableDataset)
+
+
 def _has_len(dataloader: DataLoader) -> bool:
     """ Checks if a given Dataloader has __len__ method implemented i.e. if
-    it is a finite dataloader or infinite dataloader """
+    it is a finite dataloader or infinite dataloader. """
+
     try:
         # try getting the length
         if len(dataloader) == 0:
             raise ValueError('`Dataloader` returned 0 length.'
                              ' Please make sure that your Dataloader at least returns 1 batch')
-        return True
+        has_len = True
     except TypeError:
-        return False
+        has_len = False
     except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
-        return False
+        has_len = False
+
+    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
+        rank_zero_warn(
+            'Your `IterableDataset` has `__len__` defined.'
+            ' In combination with multi-processing data loading (e.g. batch size > 1),'
+            ' this can lead to unintended side effects since the samples will be duplicated.'
+        )
+    return has_len
 
 
 class TrainerDataLoadingMixin(ABC):
@@ -128,12 +144,9 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
     def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
 
         # don't do anything if it's not a dataloader
-        # don't manipulate iterable datasets
         is_dataloader = isinstance(dataloader, DataLoader)
-
-        is_iterable_ds = False
-        if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'):
-            is_iterable_ds = isinstance(dataloader.dataset, IterableDataset)
+        # don't manipulate iterable datasets
+        is_iterable_ds = _has_iterable_dataset(dataloader)
 
         if not is_dataloader or is_iterable_ds:
             return dataloader
@@ -285,11 +298,7 @@ def _reset_eval_dataloader(
         # datasets could be none, 1 or 2+
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                try:
-                    num_batches = len(dataloader)
-                except (TypeError, NotImplementedError):
-                    num_batches = float('inf')
-
+                num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
                 self._worker_check(dataloader, f'{mode} dataloader {i}')
 
                 # percent or num_steps
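
The two helpers above can be exercised directly; a quick sketch, assuming pytorch-lightning at this commit is importable (the UnsizedIterable class is ours, for illustration):

from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len


class UnsizedIterable(IterableDataset):
    """An iterable dataset with no __len__, i.e. of unknown size."""

    def __iter__(self):
        return iter(range(3))


# A map-style dataset (anything with __getitem__ and __len__) keeps its length:
assert _has_len(DataLoader(list(range(3))))

# For an IterableDataset without __len__, len(dataloader) raises TypeError,
# so _has_len returns False and _reset_eval_dataloader falls back to float('inf'):
unsized = DataLoader(UnsizedIterable())
assert _has_iterable_dataset(unsized)
assert not _has_len(unsized)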
2 changes: 1 addition & 1 deletion requirements/base.txt
@@ -6,4 +6,4 @@ tensorboard>=1.14
 future>=0.17.1 # required for builtins in setup.py
 # pyyaml>=3.13
 PyYAML>=5.1 # OmegaConf requirement
-tqdm>=4.41.0
\ No newline at end of file
+tqdm>=4.41.0
34 changes: 33 additions & 1 deletion tests/trainer/test_dataloaders.py
@@ -2,11 +2,13 @@
 
 import pytest
 import torch
+from packaging.version import parse
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import Subset
+from torch.utils.data.dataset import Subset, IterableDataset
 
 import tests.base.develop_pipelines as tpipes
 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.data_loading import _has_len, _has_iterable_dataset
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 
@@ -487,6 +489,36 @@ def test_warning_with_few_workers(tmpdir, ckpt_path):
     trainer.test(**test_options)
 
 
+@pytest.mark.xfail(
+    parse(torch.__version__) < parse("1.4.0"),
+    reason="IterableDataset with __len__ before 1.4 raises",
+)
+def test_warning_with_iterable_dataset_and_len(tmpdir):
+    """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
+    model = EvalModelTemplate()
+    original_dataset = model.train_dataloader().dataset
+
+    class IterableWithLen(IterableDataset):
+
+        def __iter__(self):
+            return iter(original_dataset)
+
+        def __len__(self):
+            return len(original_dataset)
+
+    dataloader = DataLoader(IterableWithLen(), batch_size=16)
+    assert _has_len(dataloader)
+    assert _has_iterable_dataset(dataloader)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=3,
+    )
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.test(model, test_dataloaders=[dataloader])
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
 def test_dataloader_reinit_for_subclass():
 
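
A side note on the version gating: the trainer code above uses distutils.version.LooseVersion, while this test uses packaging.version.parse. For plain release strings such as '1.4.0' the two agree; packaging implements PEP 440 and is generally the more robust choice for pre-release or local version tags. A small sketch, assuming both modules are installed:

from distutils.version import LooseVersion

import torch
from packaging.version import parse

# Both comparisons gate the warning/test on torch >= 1.4.0 and agree for
# ordinary release strings like '1.4.0' or '1.5.1'.
print(LooseVersion(torch.__version__) >= LooseVersion('1.4.0'))
print(parse(torch.__version__) >= parse('1.4.0'))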
