From abdac079dce0c08648f67344f92c19c4dbf4f582 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Mon, 16 Nov 2020 22:40:28 +0100
Subject: [PATCH 1/3] PyTorch Hub and autoShape update

---
 detect.py        |  2 +-
 hubconf.py       | 14 ++++----
 models/common.py | 83 ++++++++++++++++++++++++++++++++++++++----------
 test.py          |  2 +-
 utils/general.py |  6 ++--
 5 files changed, 78 insertions(+), 29 deletions(-)

diff --git a/detect.py b/detect.py
index 1b70dbb7ef89..90d01ea37446 100644
--- a/detect.py
+++ b/detect.py
@@ -89,7 +89,7 @@ def detect(save_img=False):
             txt_path = str(save_dir / 'labels' / p.stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
             s += '%gx%g ' % img.shape[2:]  # print string
             gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
-            if det is not None and len(det):
+            if len(det):
                 # Rescale boxes from img_size to im0 size
                 det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
 
diff --git a/hubconf.py b/hubconf.py
index 790868d9ea4d..26cb3f4f7774 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -5,15 +5,16 @@
     model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
 """
 
-dependencies = ['torch', 'yaml']
 from pathlib import Path
 
 import torch
+from PIL import Image
 
 from models.yolo import Model
 from utils.general import set_logging
 from utils.google_utils import attempt_download
 
+dependencies = ['torch', 'yaml', 'pillow']
 set_logging()
 
 
@@ -108,11 +109,10 @@ def yolov5x(pretrained=False, channels=3, classes=80):
 
 if __name__ == '__main__':
     model = create(name='yolov5s', pretrained=True, channels=3, classes=80)  # example
-    model = model.fuse().eval().autoshape()  # for autoshaping of PIL/cv2/np inputs and NMS
+    model = model.fuse().autoshape()  # for PIL/cv2/np inputs and NMS
 
     # Verify inference
-    from PIL import Image
-
-    img = Image.open('data/images/zidane.jpg')
-    y = model(img)
-    print(y[0].shape)
+    imgs = [Image.open(x) for x in Path('data/images').glob('*.jpg')]
+    results = model(imgs)
+    results.show()
+    results.print()
diff --git a/models/common.py b/models/common.py
index b48ad48b57be..894a7757f636 100644
--- a/models/common.py
+++ b/models/common.py
@@ -5,9 +5,11 @@
 import numpy as np
 import torch
 import torch.nn as nn
+from PIL import Image, ImageDraw
 
 from utils.datasets import letterbox
-from utils.general import non_max_suppression, make_divisible, scale_coords
+from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
+from utils.plots import color_list
 
 
 def autopad(k, p=None):  # kernel, padding
@@ -125,9 +127,9 @@ class autoShape(nn.Module):
 
     def __init__(self, model):
         super(autoShape, self).__init__()
-        self.model = model
+        self.model = model.eval()
 
-    def forward(self, x, size=640, augment=False, profile=False):
+    def forward(self, imgs, size=640, augment=False, profile=False):
         # supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
         #   opencv:     x = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
         #   PIL:        x = Image.open('image.jpg')  # HWC x(720,1280,3)
         #   numpy:      x = np.zeros((720,1280,3))  # HWC
         #   torch:      x = torch.zeros(16,3,720,1280)  # BCHW
@@ -136,36 +138,83 @@ def forward(self, x, size=640, augment=False, profile=False):
         #   multiple:   x = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
 
         p = next(self.model.parameters())  # for device and type
-        if isinstance(x, torch.Tensor):  # torch
-            return self.model(x.to(p.device).type_as(p), augment, profile)  # inference
+        if isinstance(imgs, torch.Tensor):  # torch
+            return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference
 
         # Pre-process
-        if not isinstance(x, list):
-            x = [x]
+        if not isinstance(imgs, list):
+            imgs = [imgs]
         shape0, shape1 = [], []  # image and inference shapes
-        batch = range(len(x))  # batch size
+        batch = range(len(imgs))  # batch size
         for i in batch:
-            x[i] = np.array(x[i])  # to numpy
-            x[i] = x[i][:, :, :3] if x[i].ndim == 3 else np.tile(x[i][:, :, None], 3)  # enforce 3ch input
-            s = x[i].shape[:2]  # HWC
+            imgs[i] = np.array(imgs[i])  # to numpy
+            imgs[i] = imgs[i][:, :, :3] if imgs[i].ndim == 3 else np.tile(imgs[i][:, :, None], 3)  # enforce 3ch input
+            s = imgs[i].shape[:2]  # HWC
             shape0.append(s)  # image shape
             g = (size / max(s))  # gain
             shape1.append([y * g for y in s])
         shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
-        x = [letterbox(x[i], new_shape=shape1, auto=False)[0] for i in batch]  # pad
+        x = [letterbox(imgs[i], new_shape=shape1, auto=False)[0] for i in batch]  # pad
         x = np.stack(x, 0) if batch[-1] else x[0][None]  # stack
         x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
         x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
 
         # Inference
-        x = self.model(x, augment, profile)  # forward
-        x = non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
+        with torch.no_grad():
+            y = self.model(x, augment, profile)[0]  # forward
+        y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
 
         # Post-process
         for i in batch:
-            if x[i] is not None:
-                x[i][:, :4] = scale_coords(shape1, x[i][:, :4], shape0[i])
-        return x
+            if y[i] is not None:
+                y[i][:, :4] = scale_coords(shape1, y[i][:, :4], shape0[i])
+
+        return Detections(imgs, y, self.names)
+
+
+class Detections:
+    # detections class for YOLOv5 inference results
+    def __init__(self, imgs, pred, names=None):
+        super(Detections, self).__init__()
+        self.imgs = imgs  # list of images as numpy arrays
+        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
+        self.names = names  # class names
+        self.xyxy = pred  # xyxy pixels
+        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
+        gn = [torch.Tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.]) for im in imgs]  # normalization gains
+        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
+        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
+
+    def display(self, pprint=False, show=False, save=False):
+        colors = color_list()
+        for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
+            str = f'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
+            if pred is not None:
+                for c in pred[:, -1].unique():
+                    n = (pred[:, -1] == c).sum()  # detections per class
+                    str += f'{n} {self.names[int(c)]}s, '  # add to string
+                if show or save:
+                    img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img  # from np
+                    for *box, conf, cls in pred:  # xyxy, confidence, class
+                        # str += '%s %.2f, ' % (names[int(cls)], conf)  # label
+                        ImageDraw.Draw(img).rectangle(box, width=4, outline=colors[int(cls) % 10])  # plot
+            if save:
+                f = f'results{i}.jpg'
+                str += f"saved to '{f}'"
+                img.save(f)  # save
+            if show:
+                img.show(f'Image {i}')  # show
+            if pprint:
+                print(str)
+
+    def print(self):
+        self.display(pprint=True)  # print results
+
+    def show(self):
+        self.display(show=True)  # show results
+
+    def save(self):
+        self.display(save=True)  # save results
 
 
 class Flatten(nn.Module):
diff --git a/test.py b/test.py
index e0802983befe..15dd243622f0 100644
--- a/test.py
+++ b/test.py
@@ -126,7 +126,7 @@ def test(data,
             tcls = labels[:, 0].tolist() if nl else []  # target class
             seen += 1
 
-            if pred is None:
+            if len(pred) == 0:
                 if nl:
                     stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
                 continue
diff --git a/utils/general.py b/utils/general.py
index fd55217d702d..2134cf132776 100755
--- a/utils/general.py
+++ b/utils/general.py
@@ -142,7 +142,7 @@ def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
 
 def xyxy2xywh(x):
     # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
-    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
     y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
     y[:, 2] = x[:, 2] - x[:, 0]  # width
@@ -152,7 +152,7 @@ def xyxy2xywh(x):
 
 def xywh2xyxy(x):
     # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
-    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
     y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
     y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
@@ -280,7 +280,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
     multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)
 
     t = time.time()
-    output = [None] * prediction.shape[0]
+    output = [torch.zeros(0, 6)] * prediction.shape[0]
     for xi, x in enumerate(prediction):  # image index, image inference
         # Apply constraints
         # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height

From b384e213e80b1f676928bd0ef657db58e65fe75c Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Mon, 16 Nov 2020 22:41:09 +0100
Subject: [PATCH 2/3] comment x for imgs

---
 models/common.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/models/common.py b/models/common.py
index 894a7757f636..b6c37af666be 100644
--- a/models/common.py
+++ b/models/common.py
@@ -131,11 +131,11 @@ def __init__(self, model):
 
     def forward(self, imgs, size=640, augment=False, profile=False):
         # supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
-        #   opencv:     x = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
-        #   PIL:        x = Image.open('image.jpg')  # HWC x(720,1280,3)
-        #   numpy:      x = np.zeros((720,1280,3))  # HWC
-        #   torch:      x = torch.zeros(16,3,720,1280)  # BCHW
-        #   multiple:   x = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
+        #   opencv:     imgs = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
+        #   PIL:        imgs = Image.open('image.jpg')  # HWC x(720,1280,3)
+        #   numpy:      imgs = np.zeros((720,1280,3))  # HWC
+        #   torch:      imgs = torch.zeros(16,3,720,1280)  # BCHW
+        #   multiple:   imgs = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
 
         p = next(self.model.parameters())  # for device and type
         if isinstance(imgs, torch.Tensor):  # torch

From db35c2c64e29318b613a6b81b20adfde6ba2d443 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Mon, 16 Nov 2020 22:49:42 +0100
Subject: [PATCH 3/3] reduce comment

---
 hubconf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hubconf.py b/hubconf.py
index 26cb3f4f7774..61ff4cc4f239 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -42,7 +42,7 @@ def create(name, pretrained, channels, classes):
         model.load_state_dict(state_dict, strict=False)  # load
         if len(ckpt['model'].names) == classes:
            model.names = ckpt['model'].names  # set class names attribute
-        # model = model.autoshape()  # for autoshaping of PIL/cv2/np inputs and NMS
+        # model = model.autoshape()  # for PIL/cv2/np inputs and NMS
         return model
 
     except Exception as e:
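
Usage sketch: a minimal example of the autoShape / Detections API these patches add, assuming a yolov5 copy that already includes the three changes above and that the repo's bundled 'data/images/zidane.jpg' is available; the torch.hub.load() call mirrors the hubconf.py docstring, and the attributes and methods used below come from the new Detections class.

    import torch
    from PIL import Image

    # Load a pretrained model via the hubconf.py entrypoint, then wrap it with autoShape
    # so it accepts PIL/cv2/numpy inputs and applies NMS internally.
    model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
    model = model.fuse().autoshape()

    # Batched inference on a list of images returns a Detections object.
    imgs = [Image.open('data/images/zidane.jpg')]
    results = model(imgs)

    results.print()           # per-image summary string (detections per class)
    results.show()            # draw boxes and display each image
    print(results.xyxy[0])    # boxes for image 0 as (x1, y1, x2, y2, conf, cls) in pixels
    print(results.xywhn[0])   # same boxes in xywh, normalized by image width/height

Note that torch.hub.load() fetches the hub copy of the repo, so the snippet only behaves as shown once these patches are part of the code being loaded; with older code, model(imgs) returns raw tensors instead of a Detections object.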