From d00b2d57d3495c438f36e62bbd77a7918999c7f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Jul 2020 05:57:42 +0200 Subject: [PATCH] add warning when checking len --- pytorch_lightning/trainer/data_loading.py | 26 +++++++++++++++++------ 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 06ab7b316e1c2..676b5390f01bf 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -41,19 +41,33 @@ HOROVOD_AVAILABLE = True +def _has_iterable_dataset(dataloader: DataLoader): + return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \ + and isinstance(dataloader.dataset, IterableDataset) + + def _has_len(dataloader: DataLoader) -> bool: """ Checks if a given Dataloader has __len__ method implemented i.e. if - it is a finite dataloader or infinite dataloader """ + it is a finite dataloader or infinite dataloader. """ + try: # try getting the length if len(dataloader) == 0: raise ValueError('`Dataloader` returned 0 length.' ' Please make sure that your Dataloader at least returns 1 batch') - return True + has_len = True except TypeError: - return False + has_len = False except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used - return False + has_len = False + + if has_len and _has_iterable_dataset(dataloader): + rank_zero_warn( + 'Your `IterableDataset` has `__len__` defined.' + ' In combination with multi-processing data loading (e.g. batch size > 1),' + ' this can lead to unintended side effects since the samples will be duplicated.' 
+ ) + return has_len class TrainerDataLoadingMixin(ABC): @@ -131,9 +145,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: # don't manipulate iterable datasets is_dataloader = isinstance(dataloader, DataLoader) - is_iterable_ds = False - if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'): - is_iterable_ds = isinstance(dataloader.dataset, IterableDataset) + is_iterable_ds = _has_iterable_dataset(dataloader) if not is_dataloader or is_iterable_ds: return dataloader