Skip to content

Commit

Permalink
Fix overlapping samples in DDP when no global seed is set (#17713)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

(cherry picked from commit 53815e6)
  • Loading branch information
awaelchli authored and lantiga committed Jun 2, 2023
1 parent 64d84cc commit b810098
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed inconsistent settings for FSDP Precision ([#17670](https://github.com/Lightning-AI/lightning/issues/17670))


- Fixed an edge case causing overlapping samples in DDP when no global seed is set ([#17713](https://github.com/Lightning-AI/lightning/pull/17713))


## [2.0.2] - 2023-04-24

### Fixed
Expand Down
9 changes: 6 additions & 3 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from dataclasses import dataclass, field
from typing import Any, Iterable, Optional, Tuple, Union

from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

import lightning.pytorch as pl
Expand Down Expand Up @@ -245,8 +245,11 @@ def _get_distributed_sampler(
"""This function is used to created the distributed sampler injected within the user DataLoader."""
kwargs["shuffle"] = shuffle and not overfit_batches
kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
cls = UnrepeatedDistributedSamplerWrapper if mode == RunningStage.PREDICTING else DistributedSamplerWrapper
return cls(dataloader.sampler, **kwargs)
if mode == RunningStage.PREDICTING:
return UnrepeatedDistributedSamplerWrapper(dataloader.sampler, **kwargs)
if isinstance(dataloader.sampler, (RandomSampler, SequentialSampler)):
return DistributedSampler(dataloader.dataset, **kwargs)
return DistributedSamplerWrapper(dataloader.sampler, **kwargs)


def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
Expand Down
34 changes: 34 additions & 0 deletions tests/tests_pytorch/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,40 @@ def test_dataloader_distributed_sampler(tmpdir):
trainer.test(model)


class TestModelUniqueDDPSampling(BoringModel):
def __init__(self):
super().__init__()
self.seen_samples = []

def training_step(self, batch):
self.seen_samples.extend(batch.tolist())

def on_train_end(self):
seen_samples = self.all_gather(self.seen_samples)
# The samples should be unique across all processes
assert set(torch.cat(seen_samples).view(-1).tolist()) == set(range(32))


@RunIf(standalone=True)
def test_distributed_sampler_without_global_seed(tmpdir):
"""Test that the samples are non-overlapping in DDP when shuffling is enabled and no global seed is set."""
# This test must run without a global seed set (e.g. through `seed_everything`), to ensure that each process
# starts with a different initial state.
assert "PL_GLOBAL_SEED" not in os.environ
train_dataloader = DataLoader(range(32), shuffle=True, batch_size=4)
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=False,
logger=False,
enable_progress_bar=False,
accelerator="cpu",
devices=2,
strategy="ddp",
max_epochs=1,
)
trainer.fit(TestModelUniqueDDPSampling(), train_dataloader)


class ModelWithDataLoaderDistributedSampler(BoringModel):
def train_dataloader(self):
dataloader = super().train_dataloader()
Expand Down

0 comments on commit b810098

Please sign in to comment.