From 20fb666dec1cb4ffb4886e0a1886fcd34bf07997 Mon Sep 17 00:00:00 2001 From: Werner Duvaud <40442230+werner-duvaud@users.noreply.github.com> Date: Fri, 12 Nov 2021 05:15:33 +0100 Subject: [PATCH 1/7] Fix shuffle DataLoader argument --- utils/datasets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/datasets.py b/utils/datasets.py index f153db0d7104..ddf5101e6e0a 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -114,6 +114,7 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader() dataloader = loader(dataset, batch_size=batch_size, + shuffle=(sampler is None), num_workers=nw, sampler=sampler, pin_memory=True, From d8fd61d1fa2b1c3312a37b85e59fa23f7ea054ff Mon Sep 17 00:00:00 2001 From: Werner Duvaud <40442230+werner-duvaud@users.noreply.github.com> Date: Fri, 12 Nov 2021 06:46:12 +0100 Subject: [PATCH 2/7] Add shuffle argument --- train.py | 2 +- utils/datasets.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 96b3c2fdc516..91bcd1e1e2e8 100644 --- a/train.py +++ b/train.py @@ -212,7 +212,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls, hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=LOCAL_RANK, workers=workers, image_weights=opt.image_weights, quad=opt.quad, - prefix=colorstr('train: ')) + prefix=colorstr('train: '), shuffle=True) mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class nb = len(train_loader) # number of batches assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}' diff --git a/utils/datasets.py b/utils/datasets.py index ddf5101e6e0a..f897471bd8cd 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -93,7 +93,7 @@ def exif_transpose(image): def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0, - rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''): + rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False): # Make sure only the first process in DDP process the dataset first, and the following others can use the cache with torch_distributed_zero_first(rank): dataset = LoadImagesAndLabels(path, imgsz, batch_size, @@ -109,12 +109,12 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non batch_size = min(batch_size, len(dataset)) nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers - sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None + sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle) if rank != -1 else None loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader() dataloader = loader(dataset, batch_size=batch_size, - shuffle=(sampler is None), + shuffle=(shuffle and sampler is None), num_workers=nw, sampler=sampler, pin_memory=True, From c363fb78799b74803070b0395a8efdcce0461823 Mon Sep 17 00:00:00 2001 From: Werner Duvaud <40442230+werner-duvaud@users.noreply.github.com> Date: Fri, 12 Nov 2021 14:55:25 +0100 Subject: [PATCH 3/7] Disable shuffle when rect --- utils/datasets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/datasets.py b/utils/datasets.py index f897471bd8cd..eb87eecd4186 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -108,6 +108,7 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non prefix=prefix) batch_size = min(batch_size, len(dataset)) + shuffle = shuffle and not rect # disable shuffle in rect mode to keep the dataset sorted by aspect ratio nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle) if rank != -1 else None loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader From 89abf9eb4fafd43fd383df4ffbe7cbea74b79795 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 13 Nov 2021 12:19:14 +0100 Subject: [PATCH 4/7] Cleanup, add rect warning --- utils/datasets.py | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/utils/datasets.py b/utils/datasets.py index eb87eecd4186..e1ce4a6e5a22 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -22,7 +22,7 @@ import torch.nn.functional as F import yaml from PIL import ExifTags, Image, ImageOps -from torch.utils.data import Dataset +from torch.utils.data import Dataset, DataLoader, distributed, dataloader from tqdm import tqdm from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective @@ -94,12 +94,14 @@ def exif_transpose(image): def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0, rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False): - # Make sure only the first process in DDP process the dataset first, and the following others can use the cache - with torch_distributed_zero_first(rank): + if rect and shuffle: + LOGGER.warning('WARNING: --rect is incompatible with Dataloader shuffle, setting shuffle=False') + shuffle = False + with torch_distributed_zero_first(rank): # Init dataset *.cache only once if DDP dataset = LoadImagesAndLabels(path, imgsz, batch_size, - augment=augment, # augment images - hyp=hyp, # augmentation hyperparameters - rect=rect, # rectangular training + augment=augment, # augmentation + hyp=hyp, # hyperparameters + rect=rect, # rectangular batches cache_images=cache, single_cls=single_cls, stride=int(stride), @@ -108,22 +110,19 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non prefix=prefix) batch_size = min(batch_size, len(dataset)) - shuffle = shuffle and not rect # disable shuffle in rect mode to keep the dataset sorted by aspect ratio nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers - sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle) if rank != -1 else None - loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader - # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader() - dataloader = loader(dataset, - batch_size=batch_size, - shuffle=(shuffle and sampler is None), - num_workers=nw, - sampler=sampler, - pin_memory=True, - collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn) - return dataloader, dataset - - -class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader): + sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) + loader = DataLoader if image_weights else InfiniteDataLoader # only Dataloader allows for attribute updates + return loader(dataset, + batch_size=batch_size, + shuffle=shuffle and sampler is None, + num_workers=nw, + sampler=sampler, + pin_memory=True, + collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset + + +class InfiniteDataLoader(dataloader.DataLoader): """ Dataloader that reuses workers Uses same syntax as vanilla DataLoader From 83fbe49bed1432d004644e08128d90f5e1e66467 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 13 Nov 2021 11:19:24 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- utils/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/datasets.py b/utils/datasets.py index e1ce4a6e5a22..e64cca8b1b6c 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -22,7 +22,7 @@ import torch.nn.functional as F import yaml from PIL import ExifTags, Image, ImageOps -from torch.utils.data import Dataset, DataLoader, distributed, dataloader +from torch.utils.data import DataLoader, Dataset, dataloader, distributed from tqdm import tqdm from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective From 1e3f48c4b8f4bb2748412cdbfc1308a2f7e3d358 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 13 Nov 2021 12:19:56 +0100 Subject: [PATCH 6/7] Cleanup2 --- utils/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/datasets.py b/utils/datasets.py index e64cca8b1b6c..e5be469f0e03 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -97,7 +97,7 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non if rect and shuffle: LOGGER.warning('WARNING: --rect is incompatible with Dataloader shuffle, setting shuffle=False') shuffle = False - with torch_distributed_zero_first(rank): # Init dataset *.cache only once if DDP + with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP dataset = LoadImagesAndLabels(path, imgsz, batch_size, augment=augment, # augmentation hyp=hyp, # hyperparameters From 1873938aea969cd38998e9067243a197606918c1 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 13 Nov 2021 12:26:10 +0100 Subject: [PATCH 7/7] Cleanup3 --- utils/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/datasets.py b/utils/datasets.py index e5be469f0e03..3504998b125d 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -95,7 +95,7 @@ def exif_transpose(image): def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0, rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False): if rect and shuffle: - LOGGER.warning('WARNING: --rect is incompatible with Dataloader shuffle, setting shuffle=False') + LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False') shuffle = False with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP dataset = LoadImagesAndLabels(path, imgsz, batch_size, @@ -112,7 +112,7 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non batch_size = min(batch_size, len(dataset)) nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) - loader = DataLoader if image_weights else InfiniteDataLoader # only Dataloader allows for attribute updates + loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates return loader(dataset, batch_size=batch_size, shuffle=shuffle and sampler is None,