Skip to content

Commit

Permalink
Update dataloaders.py
Browse files Browse the repository at this point in the history
This is to address (and hopefully fix) this issue: Multi-GPU DDP RAM multiple-cache bug ultralytics#3818 (ultralytics#3818).  This was a very serious and "blocking" issue until I could figure out what was going on.  The problem was especially bad when running Multi-GPU jobs with 8 GPUs, RAM usage was 8x higher than expected (!), causing repeated OOM failures.  Hopefully this fix will help others.
DDP causes each RANK to launch it's own process (one for each GPU) with it's own trainloader, and its own RAM image cache.  The DistributedSampler used by DDP (https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py) will feed only a subset of images (1/WORLD_SIZE) to each available GPU on each epoch, but since the images are shuffled between epochs, each GPU process must still cache all images.  So I created a subclass of DistributedSampler called SmartDistributedSampler that forces each GPU process to always sample the same subset (using modulo arithmetic with RANK and WORLD_SIZE) while still allowing random shuffling between epochs.  I don't believe this disrupts the overall "randomness" of the sampling, and I haven't noticed any performance degradation.  

Signed-off-by: davidsvaughn <davidsvaughn@gmail.com>
  • Loading branch information
davidsvaughn committed Dec 2, 2022
1 parent d7955fe commit 5488bd5
Showing 1 changed file with 37 additions and 5 deletions.
42 changes: 37 additions & 5 deletions utils/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders

# Get orientation exif tag
Expand Down Expand Up @@ -100,6 +101,33 @@ def seed_worker(worker_id):
random.seed(worker_seed)


# Inherit from DistributedSampler and override iterator
# https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
class SmartDistributedSampler(distributed.DistributedSampler):
def __iter__(self):
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
idx = torch.randperm(len(self.dataset), generator=g)
if not self.shuffle:
idx = idx.sort()[0]

## force each rank (i.e. GPU process) to sample the same subset of data every epoch
idx = idx[idx % self.num_replicas == self.rank] # num_replicas == WORLD_SIZE

idx = idx.tolist()
if self.drop_last:
idx = idx[:self.num_samples]
else:
padding_size = self.num_samples - len(idx)
if padding_size <= len(idx):
idx += idx[:padding_size]
else:
idx += (idx * math.ceil(padding_size / len(idx)))[:padding_size]

return iter(idx)


def create_dataloader(path,
imgsz,
batch_size,
Expand Down Expand Up @@ -127,6 +155,7 @@ def create_dataloader(path,
augment=augment, # augmentation
hyp=hyp, # hyperparameters
rect=rect, # rectangular batches
rank=rank,
cache_images=cache,
single_cls=single_cls,
stride=int(stride),
Expand All @@ -137,7 +166,7 @@ def create_dataloader(path,
batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
sampler = None if rank == -1 else SmartDistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
Expand Down Expand Up @@ -440,6 +469,7 @@ def __init__(self,
augment=False,
hyp=None,
rect=False,
rank=-1,
image_weights=False,
cache_images=False,
single_cls=False,
Expand Down Expand Up @@ -568,12 +598,14 @@ def __init__(self,
cache_images = False
self.ims = [None] * n
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
self.idx = np.array(self.indices) # DDP indices
if cache_images:
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
self.im_hw0, self.im_hw = [None] * n, [None] * n
self.idx = self.idx[self.idx % WORLD_SIZE == RANK] if rank>-1 else self.idx # see: SmartDistributedSampler (above)
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
results = ThreadPool(NUM_THREADS).imap(lambda i: (i, fcn(i)) , self.idx)
pbar = tqdm(results, total=len(self.idx), bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar:
if cache_images == 'disk':
b += self.npy_files[i].stat().st_size
Expand Down Expand Up @@ -749,7 +781,7 @@ def load_mosaic(self, index):
labels4, segments4 = [], []
s = self.img_size
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border) # mosaic center x, y
indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices
indices = [index] + random.choices(self.idx, k=3) # 3 additional image indices
random.shuffle(indices)
for i, index in enumerate(indices):
# Load image
Expand Down Expand Up @@ -806,7 +838,7 @@ def load_mosaic9(self, index):
# YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
labels9, segments9 = [], []
s = self.img_size
indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices
indices = [index] + random.choices(self.idx, k=8) # 8 additional image indices
random.shuffle(indices)
hp, wp = -1, -1 # height, width previous
for i, index in enumerate(indices):
Expand Down

0 comments on commit 5488bd5

Please sign in to comment.