Commit: Cleanup2

glenn-jocher committed Aug 2, 2021
1 parent 0419df1 commit fa7e4ce
Showing 3 changed files with 35 additions and 43 deletions.
1 change: 0 additions & 1 deletion requirements.txt
@@ -10,7 +10,6 @@ scipy>=1.4.1
torch>=1.7.0
torchvision>=0.8.1
tqdm>=4.41.0
-psutil

# logging -------------------------------------
tensorboard>=2.4.1
4 changes: 2 additions & 2 deletions train.py
@@ -201,7 +201,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary

# Trainloader
train_loader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
-hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
+hyp=hyp, augment=True, cache=opt.cache, rect=opt.rect, rank=RANK,
workers=workers, image_weights=opt.image_weights, quad=opt.quad,
prefix=colorstr('train: '))
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
@@ -211,7 +211,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
# Process 0
if RANK in [-1, 0]:
val_loader = create_dataloader(val_path, imgsz, batch_size // WORLD_SIZE * 2, gs, single_cls,
-hyp=hyp, cache=opt.cache_images and not noval, rect=True, rank=-1,
+hyp=hyp, cache=None if noval else opt.cache, rect=True, rank=-1,
workers=workers, pad=0.5,
prefix=colorstr('val: '))[0]

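Both dataloader calls now take the caching mode from a single opt.cache value instead of the old opt.cache_images flag, and the val loader skips caching entirely when noval is set. A minimal sketch of how the option is presumably parsed follows; the exact argparse wiring is outside this diff, so the flag name, const and default here are assumptions.

    # Hypothetical sketch of the assumed --cache option (not part of this commit's diff)
    import argparse

    parser = argparse.ArgumentParser()
    # bare '--cache' would cache decoded images in RAM; '--cache disk' would write .npy files to disk
    parser.add_argument('--cache', type=str, nargs='?', const='ram', default=None,
                        help="image cache mode: 'ram' or 'disk'")
    opt = parser.parse_args(['--cache', 'disk'])
    print(opt.cache)  # -> 'disk'
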
73 changes: 33 additions & 40 deletions utils/datasets.py
@@ -21,8 +21,6 @@
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm
-import psutil
-import re

from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
@@ -48,6 +46,7 @@ def get_hash(paths):
h.update(''.join(paths).encode()) # hash paths
return h.hexdigest() # return hash


def exif_size(img):
# Returns exif-corrected PIL size
s = img.size # (width, height)
@@ -373,8 +372,6 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
self.stride = stride
self.path = path
self.albumentations = Albumentations() if augment else None
-self.cache_images = cache_images
-self.prefix = prefix

try:
f = [] # image files
@@ -400,10 +397,6 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
# Check cache
self.label_files = img2label_paths(self.img_files) # labels
cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
if cache_images == "disk":
cache_dir = Path(self.img_files[0]).parent / "images_npy"
if not cache_dir.is_dir():
cache_dir.mkdir(parents=True)
try:
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
assert cache['version'] == 0.4 and cache['hash'] == get_hash(self.label_files + self.img_files)
@@ -462,24 +455,25 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride

# Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
-self.imgs = [None] * n
+self.imgs, self.img_npy = [None] * n, [None] * n
if cache_images:
+if cache_images == 'disk':
+self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy')
+self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files]
+self.im_cache_dir.mkdir(parents=True, exist_ok=True)
gb = 0 # Gigabytes of cached images
self.img_hw0, self.img_hw = [None] * n, [None] * n
-results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x, no_cache=True), zip(repeat(self), range(n)))
+results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
pbar = tqdm(enumerate(results), total=n)
if cache_images == "disk":
parent_path = Path(self.img_files[0]).parent / "images_npy"
disk = psutil.disk_usage(parent_path)
for i, x in pbar:
img, self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
np.save(parent_path / (self.prefix + str(i)+".npy"),img)
pbar.desc = f'{prefix}Disk usage for cache({disk.used / 1E9:.1f}GB / {disk.total / 1E9:.1f}GB = {disk.percent}%)'
else:
for i, x in pbar:
self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
+for i, x in pbar:
+im, self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
+if cache_images == 'disk':
+np.save(self.img_npy[i].as_posix(), im)
+gb += self.img_npy[i].stat().st_size
+else:
+self.imgs[i] = im
gb += self.imgs[i].nbytes
-pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
+pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
pbar.close()

def cache_labels(self, path=Path('./labels.cache'), prefix=''):
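The rewritten caching loop above stores each image either in self.imgs (RAM) or as a .npy file inside a sibling '<images_dir>_npy' directory, accumulating the cached size in gb from the array bytes or the saved file size. A self-contained sketch of that .npy disk-cache pattern, with assumed example paths (the resize performed by load_image is omitted here):

    # Standalone sketch of the .npy disk-cache pattern; paths are illustrative assumptions
    from pathlib import Path

    import cv2
    import numpy as np

    img_files = ['datasets/coco128/images/train2017/000000000009.jpg']  # assumed image list
    im_cache_dir = Path(Path(img_files[0]).parent.as_posix() + '_npy')  # sibling '*_npy' directory
    im_cache_dir.mkdir(parents=True, exist_ok=True)

    gb = 0  # bytes cached so far
    for f in img_files:
        npy = im_cache_dir / Path(f).with_suffix('.npy').name
        if not npy.exists():
            im = cv2.imread(f)  # BGR array
            assert im is not None, 'Image Not Found ' + f
            np.save(npy.as_posix(), im)  # cache the decoded array to disk
        gb += npy.stat().st_size  # count the on-disk size
    print(f'Cached {gb / 1E9:.1f}GB to {im_cache_dir}')
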
@@ -633,26 +627,25 @@ def collate_fn4(batch):


# Ancillary functions --------------------------------------------------------------------------------------------------
-def load_image(self, index, no_cache=False):
-# loads 1 image from dataset, returns img, original hw, resized hw
-img = self.imgs[index]
-if img is None: # not cached
-if no_cache == False and self.cache_images == "disk":
-parent_path = Path(self.img_files[index]).parent
-img = np.load(parent_path / "images_npy" / (self.prefix + str(index) + ".npy"))
-return img, self.img_hw0[index], self.img_hw[index] # img, hw_original, hw_resized
-else:
-path = self.img_files[index]
-img = cv2.imread(path) # BGR
-assert img is not None, 'Image Not Found ' + path
-h0, w0 = img.shape[:2] # orig hw
-r = self.img_size / max(h0, w0) # ratio
-if r != 1: # if sizes are not equal
-img = cv2.resize(img, (int(w0 * r), int(h0 * r)),
-interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
-return img, (h0, w0), img.shape[:2] # img, hw_original, hw_resized
+def load_image(self, i):
+# loads 1 image from dataset index 'i', returns im, original hw, resized hw
+im = self.imgs[i]
+if im is None: # not cached in ram
+npy = self.img_npy[i]
+if npy and npy.exists(): # load npy
+im = np.load(npy)
+else: # read image
+path = self.img_files[i]
+im = cv2.imread(path) # BGR
+assert im is not None, 'Image Not Found ' + path
+h0, w0 = im.shape[:2] # orig hw
+r = self.img_size / max(h0, w0) # ratio
+if r != 1: # if sizes are not equal
+im = cv2.resize(im, (int(w0 * r), int(h0 * r)),
+interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
+return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
else:
-return self.imgs[index], self.img_hw0[index], self.img_hw[index] # img, hw_original, hw_resized
+return self.imgs[i], self.img_hw0[i], self.img_hw[i] # im, hw_original, hw_resized


def load_mosaic(self, index):
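Taken together, the rewritten load_image resolves an image in three steps: the in-RAM cache (self.imgs), then the per-image .npy path (self.img_npy), and finally cv2.imread followed by resizing the long side to self.img_size. A hedged usage sketch, assuming the dataset class shown in this file is LoadImagesAndLabels and using an illustrative path:

    # Assumed usage: with 'disk' caching, later lookups read .npy files instead of re-decoding JPEGs
    from utils.datasets import LoadImagesAndLabels, load_image

    dataset = LoadImagesAndLabels('datasets/coco128/images/train2017', img_size=640, cache_images='disk')
    im, (h0, w0), (h, w) = load_image(dataset, 0)  # resized image, original hw, resized hw
    print(im.shape, (h0, w0), (h, w))
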
