Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch Hub and autoShape update #1415

Merged
merged 3 commits into from
Nov 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def detect(save_img=False):
txt_path = str(save_dir / 'labels' / p.stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
s += '%gx%g ' % img.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
if det is not None and len(det):
if len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

Expand Down
16 changes: 8 additions & 8 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, channels=3, classes=80)
"""

dependencies = ['torch', 'yaml']
from pathlib import Path

import torch
from PIL import Image

from models.yolo import Model
from utils.general import set_logging
from utils.google_utils import attempt_download

dependencies = ['torch', 'yaml', 'pillow']
set_logging()


Expand Down Expand Up @@ -41,7 +42,7 @@ def create(name, pretrained, channels, classes):
model.load_state_dict(state_dict, strict=False) # load
if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute
# model = model.autoshape() # for autoshaping of PIL/cv2/np inputs and NMS
# model = model.autoshape() # for PIL/cv2/np inputs and NMS
return model

except Exception as e:
Expand Down Expand Up @@ -108,11 +109,10 @@ def yolov5x(pretrained=False, channels=3, classes=80):

if __name__ == '__main__':
model = create(name='yolov5s', pretrained=True, channels=3, classes=80) # example
model = model.fuse().eval().autoshape() # for autoshaping of PIL/cv2/np inputs and NMS
model = model.fuse().autoshape() # for PIL/cv2/np inputs and NMS

# Verify inference
from PIL import Image

img = Image.open('data/images/zidane.jpg')
y = model(img)
print(y[0].shape)
imgs = [Image.open(x) for x in Path('data/images').glob('*.jpg')]
results = model(imgs)
results.show()
results.print()
93 changes: 71 additions & 22 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageDraw

from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords
from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
from utils.plots import color_list


def autopad(k, p=None): # kernel, padding
Expand Down Expand Up @@ -125,47 +127,94 @@ class autoShape(nn.Module):

def __init__(self, model):
super(autoShape, self).__init__()
self.model = model
self.model = model.eval()

def forward(self, x, size=640, augment=False, profile=False):
def forward(self, imgs, size=640, augment=False, profile=False):
# supports inference from various sources. For height=720, width=1280, RGB images example inputs are:
# opencv: x = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
# PIL: x = Image.open('image.jpg') # HWC x(720,1280,3)
# numpy: x = np.zeros((720,1280,3)) # HWC
# torch: x = torch.zeros(16,3,720,1280) # BCHW
# multiple: x = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
# opencv: imgs = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
# PIL: imgs = Image.open('image.jpg') # HWC x(720,1280,3)
# numpy: imgs = np.zeros((720,1280,3)) # HWC
# torch: imgs = torch.zeros(16,3,720,1280) # BCHW
# multiple: imgs = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images

p = next(self.model.parameters()) # for device and type
if isinstance(x, torch.Tensor): # torch
return self.model(x.to(p.device).type_as(p), augment, profile) # inference
if isinstance(imgs, torch.Tensor): # torch
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference

# Pre-process
if not isinstance(x, list):
x = [x]
if not isinstance(imgs, list):
imgs = [imgs]
shape0, shape1 = [], [] # image and inference shapes
batch = range(len(x)) # batch size
batch = range(len(imgs)) # batch size
for i in batch:
x[i] = np.array(x[i]) # to numpy
x[i] = x[i][:, :, :3] if x[i].ndim == 3 else np.tile(x[i][:, :, None], 3) # enforce 3ch input
s = x[i].shape[:2] # HWC
imgs[i] = np.array(imgs[i]) # to numpy
imgs[i] = imgs[i][:, :, :3] if imgs[i].ndim == 3 else np.tile(imgs[i][:, :, None], 3) # enforce 3ch input
s = imgs[i].shape[:2] # HWC
shape0.append(s) # image shape
g = (size / max(s)) # gain
shape1.append([y * g for y in s])
shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
x = [letterbox(x[i], new_shape=shape1, auto=False)[0] for i in batch] # pad
x = [letterbox(imgs[i], new_shape=shape1, auto=False)[0] for i in batch] # pad
x = np.stack(x, 0) if batch[-1] else x[0][None] # stack
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32

# Inference
x = self.model(x, augment, profile) # forward
x = non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
with torch.no_grad():
y = self.model(x, augment, profile)[0] # forward
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS

# Post-process
for i in batch:
if x[i] is not None:
x[i][:, :4] = scale_coords(shape1, x[i][:, :4], shape0[i])
return x
if y[i] is not None:
y[i][:, :4] = scale_coords(shape1, y[i][:, :4], shape0[i])

return Detections(imgs, y, self.names)


class Detections:
# detections class for YOLOv5 inference results
def __init__(self, imgs, pred, names=None):
super(Detections, self).__init__()
self.imgs = imgs # list of images as numpy arrays
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
self.names = names # class names
self.xyxy = pred # xyxy pixels
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
gn = [torch.Tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.]) for im in imgs] # normalization gains
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized

def display(self, pprint=False, show=False, save=False):
colors = color_list()
for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
str = f'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
if pred is not None:
for c in pred[:, -1].unique():
n = (pred[:, -1] == c).sum() # detections per class
str += f'{n} {self.names[int(c)]}s, ' # add to string
if show or save:
img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np
for *box, conf, cls in pred: # xyxy, confidence, class
# str += '%s %.2f, ' % (names[int(cls)], conf) # label
ImageDraw.Draw(img).rectangle(box, width=4, outline=colors[int(cls) % 10]) # plot
if save:
f = f'results{i}.jpg'
str += f"saved to '{f}'"
img.save(f) # save
if show:
img.show(f'Image {i}') # show
if pprint:
print(str)

def print(self):
self.display(pprint=True) # print results

def show(self):
self.display(show=True) # show results

def save(self):
self.display(save=True) # save results


class Flatten(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test(data,
tcls = labels[:, 0].tolist() if nl else [] # target class
seen += 1

if pred is None:
if len(pred) == 0:
if nl:
stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
continue
Expand Down
6 changes: 3 additions & 3 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)

def xyxy2xywh(x):
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
y[:, 2] = x[:, 2] - x[:, 0] # width
Expand All @@ -152,7 +152,7 @@ def xyxy2xywh(x):

def xywh2xyxy(x):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
Expand Down Expand Up @@ -280,7 +280,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, merge=False,
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)

t = time.time()
output = [None] * prediction.shape[0]
output = [torch.zeros(0, 6)] * prediction.shape[0]
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
Expand Down