Warn user when IterableDataset has __len__ defined #2437

Merged: 11 commits, Jul 1, 2020
Changes from all commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added reduce ddp results on eval ([#2434](https://github.com/PyTorchLightning/pytorch-lightning/pull/2434))

- Added a warning when an `IterableDataset` has `__len__` defined ([#2437](https://github.com/PyTorchLightning/pytorch-lightning/pull/2437))

### Changed


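For context, the sketch below illustrates the situation this changelog entry warns about: an `IterableDataset` that also implements `__len__`, loaded with multiple workers. The class, sizes, and variable names are made up for illustration and are not part of the PR.

import torch
from torch.utils.data import DataLoader, IterableDataset


class RangeIterable(IterableDataset):
    """Illustrative iterable dataset that also reports a length."""

    def __init__(self, n: int):
        self.n = n

    def __iter__(self):
        # every DataLoader worker runs this iterator from the start, so with
        # num_workers > 1 each sample is produced once per worker
        return iter(range(self.n))

    def __len__(self):
        return self.n


if __name__ == "__main__":
    loader = DataLoader(RangeIterable(8), batch_size=4, num_workers=2)
    # len(loader) is finite because __len__ exists, yet iterating with two
    # workers yields every sample twice, so these two counts will not match
    print(len(loader), sum(1 for _ in loader))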
39 changes: 24 additions & 15 deletions pytorch_lightning/trainer/data_loading.py
@@ -1,8 +1,10 @@
import multiprocessing
import platform
from abc import ABC, abstractmethod
from distutils.version import LooseVersion
from typing import Union, List, Tuple, Callable, Optional
import multiprocessing

import torch
import torch.distributed as torch_distrib
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
@@ -41,19 +43,33 @@
HOROVOD_AVAILABLE = True


def _has_iterable_dataset(dataloader: DataLoader):
    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
        and isinstance(dataloader.dataset, IterableDataset)

Review comment (Member): if you encapsulate the if statement in () you do not need to use \
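A small sketch of the reviewer's suggestion: wrapping the expression in parentheses makes the line continuation implicit, so the backslash is not needed. `ITERABLE_DATASET_EXISTS` and `IterableDataset` are the names already used in this module; this only illustrates the style change, not the merged code.

def _has_iterable_dataset(dataloader: DataLoader):
    # parentheses let the condition span several lines without '\'
    return (
        ITERABLE_DATASET_EXISTS
        and hasattr(dataloader, 'dataset')
        and isinstance(dataloader.dataset, IterableDataset)
    )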


def _has_len(dataloader: DataLoader) -> bool:
    """ Checks if a given Dataloader has __len__ method implemented i.e. if
    it is a finite dataloader or infinite dataloader """
    it is a finite dataloader or infinite dataloader. """

    try:
        # try getting the length
        if len(dataloader) == 0:
            raise ValueError('`Dataloader` returned 0 length.'
                             ' Please make sure that your Dataloader at least returns 1 batch')
        return True
        has_len = True
    except TypeError:
        return False
        has_len = False
    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
        return False
        has_len = False
Review comment (Member) on lines 61 to +64: if the resulting action is the same, let's write it in one. Suggested change:

    except (TypeError, NotImplementedError):  # e.g. raised by torchtext if a batch_size_fn is used
        has_len = False


    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
        rank_zero_warn(
            'Your `IterableDataset` has `__len__` defined.'
            ' In combination with multi-processing data loading (e.g. batch size > 1),'
            ' this can lead to unintended side effects since the samples will be duplicated.'
        )
    return has_len
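As a rough illustration of what `_has_len` returns after this change, the snippet below probes a finite map-style loader and an iterable loader without `__len__`. It assumes the pytorch-lightning version from this PR, which exposes `_has_len` in `pytorch_lightning.trainer.data_loading` (the tests below import it the same way); the dataset class is made up.

import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

from pytorch_lightning.trainer.data_loading import _has_len


class EndlessStream(IterableDataset):
    """Illustrative iterable dataset without __len__, i.e. an 'infinite' loader."""

    def __iter__(self):
        while True:
            yield torch.zeros(1)


finite_loader = DataLoader(TensorDataset(torch.randn(10, 3)), batch_size=5)
endless_loader = DataLoader(EndlessStream(), batch_size=5)

print(_has_len(finite_loader))   # True: len(finite_loader) == 2
print(_has_len(endless_loader))  # False: len() raises TypeError, caught above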


class TrainerDataLoadingMixin(ABC):
@@ -128,12 +144,9 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
    def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:

        # don't do anything if it's not a dataloader
        # don't manipulate iterable datasets
        is_dataloader = isinstance(dataloader, DataLoader)

        is_iterable_ds = False
        if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'):
            is_iterable_ds = isinstance(dataloader.dataset, IterableDataset)
        # don't manipulate iterable datasets
        is_iterable_ds = _has_iterable_dataset(dataloader)

        if not is_dataloader or is_iterable_ds:
            return dataloader
@@ -285,11 +298,7 @@ def _reset_eval_dataloader(
        # datasets could be none, 1 or 2+
        if len(dataloaders) != 0:
            for i, dataloader in enumerate(dataloaders):
                try:
                    num_batches = len(dataloader)
                except (TypeError, NotImplementedError):
                    num_batches = float('inf')

                num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
                self._worker_check(dataloader, f'{mode} dataloader {i}')

                # percent or num_steps
2 changes: 1 addition & 1 deletion requirements/base.txt
@@ -6,4 +6,4 @@ tensorboard>=1.14
future>=0.17.1 # required for builtins in setup.py
# pyyaml>=3.13
PyYAML>=5.1 # OmegaConf requirement
tqdm>=4.41.0
tqdm>=4.41.0
34 changes: 33 additions & 1 deletion tests/trainer/test_dataloaders.py
@@ -2,11 +2,13 @@

import pytest
import torch
from packaging.version import parse
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Subset
from torch.utils.data.dataset import Subset, IterableDataset

import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.data_loading import _has_len, _has_iterable_dataset
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate

@@ -487,6 +489,36 @@ def test_warning_with_few_workers(tmpdir, ckpt_path):
    trainer.test(**test_options)


@pytest.mark.xfail(
    parse(torch.__version__) < parse("1.4.0"),
    reason="IterableDataset with __len__ before 1.4 raises",
)
def test_warning_with_iterable_dataset_and_len(tmpdir):
    """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
    model = EvalModelTemplate()
    original_dataset = model.train_dataloader().dataset

    class IterableWithLen(IterableDataset):

        def __iter__(self):
            return iter(original_dataset)

        def __len__(self):
            return len(original_dataset)

    dataloader = DataLoader(IterableWithLen(), batch_size=16)
    assert _has_len(dataloader)
    assert _has_iterable_dataset(dataloader)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=3,
    )
    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
        trainer.test(model, test_dataloaders=[dataloader])
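To see the warning outside the pytest machinery, a rough standalone version of the same check could look like the following. It assumes torch >= 1.4 and this PR's pytorch-lightning; `SizedStream` is a hypothetical stand-in for the test's `IterableWithLen`.

import warnings

import torch
from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len


class SizedStream(IterableDataset):
    """Hypothetical iterable dataset that (questionably) also defines __len__."""

    def __iter__(self):
        return iter(torch.randn(32, 3))

    def __len__(self):
        return 32


loader = DataLoader(SizedStream(), batch_size=16)
assert _has_iterable_dataset(loader)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert _has_len(loader)  # a length is defined, so this stays True

# on torch >= 1.4 the rank_zero_warn added in this PR should surface here
assert any('`__len__` defined' in str(w.message) for w in caught)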


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
def test_dataloader_reinit_for_subclass():
