Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New LetterBox(size) CenterCrop(size), ToTensor() transforms (#9213) #9213

Merged
merged 13 commits into from
Aug 30, 2022
50 changes: 49 additions & 1 deletion utils/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import cv2
import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

Expand Down Expand Up @@ -345,4 +346,51 @@ def classify_albumentations(augment=True,
def classify_transforms(size=224):
# Transforms to apply if albumentations not installed
assert isinstance(size, int), f'ERROR: classify_transforms size {size} must be integer, not (list, tuple)'
return T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])


class LetterBox:
# YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
def __init__(self, size=(640, 640), auto=False, stride=32):
super().__init__()
self.h, self.w = (size, size) if isinstance(size, int) else size
self.auto = auto # pass max size integer, automatically solve for short side using stride
self.stride = stride # used with auto

def __call__(self, im): # im = np.array HWC
imh, imw = im.shape[:2]
r = min(self.h / imh, self.w / imw) # ratio of new/old
h, w = round(imh * r), round(imw * r) # resized image
hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w
top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
return im_out


class CenterCrop:
# YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
def __init__(self, size=640):
super().__init__()
self.h, self.w = (size, size) if isinstance(size, int) else size

def __call__(self, im): # im = np.array HWC
imh, imw = im.shape[:2]
m = min(imh, imw) # min dimension
top, left = (imh - m) // 2, (imw - m) // 2
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)


class ToTensor:
# YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
def __init__(self, half=False):
super().__init__()
self.half = half

def __call__(self, im): # im = np.array HWC in BGR order
im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
im = torch.from_numpy(im) # to torch
im = im.half() if self.half else im.float() # uint8 to fp16/32
im /= 255.0 # 0-255 to 0.0-1.0
return im
22 changes: 11 additions & 11 deletions utils/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def __next__(self):
s = f'image {self.count}/{self.nf} {path}: '

if self.transforms:
im = self.transforms(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)) # transforms
im = self.transforms(im0) # transforms
else:
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
Expand Down Expand Up @@ -386,7 +386,7 @@ def __next__(self):

im0 = self.imgs.copy()
if self.transforms:
im = np.stack([self.transforms(cv2.cvtColor(x, cv2.COLOR_BGR2RGB)) for x in im0]) # transforms
im = np.stack([self.transforms(x) for x in im0]) # transforms
else:
im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
Expand Down Expand Up @@ -1113,18 +1113,18 @@ def __init__(self, root, augment, imgsz, cache=False):

def __getitem__(self, i):
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
if self.cache_ram and im is None:
im = self.samples[i][3] = cv2.imread(f)
elif self.cache_disk:
if not fn.exists(): # load npy
np.save(fn.as_posix(), cv2.imread(f))
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
if self.album_transforms:
if self.cache_ram and im is None:
im = self.samples[i][3] = cv2.imread(f)
elif self.cache_disk:
if not fn.exists(): # load npy
np.save(fn.as_posix(), cv2.imread(f))
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
else:
sample = self.torch_transforms(self.loader(f))
sample = self.torch_transforms(im)
return sample, j


Expand Down