diff --git a/export.py b/export.py index 83e293b72e73..cec85958b4a9 100644 --- a/export.py +++ b/export.py @@ -156,8 +156,8 @@ def run(weights='./yolov5s.pt', # weights path # Finish print(f'\nExport complete ({time.time() - t:.2f}s)' - f"Results saved to {colorstr('bold', file.parent.resolve())}\n" - f'Visualize with https://netron.app') + f"\nResults saved to {colorstr('bold', file.parent.resolve())}" + f'\nVisualize with https://netron.app') def parse_opt(): diff --git a/train.py b/train.py index d4a5495d3b3b..34bd8e73c290 100644 --- a/train.py +++ b/train.py @@ -201,7 +201,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # Trainloader train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls, - hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK, + hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK, workers=workers, image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: ')) mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class @@ -211,7 +211,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # Process 0 if RANK in [-1, 0]: val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls, - hyp=hyp, cache=opt.cache_images and not noval, rect=True, rank=-1, + hyp=hyp, cache=None if noval else opt.cache, rect=True, rank=-1, workers=workers, pad=0.5, prefix=colorstr('val: '))[0] @@ -389,7 +389,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # end epoch ---------------------------------------------------------------------------------------------------- # end training ----------------------------------------------------------------------------------------------------- if RANK in [-1, 0]: - LOGGER.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n') + LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.') if not evolve: if is_coco: # COCO dataset for m in [last, best] if best.exists() else [last]: # speed, mAP tests @@ -430,7 +430,7 @@ def parse_opt(known=False): parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check') parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations') parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') - parser.add_argument('--cache-images', action='store_true', help='cache images for faster training') + parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"') parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') diff --git a/utils/datasets.py b/utils/datasets.py index fffe39a61459..1c780cdbac4b 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -455,16 +455,25 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM) - self.imgs = [None] * n + self.imgs, self.img_npy = [None] * n, [None] * n if cache_images: + if cache_images == 'disk': + self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy') + self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files] + self.im_cache_dir.mkdir(parents=True, exist_ok=True) gb = 0 # Gigabytes of cached images self.img_hw0, self.img_hw = [None] * n, [None] * n results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x), zip(repeat(self), range(n))) pbar = tqdm(enumerate(results), total=n) for i, x in pbar: - self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i) - gb += self.imgs[i].nbytes - pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)' + if cache_images == 'disk': + if not self.img_npy[i].exists(): + np.save(self.img_npy[i].as_posix(), x[0]) + gb += self.img_npy[i].stat().st_size + else: + self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i) + gb += self.imgs[i].nbytes + pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})' pbar.close() def cache_labels(self, path=Path('./labels.cache'), prefix=''): @@ -618,21 +627,25 @@ def collate_fn4(batch): # Ancillary functions -------------------------------------------------------------------------------------------------- -def load_image(self, index): - # loads 1 image from dataset, returns img, original hw, resized hw - img = self.imgs[index] - if img is None: # not cached - path = self.img_files[index] - img = cv2.imread(path) # BGR - assert img is not None, 'Image Not Found ' + path - h0, w0 = img.shape[:2] # orig hw +def load_image(self, i): + # loads 1 image from dataset index 'i', returns im, original hw, resized hw + im = self.imgs[i] + if im is None: # not cached in ram + npy = self.img_npy[i] + if npy and npy.exists(): # load npy + im = np.load(npy) + else: # read image + path = self.img_files[i] + im = cv2.imread(path) # BGR + assert im is not None, 'Image Not Found ' + path + h0, w0 = im.shape[:2] # orig hw r = self.img_size / max(h0, w0) # ratio if r != 1: # if sizes are not equal - img = cv2.resize(img, (int(w0 * r), int(h0 * r)), - interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR) - return img, (h0, w0), img.shape[:2] # img, hw_original, hw_resized + im = cv2.resize(im, (int(w0 * r), int(h0 * r)), + interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR) + return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized else: - return self.imgs[index], self.img_hw0[index], self.img_hw[index] # img, hw_original, hw_resized + return self.imgs[i], self.img_hw0[i], self.img_hw[i] # im, hw_original, hw_resized def load_mosaic(self, index):