From 9eea3ac742d11f8a5679c7b937c742450c05787c Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Thu, 3 Nov 2022 19:16:41 +0100
Subject: [PATCH 1/6] AutoCache

---
 segment/train.py     |  2 +-
 train.py             |  2 +-
 utils/__init__.py    |  1 -
 utils/dataloaders.py | 24 ++++++++++++++++++++++--
 4 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/segment/train.py b/segment/train.py
index 7950f95df4f2..bdfb6f86a9ed 100644
--- a/segment/train.py
+++ b/segment/train.py
@@ -474,7 +474,7 @@ def parse_opt(known=False):
     parser.add_argument('--noplots', action='store_true', help='save no plot files')
     parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
-    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
+    parser.add_argument('--cache', type=str, nargs='?', const='auto', help='image --cache ram/disk/auto')
     parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
diff --git a/train.py b/train.py
index e882748581bf..7537c2a4c45b 100644
--- a/train.py
+++ b/train.py
@@ -444,7 +444,7 @@ def parse_opt(known=False):
     parser.add_argument('--noplots', action='store_true', help='save no plot files')
     parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
-    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
+    parser.add_argument('--cache', type=str, nargs='?', const='auto', help='image --cache ram/disk/auto')
     parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
diff --git a/utils/__init__.py b/utils/__init__.py
index 0afe6f475625..8354d91c4269 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -47,7 +47,6 @@ def notebook_init(verbose=True):
     from utils.general import check_font, check_requirements, is_colab
     from utils.torch_utils import select_device  # imports
 
-    check_requirements(('psutil', 'IPython'))
     check_font()
 
     import psutil
diff --git a/utils/dataloaders.py b/utils/dataloaders.py
index 4e5b75edb5c2..d54a344d9a2e 100644
--- a/utils/dataloaders.py
+++ b/utils/dataloaders.py
@@ -10,6 +10,7 @@
 import math
 import os
 import random
+import psutil
 import shutil
 import time
 from itertools import repeat
@@ -31,7 +32,7 @@
                                  cutout, letterbox, mixup, random_perspective)
 from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
                            cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy, xywh2xyxy, xywhn2xyxy,
-                           xyxy2xywhn)
+                           xyxy2xywhn, colorstr)
 from utils.torch_utils import torch_distributed_zero_first
 
 # Parameters
@@ -565,6 +566,8 @@ def __init__(self,
             self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride
 
         # Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources)
+        if cache_images == 'auto':
+            cache_images = self.autocache()  # AutoCache
         self.ims = [None] * n
         self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
         if cache_images:
@@ -579,9 +582,26 @@ def __init__(self,
                 else:  # 'ram'
                     self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                     gb += self.ims[i].nbytes
-                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
+                pbar.desc = f'{prefix}Caching images ({gb << 30:.1f}GB {cache_images})'
             pbar.close()
 
+    def autocache(self, safety_margin=0.5, prefix=colorstr('AutoCache: ')):
+        # AutoCache: check image caching requirements vs available memory
+        bytes = 0  # gigabytes
+        gb = 1 << 30  # bytes in a GB
+        n = min(self.n, 30)  # number of samples
+        for _ in range(n):
+            im = cv2.imread(random.choice(self.im_files))  # sample image
+            ratio = self.img_size / max(im.shape[0], im.shape[1])  # max(h, w)  # ratio
+            bytes += im.nbytes * ratio ** 2
+        mem_required = bytes * self.n / n  # GB required to cache dataset into RAM
+        mem = psutil.virtual_memory()
+        cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
+        LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
+                    f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
+                    f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
+        return cache
+
     def cache_labels(self, path=Path('./labels.cache'), prefix=''):
         # Cache dataset labels, check images and read shapes
         x = {}  # dict
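PATCH 1 introduces the heuristic in one piece: read up to 30 random images (with replacement), scale each one's decoded in-memory size by the square of its resize ratio, extrapolate the total to the full dataset, and cache only if that estimate plus a safety margin fits in available RAM. A minimal standalone sketch of the same idea, for illustration only (the function name and arguments here are invented, not part of the patch):

    import random

    import cv2
    import psutil

    def estimate_cache_ok(im_files, img_size=640, safety_margin=0.5, n_samples=30):
        # Sample a few images and estimate their decoded size after resizing
        b, gb = 0, 1 << 30  # accumulated bytes, bytes per gigabyte
        n = min(len(im_files), n_samples)
        for _ in range(n):
            im = cv2.imread(random.choice(im_files))  # HWC uint8 BGR array
            ratio = img_size / max(im.shape[0], im.shape[1])  # resize ratio on the long side
            b += im.nbytes * ratio ** 2  # pixel count (and bytes) scale with ratio squared
        mem_required = b * len(im_files) / n  # extrapolate the sample to the full dataset
        return mem_required * (1 + safety_margin) < psutil.virtual_memory().available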
From 5400ac77163a7d6f29acee2ae56ce68a7686505a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 3 Nov 2022 18:18:41 +0000
Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 utils/dataloaders.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/utils/dataloaders.py b/utils/dataloaders.py
index d54a344d9a2e..01527d1f5bbf 100644
--- a/utils/dataloaders.py
+++ b/utils/dataloaders.py
@@ -10,7 +10,6 @@
 import math
 import os
 import random
-import psutil
 import shutil
 import time
 from itertools import repeat
@@ -20,6 +19,7 @@
 from urllib.parse import urlparse
 
 import numpy as np
+import psutil
 import torch
 import torch.nn.functional as F
 import torchvision
@@ -31,8 +31,8 @@
 from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
                                  cutout, letterbox, mixup, random_perspective)
 from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
-                           cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy, xywh2xyxy, xywhn2xyxy,
-                           xyxy2xywhn, colorstr)
+                           colorstr, cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy, xywh2xyxy,
+                           xywhn2xyxy, xyxy2xywhn)
 from utils.torch_utils import torch_distributed_zero_first
 
 # Parameters

From 2904e47c677736d8e5d028a64588ca393ab28913 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Thu, 3 Nov 2022 19:38:46 +0100
Subject: [PATCH 3/6] AutoCache

---
 utils/dataloaders.py | 21 ++++++++++-----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/utils/dataloaders.py b/utils/dataloaders.py
index d54a344d9a2e..527e24f0a43e 100644
--- a/utils/dataloaders.py
+++ b/utils/dataloaders.py
@@ -567,34 +567,33 @@ def __init__(self,
 
         # Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources)
         if cache_images == 'auto':
-            cache_images = self.autocache()  # AutoCache
+            cache_images = 'ram' if self.autocache() else False  # AutoCache
         self.ims = [None] * n
         self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
         if cache_images:
-            gb = 0  # Gigabytes of cached images
+            b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
             self.im_hw0, self.im_hw = [None] * n, [None] * n
             fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
             results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
             pbar = tqdm(enumerate(results), total=n, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0)
             for i, x in pbar:
                 if cache_images == 'disk':
-                    gb += self.npy_files[i].stat().st_size
+                    b += self.npy_files[i].stat().st_size
                 else:  # 'ram'
                     self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
-                    gb += self.ims[i].nbytes
-                pbar.desc = f'{prefix}Caching images ({gb << 30:.1f}GB {cache_images})'
+                    b += self.ims[i].nbytes
+                pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
             pbar.close()
 
-    def autocache(self, safety_margin=0.5, prefix=colorstr('AutoCache: ')):
+    def autocache(self, safety_margin=0.3, prefix=colorstr('AutoCache: ')):
         # AutoCache: check image caching requirements vs available memory
-        bytes = 0  # gigabytes
-        gb = 1 << 30  # bytes in a GB
-        n = min(self.n, 30)  # number of samples
+        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
+        n = min(self.n, 30)  # extrapolate from 30 random images
         for _ in range(n):
             im = cv2.imread(random.choice(self.im_files))  # sample image
             ratio = self.img_size / max(im.shape[0], im.shape[1])  # max(h, w)  # ratio
-            bytes += im.nbytes * ratio ** 2
-        mem_required = bytes * self.n / n  # GB required to cache dataset into RAM
+            b += im.nbytes * ratio ** 2
+        mem_required = b * self.n / n  # GB required to cache dataset into RAM
         mem = psutil.virtual_memory()
         cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
         LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
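Among PATCH 3's renames there is one real bug fix: PATCH 1 formatted the progress bar with "gb << 30", which multiplies the byte counter by 2**30 instead of dividing it, so the description would have printed an astronomically inflated figure. PATCH 3's "b / gb" is the intended bytes-to-gigabytes conversion. A quick check with an illustrative byte count:

    b = 3 * (1 << 30)  # suppose 3 GiB of images cached so far, counted in bytes
    gb = 1 << 30  # bytes per gigabyte
    print(f'{b / gb:.1f}GB')   # '3.0GB' -- the PATCH 3 expression, correct
    print(f'{b << 30:.1f}GB')  # '3458764513820540928.0GB' -- the PATCH 1 expression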
From b40a724702a361ea9a7f10f015adee73a7a4425d Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Thu, 3 Nov 2022 20:12:27 +0100
Subject: [PATCH 4/6] AutoCache

---
 utils/dataloaders.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/utils/dataloaders.py b/utils/dataloaders.py
index f8e7c2dbede3..6f711b5196ce 100644
--- a/utils/dataloaders.py
+++ b/utils/dataloaders.py
@@ -585,7 +585,7 @@ def __init__(self,
             pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
         pbar.close()
 
-    def autocache(self, safety_margin=0.3, prefix=colorstr('AutoCache: ')):
+    def autocache(self, safety_margin=0.3, verbose=False, prefix=colorstr('AutoCache: ')):
         # AutoCache: check image caching requirements vs available memory
         b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
         n = min(self.n, 30)  # extrapolate from 30 random images
@@ -596,9 +596,10 @@ def autocache(self, safety_margin=0.3, verbose=False, prefix=colorstr('AutoCache
         mem_required = b * self.n / n  # GB required to cache dataset into RAM
         mem = psutil.virtual_memory()
         cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
-        LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
-                    f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
-                    f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
+        if verbose:
+            LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
+                        f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
+                        f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
         return cache
 
     def cache_labels(self, path=Path('./labels.cache'), prefix=''):

From ee90a4613fec940042c249ca7e0c6150840c71d5 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 4 Nov 2022 15:13:23 +0100
Subject: [PATCH 5/6] AutoCache

---
 segment/train.py     |  2 +-
 train.py             |  2 +-
 utils/dataloaders.py | 12 ++++++------
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/segment/train.py b/segment/train.py
index bdfb6f86a9ed..f067918e7c3c 100644
--- a/segment/train.py
+++ b/segment/train.py
@@ -474,7 +474,7 @@ def parse_opt(known=False):
     parser.add_argument('--noplots', action='store_true', help='save no plot files')
     parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
-    parser.add_argument('--cache', type=str, nargs='?', const='auto', help='image --cache ram/disk/auto')
+    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
     parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
diff --git a/train.py b/train.py
index 7537c2a4c45b..1fe6cf4d9ebd 100644
--- a/train.py
+++ b/train.py
@@ -444,7 +444,7 @@ def parse_opt(known=False):
     parser.add_argument('--noplots', action='store_true', help='save no plot files')
     parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
-    parser.add_argument('--cache', type=str, nargs='?', const='auto', help='image --cache ram/disk/auto')
+    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
     parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
diff --git a/utils/dataloaders.py b/utils/dataloaders.py
index 6f711b5196ce..b33a24a46f9c 100644
--- a/utils/dataloaders.py
+++ b/utils/dataloaders.py
@@ -565,9 +565,9 @@ def __init__(self,
 
             self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride
 
-        # Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources)
-        if cache_images == 'auto':
-            cache_images = 'ram' if self.autocache() else False  # AutoCache
+        # Cache images into RAM/disk for faster training
+        if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix):
+            cache_images = False
         self.ims = [None] * n
         self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
         if cache_images:
@@ -585,8 +585,8 @@ def __init__(self,
             pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
         pbar.close()
 
-    def autocache(self, safety_margin=0.3, verbose=False, prefix=colorstr('AutoCache: ')):
-        # AutoCache: check image caching requirements vs available memory
+    def check_cache_ram(self, safety_margin=0.1, prefix=''):
+        # Check image caching requirements vs available memory
         b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
         n = min(self.n, 30)  # extrapolate from 30 random images
         for _ in range(n):
@@ -596,7 +596,7 @@ def check_cache_ram(self, safety_margin=0.1, prefix=''):
         mem_required = b * self.n / n  # GB required to cache dataset into RAM
         mem = psutil.virtual_memory()
         cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
-        if verbose:
+        if not cache:
             LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
                         f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
                         f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
         return cache
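To put numbers on the final gate from PATCH 5: with the default safety_margin=0.1, RAM caching proceeds only if the extrapolated dataset size times 1.1 fits in available memory. A back-of-the-envelope check with invented figures (20,000 images averaging 640x480x3 uint8 after resize, 16 GiB free):

    gb = 1 << 30  # bytes per gigabyte
    per_image = 640 * 480 * 3  # 921,600 bytes for one resized uint8 BGR image
    mem_required = per_image * 20_000  # ~18.4e9 bytes for the whole dataset
    available = 16 * gb  # suppose 16 GiB of free RAM
    cache = mem_required * (1 + 0.1) < available
    print(f'{mem_required / gb:.1f}GB required, {available / gb:.1f}GB available, cache={cache}')
    # -> '17.2GB required, 16.0GB available, cache=False', so training falls back to
    #    loading images per batch instead of caching them in RAM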
From b071c982903eadd37679dcfbbc6853f8afa0faea Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Fri, 4 Nov 2022 15:19:06 +0100
Subject: [PATCH 6/6] AutoCache

---
 utils/general.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/general.py b/utils/general.py
index aae466ba5c90..0c3b44d7f9b0 100644
--- a/utils/general.py
+++ b/utils/general.py
@@ -374,7 +374,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta
     if s and install and AUTOINSTALL:  # check environment variable
         LOGGER.info(f"{prefix} YOLOv5 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
         try:
-            assert check_online(), "AutoUpdate skipped (offline)"
+            # assert check_online(), "AutoUpdate skipped (offline)"
             LOGGER.info(check_output(f'pip install {s} {cmds}', shell=True).decode())
             source = file if 'file' in locals() else requirements
             s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
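After the series, the user-facing switch is the --cache flag as PATCH 5 leaves it: a bare --cache selects 'ram' (const='ram') but is skipped automatically when check_cache_ram() reports the dataset will not fit, while --cache disk writes decoded images to .npy files alongside the originals (self.npy_files above). A usage sketch, with an illustrative dataset argument:

    python train.py --data coco128.yaml --cache         # RAM cache, auto-skipped if too large
    python train.py --data coco128.yaml --cache disk    # cache decoded images as .npy files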