diff --git a/CHANGELOG.md b/CHANGELOG.md index 0968488111d24..6d0bf1b750900 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added reduce ddp results on eval ([#2434](https://github.com/PyTorchLightning/pytorch-lightning/pull/2434)) +- Added a warning when an `IterableDataset` has `__len__` defined ([#2437](https://github.com/PyTorchLightning/pytorch-lightning/pull/2437)) + ### Changed diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 06ab7b316e1c2..e283166234968 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -1,8 +1,10 @@ +import multiprocessing import platform from abc import ABC, abstractmethod +from distutils.version import LooseVersion from typing import Union, List, Tuple, Callable, Optional -import multiprocessing +import torch import torch.distributed as torch_distrib from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -41,19 +43,33 @@ HOROVOD_AVAILABLE = True +def _has_iterable_dataset(dataloader: DataLoader): + return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \ + and isinstance(dataloader.dataset, IterableDataset) + + def _has_len(dataloader: DataLoader) -> bool: """ Checks if a given Dataloader has __len__ method implemented i.e. if - it is a finite dataloader or infinite dataloader """ + it is a finite dataloader or infinite dataloader. """ + try: # try getting the length if len(dataloader) == 0: raise ValueError('`Dataloader` returned 0 length.' ' Please make sure that your Dataloader at least returns 1 batch') - return True + has_len = True except TypeError: - return False + has_len = False except NotImplementedError: # e.g. 
raised by torchtext if a batch_size_fn is used - return False + has_len = False + + if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"): + rank_zero_warn( + 'Your `IterableDataset` has `__len__` defined.' + ' In combination with multi-processing data loading (e.g. batch size > 1),' + ' this can lead to unintended side effects since the samples will be duplicated.' + ) + return has_len class TrainerDataLoadingMixin(ABC): @@ -128,12 +144,9 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: # don't do anything if it's not a dataloader - # don't manipulate iterable datasets is_dataloader = isinstance(dataloader, DataLoader) - - is_iterable_ds = False - if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'): - is_iterable_ds = isinstance(dataloader.dataset, IterableDataset) + # don't manipulate iterable datasets + is_iterable_ds = _has_iterable_dataset(dataloader) if not is_dataloader or is_iterable_ds: return dataloader @@ -285,11 +298,7 @@ def _reset_eval_dataloader( # datasets could be none, 1 or 2+ if len(dataloaders) != 0: for i, dataloader in enumerate(dataloaders): - try: - num_batches = len(dataloader) - except (TypeError, NotImplementedError): - num_batches = float('inf') - + num_batches = len(dataloader) if _has_len(dataloader) else float('inf') self._worker_check(dataloader, f'{mode} dataloader {i}') # percent or num_steps diff --git a/requirements/base.txt b/requirements/base.txt index 2b26d79033f6f..bacc868dada85 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -6,4 +6,4 @@ tensorboard>=1.14 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 PyYAML>=5.1 # OmegaConf requirement -tqdm>=4.41.0 \ No newline at end of file +tqdm>=4.41.0 diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index b36eca8a2e429..99fe02979013f 100644 --- 
a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -2,11 +2,13 @@ import pytest import torch +from packaging.version import parse from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import Subset +from torch.utils.data.dataset import Subset, IterableDataset import tests.base.develop_pipelines as tpipes from pytorch_lightning import Trainer +from pytorch_lightning.trainer.data_loading import _has_len, _has_iterable_dataset from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate @@ -487,6 +489,36 @@ def test_warning_with_few_workers(tmpdir, ckpt_path): trainer.test(**test_options) +@pytest.mark.xfail( + parse(torch.__version__) < parse("1.4.0"), + reason="IterableDataset with __len__ before 1.4 raises", +) +def test_warning_with_iterable_dataset_and_len(tmpdir): + """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """ + model = EvalModelTemplate() + original_dataset = model.train_dataloader().dataset + + class IterableWithLen(IterableDataset): + + def __iter__(self): + return iter(original_dataset) + + def __len__(self): + return len(original_dataset) + + dataloader = DataLoader(IterableWithLen(), batch_size=16) + assert _has_len(dataloader) + assert _has_iterable_dataset(dataloader) + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=3, + ) + with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'): + trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader]) + with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'): + trainer.test(model, test_dataloaders=[dataloader]) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs') def test_dataloader_reinit_for_subclass():