From 8baec1a191e5db855efc0aeabf64ef414b93d12a Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Sun, 2 Aug 2020 08:52:57 +0530
Subject: [PATCH] Fix shuffle for distributed sampler (#2789)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Fix shuffle for distributed sampler

* add test

* test

* chlog

* update test

* update test

* update test

* assertions via callback

* define callback outside for pickling

* skip ddp test on windows

Co-authored-by: Adrian Wälchli
---
 CHANGELOG.md                              |  2 ++
 pytorch_lightning/trainer/data_loading.py |  6 ++--
 tests/trainer/test_dataloaders.py         | 39 ++++++++++++++++++++++-
 3 files changed, 44 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c44db54cac5a1..ac509ae5b5a88 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -49,6 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
 
+- Fixed shuffle argument for distributed sampler ([#2789](https://github.com/PyTorchLightning/pytorch-lightning/pull/2789))
+
 ## [0.8.5] - 2020-07-09
 
 ### Added
 
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 525956d257521..09186765c6eee 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -163,7 +163,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
                     ' `replace_sampler_ddp`=False if you want to use your custom sampler.')
 
             # replace with distributed sampler
-            sampler = self._get_distributed_sampler(dataloader)
+            sampler = self._get_distributed_sampler(dataloader, train)
             dataloader = self.replace_sampler(dataloader, sampler)
 
         return dataloader
@@ -179,7 +179,7 @@ def replace_sampler(self, dataloader, sampler):
         dataloader = type(dataloader)(**dl_args)
         return dataloader
 
-    def _get_distributed_sampler(self, dataloader):
+    def _get_distributed_sampler(self, dataloader, train):
         if self.use_tpu:
             kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
         elif self.use_horovod:
@@ -193,6 +193,8 @@ def _get_distributed_sampler(self, dataloader):
             }
             assert self.distributed_backend is not None
             kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
+
+        kwargs['shuffle'] = train
         sampler = DistributedSampler(dataloader.dataset, **kwargs)
         return sampler
 
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 23aa7547a689d..1c7e21b7a72bb 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -7,9 +7,10 @@
 from packaging.version import parse
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.dataset import IterableDataset, Subset
+from torch.utils.data.distributed import DistributedSampler
 
 import tests.base.develop_pipelines as tpipes
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
@@ -640,6 +641,42 @@ class CustomSampler(torch.utils.data.Sampler):
         CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)
 
 
+class DistribSamplerCallback(Callback):
+
+    def on_train_start(self, trainer, pl_module):
+        train_sampler = trainer.train_dataloader.sampler
+        assert isinstance(train_sampler, DistributedSampler)
+        assert train_sampler.shuffle
+
+    def on_validation_start(self, trainer, pl_module):
+        val_sampler = trainer.val_dataloaders[0].sampler
+        assert isinstance(val_sampler, DistributedSampler)
+        assert not val_sampler.shuffle
+
+    def on_test_start(self, trainer, pl_module):
+        test_sampler = trainer.test_dataloaders[0].sampler
+        assert isinstance(test_sampler, DistributedSampler)
+        assert not test_sampler.shuffle
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
+def test_dataloader_distributed_sampler(tmpdir):
+    """ Test DistributedSampler and its arguments for DDP backend """
+
+    model = EvalModelTemplate()
+    trainer = Trainer(
+        gpus=[0, 1],
+        num_nodes=1,
+        distributed_backend='ddp_spawn',
+        default_root_dir=tmpdir,
+        max_steps=1,
+        callbacks=[DistribSamplerCallback()]
+    )
+    trainer.fit(model)
+    trainer.test(ckpt_path=None)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
 def test_batch_size_smaller_than_num_gpus(tmpdir):
     # we need at least 3 gpus for this test
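
Below the patch, a minimal standalone sketch of the behaviour the fix establishes: the training dataloader gets a shuffling DistributedSampler, while validation and test samplers keep deterministic order. The make_distributed_sampler helper and the dataset/rank values are illustrative assumptions, not PyTorch Lightning APIs; the real logic lives in the patched _get_distributed_sampler above.

# Sketch only; make_distributed_sampler is a hypothetical helper, not part of the patch.
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler


def make_distributed_sampler(dataset, train, num_replicas, rank):
    # Mirrors the intent of the fix: shuffle only when building the
    # sampler for the training dataloader.
    return DistributedSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=train)


if __name__ == '__main__':
    dataset = TensorDataset(torch.arange(8).float())
    # Explicit num_replicas/rank so no process group is required for the sketch.
    train_sampler = make_distributed_sampler(dataset, train=True, num_replicas=2, rank=0)
    eval_sampler = make_distributed_sampler(dataset, train=False, num_replicas=2, rank=0)
    assert train_sampler.shuffle
    assert not eval_sampler.shuffle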