Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New DetectMultiBackend() class #5549

Merged
merged 67 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
b90e8bf
New `DetectMultiBackend()` class
glenn-jocher Nov 7, 2021
0d9bc34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2021
795fbe8
pb to pt fix
glenn-jocher Nov 7, 2021
876e0f8
Merge remote-tracking branch 'origin/add/detect_multi_backend' into a…
glenn-jocher Nov 7, 2021
9e1d0a6
Cleanup
glenn-jocher Nov 7, 2021
4f77588
explicit apply_classifier path
glenn-jocher Nov 7, 2021
96c9462
Cleanup2
glenn-jocher Nov 7, 2021
398d377
Cleanup3
glenn-jocher Nov 7, 2021
89bf2f1
Cleanup4
glenn-jocher Nov 7, 2021
47550b0
Cleanup5
glenn-jocher Nov 7, 2021
d08b356
Cleanup6
glenn-jocher Nov 7, 2021
407d5d3
val.py MultiBackend inference
glenn-jocher Nov 7, 2021
71b320a
warmup fix
glenn-jocher Nov 7, 2021
293e98d
to device fix
glenn-jocher Nov 7, 2021
201107e
pt fix
glenn-jocher Nov 7, 2021
a7f17e9
device fix
glenn-jocher Nov 7, 2021
3cf44c3
Val cleanup
glenn-jocher Nov 7, 2021
d32ca2e
COCO128 URL to assets
glenn-jocher Nov 7, 2021
5f3a5fb
half fix
glenn-jocher Nov 7, 2021
9015662
detect fix
glenn-jocher Nov 7, 2021
dc5c370
detect fix 2
glenn-jocher Nov 7, 2021
77fbc8f
remove half from DetectMultiBackend
glenn-jocher Nov 7, 2021
a80c511
training half handling
glenn-jocher Nov 7, 2021
e165bd4
training half handling 2
glenn-jocher Nov 7, 2021
c743bb6
training half handling 3
glenn-jocher Nov 7, 2021
f312ab6
Cleanup
glenn-jocher Nov 7, 2021
bdde9ef
Fix CI error
glenn-jocher Nov 7, 2021
7003939
Add torchscript _extra_files
glenn-jocher Nov 7, 2021
ef3f161
Add TorchScript
glenn-jocher Nov 7, 2021
82bfd0f
Add CoreML
glenn-jocher Nov 7, 2021
109c5d6
CoreML cleanup
glenn-jocher Nov 7, 2021
3248a57
New `DetectMultiBackend()` class
glenn-jocher Nov 7, 2021
a9a0fed
pb to pt fix
glenn-jocher Nov 7, 2021
11bd91c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2021
07b4289
Cleanup
glenn-jocher Nov 7, 2021
b6f6c0d
explicit apply_classifier path
glenn-jocher Nov 7, 2021
4d85ec2
Cleanup2
glenn-jocher Nov 7, 2021
df737a0
Cleanup3
glenn-jocher Nov 7, 2021
70e9dfb
Cleanup4
glenn-jocher Nov 7, 2021
9729521
Cleanup5
glenn-jocher Nov 7, 2021
ab7358f
Cleanup6
glenn-jocher Nov 7, 2021
c1bf0e2
val.py MultiBackend inference
glenn-jocher Nov 7, 2021
b5bae24
warmup fix
glenn-jocher Nov 7, 2021
96f2b3c
to device fix
glenn-jocher Nov 7, 2021
32974f2
pt fix
glenn-jocher Nov 7, 2021
d955ed6
device fix
glenn-jocher Nov 7, 2021
9c25359
Val cleanup
glenn-jocher Nov 7, 2021
e9cd5eb
COCO128 URL to assets
glenn-jocher Nov 7, 2021
54d3dfa
half fix
glenn-jocher Nov 7, 2021
0b07c0c
detect fix
glenn-jocher Nov 7, 2021
55eefc0
detect fix 2
glenn-jocher Nov 7, 2021
17676ae
remove half from DetectMultiBackend
glenn-jocher Nov 7, 2021
3985a59
training half handling
glenn-jocher Nov 7, 2021
9844c81
training half handling 2
glenn-jocher Nov 7, 2021
fe94f4b
training half handling 3
glenn-jocher Nov 7, 2021
28de246
Cleanup
glenn-jocher Nov 7, 2021
709d9ce
Fix CI error
glenn-jocher Nov 7, 2021
19bdb6e
Add torchscript _extra_files
glenn-jocher Nov 7, 2021
358d9e3
Add TorchScript
glenn-jocher Nov 7, 2021
dc0b748
Add CoreML
glenn-jocher Nov 7, 2021
0bfaba5
CoreML cleanup
glenn-jocher Nov 7, 2021
47acb1e
Merge remote-tracking branch 'origin/add/detect_multi_backend' into a…
glenn-jocher Nov 8, 2021
ba46597
Merge branch 'master' into add/detect_multi_backend
glenn-jocher Nov 8, 2021
8422530
Merge remote-tracking branch 'origin/add/detect_multi_backend' into a…
glenn-jocher Nov 8, 2021
cd92e01
revert default to pt
glenn-jocher Nov 9, 2021
ffa76ee
Add Usage examples
glenn-jocher Nov 9, 2021
0f98f01
Cleanup val
glenn-jocher Nov 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion data/coco128.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 't


# Download script/URL (optional)
download: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
download: https://ultralytics.com/assets/coco128.zip
133 changes: 27 additions & 106 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@

import argparse
import os
import platform
import sys
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn

Expand All @@ -29,13 +27,12 @@
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

from models.experimental import attempt_load
from models.common import DetectMultiBackend
from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
from utils.general import (LOGGER, apply_classifier, check_file, check_img_size, check_imshow, check_requirements,
check_suffix, colorstr, increment_path, non_max_suppression, print_args, scale_coords,
strip_optimizer, xyxy2xywh)
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr,
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import load_classifier, select_device, time_sync
from utils.torch_utils import select_device, time_sync


@torch.no_grad()
Expand Down Expand Up @@ -77,120 +74,45 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir

# Initialize
# Load model
device = select_device(device)
half &= device.type != 'cpu' # half precision only supported on CUDA
model = DetectMultiBackend(weights, device=device, dnn=dnn)
stride, names, pt, jit, onnx = model.stride, model.names, model.pt, model.jit, model.onnx
imgsz = check_img_size(imgsz, s=stride) # check image size

# Load model
w = str(weights[0] if isinstance(weights, list) else weights)
classify, suffix, suffixes = False, Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '']
check_suffix(w, suffixes) # check weights have acceptable suffix
pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes) # backend booleans
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
# Half
half &= pt and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
if pt:
model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
stride = int(model.stride.max()) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
if half:
model.half() # to FP16
if classify: # second-stage classifier
modelc = load_classifier(name='resnet50', n=2) # initialize
modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
elif onnx:
if dnn:
check_requirements(('opencv-python>=4.5.4',))
net = cv2.dnn.readNetFromONNX(w)
else:
check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
else: # TensorFlow models
import tensorflow as tf
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped import
return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
tf.nest.map_structure(x.graph.as_graph_element, outputs))

graph_def = tf.Graph().as_graph_def()
graph_def.ParseFromString(open(w, 'rb').read())
frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
elif saved_model:
model = tf.keras.models.load_model(w)
elif tflite:
if "edgetpu" in w: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
import tflite_runtime.interpreter as tflri
delegate = {'Linux': 'libedgetpu.so.1', # install libedgetpu https://coral.ai/software/#edgetpu-runtime
'Darwin': 'libedgetpu.1.dylib',
'Windows': 'edgetpu.dll'}[platform.system()]
interpreter = tflri.Interpreter(model_path=w, experimental_delegates=[tflri.load_delegate(delegate)])
else:
interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
int8 = input_details[0]['dtype'] == np.uint8 # is TFLite quantized uint8 model
imgsz = check_img_size(imgsz, s=stride) # check image size
model.model.half() if half else model.model.float()

# Dataloader
if webcam:
view_img = check_imshow()
cudnn.benchmark = True # set True to speed up constant image size inference
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt and not jit)
bs = len(dataset) # batch_size
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt and not jit)
bs = 1 # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs

# Run inference
if pt and device.type != 'cpu':
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters()))) # run once
model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.model.parameters()))) # warmup
dt, seen = [0.0, 0.0, 0.0], 0
for path, img, im0s, vid_cap, s in dataset:
for path, im, im0s, vid_cap, s in dataset:
t1 = time_sync()
if onnx:
img = img.astype('float32')
else:
img = torch.from_numpy(img).to(device)
img = img.half() if half else img.float() # uint8 to fp16/32
img /= 255 # 0 - 255 to 0.0 - 1.0
if len(img.shape) == 3:
img = img[None] # expand for batch dim
im = torch.from_numpy(im).to(device)
im = im.half() if half else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
im = im[None] # expand for batch dim
t2 = time_sync()
dt[0] += t2 - t1

# Inference
if pt:
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
pred = model(img, augment=augment, visualize=visualize)[0]
elif onnx:
if dnn:
net.setInput(img)
pred = torch.tensor(net.forward())
else:
pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
else: # tensorflow model (tflite, pb, saved_model)
imn = img.permute(0, 2, 3, 1).cpu().numpy() # image in numpy
if pb:
pred = frozen_func(x=tf.constant(imn)).numpy()
elif saved_model:
pred = model(imn, training=False).numpy()
elif tflite:
if int8:
scale, zero_point = input_details[0]['quantization']
imn = (imn / scale + zero_point).astype(np.uint8) # de-scale
interpreter.set_tensor(input_details[0]['index'], imn)
interpreter.invoke()
pred = interpreter.get_tensor(output_details[0]['index'])
if int8:
scale, zero_point = output_details[0]['quantization']
pred = (pred.astype(np.float32) - zero_point) * scale # re-scale
pred[..., 0] *= imgsz[1] # x
pred[..., 1] *= imgsz[0] # y
pred[..., 2] *= imgsz[1] # w
pred[..., 3] *= imgsz[0] # h
pred = torch.tensor(pred)
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
pred = model(im, augment=augment, visualize=visualize)
t3 = time_sync()
dt[1] += t3 - t2

Expand All @@ -199,8 +121,7 @@ def wrap_frozen_graph(gd, inputs, outputs):
dt[2] += time_sync() - t3

# Second-stage classifier (optional)
if classify:
pred = apply_classifier(pred, modelc, img, im0s)
# pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

# Process predictions
for i, det in enumerate(pred): # per image
Expand All @@ -212,15 +133,15 @@ def wrap_frozen_graph(gd, inputs, outputs):
p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)

p = Path(p) # to Path
save_path = str(save_dir / p.name) # img.jpg
txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt
s += '%gx%g ' % img.shape[2:] # print string
save_path = str(save_dir / p.name) # im.jpg
txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt
s += '%gx%g ' % im.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
imc = im0.copy() if save_crop else im0 # for save_crop
annotator = Annotator(im0, line_width=line_thickness, example=str(names))
if len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()

# Print results
for c in det[:, -1].unique():
Expand Down
5 changes: 4 additions & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""

import argparse
import json
import os
import subprocess
import sys
Expand Down Expand Up @@ -54,7 +55,9 @@ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:'
f = file.with_suffix('.torchscript.pt')

ts = torch.jit.trace(model, im, strict=False)
(optimize_for_mobile(ts) if optimize else ts).save(f)
d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
(optimize_for_mobile(ts) if optimize else ts).save(f, _extra_files=extra_files)

LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
except Exception as e:
Expand Down
128 changes: 127 additions & 1 deletion models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
Common modules
"""

import json
import math
import platform
import warnings
from copy import copy
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import requests
Expand All @@ -17,7 +20,8 @@
from torch.cuda import amp

from utils.datasets import exif_transpose, letterbox
from utils.general import LOGGER, colorstr, increment_path, make_divisible, non_max_suppression, scale_coords, xyxy2xywh
from utils.general import (LOGGER, check_requirements, check_suffix, colorstr, increment_path, make_divisible,
non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import time_sync

Expand Down Expand Up @@ -269,6 +273,128 @@ def forward(self, x):
return torch.cat(x, self.d)


class DetectMultiBackend(nn.Module):
# YOLOv5 MultiBackend class for python inference on various backends
def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
# Usage:
# PyTorch: weights = *.pt
# TorchScript: *.torchscript.pt
# CoreML: *.mlmodel
# TensorFlow: *_saved_model
# TensorFlow: *.pb
# TensorFlow Lite: *.tflite
# ONNX Runtime: *.onnx
# OpenCV DNN: *.onnx with dnn=True
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '', '.mlmodel']
check_suffix(w, suffixes) # check weights have acceptable suffix
pt, onnx, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
jit = pt and 'torchscript' in w.lower()
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults

if jit: # TorchScript
LOGGER.info(f'Loading {w} for TorchScript inference...')
extra_files = {'config.txt': ''} # model metadata
model = torch.jit.load(w, _extra_files=extra_files)
if extra_files['config.txt']:
d = json.loads(extra_files['config.txt']) # extra_files dict
stride, names = int(d['stride']), d['names']
elif pt: # PyTorch
from models.experimental import attempt_load # scoped to avoid circular import
model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
stride = int(model.stride.max()) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
elif coreml: # CoreML *.mlmodel
import coremltools as ct
model = ct.models.MLModel(w)
elif dnn: # ONNX OpenCV DNN
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
check_requirements(('opencv-python>=4.5.4',))
net = cv2.dnn.readNetFromONNX(w)
elif onnx: # ONNX Runtime
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
else: # TensorFlow model (TFLite, pb, saved_model)
import tensorflow as tf
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
tf.nest.map_structure(x.graph.as_graph_element, outputs))

LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
graph_def = tf.Graph().as_graph_def()
graph_def.ParseFromString(open(w, 'rb').read())
frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
elif saved_model:
LOGGER.info(f'Loading {w} for TensorFlow saved_model inference...')
model = tf.keras.models.load_model(w)
elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
if 'edgetpu' in w.lower():
LOGGER.info(f'Loading {w} for TensorFlow Edge TPU inference...')
import tflite_runtime.interpreter as tfli
delegate = {'Linux': 'libedgetpu.so.1', # install https://coral.ai/software/#edgetpu-runtime
'Darwin': 'libedgetpu.1.dylib',
'Windows': 'edgetpu.dll'}[platform.system()]
interpreter = tfli.Interpreter(model_path=w, experimental_delegates=[tfli.load_delegate(delegate)])
else:
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
interpreter = tf.lite.Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
self.__dict__.update(locals()) # assign all variables to self

def forward(self, im, augment=False, visualize=False, val=False):
# YOLOv5 MultiBackend inference
b, ch, h, w = im.shape # batch, channel, height, width
if self.pt: # PyTorch
y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
return y if val else y[0]
elif self.coreml: # CoreML *.mlmodel
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
im = Image.fromarray((im[0] * 255).astype('uint8'))
# im = im.resize((192, 320), Image.ANTIALIAS)
y = self.model.predict({'image': im}) # coordinates are xywh normalized
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
elif self.onnx: # ONNX
im = im.cpu().numpy() # torch to numpy
if self.dnn: # ONNX OpenCV DNN
self.net.setInput(im)
y = self.net.forward()
else: # ONNX Runtime
y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
else: # TensorFlow model (TFLite, pb, saved_model)
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.pb:
y = self.frozen_func(x=self.tf.constant(im)).numpy()
elif self.saved_model:
y = self.model(im, training=False).numpy()
elif self.tflite:
input, output = self.input_details[0], self.output_details[0]
int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
if int8:
scale, zero_point = input['quantization']
im = (im / scale + zero_point).astype(np.uint8) # de-scale
self.interpreter.set_tensor(input['index'], im)
self.interpreter.invoke()
y = self.interpreter.get_tensor(output['index'])
if int8:
scale, zero_point = output['quantization']
y = (y.astype(np.float32) - zero_point) * scale # re-scale
y[..., 0] *= w # x
y[..., 1] *= h # y
y[..., 2] *= w # w
y[..., 3] *= h # h
y = torch.tensor(y)
return (y, []) if val else y


class AutoShape(nn.Module):
# YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
conf = 0.25 # NMS confidence threshold
Expand Down
3 changes: 2 additions & 1 deletion utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,8 @@ def print_mutation(results, hyp, save_dir, bucket):


def apply_classifier(x, model, img, im0):
# Apply a second stage classifier to yolo outputs
# Apply a second stage classifier to YOLO outputs
# Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
im0 = [im0] if isinstance(im0, np.ndarray) else im0
for i, d in enumerate(x): # per image
if d is not None and len(d):
Expand Down
Loading