Added DistributedSamplerWrapper to automatically wrap non-dist samplers in cases where we use dist mode (#1856)

* Added DistributedSamplerWrapper
* Test passes, config updated
* Added wrapping around the non-dist sampler in case of distributed mode
* Removed a redundant file
* Made lint happy
* Removed a breaking change that did not break a thing
* Tidy
* Lint is happy, again
* Added a warning in case of undesired behavior. Introduced a flag to turn off the auto-wrap.

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Co-authored-by: Ofri Masad <ofrimasad@users.noreply.github.com>
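In effect, when training runs in DDP mode with a plain, non-distributed sampler, SuperGradients now substitutes the equivalent of the wrapping sketched below. This is a minimal, hypothetical usage sketch, not code from this commit; it assumes torch.distributed is already initialized, as it is inside a DDP run, and TensorDataset/WeightedRandomSampler stand in for whatever dataset and sampler the user configured:

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from super_gradients.training.datasets.samplers.distributed_sampler_wrapper import DistributedSamplerWrapper

train_set = TensorDataset(torch.arange(80))                            # stand-in for a real dataset
weighted = WeightedRandomSampler(weights=[1.0] * 80, num_samples=80)   # any non-distributed sampler
# Instead of passing `weighted` directly, each DDP rank gets a wrapped, per-rank-sharded view of it:
train_loader = DataLoader(train_set, batch_size=4, sampler=DistributedSamplerWrapper(weighted))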
Commit 5727c28 (1 parent: 8819667). Showing 6 changed files with 192 additions and 4 deletions.
Samplers package __init__: the new wrapper is imported and re-exported.

@@ -1,7 +1,8 @@
 from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
 from super_gradients.training.datasets.samplers.repeated_augmentation_sampler import RepeatAugSampler
+from super_gradients.training.datasets.samplers.distributed_sampler_wrapper import DistributedSamplerWrapper
 from super_gradients.common.object_names import Samplers
 from super_gradients.common.registry.registry import SAMPLERS


-__all__ = ["SAMPLERS", "Samplers", "InfiniteSampler", "RepeatAugSampler"]
+__all__ = ["SAMPLERS", "Samplers", "InfiniteSampler", "RepeatAugSampler", "DistributedSamplerWrapper"]
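With this re-export in place, the wrapper can also be imported directly from the samplers package rather than from its module, e.g.:

from super_gradients.training.datasets.samplers import DistributedSamplerWrapper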
src/super_gradients/training/datasets/samplers/distributed_sampler_wrapper.py: new file, 71 additions.
from operator import itemgetter
from typing import Optional

from torch.utils.data import Dataset, Sampler, DistributedSampler


class DatasetFromSampler(Dataset):
    def __init__(self, sampler: Sampler):
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        if self.sampler_list is None:  # we don't instantiate the list in __init__ because we want to shuffle first (happens in DistributedSamplerWrapper.__iter__)
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)


class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.

    It is especially useful in conjunction with `torch.nn.parallel.DistributedDataParallel`.
    In such a case, each process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in distributed training
            rank (int, optional): Rank of the current process within ``num_replicas``
            shuffle (bool, optional): If true (default), sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self):
        # Re-materialize the wrapped sampler so that stochastic samplers are re-drawn every epoch.
        self.dataset = DatasetFromSampler(self.sampler)
        # DistributedSampler.__iter__ yields this rank's (optionally shuffled) positions
        # into the wrapped sampler's output ("indexes of indexes").
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        # Map those positions back to the wrapped sampler's actual indices.
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
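A minimal sketch (not part of the commit) of how the wrapper shards an arbitrary sampler across ranks. Passing num_replicas and rank explicitly avoids the need for an initialized process group, and SequentialSampler stands in for any user sampler:

from torch.utils.data import SequentialSampler
from super_gradients.training.datasets.samplers.distributed_sampler_wrapper import DistributedSamplerWrapper

base_sampler = SequentialSampler(range(8))                       # yields 0..7 in order
rank0 = DistributedSamplerWrapper(base_sampler, num_replicas=2, rank=0)
rank1 = DistributedSamplerWrapper(base_sampler, num_replicas=2, rank=1)
rank0.set_epoch(0)   # set_epoch is inherited from DistributedSampler and reseeds the shuffle
rank1.set_epoch(0)

print(len(rank0), len(rank1))                                    # 4 4 -> each rank gets half the draws
print(sorted(list(rank0) + list(rank1)))                         # [0, 1, ..., 7] -> together, one full pass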
New DDP sanity-check script exercising DistributedSamplerWrapper: new file, 101 additions.
import collections
import sys
from itertools import chain

import torch
from torch.utils.data import Dataset, DataLoader

from super_gradients import setup_device
from super_gradients.training.datasets.samplers.distributed_sampler_wrapper import DistributedSamplerWrapper


class DummyDataset(Dataset):
    def __init__(self, length=42):
        super().__init__()
        self.length = length

    def __getitem__(self, index):
        return -index

    def __len__(self):
        return self.length


class RepeatSampler(torch.utils.data.Sampler):
    def __init__(self, data_source, repeat_times):
        self.data_source = data_source
        self.repeat_times = repeat_times
        self.num_samples = repeat_times * len(data_source)

    def __iter__(self):
        indices = list(range(len(self.data_source)))
        return iter(indices * self.repeat_times)

    def __len__(self):
        return self.num_samples


def aggregate_epoch(data_loader):
    results = list()
    for batch in data_loader:
        for element in batch:
            results.append(element.item())
    return results


def compare_counts(x, y):
    return collections.Counter(x) == collections.Counter(y)


if __name__ == "__main__":
    n_gpus = 2
    sampler_n_repeats = 3
    bs = 4
    data_size = 10 * n_gpus * bs

    setup_device(
        device="cuda",
        multi_gpu="DDP",
        num_gpus=n_gpus,
    )

    dataset = DummyDataset(length=data_size)
    sampler = RepeatSampler(dataset, repeat_times=sampler_n_repeats)
    dataloader = DataLoader(dataset, batch_size=bs, sampler=sampler)

    whole_epoch_data = list(chain.from_iterable([[-i] * sampler_n_repeats for i in range(data_size)]))

    # Test a *non-distributed* sampler *in DDP mode*.
    # THIS IS A BAD EXAMPLE, BECAUSE A DISTRIBUTED SAMPLER IS EXPECTED IN DDP MODE.
    # The expected `len(dataloader)`, when implemented correctly, should ALSO be divided by `n_gpus`.
    if len(dataloader) != (data_size * sampler_n_repeats) / bs:
        print(f"Wrong DataLoader length. Expected: {((data_size * sampler_n_repeats) / bs)=}, got {len(dataloader)}")
        torch.distributed.destroy_process_group()
        sys.exit(1)

    epoch_data_per_rank = aggregate_epoch(dataloader)
    if not compare_counts(epoch_data_per_rank, whole_epoch_data):  # NOTE THAT EACH GPU SEES ALL THE DATA -- NOT WHAT WE WANT!
        torch.distributed.destroy_process_group()
        sys.exit(1)

    dist_sampler = DistributedSamplerWrapper(sampler)
    dataloader = DataLoader(dataset, batch_size=bs, sampler=dist_sampler)

    if len(dataloader) != (data_size * sampler_n_repeats) / (bs * n_gpus):
        print(f"Wrong DataLoader length. Expected: {((data_size * sampler_n_repeats) / (bs * n_gpus))=}, got {len(dataloader)}")
        torch.distributed.destroy_process_group()
        sys.exit(1)

    # The dataset is now split across `n_gpus` processes. Aggregate and make sure we still cover the same data.
    per_rank_aggregated = torch.tensor(aggregate_epoch(dataloader)).cuda()
    all_gathered_placeholder = torch.zeros(len(per_rank_aggregated) * n_gpus, dtype=torch.int64).cuda()

    torch.distributed.all_gather_into_tensor(all_gathered_placeholder, per_rank_aggregated)

    if not compare_counts(all_gathered_placeholder.cpu().tolist(), whole_epoch_data):
        torch.distributed.destroy_process_group()
        sys.exit(1)

    torch.distributed.destroy_process_group()
    sys.exit(0)
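Working through the constants in the script: data_size = 10 * 2 * 4 = 80 samples, and RepeatSampler yields each one sampler_n_repeats = 3 times, i.e. 240 draws per epoch. Without the wrapper, every rank iterates all 240 draws, so len(dataloader) = 240 / 4 = 60 batches on each GPU and each GPU sees the full dataset. With DistributedSamplerWrapper, each of the 2 ranks receives 240 / 2 = 120 draws, i.e. 240 / (4 * 2) = 30 batches, and all-gathering both ranks' epochs reproduces the full multiset of 240 values exactly once.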