Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shuffle for distributed sampler #2789

Merged
merged 10 commits into from
Aug 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))

- Fixed shuffle argument for distributed sampler ([#2789](https://github.com/PyTorchLightning/pytorch-lightning/pull/2789))

## [0.8.5] - 2020-07-09

### Added
Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
' `replace_sampler_ddp`=False if you want to use your custom sampler.')

# replace with distributed sampler
sampler = self._get_distributed_sampler(dataloader)
sampler = self._get_distributed_sampler(dataloader, train)
dataloader = self.replace_sampler(dataloader, sampler)

return dataloader
Expand All @@ -179,7 +179,7 @@ def replace_sampler(self, dataloader, sampler):
dataloader = type(dataloader)(**dl_args)
return dataloader

def _get_distributed_sampler(self, dataloader):
def _get_distributed_sampler(self, dataloader, train):
if self.use_tpu:
kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
elif self.use_horovod:
Expand All @@ -193,6 +193,8 @@ def _get_distributed_sampler(self, dataloader):
}
assert self.distributed_backend is not None
kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)

kwargs['shuffle'] = train
sampler = DistributedSampler(dataloader.dataset, **kwargs)
return sampler

Expand Down
39 changes: 38 additions & 1 deletion tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from packaging.version import parse
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import IterableDataset, Subset
from torch.utils.data.distributed import DistributedSampler

import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer
from pytorch_lightning import Trainer, Callback
from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
Expand Down Expand Up @@ -640,6 +641,42 @@ class CustomSampler(torch.utils.data.Sampler):
CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)


class DistribSamplerCallback(Callback):

def on_train_start(self, trainer, pl_module):
train_sampler = trainer.train_dataloader.sampler
assert isinstance(train_sampler, DistributedSampler)
assert train_sampler.shuffle

def on_validation_start(self, trainer, pl_module):
val_sampler = trainer.val_dataloaders[0].sampler
assert isinstance(val_sampler, DistributedSampler)
assert not val_sampler.shuffle

def on_test_start(self, trainer, pl_module):
test_sampler = trainer.test_dataloaders[0].sampler
assert isinstance(test_sampler, DistributedSampler)
assert not test_sampler.shuffle


@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
def test_dataloader_distributed_sampler(tmpdir):
""" Test DistributedSampler and it's arguments for DDP backend """

model = EvalModelTemplate()
trainer = Trainer(
gpus=[0, 1],
num_nodes=1,
distributed_backend='ddp_spawn',
default_root_dir=tmpdir,
max_steps=1,
callbacks=[DistribSamplerCallback()]
)
trainer.fit(model)
trainer.test(ckpt_path=None)


@pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
def test_batch_size_smaller_than_num_gpus(tmpdir):
# we need at least 3 gpus for this test
Expand Down