Bugfix: update dataloaders.py to fix Multi-GPU DDP RAM multiple-cache issue #10383

Merged
11 commits merged on Jan 3, 2024
51 changes: 43 additions & 8 deletions utils/dataloaders.py
@@ -41,6 +41,7 @@
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders

# Get orientation exif tag
@@ -100,6 +101,34 @@ def seed_worker(worker_id):
random.seed(worker_seed)


# Inherit from DistributedSampler and override iterator
# https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
class SmartDistributedSampler(distributed.DistributedSampler):

def __iter__(self):
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)

# determine the eventual size (n) of self.indices (DDP indices)
n = int((len(self.dataset) - self.rank - 1) / self.num_replicas) + 1 # num_replicas == WORLD_SIZE
idx = torch.randperm(n, generator=g)
if not self.shuffle:
idx = idx.sort()[0]

idx = idx.tolist()
if self.drop_last:
idx = idx[:self.num_samples]
else:
padding_size = self.num_samples - len(idx)
if padding_size <= len(idx):
idx += idx[:padding_size]
else:
idx += (idx * math.ceil(padding_size / len(idx)))[:padding_size]

return iter(idx)


def create_dataloader(path,
imgsz,
batch_size,
@@ -132,12 +161,13 @@ def create_dataloader(path,
stride=int(stride),
pad=pad,
image_weights=image_weights,
prefix=prefix)
prefix=prefix,
rank=rank)

batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
sampler = None if rank == -1 else SmartDistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
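Taken together, the sampler defined above and the rank argument passed here mean that each DDP rank draws a deterministic per-epoch shuffle over only its own share of the dataset, padded so every rank yields the same number of samples (the dataset-side shard selection appears in the __init__ changes below). A minimal standalone sketch of that scheme, covering the shuffle=True, drop_last=False case; the function name and numbers are illustrative, not part of the PR:

import math
import torch

def shard_schedule(shard_len, dataset_len, world_size, seed=0, epoch=0):
    # Sketch of SmartDistributedSampler.__iter__ for shuffle=True, drop_last=False:
    # returns positions into one rank's shard of the dataset.
    num_samples = math.ceil(dataset_len / world_size)   # common length across ranks
    g = torch.Generator()
    g.manual_seed(seed + epoch)                         # deterministic per epoch
    idx = torch.randperm(shard_len, generator=g).tolist()
    pad = num_samples - len(idx)
    if pad > 0:                                         # repeat early picks so no rank runs short
        idx += (idx * math.ceil(pad / len(idx)))[:pad]
    return idx

# e.g. 10 images on 3 ranks: the ranks own shards of 4, 3 and 3 images
for rank, shard_len in enumerate((4, 3, 3)):
    print(rank, shard_schedule(shard_len, dataset_len=10, world_size=3))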
@@ -448,7 +478,9 @@ def __init__(self,
stride=32,
pad=0.0,
min_items=0,
prefix=''):
prefix='',
rank=-1,
seed=0):
self.img_size = img_size
self.augment = augment
self.hyp = hyp
@@ -526,7 +558,10 @@ def __init__(self,
nb = bi[-1] + 1 # number of batches
self.batch = bi # batch index of image
self.n = n
self.indices = range(n)
self.indices = np.arange(n)
if rank > -1: # DDP indices (see: SmartDistributedSampler)
# force each rank (i.e. GPU process) to sample the same subset of data on every epoch
self.indices = self.indices[np.random.RandomState(seed=seed).permutation(n) % WORLD_SIZE == RANK]

# Update labels
include_class = [] # filter labels to include only these classes (optional)
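The self.indices partition a few lines above is what removes the duplicate RAM cache: the seeded permutation is identical on every rank, and the modulo test hands each rank a disjoint slice of it, so a process only ever loads, and therefore only ever caches, the images it owns. A small self-contained check of that property, with illustrative values:

import numpy as np

n, world_size, seed = 10, 2, 0
perm = np.random.RandomState(seed=seed).permutation(n)   # identical on every rank
shards = [np.arange(n)[perm % world_size == rank] for rank in range(world_size)]

assert len(np.intersect1d(shards[0], shards[1])) == 0    # shards are disjoint
assert len(np.union1d(shards[0], shards[1])) == n        # together they cover the dataset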
@@ -574,14 +609,14 @@ def __init__(self,
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabyte
self.im_hw0, self.im_hw = [None] * n, [None] * n
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
results = ThreadPool(NUM_THREADS).imap(lambda i: (i, fcn(i)), self.indices)
pbar = tqdm(results, total=len(self.indices), bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar:
if cache_images == 'disk':
b += self.npy_files[i].stat().st_size
else: # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
b += self.ims[i].nbytes
b += self.ims[i].nbytes * WORLD_SIZE
pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
pbar.close()

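Because each rank now caches only its own len(self.indices) images (roughly n / WORLD_SIZE of the dataset), the progress bar multiplies the per-image byte count by WORLD_SIZE so that the reported figure approximates the combined RAM used by all DDP processes on the node, which is the quantity that decides whether RAM caching fits on the machine. A rough back-of-the-envelope sketch with made-up numbers:

# Hypothetical numbers, for illustration only
gb = 1 << 30
n, world_size = 118_000, 8             # dataset size, DDP processes on the node
bytes_per_image = 640 * 640 * 3        # one uint8 image resized to 640x640

per_rank = n // world_size                             # images each rank caches
per_rank_gb = per_rank * bytes_per_image / gb
print(f'{per_rank_gb:.1f}GB per rank, ~{per_rank_gb * world_size:.1f}GB reported across all ranks')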
@@ -661,7 +696,7 @@ def __getitem__(self, index):

# MixUp augmentation
if random.random() < hyp['mixup']:
img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
img, labels = mixup(img, labels, *self.load_mosaic(random.choice(self.indices)))

else:
# Load image
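The MixUp change in __getitem__ above follows from the same partition: under DDP, self.indices holds only this rank's shard, so the second mosaic must be sampled from that shard rather than from range(n), otherwise a rank would load (and, with RAM caching active, miss the cache for) images it does not own. A minimal illustration with made-up values:

import random
import numpy as np

n, world_size, rank = 100, 2, 0
indices = np.arange(n)[np.arange(n) % world_size == rank]   # this rank's shard (here simply the even indices)

old_pick = random.randint(0, n - 1)   # previous behaviour: any image, owned by this rank or not
new_pick = random.choice(indices)     # PR behaviour: always inside the rank's own shard
assert new_pick in indices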
11 changes: 7 additions & 4 deletions utils/segment/dataloaders.py
@@ -12,7 +12,7 @@
from torch.utils.data import DataLoader, distributed

from ..augmentations import augment_hsv, copy_paste, letterbox
from ..dataloaders import InfiniteDataLoader, LoadImagesAndLabels, seed_worker
from ..dataloaders import InfiniteDataLoader, LoadImagesAndLabels, SmartDistributedSampler, seed_worker
from ..general import LOGGER, xyn2xy, xywhn2xyxy, xyxy2xywhn
from ..torch_utils import torch_distributed_zero_first
from .augmentations import mixup, random_perspective
@@ -56,12 +56,13 @@ def create_dataloader(path,
image_weights=image_weights,
prefix=prefix,
downsample_ratio=mask_downsample_ratio,
overlap=overlap_mask)
overlap=overlap_mask,
rank=rank)

batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
sampler = None if rank == -1 else SmartDistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
@@ -97,9 +98,11 @@ def __init__(
prefix="",
downsample_ratio=1,
overlap=False,
rank=-1,
seed=0,
):
super().__init__(path, img_size, batch_size, augment, hyp, rect, image_weights, cache_images, single_cls,
stride, pad, min_items, prefix)
stride, pad, min_items, prefix, rank, seed)
self.downsample_ratio = downsample_ratio
self.overlap = overlap
