Skip to content

Commit

Permalink
Add Hub results.pandas() method (ultralytics#2725)
Browse files Browse the repository at this point in the history
* Add Hub results.pandas() method

New method converts results from torch tensors to pandas DataFrames with column names.

This PR may partially resolve issue ultralytics#2703

```python
results = model(imgs)

print(results.pandas().xyxy[0])
         xmin        ymin        xmax        ymax  confidence  class    name
0   57.068970  391.770599  241.383545  905.797852    0.868964      0  person
1  667.661255  399.303589  810.000000  881.396667    0.851888      0  person
2  222.878387  414.774231  343.804474  857.825073    0.838376      0  person
3    4.205386  234.447678  803.739136  750.023376    0.658006      5     bus
4    0.000000  550.596008   76.681190  878.669922    0.450596      0  person
```

* Update comments 

torch example input now shown resized to size=640 and also now a multiple of P6 stride 64 (see ultralytics#2722 (comment))

* apply decorators

* PEP8

* Update common.py

* pd.options.display.max_columns = 10

* Update common.py
  • Loading branch information
glenn-jocher committed Apr 7, 2021
1 parent b7dd733 commit 2ece733
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 18 deletions.
2 changes: 1 addition & 1 deletion hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create(name, pretrained, channels, classes, autoshape):
fname = f'{name}.pt' # checkpoint filename
attempt_download(fname) # download if not found locally
ckpt = torch.load(fname, map_location=torch.device('cpu')) # load
msd = model.state_dict() # model state_dict
msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
model.load_state_dict(csd, strict=False) # load
Expand Down
46 changes: 29 additions & 17 deletions models/common.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# YOLOv5 common modules

import math
from copy import copy
from pathlib import Path

import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
from PIL import Image
from torch.cuda import amp

from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
Expand Down Expand Up @@ -235,14 +236,16 @@ def autoshape(self):
print('autoShape already enabled, skipping... ') # model already converted to model.autoshape()
return self

@torch.no_grad()
@torch.cuda.amp.autocast()
def forward(self, imgs, size=640, augment=False, profile=False):
# Inference from various sources. For height=720, width=1280, RGB images example inputs are:
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
# filename: imgs = 'data/samples/zidane.jpg'
# URI: = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
# PIL: = Image.open('image.jpg') # HWC x(720,1280,3)
# numpy: = np.zeros((720,1280,3)) # HWC
# torch: = torch.zeros(16,3,720,1280) # BCHW
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
# PIL: = Image.open('image.jpg') # HWC x(640,1280,3)
# numpy: = np.zeros((640,1280,3)) # HWC
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images

t = [time_synchronized()]
Expand Down Expand Up @@ -275,15 +278,14 @@ def forward(self, imgs, size=640, augment=False, profile=False):
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
t.append(time_synchronized())

with torch.no_grad(), amp.autocast(enabled=p.device.type != 'cpu'):
# Inference
y = self.model(x, augment, profile)[0] # forward
t.append(time_synchronized())
# Inference
y = self.model(x, augment, profile)[0] # forward
t.append(time_synchronized())

# Post-process
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i])
# Post-process
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i])

t.append(time_synchronized())
return Detections(imgs, y, files, t, self.names, x.shape)
Expand Down Expand Up @@ -347,17 +349,27 @@ def render(self):
self.display(render=True) # render results
return self.imgs

def __len__(self):
return self.n
def pandas(self):
# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
new = copy(self) # return copy
ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
return new

def tolist(self):
# return a list of Detections objects, i.e. 'for result in results.tolist():'
x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)]
x = [Detections([self.imgs[i]], [self.pred[i]], self.names, self.s) for i in range(self.n)]
for d in x:
for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
setattr(d, k, getattr(d, k)[0]) # pop out of list
return x

def __len__(self):
return self.n


class Classify(nn.Module):
# Classification head, i.e. x(b,c1,20,20) to x(b,c2)
Expand Down
2 changes: 2 additions & 0 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision
import yaml
Expand All @@ -24,6 +25,7 @@
# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads

Expand Down

0 comments on commit 2ece733

Please sign in to comment.