Raise an error when lightning replaces an existing sampler (#2020)
* Raise an error when lightning replaces an existing sampler

Currently, Trainer replaces the existing sampler with DistributedSampler
if running distributed training and `replace_sampler_ddp=True` (default
behaviour). If a user has already configured their own sampler, this would
lead to widely different results when running distributed vs. non-distributed
training.

This PR fixes this by raising an error if the user has configured a custom
sampler and uses `replace_sampler_ddp=True`. The recommended behaviour from
now on is to either remove the sampler or set `replace_sampler_ddp=False`.
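
For illustration, here is a minimal sketch (not part of the original commit message) of the two recommended configurations. The `Trainer` arguments shown (`gpus`, `distributed_backend`) follow the 0.8-era API, and the dataset and loader names are placeholders.

```python
from torch.utils.data import DataLoader, DistributedSampler
from pytorch_lightning import Trainer

dataset = list(range(1000))  # placeholder dataset

# Option 1: let Lightning inject DistributedSampler -- simply do not set a
# sampler on the DataLoader (replace_sampler_ddp=True is the default).
train_loader = DataLoader(dataset, batch_size=32)
trainer = Trainer(gpus=2, distributed_backend='ddp')

# Option 2: keep a custom sampler and opt out of the automatic replacement.
# num_replicas/rank are spelled out only so this sketch constructs outside a
# running process group; inside train_dataloader() they are inferred.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler)
trainer = Trainer(gpus=2, distributed_backend='ddp', replace_sampler_ddp=False)
```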

* Fix tests

* Simpler fix

* Fix tests

* Make inner method protected

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2 people authored and justusschock committed Jun 29, 2020
1 parent a18067b commit 74260a7
Showing 2 changed files with 38 additions and 24 deletions.
pytorch_lightning/trainer/data_loading.py (24 additions, 24 deletions)

@@ -3,7 +3,7 @@
 from typing import Union, List, Tuple, Callable

 import torch.distributed as torch_distrib
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler

 from pytorch_lightning.core import LightningModule
@@ -114,39 +114,39 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
         need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)

         if self.replace_sampler_ddp and need_dist_sampler:
+            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
+                raise MisconfigurationException(
+                    'You seem to have configured a sampler in your DataLoader. This will be replaced '
+                    ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
+                    ' distributed training. Either remove the sampler from your DataLoader or set'
+                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.')

             skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']

             dl_args = {
                 k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
             }

-            if self.use_tpu:
-                sampler = DistributedSampler(
-                    dataloader.dataset,
-                    num_replicas=xm.xrt_world_size(),
-                    rank=xm.get_ordinal(),
-                )
-            elif self.use_horovod:
-                sampler = DistributedSampler(dataloader.dataset,
-                                             num_replicas=hvd.size(),
-                                             rank=hvd.rank())
-            else:
-                world_size = {
-                    'ddp': self.num_nodes * self.num_processes,
-                    'ddp2': self.num_nodes,
-                    'ddp_cpu': self.num_processes * self.num_nodes
-                }
-                sampler = DistributedSampler(
-                    dataloader.dataset,
-                    num_replicas=world_size[self.distributed_backend],
-                    rank=self.proc_rank,
-                )
-
-            dl_args['sampler'] = sampler
+            dl_args['sampler'] = self._get_distributed_sampler(dataloader)
             dataloader = type(dataloader)(**dl_args)

         return dataloader

+    def _get_distributed_sampler(self, dataloader):
+        if self.use_tpu:
+            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+        elif self.use_horovod:
+            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
+        else:
+            world_size = {
+                'ddp': self.num_nodes * self.num_processes,
+                'ddp2': self.num_nodes,
+                'ddp_cpu': self.num_processes * self.num_nodes
+            }
+            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.proc_rank)
+        sampler = DistributedSampler(dataloader.dataset, **kwargs)
+        return sampler
+
     def reset_train_dataloader(self, model: LightningModule) -> None:
         """Resets the train dataloader and initialises required variables
         (number of batches, when to validate, etc.).
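
For readers following the refactor, the block below is a rough standalone sketch of what `auto_add_sampler` now effectively does for a plain DataLoader under DDP. The function name and the plain `ValueError` are stand-ins; the actual code raises `MisconfigurationException` and derives `num_replicas`/`rank` per backend via `_get_distributed_sampler`.

```python
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler


def replace_with_distributed_sampler(dataloader, num_replicas, rank):
    """Standalone sketch of the rebuild step shown in the diff above."""
    # Refuse to silently drop a user-configured sampler.
    if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
        raise ValueError('Custom sampler found: remove it or set `replace_sampler_ddp=False`.')

    # Re-create the DataLoader with the same constructor arguments, swapping in
    # a DistributedSampler (this mirrors the skip_keys / dl_args logic above and
    # relies on DataLoader exposing its constructor arguments as attributes).
    skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
    dl_args = {k: v for k, v in dataloader.__dict__.items()
               if not k.startswith('_') and k not in skip_keys}
    dl_args['sampler'] = DistributedSampler(dataloader.dataset,
                                            num_replicas=num_replicas, rank=rank)
    return type(dataloader)(**dl_args)


# A shuffled loader still passes the check; each replica now sees its own shard.
loader = replace_with_distributed_sampler(
    DataLoader(list(range(1000)), batch_size=32, shuffle=True), num_replicas=2, rank=0)
```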

tests/trainer/test_dataloaders.py (14 additions, 0 deletions)

@@ -416,6 +416,20 @@ class CustomDummyObj:
     assert isinstance(result, CustomDataLoader)
     assert hasattr(result, 'dummy_kwarg')

+    # Shuffled DataLoader should also work
+    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True)
+    assert isinstance(result, torch.utils.data.DataLoader)
+    assert isinstance(result, CustomDataLoader)
+    assert hasattr(result, 'dummy_kwarg')
+
+    class CustomSampler(torch.utils.data.Sampler):
+        pass
+
+    # Should raise an error if existing sampler is being replaced
+    with pytest.raises(MisconfigurationException, match='DistributedSampler'):
+        trainer.auto_add_sampler(
+            CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)
+

 @pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
 def test_batch_size_smaller_than_num_gpus():
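
A side note not drawn from the commit itself: `CustomSampler` can get away with a bare `pass` only because `auto_add_sampler` raises before the sampler is ever iterated. A sampler you actually intend to run with `replace_sampler_ddp=False` still needs `__iter__` and `__len__`, roughly like this (the class name and index pattern are made up):

```python
from torch.utils.data import Sampler


class EvenIndicesSampler(Sampler):
    """Minimal usable sampler: yields only the even indices of the dataset."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(0, len(self.data_source), 2))

    def __len__(self):
        return (len(self.data_source) + 1) // 2
```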
