Added DistributedSamplerWrapper to automatically wrap non-distributed samplers when running in distributed mode #1856

Merged: 11 commits, Feb 26, 2024
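For context, here is a minimal sketch (not taken from the PR; the dataset, weights, and batch size below are illustrative) of the behavior this change enables: a non-distributed sampler configured for a dataloader gets wrapped so that each DDP rank iterates only its own shard of the sampled indices.

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from super_gradients.training.datasets.samplers import DistributedSamplerWrapper

dataset = TensorDataset(torch.arange(100).float())
weights = torch.ones(len(dataset))                     # illustrative per-sample weights
sampler = WeightedRandomSampler(weights, num_samples=len(dataset))

# Under DDP (and unless auto_wrap_sampler_when_ddp is set to False), the dataloader
# factory now does the equivalent of:
sampler = DistributedSamplerWrapper(sampler)           # requires an initialized process group
loader = DataLoader(dataset, batch_size=16, sampler=sampler)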
1 change: 1 addition & 0 deletions .circleci/config.yml
@@ -611,6 +611,7 @@ jobs:
python3.8 -m pip install pytorch-quantization==2.1.2 --extra-index-url https://pypi.ngc.nvidia.com

python3.8 tests/verify_min_samples_ddp.py
python3.8 tests/verify_distributed_sampler_wrapper.py
python3.8 src/super_gradients/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test batch_size=4 val_batch_size=8 epochs=1 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
python3.8 src/super_gradients/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
18 changes: 16 additions & 2 deletions src/super_gradients/training/dataloaders/dataloaders.py
@@ -1,3 +1,4 @@
import warnings
from typing import Dict, Mapping

import hydra
@@ -25,7 +26,7 @@
)
from super_gradients.training.datasets.pose_estimation_datasets import COCOKeypointsDataset
from super_gradients.training.datasets.pose_estimation_datasets.rescoring_dataset import TrainRescoringDataset, ValTrainRescoringDataset
from super_gradients.training.datasets.samplers import RepeatAugSampler
from super_gradients.training.datasets.samplers import RepeatAugSampler, DistributedSamplerWrapper
from super_gradients.training.datasets.segmentation_datasets import (
CityscapesDataset,
CoCoSegmentationDataSet,
@@ -201,7 +202,20 @@ def _instantiate_sampler(dataset, dataloader_params):
# SHUFFLE IS MUTUALLY EXCLUSIVE WITH SAMPLER ARG IN DATALOADER INIT
dataloader_params["sampler"][sampler_name]["shuffle"] = dataloader_params.pop("shuffle")
dataloader_params["sampler"][sampler_name]["dataset"] = dataset
dataloader_params["sampler"] = SamplersFactory().get(dataloader_params["sampler"])
dataloader_params["sampler"] = SamplersFactory().get(dataloader_params["sampler"]) # a living object

if (
super_gradients.is_distributed()
and dataloader_params.get("auto_wrap_sampler_when_ddp", True)
and not isinstance(dataloader_params["sampler"], torch.utils.data.distributed.DistributedSampler)
):
warnings.warn(
f"You are running in distributed mode, but the configured sampler ({dataloader_params['sampler'].__class__.__name__}) is not a DistributedSampler.\n"
f"It was automatically wrapped with DistributedSamplerWrapper so that each process iterates only its own shard; the exact behavior still depends on your sampler's implementation.\n"
f"If this is not the desired behavior, set the `auto_wrap_sampler_when_ddp` argument to `False` in your dataloader config.\n"
)
dataloader_params["sampler"] = DistributedSamplerWrapper(dataloader_params["sampler"])

return dataloader_params


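Below is a sketch of how a recipe's dataloader params could opt out of the automatic wrapping. The keys follow the factory-style sampler config handled above, but the sampler name string and its kwargs are assumptions, and how the flag is consumed elsewhere in the factory is not shown in this diff.

dataloader_params = {
    "batch_size": 16,
    "shuffle": True,                       # moved under the sampler config by _instantiate_sampler
    "sampler": {"RepeatAugSampler": {}},   # sampler-specific kwargs go in the inner dict
    "auto_wrap_sampler_when_ddp": False,   # keep this sampler unwrapped even when running DDP
}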
3 changes: 2 additions & 1 deletion src/super_gradients/training/datasets/samplers/__init__.py
@@ -1,7 +1,8 @@
from super_gradients.training.datasets.samplers.infinite_sampler import InfiniteSampler
from super_gradients.training.datasets.samplers.repeated_augmentation_sampler import RepeatAugSampler
from super_gradients.training.datasets.samplers.distributed_sampler_wrapper import DistributedSamplerWrapper
from super_gradients.common.object_names import Samplers
from super_gradients.common.registry.registry import SAMPLERS


__all__ = ["SAMPLERS", "Samplers", "InfiniteSampler", "RepeatAugSampler"]
__all__ = ["SAMPLERS", "Samplers", "InfiniteSampler", "RepeatAugSampler", "DistributedSamplerWrapper"]
71 changes: 71 additions & 0 deletions src/super_gradients/training/datasets/samplers/distributed_sampler_wrapper.py
@@ -0,0 +1,71 @@
from operator import itemgetter
from typing import Optional

from torch.utils.data import Dataset, Sampler, DistributedSampler


class DatasetFromSampler(Dataset):
def __init__(self, sampler: Sampler):
self.sampler = sampler
self.sampler_list = None

def __getitem__(self, index: int):
if self.sampler_list is None:  # we don't build the list in __init__ because we want to shuffle first (this happens in DistributedSamplerWrapper.__iter__)
self.sampler_list = list(self.sampler)
return self.sampler_list[index]

def __len__(self) -> int:
"""
Returns:
int: length of the dataset
"""
return len(self.sampler)
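A tiny illustration (not part of the PR) of what DatasetFromSampler does: it exposes whatever the wrapped sampler yields as an indexable sequence, so that DistributedSampler can later pick positions from it.

from torch.utils.data import SequentialSampler

ds = DatasetFromSampler(SequentialSampler(range(5)))
print(len(ds), [ds[i] for i in range(len(ds))])  # 5 [0, 1, 2, 3, 4]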


class DistributedSamplerWrapper(DistributedSampler):
"""
Wrapper over `Sampler` for distributed training.
Allows you to use any sampler in distributed mode.

It is especially useful in conjunction with
`torch.nn.parallel.DistributedDataParallel`. In such a case, each
process can pass a DistributedSamplerWrapper instance as a DataLoader
sampler and load the subset of the original dataset's sampled data
that is exclusive to it.

.. note::
The wrapped sampler is assumed to be of constant size.
"""

def __init__(
self,
sampler,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
):
"""

Args:
sampler: Sampler used for subsampling
num_replicas (int, optional): Number of processes participating in
distributed training
rank (int, optional): Rank of the current process
within ``num_replicas``
shuffle (bool, optional): If true (default),
sampler will shuffle the indices
"""
super(DistributedSamplerWrapper, self).__init__(
DatasetFromSampler(sampler),
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
)
self.sampler = sampler

def __iter__(self):
# Re-wrap the sampler so its indices are drawn afresh each epoch; DistributedSampler
# then selects which *positions* of that index sequence belong to the current rank.
self.dataset = DatasetFromSampler(self.sampler)
indexes_of_indexes = super().__iter__()
subsampler_indexes = self.dataset
return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
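The two-level indexing above can be seen in isolation with a made-up sampler. Passing `num_replicas` and `rank` explicitly lets the wrapper run outside an initialized process group (which it would otherwise require), and `shuffle=False` keeps the positions deterministic; `ReversedSampler` is purely illustrative.

from torch.utils.data import Sampler

class ReversedSampler(Sampler):
    """Hypothetical non-distributed sampler: yields dataset indices in reverse order."""
    def __init__(self, data_source):
        self.data_source = data_source
    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))
    def __len__(self):
        return len(self.data_source)

base = ReversedSampler(list(range(8)))  # yields 7, 6, 5, ..., 0
shards = [list(DistributedSamplerWrapper(base, num_replicas=2, rank=r, shuffle=False)) for r in (0, 1)]
print(shards)  # [[7, 5, 3, 1], [6, 4, 2, 0]] -- each rank sees a disjoint half of the base sampler's output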
101 changes: 101 additions & 0 deletions tests/verify_distributed_sampler_wrapper.py
@@ -0,0 +1,101 @@
import collections
import sys
from itertools import chain

import torch
from torch.utils.data import Dataset, DataLoader

from super_gradients import setup_device
from super_gradients.training.datasets.samplers.distributed_sampler_wrapper import DistributedSamplerWrapper


class DummyDataset(Dataset):
def __init__(self, length=42):
super().__init__()
self.length = length

def __getitem__(self, index):
return -index

def __len__(self):
return self.length


class RepeatSampler(torch.utils.data.Sampler):
def __init__(self, data_source, repeat_times):
self.data_source = data_source
self.repeat_times = repeat_times
self.num_samples = repeat_times * len(data_source)

def __iter__(self):
indices = list(range(len(self.data_source)))
return iter(indices * self.repeat_times)

def __len__(self):
return self.num_samples


def aggregate_epoch(data_loader):
results = list()

for batch in data_loader:
for element in batch:
results.append(element.item())
return results


def compare_counts(x, y):
return collections.Counter(x) == collections.Counter(y)


if __name__ == "__main__":
n_gpus = 2
sampler_n_repeats = 3
bs = 4
data_size = 10 * n_gpus * bs

setup_device(
device="cuda",
multi_gpu="DDP",
num_gpus=n_gpus,
)

dataset = DummyDataset(length=data_size)
sampler = RepeatSampler(dataset, repeat_times=sampler_n_repeats)
dataloader = DataLoader(dataset, batch_size=bs, sampler=sampler)

whole_epoch_data = list(chain.from_iterable([[-i] * sampler_n_repeats for i in range(data_size)]))

# Test a *non-distributed* sampler *in DDP mode*.
# This is the broken baseline: DDP is expected to use a distributed sampler, and a
# correctly sharded `len(dataloader)` would ALSO be divided by `n_gpus`.
if len(dataloader) != (data_size * sampler_n_repeats) / bs:
print(f"Wrong DataLoader length. Expected: {((data_size * sampler_n_repeats) / bs)=}, got {len(dataloader)}")
torch.distributed.destroy_process_group()
sys.exit(1)

epoch_data_per_rank = aggregate_epoch(dataloader)
if not compare_counts(epoch_data_per_rank, whole_epoch_data): # NOTE THAT EACH GPU SEES ALL DATA -- NOT WHAT WE WANT!
torch.distributed.destroy_process_group()
sys.exit(1)

dist_sampler = DistributedSamplerWrapper(sampler)
dataloader = DataLoader(dataset, batch_size=bs, sampler=dist_sampler)

if len(dataloader) != (data_size * sampler_n_repeats) / (bs * n_gpus):
print(f"Wrong DataLoader length. Expected: {((data_size * sampler_n_repeats) / (bs*n_gpus))=}, got {len(dataloader)}")
torch.distributed.destroy_process_group()
sys.exit(1)

# We have dataset split across `n_gpus` processes. Let's aggregate and make sure we get the same results.
per_rank_aggregated = torch.tensor(aggregate_epoch(dataloader)).cuda()
all_gathered_placeholder = torch.zeros(len(per_rank_aggregated) * n_gpus, dtype=torch.int64).cuda()

torch.distributed.all_gather_into_tensor(all_gathered_placeholder, per_rank_aggregated)

if not compare_counts(all_gathered_placeholder.cpu().tolist(), whole_epoch_data):
torch.distributed.destroy_process_group()
sys.exit(1)

torch.distributed.destroy_process_group()
sys.exit(0)
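For reference, the lengths checked above work out as follows (a standalone arithmetic sketch, not part of the test):

n_gpus, sampler_n_repeats, bs = 2, 3, 4
data_size = 10 * n_gpus * bs                      # 80 samples in DummyDataset
len_plain = (data_size * sampler_n_repeats) / bs  # 60.0 batches: every rank iterates all repeated data
len_wrapped = len_plain / n_gpus                  # 30.0 batches: each rank iterates only its own shard
print(len_plain, len_wrapped)                     # 60.0 30.0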
2 changes: 1 addition & 1 deletion tests/verify_min_samples_ddp.py
@@ -30,6 +30,6 @@ def get_dataset(dataset_size, image_size):
torch.distributed.destroy_process_group()
sys.exit(0)
else:
print(f"wrong datalaoder length, expected min_samples/(world_size*batch_size)=80/(4*4=5), got {len(dataloader)}")
print(f"wrong DataLoader length, expected min_samples/(world_size*batch_size)=80/(4*4=5), got {len(dataloader)}")
torch.distributed.destroy_process_group()
sys.exit(1)