Skip to content

Commit

Permalink
YOLOv5 AutoCache Update (#10027)
Browse files Browse the repository at this point in the history
* AutoCache

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* AutoCache

* AutoCache

* AutoCache

* AutoCache

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
glenn-jocher and pre-commit-ci[bot] committed Nov 4, 2022
1 parent 02b8a4c commit fde7758
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 11 deletions.
2 changes: 1 addition & 1 deletion segment/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def parse_opt(known=False):
parser.add_argument('--noplots', action='store_true', help='save no plot files')
parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def parse_opt(known=False):
parser.add_argument('--noplots', action='store_true', help='save no plot files')
parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
Expand Down
1 change: 0 additions & 1 deletion utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def notebook_init(verbose=True):
from utils.general import check_font, check_requirements, is_colab
from utils.torch_utils import select_device # imports

check_requirements(('psutil', 'IPython'))
check_font()

import psutil
Expand Down
34 changes: 27 additions & 7 deletions utils/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from urllib.parse import urlparse

import numpy as np
import psutil
import torch
import torch.nn.functional as F
import torchvision
Expand All @@ -30,8 +31,8 @@
from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
cutout, letterbox, mixup, random_perspective)
from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy, xywh2xyxy, xywhn2xyxy,
xyxy2xywhn)
colorstr, cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy, xywh2xyxy,
xywhn2xyxy, xyxy2xywhn)
from utils.torch_utils import torch_distributed_zero_first

# Parameters
Expand Down Expand Up @@ -564,24 +565,43 @@ def __init__(self,

self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

# Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources)
# Cache images into RAM/disk for faster training
if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix):
cache_images = False
self.ims = [None] * n
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
if cache_images:
gb = 0 # Gigabytes of cached images
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
self.im_hw0, self.im_hw = [None] * n, [None] * n
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
pbar = tqdm(enumerate(results), total=n, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar:
if cache_images == 'disk':
gb += self.npy_files[i].stat().st_size
b += self.npy_files[i].stat().st_size
else: # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
gb += self.ims[i].nbytes
pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
b += self.ims[i].nbytes
pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
pbar.close()

def check_cache_ram(self, safety_margin=0.1, prefix=''):
    """Estimate whether caching every image in RAM fits into available memory.

    Up to 30 randomly chosen files from ``self.im_files`` are read, each
    sample's byte size is scaled by ``(self.img_size / longest_side) ** 2``,
    and the total is extrapolated to ``self.n`` images.

    Args:
        safety_margin: Fractional headroom added to the estimate before it is
            compared against the currently available memory.
        prefix: Text prepended to any log message emitted here.

    Returns:
        True when the estimate (plus margin) is below the available memory,
        or when the dataset is empty; False otherwise, including when no
        sample image could be read.
    """
    total_images = self.n
    if total_images == 0:
        return True  # nothing to cache; also avoids a division by zero below

    gb = 1 << 30  # bytes per gigabyte (binary)
    sample_bytes = 0  # estimated bytes of the sampled images after resizing
    readable = 0  # number of samples that were actually decoded
    for _ in range(min(total_images, 30)):  # extrapolate from 30 random images
        im = cv2.imread(random.choice(self.im_files))  # sample image
        if im is None:  # missing or corrupt file, leave it out of the estimate
            continue
        ratio = self.img_size / max(im.shape[0], im.shape[1])  # img_size / max(h, w)
        sample_bytes += im.nbytes * ratio ** 2
        readable += 1

    if readable == 0:
        LOGGER.warning(f"{prefix}no sample image could be read to estimate RAM usage, not caching images ⚠️")
        return False

    mem_required = sample_bytes * total_images / readable  # bytes required to cache dataset into RAM
    mem = psutil.virtual_memory()
    cache = mem_required * (1 + safety_margin) < mem.available
    if not cache:
        # Only the negative outcome is reported, so the status text is fixed.
        LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
                    f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
                    f"not caching images ⚠️")
    return cache

def cache_labels(self, path=Path('./labels.cache'), prefix=''):
# Cache dataset labels, check images and read shapes
x = {} # dict
Expand Down
2 changes: 1 addition & 1 deletion utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), insta
if s and install and AUTOINSTALL: # check environment variable
LOGGER.info(f"{prefix} YOLOv5 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
try:
assert check_online(), "AutoUpdate skipped (offline)"
# assert check_online(), "AutoUpdate skipped (offline)"
LOGGER.info(check_output(f'pip install {s} {cmds}', shell=True).decode())
source = file if 'file' in locals() else requirements
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
Expand Down

0 comments on commit fde7758

Please sign in to comment.