diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 06ab7b316e1c2..676b5390f01bf 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -41,19 +41,33 @@
     HOROVOD_AVAILABLE = True
 
 
+def _has_iterable_dataset(dataloader: DataLoader):
+    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
+        and isinstance(dataloader.dataset, IterableDataset)
+
+
 def _has_len(dataloader: DataLoader) -> bool:
     """ Checks if a given Dataloader has __len__ method implemented i.e. if
-    it is a finite dataloader or infinite dataloader """
+    it is a finite dataloader or infinite dataloader. """
+
     try:
         # try getting the length
         if len(dataloader) == 0:
             raise ValueError('`Dataloader` returned 0 length.'
                              ' Please make sure that your Dataloader at least returns 1 batch')
-        return True
+        has_len = True
     except TypeError:
-        return False
+        has_len = False
     except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
-        return False
+        has_len = False
+
+    if has_len and _has_iterable_dataset(dataloader):
+        rank_zero_warn(
+            'Your `IterableDataset` has `__len__` defined.'
+            ' In combination with multi-processing data loading (e.g. batch size > 1),'
+            ' this can lead to unintended side effects since the samples will be duplicated.'
+        )
+    return has_len
 
 
 class TrainerDataLoadingMixin(ABC):
@@ -131,9 +145,7 @@
     def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
 
         # don't manipulate iterable datasets
        is_dataloader = isinstance(dataloader, DataLoader)
-        is_iterable_ds = False
-        if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'):
-            is_iterable_ds = isinstance(dataloader.dataset, IterableDataset)
+        is_iterable_ds = _has_iterable_dataset(dataloader)
         if not is_dataloader or is_iterable_ds:
             return dataloader
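
For context (not part of the diff), below is a minimal standalone sketch of the case the new warning targets: an `IterableDataset` that also implements `__len__`. The class name `RangeIterableDataset` is invented for illustration; only `DataLoader` and `IterableDataset` come from `torch.utils.data`, and the exact value printed for `len(loader)` depends on the PyTorch version.

```python
from torch.utils.data import DataLoader, IterableDataset


class RangeIterableDataset(IterableDataset):
    """Streams 0..n-1, but also (questionably) reports a length."""

    def __init__(self, n: int):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

    def __len__(self):
        return self.n


loader = DataLoader(RangeIterableDataset(10), batch_size=4)

# len() succeeds because the dataset defines __len__, so the old _has_len()
# silently treated this loader as a finite, map-style one.
print(len(loader))

# The dataset is still iterable-style; this is the combination the new
# _has_iterable_dataset() check detects and that _has_len() now warns about.
print(isinstance(loader.dataset, IterableDataset))
```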