From 332718ee55767991ba2c80a2c5b51d62b2bb1bee Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 30 Aug 2022 00:56:31 +0200 Subject: [PATCH 01/10] New LetterBox transform YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([T.ToTensor(), LetterBox(size)]) Signed-off-by: Glenn Jocher --- utils/augmentations.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/utils/augmentations.py b/utils/augmentations.py index c8499b3fc8ae..d92d5f0d65cb 100644 --- a/utils/augmentations.py +++ b/utils/augmentations.py @@ -346,3 +346,26 @@ def classify_transforms(size=224): # Transforms to apply if albumentations not installed assert isinstance(size, int), f'ERROR: classify_transforms size {size} must be integer, not (list, tuple)' return T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)]) + + +class LetterBox: + # YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([T.ToTensor(), LetterBox(size)]) + + def __init__(self, size=(640, 640), auto=False, stride=32): + super().__init__() + self.h, self.w = (size, size) if isinstance(size, int) else size + self.auto = auto # pass max size integer, automatically solve for short side using stride + self.stride = stride # used with auto + + def __call__(self, im): + imh, imw = im.shape[1:] + + r = min(self.h / imh, self.w / imw) # ratio of new/old + h, w = round(imh * r), round(imw * r) # resized image + hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w + top = round((hs - h) / 2 - 0.1) + left = round((ws - w) / 2 - 0.1) + + im_out = im.new_full((3, self.h, self.w), 0.44706) + im_out[:, top:top + h, left:left + w] = TF.resize(im, [h, w]) + return im_out From ef6e1a47111f794ea4eed78e8120fb8325f1ffa0 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 30 Aug 2022 01:17:33 +0200 Subject: [PATCH 02/10] Update augmentations.py Signed-off-by: Glenn Jocher --- utils/augmentations.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/utils/augmentations.py b/utils/augmentations.py index d92d5f0d65cb..43875aae27fd 100644 --- a/utils/augmentations.py +++ b/utils/augmentations.py @@ -359,12 +359,10 @@ def __init__(self, size=(640, 640), auto=False, stride=32): def __call__(self, im): imh, imw = im.shape[1:] - r = min(self.h / imh, self.w / imw) # ratio of new/old h, w = round(imh * r), round(imw * r) # resized image hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w - top = round((hs - h) / 2 - 0.1) - left = round((ws - w) / 2 - 0.1) + top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1) im_out = im.new_full((3, self.h, self.w), 0.44706) im_out[:, top:top + h, left:left + w] = TF.resize(im, [h, w]) From 46e24d803ecd26bd91b02e9be0e860d65c312ea0 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 30 Aug 2022 10:52:33 +0200 Subject: [PATCH 03/10] Update augmentations.py Signed-off-by: Glenn Jocher --- utils/augmentations.py | 40 +++++++++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/utils/augmentations.py b/utils/augmentations.py index 43875aae27fd..f95ddbe657eb 100644 --- a/utils/augmentations.py +++ b/utils/augmentations.py @@ -349,21 +349,47 @@ def classify_transforms(size=224): class LetterBox: - # YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([T.ToTensor(), LetterBox(size)]) - + # YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()]) def __init__(self, size=(640, 640), auto=False, stride=32): super().__init__() self.h, self.w = (size, size) if isinstance(size, int) else size self.auto = auto # pass max size integer, automatically solve for short side using stride self.stride = stride # used with auto - def __call__(self, im): - imh, imw = im.shape[1:] + def __call__(self, im): # im = np.array HWC + imh, imw = im.shape[:2] r = min(self.h / imh, self.w / imw) # ratio of new/old h, w = round(imh * r), round(imw * r) # resized image hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1) - - im_out = im.new_full((3, self.h, self.w), 0.44706) - im_out[:, top:top + h, left:left + w] = TF.resize(im, [h, w]) + im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype) + im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) return im_out + + +class CenterCrop: + # YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()]) + def __init__(self, size=640): + super().__init__() + self.h, self.w = (size, size) if isinstance(size, int) else size + + def __call__(self, im): # im = np.array HWC + imh, imw = im.shape[:2] + m = min(imh, imw) # min dimension + top, left = round((imh - m) / 2), round((imw - m) / 2 - 0.1) + return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR) + + +class ToTensor: + # YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()]) + def __init__(self, half=False): + super().__init__() + self.half = half + + def __call__(self, im): # im = np.array HWC in BGR order + im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB + im = np.ascontiguousarray(im) # contiguous + im = torch.from_numpy(im) + im = im.half() if self.half else im.float() # uint8 to fp16/32 + im /= 255.0 # 0 - 255 to 0.0 - 1.0 + return im From 91d2a40ea36882c5904f28d5e79f2fedf4bcf29e Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 30 Aug 2022 11:31:44 +0200 Subject: [PATCH 04/10] Update augmentations.py Signed-off-by: Glenn Jocher --- utils/augmentations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/augmentations.py b/utils/augmentations.py index f95ddbe657eb..6e4d50d15657 100644 --- a/utils/augmentations.py +++ b/utils/augmentations.py @@ -7,6 +7,7 @@ import random import cv2 +import torch import numpy as np import torchvision.transforms as T import torchvision.transforms.functional as TF From 0d1a009fc1f2c48721cbe556a32b18bdcbda30e7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Aug 2022 09:32:05 +0000 Subject: [PATCH 05/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- utils/augmentations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/augmentations.py b/utils/augmentations.py index 6e4d50d15657..2e89ac53bfa2 100644 --- a/utils/augmentations.py +++ b/utils/augmentations.py @@ -7,8 +7,8 @@ import random import cv2 -import torch import numpy as np +import torch import torchvision.transforms as T import torchvision.transforms.functional as TF From 3150861dcf6ce1116f4e0eee3907a09f9dc94b57 Mon Sep 17 00:00:00 2001 From: glennjocher Date: Tue, 30 Aug 2022 12:10:17 +0200 Subject: [PATCH 06/10] cleanup --- utils/augmentations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/augmentations.py b/utils/augmentations.py index 2e89ac53bfa2..e6d6777bfb33 100644 --- a/utils/augmentations.py +++ b/utils/augmentations.py @@ -346,7 +346,8 @@ def classify_albumentations(augment=True, def classify_transforms(size=224): # Transforms to apply if albumentations not installed assert isinstance(size, int), f'ERROR: classify_transforms size {size} must be integer, not (list, tuple)' - return T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)]) + # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)]) + return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)]) class LetterBox: From 51199378c206953d89d89b5a1809b314186f207c Mon Sep 17 00:00:00 2001 From: glennjocher Date: Tue, 30 Aug 2022 12:27:20 +0200 Subject: [PATCH 07/10] cleanup --- utils/augmentations.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/utils/augmentations.py b/utils/augmentations.py index e6d6777bfb33..1b83331662e1 100644 --- a/utils/augmentations.py +++ b/utils/augmentations.py @@ -389,9 +389,8 @@ def __init__(self, half=False): self.half = half def __call__(self, im): # im = np.array HWC in BGR order - im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB - im = np.ascontiguousarray(im) # contiguous - im = torch.from_numpy(im) + im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous + im = torch.from_numpy(im) # to torch im = im.half() if self.half else im.float() # uint8 to fp16/32 - im /= 255.0 # 0 - 255 to 0.0 - 1.0 + im /= 255.0 # 0-255 to 0.0-1.0 return im From ed9e8de61d8c43019536e1a2c975a95174e2750b Mon Sep 17 00:00:00 2001 From: glennjocher Date: Tue, 30 Aug 2022 12:36:00 +0200 Subject: [PATCH 08/10] cleanup --- utils/dataloaders.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/utils/dataloaders.py b/utils/dataloaders.py index 675c2898e7d7..fd4b69cdfa38 100755 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -1113,18 +1113,18 @@ def __init__(self, root, augment, imgsz, cache=False): def __getitem__(self, i): f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image + if self.cache_ram and im is None: + im = self.samples[i][3] = cv2.imread(f) + elif self.cache_disk: + if not fn.exists(): # load npy + np.save(fn.as_posix(), cv2.imread(f)) + im = np.load(fn) + else: # read image + im = cv2.imread(f) # BGR if self.album_transforms: - if self.cache_ram and im is None: - im = self.samples[i][3] = cv2.imread(f) - elif self.cache_disk: - if not fn.exists(): # load npy - np.save(fn.as_posix(), cv2.imread(f)) - im = np.load(fn) - else: # read image - im = cv2.imread(f) # BGR sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"] else: - sample = self.torch_transforms(self.loader(f)) + sample = self.torch_transforms(im) return sample, j From 85f1a9ffa1d54896b9955de39bc25e548e1f3e09 Mon Sep 17 00:00:00 2001 From: glennjocher Date: Tue, 30 Aug 2022 13:12:02 +0200 Subject: [PATCH 09/10] cleanup --- utils/augmentations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/augmentations.py b/utils/augmentations.py index 1b83331662e1..a5587351f75b 100644 --- a/utils/augmentations.py +++ b/utils/augmentations.py @@ -378,7 +378,7 @@ def __init__(self, size=640): def __call__(self, im): # im = np.array HWC imh, imw = im.shape[:2] m = min(imh, imw) # min dimension - top, left = round((imh - m) / 2), round((imw - m) / 2 - 0.1) + top, left = (imh - m) // 2, (imw - m) // 2 return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR) From 79862a02b7a1c20198681cca8f47384e5f2754f1 Mon Sep 17 00:00:00 2001 From: glennjocher Date: Tue, 30 Aug 2022 13:19:58 +0200 Subject: [PATCH 10/10] cleanup --- utils/dataloaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/dataloaders.py b/utils/dataloaders.py index fd4b69cdfa38..79d5dbf58d07 100755 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -251,7 +251,7 @@ def __next__(self): s = f'image {self.count}/{self.nf} {path}: ' if self.transforms: - im = self.transforms(cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)) # transforms + im = self.transforms(im0) # transforms else: im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB @@ -386,7 +386,7 @@ def __next__(self): im0 = self.imgs.copy() if self.transforms: - im = np.stack([self.transforms(cv2.cvtColor(x, cv2.COLOR_BGR2RGB)) for x in im0]) # transforms + im = np.stack([self.transforms(x) for x in im0]) # transforms else: im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW