Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature python train.py --cache disk #4049

Merged
merged 18 commits into from
Aug 2, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ scipy>=1.4.1
torch>=1.7.0
torchvision>=0.8.1
tqdm>=4.41.0
psutil

# logging -------------------------------------
tensorboard>=2.4.1
Expand Down
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

# Trainloader
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
hyp=hyp, augment=True, cache_device=opt.cache, rect=opt.rect, rank=RANK,
workers=workers, image_weights=opt.image_weights, quad=opt.quad,
prefix=colorstr('train: '))
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
Expand All @@ -211,7 +211,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Process 0
if RANK in [-1, 0]:
val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
hyp=hyp, cache=opt.cache_images and not noval, rect=True, rank=-1,
hyp=hyp, cache_device= ('' if noval else opt.cache), rect=True, rank=-1,
workers=workers, pad=0.5,
prefix=colorstr('val: '))[0]

Expand Down Expand Up @@ -430,7 +430,7 @@ def parse_opt(known=False):
parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
Expand Down
64 changes: 44 additions & 20 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm
import psutil
import re

from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
Expand All @@ -46,7 +48,6 @@ def get_hash(paths):
h.update(''.join(paths).encode()) # hash paths
return h.hexdigest() # return hash


def exif_size(img):
# Returns exif-corrected PIL size
s = img.size # (width, height)
Expand Down Expand Up @@ -88,15 +89,15 @@ def exif_transpose(image):
return image


def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache_device='', pad=0.0,
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
with torch_distributed_zero_first(rank):
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
augment=augment, # augment images
hyp=hyp, # augmentation hyperparameters
rect=rect, # rectangular training
cache_images=cache,
cache_device=cache_device,
single_cls=single_cls,
stride=int(stride),
pad=pad,
Expand Down Expand Up @@ -361,7 +362,7 @@ def img2label_paths(img_paths):

class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
cache_device='', single_cls=False, stride=32, pad=0.0, prefix=''):
self.img_size = img_size
self.augment = augment
self.hyp = hyp
Expand All @@ -372,6 +373,12 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
self.stride = stride
self.path = path
self.albumentations = Albumentations() if augment else None
self.cache_device = cache_device
# Use self.prefix as cache-key for on-disk-cache
self.prefix = re.sub(r'\x1B\[([0-9]{1,2}(;[0-9]{1,2})?)?[m|K]', "", prefix).replace(' ','_')

if cache_device != '' and cache_device != 'ram' and cache_device != 'disk':
raise Exception(f'{cache_device} is set in cache_device. It should be ram or disk.')

try:
f = [] # image files
Expand All @@ -397,6 +404,10 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
# Check cache
self.label_files = img2label_paths(self.img_files) # labels
cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
if cache_device == "disk":
cache_dir = Path(self.img_files[0]).parent / "images_npy"
glenn-jocher marked this conversation as resolved.
Show resolved Hide resolved
if not cache_dir.is_dir():
cache_dir.mkdir(parents=True)
try:
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
assert cache['version'] == 0.4 and cache['hash'] == get_hash(self.label_files + self.img_files)
Expand Down Expand Up @@ -456,15 +467,23 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r

# Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
self.imgs = [None] * n
if cache_images:
if cache_device != '':
gb = 0 # Gigabytes of cached images
self.img_hw0, self.img_hw = [None] * n, [None] * n
results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x, no_cache=True), zip(repeat(self), range(n)))
pbar = tqdm(enumerate(results), total=n)
for i, x in pbar:
self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
gb += self.imgs[i].nbytes
pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
if cache_device == "disk":
parent_path = Path(self.img_files[0]).parent / "images_npy"
disk = psutil.disk_usage(parent_path)
for i, x in pbar:
img, self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
np.save(parent_path / (self.prefix + str(i)+".npy"),img)
pbar.desc = f'{prefix}Disk usage for cache({disk.used / 1E9:.1f}GB / {disk.total / 1E9:.1f}GB = {disk.percent}%)'
glenn-jocher marked this conversation as resolved.
Show resolved Hide resolved
else:
for i, x in pbar:
self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
gb += self.imgs[i].nbytes
pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
pbar.close()

def cache_labels(self, path=Path('./labels.cache'), prefix=''):
Expand Down Expand Up @@ -618,19 +637,24 @@ def collate_fn4(batch):


# Ancillary functions --------------------------------------------------------------------------------------------------
def load_image(self, index):
def load_image(self, index, no_cache=False):
    """Load one dataset image by index, resized so its longest side equals self.img_size.

    Resolution order: (1) RAM cache (self.imgs), (2) on-disk .npy cache when
    self.cache_device == 'disk' and no_cache is False, (3) the original image file.

    Args:
        index (int): position of the image in self.img_files.
        no_cache (bool): when True, skip the disk cache and read/resize from the
            source file (used while the disk cache is being built).

    Returns:
        tuple: (img, (h0, w0), (h, w)) — BGR ndarray, original hw, resized hw.
    """
    img = self.imgs[index]
    if img is not None:  # already cached in RAM
        return img, self.img_hw0[index], self.img_hw[index]

    if not no_cache and self.cache_device == 'disk':
        # Disk cache stores the already-resized array, so hw0/hw come from the
        # lists populated during caching rather than from the array itself.
        parent_path = Path(self.img_files[index]).parent
        img = np.load(parent_path / 'images_npy' / (self.prefix + str(index) + '.npy'))
        return img, self.img_hw0[index], self.img_hw[index]

    # No cache available: read from the source file and resize.
    path = self.img_files[index]
    img = cv2.imread(path)  # BGR
    assert img is not None, 'Image Not Found ' + path
    h0, w0 = img.shape[:2]  # orig hw
    r = self.img_size / max(h0, w0)  # resize ratio to longest side
    if r != 1:  # if sizes are not equal
        # INTER_AREA is preferred when shrinking (unless augmenting); INTER_LINEAR otherwise
        img = cv2.resize(img, (int(w0 * r), int(h0 * r)),
                         interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
    return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized

Expand Down