Cleanup, add rect warning
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Nov 13, 2021
1 parent c363fb7 commit 89abf9e
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions utils/datasets.py
@@ -22,7 +22,7 @@
 import torch.nn.functional as F
 import yaml
 from PIL import ExifTags, Image, ImageOps
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, DataLoader, distributed, dataloader
 from tqdm import tqdm
 
 from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
@@ -94,12 +94,14 @@ def exif_transpose(image):
 
 def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
                       rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
-    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
-    with torch_distributed_zero_first(rank):
+    if rect and shuffle:
+        LOGGER.warning('WARNING: --rect is incompatible with Dataloader shuffle, setting shuffle=False')
+        shuffle = False
+    with torch_distributed_zero_first(rank):  # Init dataset *.cache only once if DDP
         dataset = LoadImagesAndLabels(path, imgsz, batch_size,
-                                      augment=augment,  # augment images
-                                      hyp=hyp,  # augmentation hyperparameters
-                                      rect=rect,  # rectangular training
+                                      augment=augment,  # augmentation
+                                      hyp=hyp,  # hyperparameters
+                                      rect=rect,  # rectangular batches
                                       cache_images=cache,
                                       single_cls=single_cls,
                                       stride=int(stride),
@@ -108,22 +110,19 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None,
                                       prefix=prefix)
 
     batch_size = min(batch_size, len(dataset))
-    shuffle = shuffle and not rect  # disable shuffle in rect mode to keep the dataset sorted by aspect ratio
     nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers])  # number of workers
-    sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle) if rank != -1 else None
-    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
-    # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
-    dataloader = loader(dataset,
-                        batch_size=batch_size,
-                        shuffle=(shuffle and sampler is None),
-                        num_workers=nw,
-                        sampler=sampler,
-                        pin_memory=True,
-                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
-    return dataloader, dataset
+    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+    loader = DataLoader if image_weights else InfiniteDataLoader  # only Dataloader allows for attribute updates
+    return loader(dataset,
+                  batch_size=batch_size,
+                  shuffle=shuffle and sampler is None,
+                  num_workers=nw,
+                  sampler=sampler,
+                  pin_memory=True,
+                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
 
 
-class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
+class InfiniteDataLoader(dataloader.DataLoader):
     """ Dataloader that reuses workers
 
     Uses same syntax as vanilla DataLoader
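For context, a minimal usage sketch of the revised create_dataloader (the dataset path and argument values below are placeholders for illustration, not part of this commit): calling it with rect=True and shuffle=True now logs the warning added above and falls back to shuffle=False, rather than silently disabling shuffle as the removed line did.

from utils.datasets import create_dataloader

# Hypothetical call; path and settings are illustrative only
train_loader, dataset = create_dataloader(
    path='data/coco128/images/train2017',  # placeholder dataset path
    imgsz=640,
    batch_size=16,
    stride=32,
    hyp=None,
    augment=False,
    rect=True,       # rectangular batches (dataset kept sorted by aspect ratio)
    rank=-1,         # -1 = single-process training, so no DistributedSampler
    workers=8,
    shuffle=True)    # incompatible with rect -> warning fires, shuffle becomes False

for imgs, targets, paths, shapes in train_loader:
    pass  # training step would go here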

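The body of InfiniteDataLoader is collapsed in the diff above. As a rough sketch of the worker-reuse idea its docstring describes (an assumed implementation for illustration, not code copied from this commit), such a loader typically swaps its batch_sampler for one that repeats forever, so the same worker processes serve every epoch instead of being torn down and respawned:

from torch.utils.data import dataloader


class _RepeatSampler:
    """Sampler wrapper that re-iterates the wrapped sampler forever."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class InfiniteDataLoader(dataloader.DataLoader):
    """DataLoader that reuses workers by never exhausting its batch sampler."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader blocks attribute assignment after init, hence object.__setattr__
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)  # number of batches per epoch

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)

Because the underlying iterator is created once and never exhausted, __iter__ simply draws len(self) batches from it per epoch while the workers stay alive.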