From 569757ecc09d115e275a6ec3662514d72dfe18c2 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 13 Mar 2021 19:50:34 -0800 Subject: [PATCH] Add autoShape() speed profiling (#2459) * Add autoShape() speed profiling * Update common.py * Create README.md * Update hubconf.py * cleanuip --- README.md | 4 ++-- hubconf.py | 8 ++++---- models/common.py | 14 +++++++++++--- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index b7129e80adfe..097b2750bf49 100755 --- a/README.md +++ b/README.md @@ -108,11 +108,11 @@ To run **batched inference** with YOLOv5 and [PyTorch Hub](https://github.com/ul import torch # Model -model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True) +model = torch.hub.load('ultralytics/yolov5', 'yolov5s') # Images dir = 'https://github.com/ultralytics/yolov5/raw/master/data/images/' -imgs = [dir + f for f in ('zidane.jpg', 'bus.jpg')] # batched list of images +imgs = [dir + f for f in ('zidane.jpg', 'bus.jpg')] # batch of images # Inference results = model(imgs) diff --git a/hubconf.py b/hubconf.py index a8eb51681794..e51ac90da36c 100644 --- a/hubconf.py +++ b/hubconf.py @@ -51,7 +51,7 @@ def create(name, pretrained, channels, classes, autoshape): raise Exception(s) from e -def yolov5s(pretrained=False, channels=3, classes=80, autoshape=True): +def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True): """YOLOv5-small model from https://github.com/ultralytics/yolov5 Arguments: @@ -65,7 +65,7 @@ def yolov5s(pretrained=False, channels=3, classes=80, autoshape=True): return create('yolov5s', pretrained, channels, classes, autoshape) -def yolov5m(pretrained=False, channels=3, classes=80, autoshape=True): +def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True): """YOLOv5-medium model from https://github.com/ultralytics/yolov5 Arguments: @@ -79,7 +79,7 @@ def yolov5m(pretrained=False, channels=3, classes=80, autoshape=True): return create('yolov5m', pretrained, channels, classes, autoshape) -def yolov5l(pretrained=False, channels=3, classes=80, autoshape=True): +def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True): """YOLOv5-large model from https://github.com/ultralytics/yolov5 Arguments: @@ -93,7 +93,7 @@ def yolov5l(pretrained=False, channels=3, classes=80, autoshape=True): return create('yolov5l', pretrained, channels, classes, autoshape) -def yolov5x(pretrained=False, channels=3, classes=80, autoshape=True): +def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True): """YOLOv5-xlarge model from https://github.com/ultralytics/yolov5 Arguments: diff --git a/models/common.py b/models/common.py index ad35f908d865..7ef5762efbf3 100644 --- a/models/common.py +++ b/models/common.py @@ -12,6 +12,7 @@ from utils.datasets import letterbox from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh from utils.plots import color_list, plot_one_box +from utils.torch_utils import time_synchronized def autopad(k, p=None): # kernel, padding @@ -190,6 +191,7 @@ def forward(self, imgs, size=640, augment=False, profile=False): # torch: = torch.zeros(16,3,720,1280) # BCHW # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images + t = [time_synchronized()] p = next(self.model.parameters()) # for device and type if isinstance(imgs, torch.Tensor): # torch return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference @@ -216,22 +218,25 @@ def forward(self, imgs, size=640, augment=False, profile=False): x = np.stack(x, 0) if n > 1 else x[0][None] # stack x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32 + t.append(time_synchronized()) # Inference with torch.no_grad(): y = self.model(x, augment, profile)[0] # forward - y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS + t.append(time_synchronized()) # Post-process + y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS for i in range(n): scale_coords(shape1, y[i][:, :4], shape0[i]) + t.append(time_synchronized()) - return Detections(imgs, y, files, self.names) + return Detections(imgs, y, files, t, self.names, x.shape) class Detections: # detections class for YOLOv5 inference results - def __init__(self, imgs, pred, files, names=None): + def __init__(self, imgs, pred, files, times, names=None, shape=None): super(Detections, self).__init__() d = pred[0].device # device gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs] # normalizations @@ -244,6 +249,8 @@ def __init__(self, imgs, pred, files, names=None): self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized self.n = len(self.pred) + self.t = ((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms) + self.s = shape # inference BCHW shape def display(self, pprint=False, show=False, save=False, render=False, save_dir=''): colors = color_list() @@ -271,6 +278,7 @@ def display(self, pprint=False, show=False, save=False, render=False, save_dir=' def print(self): self.display(pprint=True) # print results + print(f'Speed: %.1f/%.1f/%.1f ms pre-process/inference/NMS per image at shape {tuple(self.s)}' % tuple(self.t)) def show(self): self.display(show=True) # show results