Bugfix: update dataloaders.py to fix Multi-GPU DDP RAM multiple-cache issue (#10383)

* Update dataloaders.py

This addresses (and hopefully fixes) the Multi-GPU DDP RAM multiple-cache bug (#3818). It was a serious, "blocking" issue until I could figure out what was going on. The problem was especially bad when running Multi-GPU jobs with 8 GPUs: RAM usage was 8x higher than expected, causing repeated OOM failures. Hopefully this fix will help others.
DDP launches a separate process for each RANK (one per GPU), each with its own trainloader and its own RAM image cache. The DistributedSampler used by DDP (https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py) feeds only a subset of images (1/WORLD_SIZE) to each GPU on each epoch, but because the images are reshuffled between epochs, every GPU process still has to cache all of the images. So I created a subclass of DistributedSampler called SmartDistributedSampler that forces each GPU process to always sample the same subset (using modulo arithmetic with RANK and WORLD_SIZE) while still allowing random shuffling between epochs (see the short sketch after the commit message below). I don't believe this disrupts the overall "randomness" of the sampling, and I haven't noticed any performance degradation.

Signed-off-by: davidsvaughn <davidsvaughn@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dataloaders.py

move the extra parameter (rank) to the end so it won't break pre-existing positional args

* Update dataloaders.py

removing extra '#'

* Update dataloaders.py

sample from DDP index array (self.idx) in mixup mosaic

* Merge self.indices and self.idx (DDP indices) into a single attribute (self.indices).
Also add SmartDistributedSampler to the segmentation dataloader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Multiply GB displayed by WORLD_SIZE

---------

Signed-off-by: davidsvaughn <davidsvaughn@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
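
A minimal, self-contained sketch of the partitioning idea described above (illustrative only, not the repository code; rank_subset and the toy sizes are made up for the example): each rank derives a fixed subset of global image indices from a shared seed, so across WORLD_SIZE processes every image is cached by exactly one rank, while per-epoch shuffling only reorders each rank's own subset.

import numpy as np

def rank_subset(n, rank, world_size, seed=0):
    # Global image indices assigned to one DDP rank; fixed for every epoch
    # because the permutation depends only on the shared seed.
    perm = np.random.RandomState(seed=seed).permutation(n)
    return np.arange(n)[perm % world_size == rank]

n, world_size = 10, 4
subsets = [rank_subset(n, r, world_size) for r in range(world_size)]
for r, s in enumerate(subsets):
    print(f'rank {r} caches image ids {s.tolist()}')

# The subsets are disjoint and cover the dataset exactly once, so each
# process loads ~n / world_size images into RAM instead of all n.
assert sorted(np.concatenate(subsets).tolist()) == list(range(n))

# Shuffling between epochs happens over local positions within the fixed
# subset (mirroring what the epoch-seeded generator does in the sampler).
epoch, seed = 3, 0
local_order = np.random.RandomState(seed + epoch).permutation(len(subsets[0]))
print('rank 0, epoch 3 order:', subsets[0][local_order].tolist())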
3 people committed Jan 3, 2024
1 parent 151c953 commit 46ae996
Showing 2 changed files with 50 additions and 12 deletions.
51 changes: 43 additions & 8 deletions utils/dataloaders.py
@@ -41,6 +41,7 @@
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders

# Get orientation exif tag
@@ -100,6 +101,34 @@ def seed_worker(worker_id):
random.seed(worker_seed)


+# Inherit from DistributedSampler and override iterator
+# https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
+class SmartDistributedSampler(distributed.DistributedSampler):
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch and seed
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.epoch)
+
+        # determine the eventual size (n) of self.indices (DDP indices)
+        n = int((len(self.dataset) - self.rank - 1) / self.num_replicas) + 1  # num_replicas == WORLD_SIZE
+        idx = torch.randperm(n, generator=g)
+        if not self.shuffle:
+            idx = idx.sort()[0]
+
+        idx = idx.tolist()
+        if self.drop_last:
+            idx = idx[:self.num_samples]
+        else:
+            padding_size = self.num_samples - len(idx)
+            if padding_size <= len(idx):
+                idx += idx[:padding_size]
+            else:
+                idx += (idx * math.ceil(padding_size / len(idx)))[:padding_size]
+
+        return iter(idx)


def create_dataloader(path,
imgsz,
batch_size,
@@ -133,12 +162,13 @@ def create_dataloader(path,
stride=int(stride),
pad=pad,
image_weights=image_weights,
-prefix=prefix)
+prefix=prefix,
+rank=rank)

batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
-sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+sampler = None if rank == -1 else SmartDistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + seed + RANK)
@@ -449,7 +479,9 @@ def __init__(self,
stride=32,
pad=0.0,
min_items=0,
-prefix=''):
+prefix='',
+rank=-1,
+seed=0):
self.img_size = img_size
self.augment = augment
self.hyp = hyp
@@ -527,7 +559,10 @@ def __init__(self,
nb = bi[-1] + 1 # number of batches
self.batch = bi # batch index of image
self.n = n
-self.indices = range(n)
+self.indices = np.arange(n)
+if rank > -1:  # DDP indices (see: SmartDistributedSampler)
+    # force each rank (i.e. GPU process) to sample the same subset of data on every epoch
+    self.indices = self.indices[np.random.RandomState(seed=seed).permutation(n) % WORLD_SIZE == RANK]

# Update labels
include_class = [] # filter labels to include only these classes (optional)
@@ -576,14 +611,14 @@ def __init__(self,
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
self.im_hw0, self.im_hw = [None] * n, [None] * n
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
-results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
-pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
+results = ThreadPool(NUM_THREADS).imap(lambda i: (i, fcn(i)), self.indices)
+pbar = tqdm(results, total=len(self.indices), bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar:
if cache_images == 'disk':
b += self.npy_files[i].stat().st_size
else: # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
-b += self.ims[i].nbytes
+b += self.ims[i].nbytes * WORLD_SIZE
pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
pbar.close()

@@ -663,7 +698,7 @@ def __getitem__(self, index):

# MixUp augmentation
if random.random() < hyp['mixup']:
-img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
+img, labels = mixup(img, labels, *self.load_mosaic(random.choice(self.indices)))

else:
# Load image
11 changes: 7 additions & 4 deletions utils/segment/dataloaders.py
@@ -12,7 +12,7 @@
from torch.utils.data import DataLoader, distributed

from ..augmentations import augment_hsv, copy_paste, letterbox
-from ..dataloaders import InfiniteDataLoader, LoadImagesAndLabels, seed_worker
+from ..dataloaders import InfiniteDataLoader, LoadImagesAndLabels, SmartDistributedSampler, seed_worker
from ..general import LOGGER, xyn2xy, xywhn2xyxy, xyxy2xywhn
from ..torch_utils import torch_distributed_zero_first
from .augmentations import mixup, random_perspective
@@ -57,12 +57,13 @@ def create_dataloader(path,
image_weights=image_weights,
prefix=prefix,
downsample_ratio=mask_downsample_ratio,
-overlap=overlap_mask)
+overlap=overlap_mask,
+rank=rank)

batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
-sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+sampler = None if rank == -1 else SmartDistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + seed + RANK)
@@ -98,9 +99,11 @@ def __init__(
prefix='',
downsample_ratio=1,
overlap=False,
+rank=-1,
+seed=0,
):
super().__init__(path, img_size, batch_size, augment, hyp, rect, image_weights, cache_images, single_cls,
-stride, pad, min_items, prefix)
+stride, pad, min_items, prefix, rank, seed)
self.downsample_ratio = downsample_ratio
self.overlap = overlap

