From 89abf9eb4fafd43fd383df4ffbe7cbea74b79795 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sat, 13 Nov 2021 12:19:14 +0100
Subject: [PATCH] Cleanup, add rect warning

---
 utils/datasets.py | 41 ++++++++++++++++++++---------------------
 1 file changed, 20 insertions(+), 21 deletions(-)

diff --git a/utils/datasets.py b/utils/datasets.py
index eb87eecd4186..e1ce4a6e5a22 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -22,7 +22,7 @@
 import torch.nn.functional as F
 import yaml
 from PIL import ExifTags, Image, ImageOps
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader, distributed, dataloader
 from tqdm import tqdm
 
 from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
@@ -94,12 +94,14 @@ def exif_transpose(image):
 
 def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
                       rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
-    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
-    with torch_distributed_zero_first(rank):
+    if rect and shuffle:
+        LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
+        shuffle = False
+    with torch_distributed_zero_first(rank):  # Init dataset *.cache only once if DDP
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
-                                      augment=augment,  # augment images
-                                      hyp=hyp,  # augmentation hyperparameters
-                                      rect=rect,  # rectangular training
+                                      augment=augment,  # augmentation
+                                      hyp=hyp,  # hyperparameters
+                                      rect=rect,  # rectangular batches
                                       cache_images=cache,
                                       single_cls=single_cls,
                                       stride=int(stride),
@@ -108,22 +110,19 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
                                       prefix=prefix)
 
     batch_size = min(batch_size, len(dataset))
-    shuffle = shuffle and not rect  # disable shuffle in rect mode to keep the dataset sorted by aspect ratio
     nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers])  # number of workers
-    sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle) if rank != -1 else None
-    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
-    # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
-    dataloader = loader(dataset,
-                        batch_size=batch_size,
-                        shuffle=(shuffle and sampler is None),
-                        num_workers=nw,
-                        sampler=sampler,
-                        pin_memory=True,
-                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
-    return dataloader, dataset
-
-
-class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
+    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
+    return loader(dataset,
+                  batch_size=batch_size,
+                  shuffle=shuffle and sampler is None,
+                  num_workers=nw,
+                  sampler=sampler,
+                  pin_memory=True,
+                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
+
+
+class InfiniteDataLoader(dataloader.DataLoader):
     """ Dataloader that reuses workers
 
     Uses same syntax as vanilla DataLoader
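
Not part of the patch itself: a minimal standalone sketch of the rect/shuffle guard this commit adds to create_dataloader(). The resolve_shuffle() helper and the logging setup below are illustrative assumptions for a self-contained run, not YOLOv5 API; in the patch the warning is emitted through YOLOv5's own LOGGER directly inside create_dataloader().

import logging

logging.basicConfig(format='%(message)s', level=logging.INFO)
LOGGER = logging.getLogger(__name__)


def resolve_shuffle(rect, shuffle):
    # Rectangular batches (rect=True) keep the dataset sorted by aspect ratio to
    # minimize padding, so a shuffled DataLoader would defeat that ordering; the
    # guard therefore forces shuffle=False and warns, as in the patched function.
    if rect and shuffle:
        LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False
    return shuffle


print(resolve_shuffle(rect=True, shuffle=True))   # False, after the warning is logged
print(resolve_shuffle(rect=False, shuffle=True))  # True, shuffle left untouched

Downstream, the resolved shuffle value feeds both the DistributedSampler (when rank != -1) and the DataLoader's own shuffle argument, which is why the patch places the guard before either object is constructed.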