From 46ae996cb174a3e07cee80367fddc07783fd02ff Mon Sep 17 00:00:00 2001 From: davidsvaughn Date: Wed, 3 Jan 2024 02:15:07 -0500 Subject: [PATCH] Bugfix: update dataloaders.py to fix Multi-GPU DDP RAM multiple-cache issue (#10383) * Update dataloaders.py This is to address (and hopefully fix) this issue: Multi-GPU DDP RAM multiple-cache bug #3818 (https://github.com/ultralytics/yolov5/issues/3818). This was a very serious and "blocking" issue until I could figure out what was going on. The problem was especially bad when running Multi-GPU jobs with 8 GPUs, RAM usage was 8x higher than expected (!), causing repeated OOM failures. Hopefully this fix will help others. DDP causes each RANK to launch it's own process (one for each GPU) with it's own trainloader, and its own RAM image cache. The DistributedSampler used by DDP (https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py) will feed only a subset of images (1/WORLD_SIZE) to each available GPU on each epoch, but since the images are shuffled between epochs, each GPU process must still cache all images. So I created a subclass of DistributedSampler called SmartDistributedSampler that forces each GPU process to always sample the same subset (using modulo arithmetic with RANK and WORLD_SIZE) while still allowing random shuffling between epochs. I don't believe this disrupts the overall "randomness" of the sampling, and I haven't noticed any performance degradation. Signed-off-by: davidsvaughn * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update dataloaders.py move extra parameter (rank) to end so won't mess up pre-existing positional args * Update dataloaders.py removing extra '#' * Update dataloaders.py sample from DDP index array (self.idx) in mixup mosaic * Merging self.indices and self.idx (DDP indices) into single attribute (self.indices). Also adding SmartDistributedSampler to segmentation dataloader * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Multiply GB displayed by WORLD_SIZE --------- Signed-off-by: davidsvaughn Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- utils/dataloaders.py | 51 ++++++++++++++++++++++++++++++------ utils/segment/dataloaders.py | 11 +++++--- 2 files changed, 50 insertions(+), 12 deletions(-) diff --git a/utils/dataloaders.py b/utils/dataloaders.py index 1fbd0361ded4..d422ef0711cb 100644 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -41,6 +41,7 @@ VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html RANK = int(os.getenv('RANK', -1)) +WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders # Get orientation exif tag @@ -100,6 +101,34 @@ def seed_worker(worker_id): random.seed(worker_seed) +# Inherit from DistributedSampler and override iterator +# https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py +class SmartDistributedSampler(distributed.DistributedSampler): + + def __iter__(self): + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + + # determine the the eventual size (n) of self.indices (DDP indices) + n = int((len(self.dataset) - self.rank - 1) / self.num_replicas) + 1 # num_replicas == WORLD_SIZE + idx = torch.randperm(n, generator=g) + if not self.shuffle: + idx = idx.sort()[0] + + idx = idx.tolist() + if self.drop_last: + idx = idx[:self.num_samples] + else: + padding_size = self.num_samples - len(idx) + if padding_size <= len(idx): + idx += idx[:padding_size] + else: + idx += (idx * math.ceil(padding_size / len(idx)))[:padding_size] + + return iter(idx) + + def create_dataloader(path, imgsz, batch_size, @@ -133,12 +162,13 @@ def create_dataloader(path, stride=int(stride), pad=pad, image_weights=image_weights, - prefix=prefix) + prefix=prefix, + rank=rank) batch_size = min(batch_size, len(dataset)) nd = torch.cuda.device_count() # number of CUDA devices nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers - sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) + sampler = None if rank == -1 else SmartDistributedSampler(dataset, shuffle=shuffle) loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates generator = torch.Generator() generator.manual_seed(6148914691236517205 + seed + RANK) @@ -449,7 +479,9 @@ def __init__(self, stride=32, pad=0.0, min_items=0, - prefix=''): + prefix='', + rank=-1, + seed=0): self.img_size = img_size self.augment = augment self.hyp = hyp @@ -527,7 +559,10 @@ def __init__(self, nb = bi[-1] + 1 # number of batches self.batch = bi # batch index of image self.n = n - self.indices = range(n) + self.indices = np.arange(n) + if rank > -1: # DDP indices (see: SmartDistributedSampler) + # force each rank (i.e. GPU process) to sample the same subset of data on every epoch + self.indices = self.indices[np.random.RandomState(seed=seed).permutation(n) % WORLD_SIZE == RANK] # Update labels include_class = [] # filter labels to include only these classes (optional) @@ -576,14 +611,14 @@ def __init__(self, b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes self.im_hw0, self.im_hw = [None] * n, [None] * n fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image - results = ThreadPool(NUM_THREADS).imap(fcn, range(n)) - pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0) + results = ThreadPool(NUM_THREADS).imap(lambda i: (i, fcn(i)), self.indices) + pbar = tqdm(results, total=len(self.indices), bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0) for i, x in pbar: if cache_images == 'disk': b += self.npy_files[i].stat().st_size else: # 'ram' self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i) - b += self.ims[i].nbytes + b += self.ims[i].nbytes * WORLD_SIZE pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})' pbar.close() @@ -663,7 +698,7 @@ def __getitem__(self, index): # MixUp augmentation if random.random() < hyp['mixup']: - img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1))) + img, labels = mixup(img, labels, *self.load_mosaic(random.choice(self.indices))) else: # Load image diff --git a/utils/segment/dataloaders.py b/utils/segment/dataloaders.py index 3ee826dba69c..5398617eef68 100644 --- a/utils/segment/dataloaders.py +++ b/utils/segment/dataloaders.py @@ -12,7 +12,7 @@ from torch.utils.data import DataLoader, distributed from ..augmentations import augment_hsv, copy_paste, letterbox -from ..dataloaders import InfiniteDataLoader, LoadImagesAndLabels, seed_worker +from ..dataloaders import InfiniteDataLoader, LoadImagesAndLabels, SmartDistributedSampler, seed_worker from ..general import LOGGER, xyn2xy, xywhn2xyxy, xyxy2xywhn from ..torch_utils import torch_distributed_zero_first from .augmentations import mixup, random_perspective @@ -57,12 +57,13 @@ def create_dataloader(path, image_weights=image_weights, prefix=prefix, downsample_ratio=mask_downsample_ratio, - overlap=overlap_mask) + overlap=overlap_mask, + rank=rank) batch_size = min(batch_size, len(dataset)) nd = torch.cuda.device_count() # number of CUDA devices nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers - sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) + sampler = None if rank == -1 else SmartDistributedSampler(dataset, shuffle=shuffle) loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates generator = torch.Generator() generator.manual_seed(6148914691236517205 + seed + RANK) @@ -98,9 +99,11 @@ def __init__( prefix='', downsample_ratio=1, overlap=False, + rank=-1, + seed=0, ): super().__init__(path, img_size, batch_size, augment, hyp, rect, image_weights, cache_images, single_cls, - stride, pad, min_items, prefix) + stride, pad, min_items, prefix, rank, seed) self.downsample_ratio = downsample_ratio self.overlap = overlap