diff --git a/utils/datasets.py b/utils/datasets.py
index b6e43b94cfe9..bda435776629 100644
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -9,7 +9,7 @@
 import shutil
 import time
 from itertools import repeat
-from multiprocessing.pool import ThreadPool
+from multiprocessing.pool import ThreadPool, Pool
 from pathlib import Path
 from threading import Thread
 
@@ -29,6 +29,7 @@
 help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
 img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
 vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
+num_threads = min(8, os.cpu_count())  # number of multiprocessing threads
 logger = logging.getLogger(__name__)
 
 # Get orientation exif tag
@@ -447,7 +448,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
         if cache_images:
             gb = 0  # Gigabytes of cached images
             self.img_hw0, self.img_hw = [None] * n, [None] * n
-            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))  # 8 threads
+            results = ThreadPool(num_threads).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
             pbar = tqdm(enumerate(results), total=n)
             for i, x in pbar:
                 self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
@@ -458,53 +459,24 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
     def cache_labels(self, path=Path('./labels.cache'), prefix=''):
         # Cache dataset labels, check images and read shapes
         x = {}  # dict
-        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, duplicate
-        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
-        for i, (im_file, lb_file) in enumerate(pbar):
-            try:
-                # verify images
-                im = Image.open(im_file)
-                im.verify()  # PIL verify
-                shape = exif_size(im)  # image size
-                segments = []  # instance segments
-                assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
-                assert im.format.lower() in img_formats, f'invalid image format {im.format}'
-
-                # verify labels
-                if os.path.isfile(lb_file):
-                    nf += 1  # label found
-                    with open(lb_file, 'r') as f:
-                        l = [x.split() for x in f.read().strip().splitlines() if len(x)]
-                    if any([len(x) > 8 for x in l]):  # is segment
-                        classes = np.array([x[0] for x in l], dtype=np.float32)
-                        segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
-                        l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
-                    l = np.array(l, dtype=np.float32)
-                    if len(l):
-                        assert l.shape[1] == 5, 'labels require 5 columns each'
-                        assert (l >= 0).all(), 'negative labels'
-                        assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
-                        assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
-                    else:
-                        ne += 1  # label empty
-                        l = np.zeros((0, 5), dtype=np.float32)
-                else:
-                    nm += 1  # label missing
-                    l = np.zeros((0, 5), dtype=np.float32)
-                x[im_file] = [l, shape, segments]
-            except Exception as e:
-                nc += 1
-                logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
-
-            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
-                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
+        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupt
+        desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
+        with Pool(num_threads) as pool:
+            pbar = tqdm(pool.imap_unordered(verify_image_label,
+                                            zip(self.img_files, self.label_files, repeat(prefix))),
+                        desc=desc, total=len(self.img_files))
+            for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f in pbar:
+                if im_file:
+                    x[im_file] = [l, shape, segments]
+                nm, nf, ne, nc = nm + nm_f, nf + nf_f, ne + ne_f, nc + nc_f
+                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
         pbar.close()
 
         if nf == 0:
             logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
 
         x['hash'] = get_hash(self.label_files + self.img_files)
-        x['results'] = nf, nm, ne, nc, i + 1
+        x['results'] = nf, nm, ne, nc, len(self.img_files)
         x['version'] = 0.2  # cache version
         try:
             torch.save(x, path)  # save cache for next time
@@ -1069,3 +1041,44 @@ def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0), annotated_only=False):
         if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
             with open(path / txt[i], 'a') as f:
                 f.write(str(img) + '\n')  # add image to txt file
+
+
+def verify_image_label(params):
+    # Verify one image-label pair
+    im_file, lb_file, prefix = params
+    nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupt
+    try:
+        # verify images
+        im = Image.open(im_file)
+        im.verify()  # PIL verify
+        shape = exif_size(im)  # image size
+        segments = []  # instance segments
+        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
+        assert im.format.lower() in img_formats, f'invalid image format {im.format}'
+
+        # verify labels
+        if os.path.isfile(lb_file):
+            nf = 1  # label found
+            with open(lb_file, 'r') as f:
+                l = [x.split() for x in f.read().strip().splitlines() if len(x)]
+            if any([len(x) > 8 for x in l]):  # is segment
+                classes = np.array([x[0] for x in l], dtype=np.float32)
+                segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
+                l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
+            l = np.array(l, dtype=np.float32)
+            if len(l):
+                assert l.shape[1] == 5, 'labels require 5 columns each'
+                assert (l >= 0).all(), 'negative labels'
+                assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
+                assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
+            else:
+                ne = 1  # label empty
+                l = np.zeros((0, 5), dtype=np.float32)
+        else:
+            nm = 1  # label missing
+            l = np.zeros((0, 5), dtype=np.float32)
+        return im_file, l, shape, segments, nm, nf, ne, nc
+    except Exception as e:
+        nc = 1
+        logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
+        return [None] * 4 + [nm, nf, ne, nc]
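
# ---------------------------------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the diff above): the pattern the new cache_labels() relies on, i.e. fanning a
# per-file verification function out to a multiprocessing Pool via imap_unordered and folding the per-file counter
# tuples back into running totals under a tqdm bar. The worker name, the file paths and the trivial "check" below are
# assumptions made for this example only; the real worker is verify_image_label() in the diff.
import os
from itertools import repeat
from multiprocessing.pool import Pool

from tqdm import tqdm

num_threads = min(8, os.cpu_count())  # same worker cap as in the diff


def verify_pair(args):  # hypothetical worker: one (image, label, prefix) triple in, one result tuple out
    im_file, lb_file, prefix = args
    nm = nf = nc = 0  # missing, found, corrupt
    try:
        nf = int(os.path.isfile(lb_file))  # stand-in for the real PIL image and label checks
        nm = 1 - nf
        return im_file, nm, nf, nc
    except Exception as e:
        nc = 1
        print(f'{prefix}WARNING: ignoring {im_file}: {e}')
        return None, nm, nf, nc


def scan(img_files, label_files, prefix=''):
    nm = nf = nc = 0
    found = {}
    with Pool(num_threads) as pool:
        pbar = tqdm(pool.imap_unordered(verify_pair, zip(img_files, label_files, repeat(prefix))),
                    total=len(img_files), desc=f'{prefix}Scanning...')
        for im_file, nm_f, nf_f, nc_f in pbar:  # results arrive in completion order, not input order
            if im_file:
                found[im_file] = True  # keyed by path, so ordering does not matter
            nm, nf, nc = nm + nm_f, nf + nf_f, nc + nc_f
            pbar.desc = f'{prefix}{nf} found, {nm} missing, {nc} corrupt'
    return found, (nf, nm, nc, len(img_files))


if __name__ == '__main__':  # guard is required because Pool spawns worker processes
    imgs = ['a.jpg', 'b.jpg']  # hypothetical paths
    lbls = ['a.txt', 'b.txt']
    print(scan(imgs, lbls))

# Design note: imap_unordered is a reasonable fit here because each result carries its own key (the image path), so
# input order is irrelevant and slow files never block the progress of faster ones.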