From bfba802964322336d6520ed90a7ab9f9eac162c0 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 7 Apr 2021 01:50:38 +0200
Subject: [PATCH 1/7] Add Hub results.pandas() method

New method converts results from torch tensors to pandas DataFrames with
column names. This PR may partially resolve issue
https://github.com/ultralytics/yolov5/issues/2703

```python
print(results.pandas().xyxy[0])
         xmin        ymin        xmax        ymax  confidence  class    name
0   57.068970  391.770599  241.383545  905.797852    0.868964    0.0  person
1  667.661255  399.303589  810.000000  881.396667    0.851888    0.0  person
2  222.878387  414.774231  343.804474  857.825073    0.838376    0.0  person
3    4.205386  234.447678  803.739136  750.023376    0.658006    5.0     bus
4    0.000000  550.596008   76.681190  878.669922    0.450596    0.0  person
```
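For reference, a minimal end-to-end usage sketch (the Hub model name is the
standard `ultralytics/yolov5` entrypoint; the image path follows the repo's
examples and is illustrative):

```python
import torch

# load a pretrained YOLOv5s model from PyTorch Hub
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

# run inference on one image; results.pandas() is the method added here
results = model('data/samples/zidane.jpg')
df = results.pandas().xyxy[0]  # DataFrame with xmin..name columns
print(df[df['confidence'] > 0.5])  # e.g. keep confident detections only
```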
---
 models/common.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/models/common.py b/models/common.py
index 4fd1a8159c64..c7571c202231 100644
--- a/models/common.py
+++ b/models/common.py
@@ -1,9 +1,11 @@
 # YOLOv5 common modules
 
 import math
+from copy import copy
 from pathlib import Path
 
 import numpy as np
+import pandas as pd
 import requests
 import torch
 import torch.nn as nn
@@ -347,17 +349,27 @@ def render(self):
         self.display(render=True)  # render results
         return self.imgs
 
-    def __len__(self):
-        return self.n
+    def pandas(self):
+        # return detections as pandas DataFrames
+        new = copy(self)  # do not replace self
+        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
+        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
+        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
+            a = [[x + [self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # updated attribute
+            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
+        return new
 
     def tolist(self):
         # return a list of Detections objects, i.e. 'for result in results.tolist():'
-        x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)]
+        x = [Detections([self.imgs[i]], [self.pred[i]], self.names, self.s) for i in range(self.n)]
        for d in x:
             for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
                 setattr(d, k, getattr(d, k)[0])  # pop out of list
         return x
 
+    def __len__(self):
+        return self.n
+
 
 class Classify(nn.Module):
     # Classification head, i.e. x(b,c1,20,20) to x(b,c2)

From ccbe7dece9b60f15c561ce9305d3901402c21943 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 7 Apr 2021 15:15:59 +0200
Subject: [PATCH 2/7] Update comments

The torch example input is now shown resized to size=640, and its shape is
now a multiple of the P6 stride 64 (see
https://github.com/ultralytics/yolov5/issues/2722#issuecomment-814785930).
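As a sanity check on the shape arithmetic, a quick sketch (a letterbox-style
helper written here for illustration, not the repo's own function):

```python
import math

def scaled_shape(h, w, size=640, stride=64):
    # scale the longer side down to `size`, then round each side up
    # to the nearest multiple of `stride` (letterbox-style padding)
    r = size / max(h, w)  # resize ratio
    return math.ceil(h * r / stride) * stride, math.ceil(w * r / stride) * stride

print(scaled_shape(640, 1280))  # (320, 640) -> the new torch example, a clean P6 multiple
print(scaled_shape(720, 1280))  # (384, 640) -> why the old 720-row example was replaced
```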
---
 models/common.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/models/common.py b/models/common.py
index c7571c202231..b75cfe495e2e 100644
--- a/models/common.py
+++ b/models/common.py
@@ -238,13 +238,13 @@ def autoshape(self):
         return self
 
     def forward(self, imgs, size=640, augment=False, profile=False):
-        # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
+        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
         #   filename:   imgs = 'data/samples/zidane.jpg'
         #   URI:             = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
-        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(720,1280,3)
-        #   PIL:             = Image.open('image.jpg')  # HWC x(720,1280,3)
-        #   numpy:           = np.zeros((720,1280,3))  # HWC
-        #   torch:           = torch.zeros(16,3,720,1280)  # BCHW
+        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
+        #   PIL:             = Image.open('image.jpg')  # HWC x(640,1280,3)
+        #   numpy:           = np.zeros((640,1280,3))  # HWC
+        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640)
         #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
 
         t = [time_synchronized()]

From 2ff0dfb54f522597a44fdc8956ff1374ab03c24c Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 7 Apr 2021 15:39:00 +0200
Subject: [PATCH 3/7] Apply decorators

Replace the inline torch.no_grad() and amp.autocast() context managers in
AutoShape.forward() with method decorators.
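The decorator pattern in isolation, as a small self-contained sketch
(TinyModel is illustrative; note that unlike the removed
amp.autocast(enabled=p.device.type != 'cpu') context manager, the bare
decorator requests autocast unconditionally, though it only affects CUDA ops):

```python
import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    @torch.no_grad()            # inference only: no autograd graph is built
    @torch.cuda.amp.autocast()  # eligible CUDA ops run in mixed precision
    def forward(self, x):
        return self.linear(x)

y = TinyModel().eval()(torch.randn(1, 4))
print(y.requires_grad)  # False -> no_grad applied via the decorator
```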
---
 models/common.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/models/common.py b/models/common.py
index b75cfe495e2e..8dafc434ab8c 100644
--- a/models/common.py
+++ b/models/common.py
@@ -10,7 +10,6 @@
 import torch
 import torch.nn as nn
 from PIL import Image
-from torch.cuda import amp
 
 from utils.datasets import letterbox
 from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
@@ -237,6 +236,8 @@ def autoshape(self):
         print('autoShape already enabled, skipping... ')  # model already converted to model.autoshape()
         return self
 
+    @torch.no_grad()
+    @torch.cuda.amp.autocast()
     def forward(self, imgs, size=640, augment=False, profile=False):
         # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
         #   filename:   imgs = 'data/samples/zidane.jpg'
@@ -244,7 +245,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
         #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
         #   PIL:             = Image.open('image.jpg')  # HWC x(640,1280,3)
         #   numpy:           = np.zeros((640,1280,3))  # HWC
-        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640)
+        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
         #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
 
         t = [time_synchronized()]
@@ -277,15 +278,14 @@ def forward(self, imgs, size=640, augment=False, profile=False):
         x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
         t.append(time_synchronized())
 
-        with torch.no_grad(), amp.autocast(enabled=p.device.type != 'cpu'):
-            # Inference
-            y = self.model(x, augment, profile)[0]  # forward
-            t.append(time_synchronized())
+        # Inference
+        y = self.model(x, augment, profile)[0]  # forward
+        t.append(time_synchronized())
 
-            # Post-process
-            y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
-            for i in range(n):
-                scale_coords(shape1, y[i][:, :4], shape0[i])
+        # Post-process
+        y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
+        for i in range(n):
+            scale_coords(shape1, y[i][:, :4], shape0[i])
 
         t.append(time_synchronized())
         return Detections(imgs, y, files, t, self.names, x.shape)

From dd710dd3cc06c8850ec2a5517d31878eb36f7e31 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 7 Apr 2021 15:39:24 +0200
Subject: [PATCH 4/7] PEP8

---
 hubconf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hubconf.py b/hubconf.py
index 1e6b9c78ac6a..0f9aa150a34e 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -38,7 +38,7 @@ def create(name, pretrained, channels, classes, autoshape):
         fname = f'{name}.pt'  # checkpoint filename
         attempt_download(fname)  # download if not found locally
         ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
-        msd = model.state_dict() # model state_dict
+        msd = model.state_dict()  # model state_dict
         csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
         csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape}  # filter
         model.load_state_dict(csd, strict=False)  # load

From 49563c78cd8c2db3b16a04b9b6f0496743b29b08 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 7 Apr 2021 15:56:42 +0200
Subject: [PATCH 5/7] Update common.py

---
 models/common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/models/common.py b/models/common.py
index 8dafc434ab8c..beeba95b88f4 100644
--- a/models/common.py
+++ b/models/common.py
@@ -351,7 +351,7 @@ def render(self):
 
     def pandas(self):
         # return detections as pandas DataFrames
-        new = copy(self)  # do not replace self
+        new = copy(self)  # return copy
         ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
         cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
         for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):

From af52afa44461f800be3415653d98df3f71a6b2c7 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 7 Apr 2021 16:09:13 +0200
Subject: [PATCH 6/7] pd.options.display.max_columns = 10

---
 utils/general.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/utils/general.py b/utils/general.py
index 9822582cdb86..a8aad16a8ab9 100755
--- a/utils/general.py
+++ b/utils/general.py
@@ -13,6 +13,7 @@
 
 import cv2
 import numpy as np
+import pandas as pd
 import torch
 import torchvision
 import yaml
@@ -24,6 +25,7 @@
 # Settings
 torch.set_printoptions(linewidth=320, precision=5, profile='long')
 np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
+pd.options.display.max_columns = 10
 cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
 os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads

From 69771d7e24feab5e7515c41ce4b893fc8040f717 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 7 Apr 2021 16:13:34 +0200
Subject: [PATCH 7/7] Update common.py

---
 models/common.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/models/common.py b/models/common.py
index beeba95b88f4..412e9bf1e411 100644
--- a/models/common.py
+++ b/models/common.py
@@ -350,12 +350,12 @@ def render(self):
         return self.imgs
 
     def pandas(self):
-        # return detections as pandas DataFrames
+        # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
         new = copy(self)  # return copy
         ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
         cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
         for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
-            a = [[x + [self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # updated attribute
+            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
             setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
         return new
 
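With this final commit the 'class' column holds ints (and 'name' strings), so
the DataFrames filter naturally. A small self-contained sketch, reusing values
from the PR description above:

```python
import pandas as pd

# two rows as produced by results.pandas().xyxy[0] after this patch
df = pd.DataFrame([[57.07, 391.77, 241.38, 905.80, 0.87, 0, 'person'],
                   [4.21, 234.45, 803.74, 750.02, 0.66, 5, 'bus']],
                  columns=['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'])

print(df[df['name'] == 'person'])  # filter detections by class name
print(df[df['class'] == 5])        # or by the now-integer class id
```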