Skip to content

Commit

Permalink
New DetectMultiBackend() class (ultralytics#5549)
Browse files Browse the repository at this point in the history
* New `DetectMultiBackend()` class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* pb to pt fix

* Cleanup

* explicit apply_classifier path

* Cleanup2

* Cleanup3

* Cleanup4

* Cleanup5

* Cleanup6

* val.py MultiBackend inference

* warmup fix

* to device fix

* pt fix

* device fix

* Val cleanup

* COCO128 URL to assets

* half fix

* detect fix

* detect fix 2

* remove half from DetectMultiBackend

* training half handling

* training half handling 2

* training half handling 3

* Cleanup

* Fix CI error

* Add torchscript _extra_files

* Add TorchScript

* Add CoreML

* CoreML cleanup

* New `DetectMultiBackend()` class

* pb to pt fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup

* explicit apply_classifier path

* Cleanup2

* Cleanup3

* Cleanup4

* Cleanup5

* Cleanup6

* val.py MultiBackend inference

* warmup fix

* to device fix

* pt fix

* device fix

* Val cleanup

* COCO128 URL to assets

* half fix

* detect fix

* detect fix 2

* remove half from DetectMultiBackend

* training half handling

* training half handling 2

* training half handling 3

* Cleanup

* Fix CI error

* Add torchscript _extra_files

* Add TorchScript

* Add CoreML

* CoreML cleanup

* revert default to pt

* Add Usage examples

* Cleanup val

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
glenn-jocher and pre-commit-ci[bot] committed Nov 9, 2021
1 parent 1d1b54f commit aa1d9f2
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 165 deletions.
2 changes: 1 addition & 1 deletion data/coco128.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 't


# Download script/URL (optional)
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
download: https://ultralytics.com/assets/coco128.zip
133 changes: 27 additions & 106 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@

import argparse
import os
import platform
import sys
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn

Expand All @@ -29,13 +27,12 @@
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

from models.experimental import attempt_load
from models.common import DetectMultiBackend
from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
from utils.general import (LOGGER, apply_classifier, check_file, check_img_size, check_imshow, check_requirements,
check_suffix, colorstr, increment_path, non_max_suppression, print_args, scale_coords,
strip_optimizer, xyxy2xywh)
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr,
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import load_classifier, select_device, time_sync
from utils.torch_utils import select_device, time_sync


@torch.no_grad()
Expand Down Expand Up @@ -77,120 +74,45 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir

# Initialize
# Load model
device = select_device(device)
half &= device.type != 'cpu' # half precision only supported on CUDA
model = DetectMultiBackend(weights, device=device, dnn=dnn)
stride, names, pt, jit, onnx = model.stride, model.names, model.pt, model.jit, model.onnx
imgsz = check_img_size(imgsz, s=stride) # check image size

# Load model
w = str(weights[0] if isinstance(weights, list) else weights)
classify, suffix, suffixes = False, Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '']
check_suffix(w, suffixes) # check weights have acceptable suffix
pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes) # backend booleans
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
# Half
half &= pt and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
if pt:
model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
stride = int(model.stride.max()) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
if half:
model.half() # to FP16
if classify: # second-stage classifier
modelc = load_classifier(name='resnet50', n=2) # initialize
modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
elif onnx:
if dnn:
check_requirements(('opencv-python>=4.5.4',))
net = cv2.dnn.readNetFromONNX(w)
else:
check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
else: # TensorFlow models
import tensorflow as tf
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped import
return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
tf.nest.map_structure(x.graph.as_graph_element, outputs))

graph_def = tf.Graph().as_graph_def()
graph_def.ParseFromString(open(w, 'rb').read())
frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
elif saved_model:
model = tf.keras.models.load_model(w)
elif tflite:
if "edgetpu" in w: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
import tflite_runtime.interpreter as tflri
delegate = {'Linux': 'libedgetpu.so.1', # install libedgetpu https://coral.ai/software/#edgetpu-runtime
'Darwin': 'libedgetpu.1.dylib',
'Windows': 'edgetpu.dll'}[platform.system()]
interpreter = tflri.Interpreter(model_path=w, experimental_delegates=[tflri.load_delegate(delegate)])
else:
interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
int8 = input_details[0]['dtype'] == np.uint8 # is TFLite quantized uint8 model
imgsz = check_img_size(imgsz, s=stride) # check image size
model.model.half() if half else model.model.float()

# Dataloader
if webcam:
view_img = check_imshow()
cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt and not jit)
bs = len(dataset) # batch_size
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt and not jit)
bs = 1 # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs

# Run inference
if pt and device.type != 'cpu':
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters()))) # run once
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.model.parameters()))) # warmup
dt, seen = [0.0, 0.0, 0.0], 0
for path, img, im0s, vid_cap, s in dataset:
for path, im, im0s, vid_cap, s in dataset:
t1 = time_sync()
if onnx:
img = img.astype('float32')
else:
img = torch.from_numpy(img).to(device)
img = img.half() if half else img.float() # uint8 to fp16/32
img /= 255 # 0 - 255 to 0.0 - 1.0
if len(img.shape) == 3:
img = img[None] # expand for batch dim
im = torch.from_numpy(im).to(device)
im = im.half() if half else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
im = im[None] # expand for batch dim
t2 = time_sync()
dt[0] += t2 - t1

# Inference
if pt:
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
pred = model(img, augment=augment, visualize=visualize)[0]
elif onnx:
if dnn:
net.setInput(img)
pred = torch.tensor(net.forward())
else:
pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
else: # tensorflow model (tflite, pb, saved_model)
imn = img.permute(0, 2, 3, 1).cpu().numpy() # image in numpy
if pb:
pred = frozen_func(x=tf.constant(imn)).numpy()
elif saved_model:
pred = model(imn, training=False).numpy()
elif tflite:
if int8:
scale, zero_point = input_details[0]['quantization']
imn = (imn / scale + zero_point).astype(np.uint8) # de-scale
interpreter.set_tensor(input_details[0]['index'], imn)
interpreter.invoke()
pred = interpreter.get_tensor(output_details[0]['index'])
if int8:
scale, zero_point = output_details[0]['quantization']
pred = (pred.astype(np.float32) - zero_point) * scale # re-scale
pred[..., 0] *= imgsz[1] # x
pred[..., 1] *= imgsz[0] # y
pred[..., 2] *= imgsz[1] # w
pred[..., 3] *= imgsz[0] # h
pred = torch.tensor(pred)
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
pred = model(im, augment=augment, visualize=visualize)
t3 = time_sync()
dt[1] += t3 - t2

Expand All @@ -199,8 +121,7 @@ def wrap_frozen_graph(gd, inputs, outputs):
dt[2] += time_sync() - t3

# Second-stage classifier (optional)
if classify:
pred = apply_classifier(pred, modelc, img, im0s)
# pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

# Process predictions
for i, det in enumerate(pred): # per image
Expand All @@ -212,15 +133,15 @@ def wrap_frozen_graph(gd, inputs, outputs):
p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

p = Path(p) # to Path
save_path = str(save_dir / p.name) # img.jpg
txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt
s += '%gx%g ' % img.shape[2:] # print string
save_path = str(save_dir / p.name) # im.jpg
txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt
s += '%gx%g ' % im.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
imc = im0.copy() if save_crop else im0 # for save_crop
annotator = Annotator(im0, line_width=line_thickness, example=str(names))
if len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()

# Print results
for c in det[:, -1].unique():
Expand Down
5 changes: 4 additions & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""

import argparse
import json
import os
import subprocess
import sys
Expand Down Expand Up @@ -54,7 +55,9 @@ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:'
f = file.with_suffix('.torchscript.pt')

ts = torch.jit.trace(model, im, strict=False)
(optimize_for_mobile(ts) if optimize else ts).save(f)
d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
(optimize_for_mobile(ts) if optimize else ts).save(f, _extra_files=extra_files)

LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
Expand Down
128 changes: 127 additions & 1 deletion models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
Common modules
"""

import json
import math
import platform
import warnings
from copy import copy
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import requests
Expand All @@ -17,7 +20,8 @@
from torch.cuda import amp

from utils.datasets import exif_transpose, letterbox
from utils.general import LOGGER, colorstr, increment_path, make_divisible, non_max_suppression, scale_coords, xyxy2xywh
from utils.general import (LOGGER, check_requirements, check_suffix, colorstr, increment_path, make_divisible,
non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import time_sync

Expand Down Expand Up @@ -269,6 +273,128 @@ def forward(self, x):
return torch.cat(x, self.d)


class DetectMultiBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends
def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
# Usage:
# PyTorch: weights = *.pt
# TorchScript: *.torchscript.pt
# CoreML: *.mlmodel
# TensorFlow: *_saved_model
# TensorFlow: *.pb
# TensorFlow Lite: *.tflite
# ONNX Runtime: *.onnx
# OpenCV DNN: *.onnx with dnn=True
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '', '.mlmodel']
check_suffix(w, suffixes) # check weights have acceptable suffix
pt, onnx, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
jit = pt and 'torchscript' in w.lower()
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults

if jit: # TorchScript
LOGGER.info(f'Loading {w} for TorchScript inference...')
extra_files = {'config.txt': ''} # model metadata
model = torch.jit.load(w, _extra_files=extra_files)
if extra_files['config.txt']:
d = json.loads(extra_files['config.txt']) # extra_files dict
stride, names = int(d['stride']), d['names']
elif pt: # PyTorch
from models.experimental import attempt_load # scoped to avoid circular import
model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
stride = int(model.stride.max()) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
elif coreml: # CoreML *.mlmodel
import coremltools as ct
model = ct.models.MLModel(w)
elif dnn: # ONNX OpenCV DNN
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
check_requirements(('opencv-python>=4.5.4',))
net = cv2.dnn.readNetFromONNX(w)
elif onnx: # ONNX Runtime
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
else: # TensorFlow model (TFLite, pb, saved_model)
import tensorflow as tf
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
tf.nest.map_structure(x.graph.as_graph_element, outputs))

LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
graph_def = tf.Graph().as_graph_def()
graph_def.ParseFromString(open(w, 'rb').read())
frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
elif saved_model:
LOGGER.info(f'Loading {w} for TensorFlow saved_model inference...')
model = tf.keras.models.load_model(w)
elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
if 'edgetpu' in w.lower():
LOGGER.info(f'Loading {w} for TensorFlow Edge TPU inference...')
import tflite_runtime.interpreter as tfli
delegate = {'Linux': 'libedgetpu.so.1', # install https://coral.ai/software/#edgetpu-runtime
'Darwin': 'libedgetpu.1.dylib',
'Windows': 'edgetpu.dll'}[platform.system()]
interpreter = tfli.Interpreter(model_path=w, experimental_delegates=[tfli.load_delegate(delegate)])
else:
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
self.__dict__.update(locals()) # assign all variables to self

def forward(self, im, augment=False, visualize=False, val=False):
# YOLOv5 MultiBackend inference
b, ch, h, w = im.shape # batch, channel, height, width
if self.pt: # PyTorch
y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
return y if val else y[0]
elif self.coreml: # CoreML *.mlmodel
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
im = Image.fromarray((im[0] * 255).astype('uint8'))
# im = im.resize((192, 320), Image.ANTIALIAS)
y = self.model.predict({'image': im}) # coordinates are xywh normalized
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
elif self.onnx: # ONNX
im = im.cpu().numpy() # torch to numpy
if self.dnn: # ONNX OpenCV DNN
self.net.setInput(im)
y = self.net.forward()
else: # ONNX Runtime
y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
else: # TensorFlow model (TFLite, pb, saved_model)
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.pb:
y = self.frozen_func(x=self.tf.constant(im)).numpy()
elif self.saved_model:
y = self.model(im, training=False).numpy()
elif self.tflite:
input, output = self.input_details[0], self.output_details[0]
int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
if int8:
scale, zero_point = input['quantization']
im = (im / scale + zero_point).astype(np.uint8) # de-scale
self.interpreter.set_tensor(input['index'], im)
self.interpreter.invoke()
y = self.interpreter.get_tensor(output['index'])
if int8:
scale, zero_point = output['quantization']
y = (y.astype(np.float32) - zero_point) * scale # re-scale
y[..., 0] *= w # x
y[..., 1] *= h # y
y[..., 2] *= w # w
y[..., 3] *= h # h
y = torch.tensor(y)
return (y, []) if val else y


class AutoShape(nn.Module):
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
conf = 0.25 # NMS confidence threshold
Expand Down
3 changes: 2 additions & 1 deletion utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,8 @@ def print_mutation(results, hyp, save_dir, bucket):


def apply_classifier(x, model, img, im0):
# Apply a second stage classifier to yolo outputs
# Apply a second stage classifier to YOLO outputs
# Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
im0 = [im0] if isinstance(im0, np.ndarray) else im0
for i, d in enumerate(x): # per image
if d is not None and len(d):
Expand Down
Loading

0 comments on commit aa1d9f2

Please sign in to comment.