Raise an error when lightning replaces an existing sampler #2020

Merged
48 changes: 24 additions & 24 deletions pytorch_lightning/trainer/data_loading.py
@@ -3,7 +3,7 @@
 from typing import Union, List, Tuple, Callable

 import torch.distributed as torch_distrib
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler

 from pytorch_lightning.core import LightningModule
@@ -113,39 +113,39 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
         need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)

         if self.replace_sampler_ddp and need_dist_sampler:
+            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
Review comment (Member): This way you also have to include SubsetRandomSampler and WeightedRandomSampler, I guess.

Reply (Contributor, author): Hmm, shouldn't a warning really be raised if one of these is getting replaced by `DistributedSampler`? The original motivation for this issue was that I had a `WeightedRandomSampler` and Lightning replaced it, leading to a drop in validation accuracy. It took me a long time to debug this, hence I thought it would be a good idea to raise an error when an existing sampler is replaced.

Also, if someone uses one of these samplers, won't they get different results with distributed vs. non-distributed training?

Reply (Member): Maybe you're right about that :)
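For context, a minimal sketch of the failure mode described in this thread and of the opt-out that the new error message points to. The dataset, the sampler weights, and the 2-GPU `ddp` settings are illustrative assumptions (and `model` is a placeholder `LightningModule`), not taken from the original issue:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from pytorch_lightning import Trainer

# Hypothetical imbalanced dataset: 900 negatives, 100 positives.
labels = torch.cat([torch.zeros(900), torch.ones(100)]).long()
dataset = TensorDataset(torch.randn(1000, 16), labels)

# Oversample the minority class with a custom sampler.
weights = torch.ones(len(dataset))
weights[labels == 1] = 9.0
sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
train_loader = DataLoader(dataset, batch_size=32, sampler=sampler)

# Before this PR, a distributed Trainer silently swapped `sampler` for a
# DistributedSampler when setting up the train dataloader; with this PR the
# same run raises MisconfigurationException instead (assumes a 2-GPU machine).
trainer = Trainer(gpus=2, distributed_backend='ddp')  # replace_sampler_ddp defaults to True

# To keep the custom sampler, opt out of the automatic replacement:
trainer = Trainer(gpus=2, distributed_backend='ddp', replace_sampler_ddp=False)
# trainer.fit(model, train_loader)
```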

+                raise MisconfigurationException(
+                    'You seem to have configured a sampler in your DataLoader. This will be replaced'
+                    ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
+                    ' distributed training. Either remove the sampler from your DataLoader or set'
+                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.')

             skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']

             dl_args = {
                 k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
             }

-            if self.use_tpu:
-                sampler = DistributedSampler(
-                    dataloader.dataset,
-                    num_replicas=xm.xrt_world_size(),
-                    rank=xm.get_ordinal(),
-                )
-            elif self.use_horovod:
-                sampler = DistributedSampler(dataloader.dataset,
-                                             num_replicas=hvd.size(),
-                                             rank=hvd.rank())
-            else:
-                world_size = {
-                    'ddp': self.num_nodes * self.num_processes,
-                    'ddp2': self.num_nodes,
-                    'ddp_cpu': self.num_processes * self.num_nodes
-                }
-                sampler = DistributedSampler(
-                    dataloader.dataset,
-                    num_replicas=world_size[self.distributed_backend],
-                    rank=self.proc_rank,
-                )
-
-            dl_args['sampler'] = sampler
+            dl_args['sampler'] = self._get_distributed_sampler(dataloader)
             dataloader = type(dataloader)(**dl_args)

         return dataloader
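As an aside on the re-instantiation pattern this refactor keeps: the DataLoader is rebuilt from its own public attributes with the sampler swapped out. A standalone sketch of that pattern, using `SequentialSampler` in place of `DistributedSampler` so it runs in a single process (the toy dataset and batch size are arbitrary):

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
loader = DataLoader(dataset, batch_size=4, num_workers=0)

# Collect the DataLoader's public constructor arguments, skipping the keys
# that will be replaced or that are not constructor arguments.
skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
dl_args = {k: v for k, v in loader.__dict__.items()
           if not k.startswith('_') and k not in skip_keys}

# Inject a new sampler and rebuild a DataLoader of the same (sub)class.
dl_args['sampler'] = SequentialSampler(dataset)
new_loader = type(loader)(**dl_args)

first_batch = next(iter(new_loader))[0]
assert first_batch.tolist() == [0.0, 1.0, 2.0, 3.0]
```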

+    def _get_distributed_sampler(self, dataloader):
+        if self.use_tpu:
+            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+        elif self.use_horovod:
+            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
+        else:
+            world_size = {
+                'ddp': self.num_nodes * self.num_processes,
+                'ddp2': self.num_nodes,
+                'ddp_cpu': self.num_processes * self.num_nodes
+            }
+            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.proc_rank)
+        sampler = DistributedSampler(dataloader.dataset, **kwargs)
+        return sampler
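To make the sharding concrete: for the plain `ddp` backend the helper boils down to `DistributedSampler(dataset, num_replicas=num_nodes * num_processes, rank=proc_rank)`, so each process iterates a disjoint slice of the dataset. A minimal sketch with illustrative values (single node, two processes), not taken from a real Trainer run:

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100).float())
num_nodes, num_processes = 1, 2          # hypothetical single-node, 2-process ddp job
num_replicas = num_nodes * num_processes

shards = []
for rank in range(num_replicas):
    # Passing num_replicas/rank explicitly means no process group is needed here.
    sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank)
    shards.append(set(sampler))

# Each rank sees len(dataset) / num_replicas indices, and together they cover the dataset.
assert all(len(shard) == 50 for shard in shards)
assert set.union(*shards) == set(range(100))
```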

     def reset_train_dataloader(self, model: LightningModule) -> None:
         """Resets the train dataloader and initialises required variables
         (number of batches, when to validate, etc.).
14 changes: 14 additions & 0 deletions tests/trainer/test_dataloaders.py
@@ -393,6 +393,20 @@ class CustomDummyObj:
     assert isinstance(result, CustomDataLoader)
     assert hasattr(result, 'dummy_kwarg')

+    # Shuffled DataLoader should also work
+    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True)
+    assert isinstance(result, torch.utils.data.DataLoader)
+    assert isinstance(result, CustomDataLoader)
+    assert hasattr(result, 'dummy_kwarg')
+
+    class CustomSampler(torch.utils.data.Sampler):
+        pass
+
+    # Should raise an error if existing sampler is being replaced
+    with pytest.raises(MisconfigurationException, match='DistributedSampler'):
+        trainer.auto_add_sampler(
+            CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)
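One note on the assertion style used above: `pytest.raises(..., match=...)` applies the pattern as a regular-expression search against the string form of the raised exception, so matching the substring 'DistributedSampler' is enough to pin the new error without repeating the whole message. A tiny self-contained illustration (the local exception class is a stand-in for Lightning's, used only here):

```python
import pytest

class MisconfigurationException(Exception):
    """Stand-in for pytorch_lightning's exception class, for illustration only."""

def test_match_is_a_regex_search():
    with pytest.raises(MisconfigurationException, match='DistributedSampler'):
        raise MisconfigurationException(
            'You seem to have configured a sampler in your DataLoader. This will be replaced'
            ' by `DistributedSampler` since `replace_sampler_ddp` is True ...')
```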


 @pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
 def test_batch_size_smaller_than_num_gpus():