Fix shuffle for distributed sampler (#2789)
* Fix shuffle for distributed sampler

* add test

* test

* chlog

* update test

* update test

* update test

* assertions via callback

* define callback outside for pickling

* skip ddp test on windows

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
rohitgr7 and awaelchli committed Aug 2, 2020
1 parent 38fce2e commit 8baec1a
Showing 3 changed files with 44 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -49,6 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))

+- Fixed shuffle argument for distributed sampler ([#2789](https://github.com/PyTorchLightning/pytorch-lightning/pull/2789))
+
## [0.8.5] - 2020-07-09

### Added
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/data_loading.py
@@ -163,7 +163,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.')

            # replace with distributed sampler
-            sampler = self._get_distributed_sampler(dataloader)
+            sampler = self._get_distributed_sampler(dataloader, train)
            dataloader = self.replace_sampler(dataloader, sampler)

        return dataloader
@@ -179,7 +179,7 @@ def replace_sampler(self, dataloader, sampler):
        dataloader = type(dataloader)(**dl_args)
        return dataloader

-    def _get_distributed_sampler(self, dataloader):
+    def _get_distributed_sampler(self, dataloader, train):
        if self.use_tpu:
            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.use_horovod:
@@ -193,6 +193,8 @@ def _get_distributed_sampler(self, dataloader):
            }
            assert self.distributed_backend is not None
            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
+
+        kwargs['shuffle'] = train
        sampler = DistributedSampler(dataloader.dataset, **kwargs)
        return sampler

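For context (not part of the diff above): the `shuffle` flag that `_get_distributed_sampler` now forwards maps directly onto `torch.utils.data.distributed.DistributedSampler`. The following minimal sketch, using a made-up toy dataset and explicit `num_replicas`/`rank` values so no process group is needed, illustrates the difference between `shuffle=True` (training) and `shuffle=False` (validation/test).

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset of 8 samples; passing num_replicas/rank explicitly avoids initialising a process group.
dataset = TensorDataset(torch.arange(8))

train_sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
eval_sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)

train_sampler.set_epoch(0)   # the shuffled order is seeded by the epoch number
print(list(train_sampler))   # this rank's shard of indices, in shuffled order
print(list(eval_sampler))    # this rank's shard of indices, in the original order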
39 changes: 38 additions & 1 deletion tests/trainer/test_dataloaders.py
@@ -7,9 +7,10 @@
from packaging.version import parse
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import IterableDataset, Subset
+from torch.utils.data.distributed import DistributedSampler

import tests.base.develop_pipelines as tpipes
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
@@ -640,6 +641,42 @@ class CustomSampler(torch.utils.data.Sampler):
            CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)


+class DistribSamplerCallback(Callback):
+
+    def on_train_start(self, trainer, pl_module):
+        train_sampler = trainer.train_dataloader.sampler
+        assert isinstance(train_sampler, DistributedSampler)
+        assert train_sampler.shuffle
+
+    def on_validation_start(self, trainer, pl_module):
+        val_sampler = trainer.val_dataloaders[0].sampler
+        assert isinstance(val_sampler, DistributedSampler)
+        assert not val_sampler.shuffle
+
+    def on_test_start(self, trainer, pl_module):
+        test_sampler = trainer.test_dataloaders[0].sampler
+        assert isinstance(test_sampler, DistributedSampler)
+        assert not test_sampler.shuffle
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
+def test_dataloader_distributed_sampler(tmpdir):
+    """ Test DistributedSampler and it's arguments for DDP backend """
+
+    model = EvalModelTemplate()
+    trainer = Trainer(
+        gpus=[0, 1],
+        num_nodes=1,
+        distributed_backend='ddp_spawn',
+        default_root_dir=tmpdir,
+        max_steps=1,
+        callbacks=[DistribSamplerCallback()]
+    )
+    trainer.fit(model)
+    trainer.test(ckpt_path=None)


@pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
def test_batch_size_smaller_than_num_gpus(tmpdir):
    # we need at least 3 gpus for this test
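One note on the commit message's "define callback outside for pickling": with `distributed_backend='ddp_spawn'` the Trainer state, including its callbacks, is sent to the spawned worker processes by pickling, and Python can only pickle classes that are importable at module level. Below is a rough, self-contained sketch of that constraint; the class and function names are invented for illustration and are not part of the commit.

import pickle

class ModuleLevelCallback:            # importable by its qualified name, so it pickles fine
    pass

def make_local_callback():
    class LocalCallback:              # defined inside a function, so it cannot be pickled
        pass
    return LocalCallback()

pickle.dumps(ModuleLevelCallback())   # works
try:
    pickle.dumps(make_local_callback())
except (AttributeError, pickle.PicklingError) as err:
    print(f'cannot pickle a locally defined class: {err}')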
