Feature python train.py --cache disk #4049

Merged Aug 2, 2021 (18 commits)
Changes from 8 commits
6 changes: 4 additions & 2 deletions train.py
@@ -199,7 +199,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
     # Trainloader
     train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
-                                              hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
+                                              hyp=hyp, augment=True, cache=opt.cache_images, cache_on_disk=opt.cache_on_disk, cache_directory=opt.cache_directory, rect=opt.rect, rank=RANK,
                                               workers=workers, image_weights=opt.image_weights, quad=opt.quad,
                                               prefix=colorstr('train: '))
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
@@ -209,7 +209,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Process 0
     if RANK in [-1, 0]:
         val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
-                                       hyp=hyp, cache=opt.cache_images and not noval, rect=True, rank=-1,
+                                       hyp=hyp, cache=opt.cache_images and not noval, cache_on_disk=opt.cache_on_disk, cache_directory=opt.cache_directory, rect=True, rank=-1,
                                        workers=workers, pad=0.5,
                                        prefix=colorstr('val: '))[0]
 
@@ -440,6 +440,8 @@ def parse_opt(known=False):
     parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
     parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
+    parser.add_argument('--cache-on-disk', action='store_true', help='cache images on disk; use together with --cache-directory')
+    parser.add_argument('--cache-directory', type=str, default='', help='directory for the on-disk image cache (disabled when empty)')
     parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
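Taken together with the existing --cache-images flag, the two new options above enable disk caching from the command line. A minimal invocation sketch (the dataset yaml and cache path are illustrative placeholders, not taken from this PR):

    python train.py --data coco128.yaml --cache-images --cache-on-disk --cache-directory /tmp/yolov5_cache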
65 changes: 49 additions & 16 deletions utils/datasets.py
@@ -46,6 +46,8 @@ def get_hash(paths):
     h.update(''.join(paths).encode())  # hash paths
     return h.hexdigest()  # return hash
 
+def str2md5(str_arg):
+    return hashlib.md5(str_arg.encode()).hexdigest()
 
 def exif_size(img):
     # Returns exif-corrected PIL size
@@ -88,7 +90,7 @@ def exif_transpose(image):
     return image
 
 
-def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
+def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, cache_on_disk=False, cache_directory="", pad=0.0,
                       rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
     # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
     with torch_distributed_zero_first(rank):
@@ -97,6 +99,8 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None,
                                       hyp=hyp,  # augmentation hyperparameters
                                       rect=rect,  # rectangular training
                                       cache_images=cache,
+                                      cache_on_disk=cache_on_disk,
+                                      cache_directory=cache_directory,
                                       single_cls=single_cls,
                                       stride=int(stride),
                                       pad=pad,
@@ -361,7 +365,7 @@ def img2label_paths(img_paths):
 
 class LoadImagesAndLabels(Dataset):  # for training/testing
     def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
-                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
+                 cache_images=False, cache_on_disk=False, cache_directory="", single_cls=False, stride=32, pad=0.0, prefix=''):
         self.img_size = img_size
         self.augment = augment
         self.hyp = hyp
@@ -372,6 +376,10 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False,
         self.stride = stride
         self.path = path
         self.albumentations = Albumentations() if augment else None
+        self.cache_on_disk = cache_on_disk
+        self.cache_directory = cache_directory
+        # Use self.prefix as cache-key for on-disk-cache
+        self.prefix = prefix
 
         try:
             f = []  # image files
@@ -397,6 +405,11 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False,
         # Check cache
         self.label_files = img2label_paths(self.img_files)  # labels
         cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
+        if cache_directory:
+            cache_dir = Path(cache_directory)
+            if not cache_dir.is_dir():
+                cache_dir.mkdir(parents=True)
+            cache_path = Path(cache_directory + "/label_files_" + hashlib.md5(self.label_files[0].encode()).hexdigest()).with_suffix('.cache')
         try:
             cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
             assert cache['version'] == 0.4 and cache['hash'] == get_hash(self.label_files + self.img_files)
@@ -459,11 +472,16 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False,
         if cache_images:
             gb = 0  # Gigabytes of cached images
             self.img_hw0, self.img_hw = [None] * n, [None] * n
-            results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
+            results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x, no_cache=True), zip(repeat(self), range(n)))
             pbar = tqdm(enumerate(results), total=n)
             for i, x in pbar:
-                self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
-                gb += self.imgs[i].nbytes
+                if cache_on_disk:
+                    img, self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
+                    parent_path = Path(self.img_files[i]).parent
+                    np.save(self.cache_directory + "/" + str2md5(prefix + str(parent_path)) + "_" + str(i) + ".npy", img)
+                else:
+                    self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
+                    gb += self.imgs[i].nbytes
                 pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
             pbar.close()

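For reference, a minimal sketch of the filename scheme the caching loop above writes to and load_image reads back; the directory, prefix, and image path below are illustrative assumptions, not values taken from this PR:

    import hashlib
    from pathlib import Path

    def cache_file(cache_directory, prefix, img_path, index):
        # md5(prefix + parent directory of the image) + "_" + index + ".npy", as with str2md5 above
        key = hashlib.md5((prefix + str(Path(img_path).parent)).encode()).hexdigest()
        return cache_directory + "/" + key + "_" + str(index) + ".npy"

    # e.g. cache_file('/tmp/yolov5_cache', 'train: ', 'coco128/images/train2017/000000000009.jpg', 0)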
@@ -618,19 +636,24 @@ def collate_fn4(batch):
 
 
 # Ancillary functions --------------------------------------------------------------------------------------------------
-def load_image(self, index):
+def load_image(self, index, no_cache=False):
     # loads 1 image from dataset, returns img, original hw, resized hw
     img = self.imgs[index]
     if img is None:  # not cached
-        path = self.img_files[index]
-        img = cv2.imread(path)  # BGR
-        assert img is not None, 'Image Not Found ' + path
-        h0, w0 = img.shape[:2]  # orig hw
-        r = self.img_size / max(h0, w0)  # ratio
-        if r != 1:  # if sizes are not equal
-            img = cv2.resize(img, (int(w0 * r), int(h0 * r)),
-                             interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
-        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
+        if not no_cache and self.cache_on_disk:
+            parent_path = Path(self.img_files[index]).parent
+            img = np.load(self.cache_directory + "/" + str2md5(self.prefix + str(parent_path)) + "_" + str(index) + ".npy")
+            return img, self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized
+        else:
+            path = self.img_files[index]
+            img = cv2.imread(path)  # BGR
+            assert img is not None, 'Image Not Found ' + path
+            h0, w0 = img.shape[:2]  # orig hw
+            r = self.img_size / max(h0, w0)  # ratio
+            if r != 1:  # if sizes are not equal
+                img = cv2.resize(img, (int(w0 * r), int(h0 * r)),
+                                 interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
+            return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
     else:
         return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized

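The new no_cache flag matters in the caching loop above: the initial pass must read and resize the original files (no_cache=True) before the .npy copies exist, while normal training is served from the saved arrays. A hedged usage sketch, assuming dataset is a LoadImagesAndLabels instance built with cache_images=True and cache_on_disk=True:

    img, hw0, hw = load_image(dataset, 0)                 # served from <cache-directory>/<md5>_0.npy
    img, hw0, hw = load_image(dataset, 0, no_cache=True)  # bypasses the disk cache, re-reads and resizes the image file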
@@ -661,6 +684,13 @@ def load_mosaic(self, index):
             x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
             x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
 
+        x2b_t = min(x2b, w)
+        y2b_t = min(y2b, h)
+        x2a = x2a - (x2b - x2b_t)
+        y2a = y2a - (y2b - y2b_t)
+        x2b = x2b_t
+        y2b = y2b_t
+
         img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
         padw = x1a - x1b
         padh = y1a - y1b
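The six lines added above clamp the source box to the tile and shrink the destination box by the same amount, so the two slices in the paste keep matching shapes. A minimal numpy sketch of the failure mode being guarded against (shapes are illustrative, not taken from this PR):

    import numpy as np

    canvas = np.zeros((8, 8, 3), dtype=np.uint8)  # stand-in for img4
    tile = np.zeros((4, 4, 3), dtype=np.uint8)    # stand-in for img

    # canvas[0:4, 0:5] = tile[0:4, 0:5] would raise "could not broadcast": numpy silently clips
    # the source slice to 4 columns while the destination still expects 5.
    x2a, x2b, w = 5, 5, tile.shape[1]
    x2b_t = min(x2b, w)        # clamp source box to the tile width -> 4
    x2a = x2a - (x2b - x2b_t)  # shrink destination by the same amount -> 4
    canvas[0:4, 0:x2a] = tile[0:4, 0:x2b_t]  # shapes now match: (4, 4, 3)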
@@ -923,7 +953,10 @@ def unzip(path):
         x = []
         dataset = LoadImagesAndLabels(data[split], augment=False, rect=True)  # load dataset
         if split == 'train':
-            cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache')  # *.cache path
+            if cache_directory:
+                cache_path = Path(cache_directory + "/label_files_" + hashlib.md5(dataset.label_files[0].encode()).hexdigest()).with_suffix('.cache')
+            else:
+                cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache')  # *.cache path
         for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
             x.append(np.bincount(label[:, 0].astype(int), minlength=nc))
         x = np.array(x)  # shape(128x80)