diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index fff67d1042334..5361f47b5f3d0 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -3,7 +3,7 @@
 from typing import Union, List, Tuple, Callable
 
 import torch.distributed as torch_distrib
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 
 from pytorch_lightning.core import LightningModule
@@ -113,39 +113,39 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
         need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)
 
         if self.replace_sampler_ddp and need_dist_sampler:
+            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
+                raise MisconfigurationException(
+                    'You seem to have configured a sampler in your DataLoader. This will be replaced '
+                    ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
+                    ' distributed training. Either remove the sampler from your DataLoader or set'
+                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.')
+
             skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
 
             dl_args = {
                 k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
            }
 
-            if self.use_tpu:
-                sampler = DistributedSampler(
-                    dataloader.dataset,
-                    num_replicas=xm.xrt_world_size(),
-                    rank=xm.get_ordinal(),
-                )
-            elif self.use_horovod:
-                sampler = DistributedSampler(dataloader.dataset,
-                                             num_replicas=hvd.size(),
-                                             rank=hvd.rank())
-            else:
-                world_size = {
-                    'ddp': self.num_nodes * self.num_processes,
-                    'ddp2': self.num_nodes,
-                    'ddp_cpu': self.num_processes * self.num_nodes
-                }
-                sampler = DistributedSampler(
-                    dataloader.dataset,
-                    num_replicas=world_size[self.distributed_backend],
-                    rank=self.proc_rank,
-                )
-
-            dl_args['sampler'] = sampler
+            dl_args['sampler'] = self._get_distributed_sampler(dataloader)
             dataloader = type(dataloader)(**dl_args)
 
         return dataloader
 
+    def _get_distributed_sampler(self, dataloader):
+        if self.use_tpu:
+            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+        elif self.use_horovod:
+            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
+        else:
+            world_size = {
+                'ddp': self.num_nodes * self.num_processes,
+                'ddp2': self.num_nodes,
+                'ddp_cpu': self.num_processes * self.num_nodes
+            }
+            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.proc_rank)
+        sampler = DistributedSampler(dataloader.dataset, **kwargs)
+        return sampler
+
     def reset_train_dataloader(self, model: LightningModule) -> None:
         """Resets the train dataloader and initialises required variables
         (number of batches, when to validate, etc.).
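
From the user's perspective, the new check only fires when `replace_sampler_ddp` is left at its default of `True` while training with ddp/ddp2/Horovod/TPU. The sketch below is illustrative only: `StridedSampler`, `MyModel`, and the Trainer flags (2 GPUs, ddp) are placeholder assumptions, not part of this diff. It shows the two ways to stay on the happy path: keep the default `SequentialSampler`/`RandomSampler`, or set `replace_sampler_ddp=False` to keep a custom sampler.

```python
from torch.utils.data import DataLoader, Sampler

from pytorch_lightning import Trainer


class StridedSampler(Sampler):
    """Hypothetical custom sampler, used only for illustration."""

    def __init__(self, data_source, stride=2):
        self.data_source = data_source
        self.stride = stride

    def __iter__(self):
        return iter(range(0, len(self.data_source), self.stride))

    def __len__(self):
        return len(range(0, len(self.data_source), self.stride))


dataset = list(range(1000))

# Default or shuffled loaders keep working: their SequentialSampler/RandomSampler
# is silently replaced by a DistributedSampler when distributed training is used.
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# A custom sampler combined with the default `replace_sampler_ddp=True` now raises
# MisconfigurationException under a distributed backend; opting out keeps the sampler.
custom_loader = DataLoader(dataset, batch_size=32, sampler=StridedSampler(dataset))
trainer = Trainer(distributed_backend='ddp', gpus=2, replace_sampler_ddp=False)  # assumes 2 GPUs
# trainer.fit(MyModel(), train_dataloader=custom_loader)  # MyModel is a placeholder LightningModule
```

The `_get_distributed_sampler` extraction itself is behavior-preserving; only the new isinstance guard changes what users can pass.
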
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 7c0cf0bf95b79..13b614ee5c8ca 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -393,6 +393,20 @@ class CustomDummyObj:
     assert isinstance(result, CustomDataLoader)
     assert hasattr(result, 'dummy_kwarg')
 
+    # Shuffled DataLoader should also work
+    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True)
+    assert isinstance(result, torch.utils.data.DataLoader)
+    assert isinstance(result, CustomDataLoader)
+    assert hasattr(result, 'dummy_kwarg')
+
+    class CustomSampler(torch.utils.data.Sampler):
+        pass
+
+    # Should raise an error if existing sampler is being replaced
+    with pytest.raises(MisconfigurationException, match='DistributedSampler'):
+        trainer.auto_add_sampler(
+            CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)
+
 
 @pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
 def test_batch_size_smaller_than_num_gpus():
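
The shuffled-loader assertion above passes because of how PyTorch chooses a default sampler: a `DataLoader` built without an explicit `sampler` gets a `SequentialSampler`, and `shuffle=True` swaps that for a `RandomSampler`, which are exactly the two classes the new isinstance check accepts. A quick standalone sanity check (not part of the test suite, just an illustration):

```python
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

data = list(range(10))

# No explicit sampler -> SequentialSampler, accepted by the new check.
assert isinstance(DataLoader(data).sampler, SequentialSampler)

# shuffle=True -> RandomSampler, also accepted.
assert isinstance(DataLoader(data, shuffle=True).sampler, RandomSampler)

# Anything else (e.g. a custom Sampler subclass) now triggers MisconfigurationException
# when `replace_sampler_ddp=True` and a distributed backend is active.
```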