diff --git a/segment/train.py b/segment/train.py index 7950f95df4f2..f067918e7c3c 100644 --- a/segment/train.py +++ b/segment/train.py @@ -474,7 +474,7 @@ def parse_opt(known=False): parser.add_argument('--noplots', action='store_true', help='save no plot files') parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations') parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') - parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"') + parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk') parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') diff --git a/train.py b/train.py index e882748581bf..1fe6cf4d9ebd 100644 --- a/train.py +++ b/train.py @@ -444,7 +444,7 @@ def parse_opt(known=False): parser.add_argument('--noplots', action='store_true', help='save no plot files') parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations') parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') - parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"') + parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk') parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training') parser.add_argument('--device', default='', help='cuda device, i.e. 
0 or 0,1,2,3 or cpu') parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') diff --git a/utils/__init__.py b/utils/__init__.py index 0afe6f475625..8354d91c4269 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -47,7 +47,6 @@ def notebook_init(verbose=True): from utils.general import check_font, check_requirements, is_colab from utils.torch_utils import select_device # imports - check_requirements(('psutil', 'IPython')) check_font() import psutil diff --git a/utils/dataloaders.py b/utils/dataloaders.py index 4e5b75edb5c2..b33a24a46f9c 100644 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -19,6 +19,7 @@ from urllib.parse import urlparse import numpy as np +import psutil import torch import torch.nn.functional as F import torchvision @@ -30,8 +31,8 @@ from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste, cutout, letterbox, mixup, random_perspective) from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str, - cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy, xywh2xyxy, xywhn2xyxy, - xyxy2xywhn) + colorstr, cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy, xywh2xyxy, + xywhn2xyxy, xyxy2xywhn) from utils.torch_utils import torch_distributed_zero_first # Parameters @@ -564,24 +565,43 @@ def __init__(self, self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride - # Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources) + # Cache images into RAM/disk for faster training + if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix): + cache_images = False self.ims = [None] * n self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files] if cache_images: - gb = 0 # Gigabytes of cached images + b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes self.im_hw0, 
def check_cache_ram(self, safety_margin=0.1, prefix=''):
    # Check image caching requirements vs available memory.
    #
    # Reads up to 30 randomly chosen images (random.choice, so repeats are possible), scales each
    # image's byte size by the squared ratio of self.img_size to its longest side, and extrapolates
    # the total to all self.n images. The estimate, inflated by `safety_margin` (0.1 = 10% headroom),
    # is compared with psutil.virtual_memory().available.
    #
    # Arguments:
    #   safety_margin: fractional headroom required on top of the estimate
    #   prefix:        text prepended to any log message
    # Returns:
    #   True if the dataset is estimated to fit in available RAM (or is empty), otherwise False.
    #   False is also returned when none of the sampled images could be read, since no estimate exists.
    gb = 1 << 30  # bytes per gigabyte, used only to format the log message
    n = min(self.n, 30)  # number of random images to extrapolate from
    if n == 0:
        return True  # empty dataset: nothing to cache, and avoids dividing by zero below

    b, sampled = 0, 0  # estimated bytes of the resized samples, number of samples actually read
    for _ in range(n):
        im = cv2.imread(random.choice(self.im_files))  # sample image
        if im is None:  # cv2.imread returns None when the file is missing or cannot be decoded
            continue
        ratio = self.img_size / max(im.shape[0], im.shape[1])  # target size / max(h, w)
        b += im.nbytes * ratio ** 2  # byte count scales with area, hence ratio squared
        sampled += 1

    if sampled == 0:
        LOGGER.info(f"{prefix}unable to read any of {n} sampled images to estimate RAM usage, "
                    f"not caching images ⚠️")
        return False

    mem_required = b * self.n / sampled  # bytes required to cache the whole dataset into RAM
    mem = psutil.virtual_memory()
    cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
    if not cache:
        # Only the refusal is reported; the emitted text is unchanged from the original message.
        LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
                    f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
                    f"not caching images ⚠️")
    return cache
variable LOGGER.info(f"{prefix} YOLOv5 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...") try: - assert check_online(), "AutoUpdate skipped (offline)" + # assert check_online(), "AutoUpdate skipped (offline)" LOGGER.info(check_output(f'pip install {s} {cmds}', shell=True).decode()) source = file if 'file' in locals() else requirements s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \