Skip to content

Commit

Permalink
Add CoreML
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Nov 8, 2021
1 parent 358d9e3 commit dc0b748
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from utils.datasets import exif_transpose, letterbox
from utils.general import (check_requirements, check_suffix, colorstr, increment_path, make_divisible,
non_max_suppression, scale_coords, xyxy2xywh)
non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import time_sync

Expand Down Expand Up @@ -281,9 +281,9 @@ class DetectMultiBackend(nn.Module):
def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '']
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '', '.mlmodel']
check_suffix(w, suffixes) # check weights have acceptable suffix
pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes) # backend booleans
pt, onnx, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
jit = pt and 'torchscript' in w.lower()
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults

Expand All @@ -299,6 +299,9 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
stride = int(model.stride.max()) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
elif coreml: # CoreML *.mlmodel
import coremltools as ct
model = ct.models.MLModel(w)
elif dnn: # ONNX OpenCV DNN
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
check_requirements(('opencv-python>=4.5.4',))
Expand Down Expand Up @@ -341,9 +344,15 @@ def wrap_frozen_graph(gd, inputs, outputs):

def forward(self, im, augment=False, visualize=False, val=False):
# YOLOv5 MultiBackend inference
b, ch, h, w = im.shape # batch, channel, height, width
if self.pt: # PyTorch
y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
return y if val else y[0]
elif self.coreml: # CoreML *.mlmodel
y = self.model.predict({'image': im}) # coordinates are xywh normalized
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
elif self.onnx: # ONNX
im = im.cpu().numpy() # torch to numpy
if self.dnn: # ONNX OpenCV DNN
Expand All @@ -352,7 +361,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
else: # ONNX Runtime
y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
else: # TensorFlow model (TFLite, pb, saved_model)
im = im.permute(0, 2, 3, 1).cpu().numpy() # TF format (1,640,640,3)
im = im.permute(0, 2, 3, 1).cpu().numpy() # TF format (1,h=640,w=640,3)
if self.pb:
y = self.frozen_func(x=self.tf.constant(im)).numpy()
elif self.saved_model:
Expand All @@ -369,10 +378,10 @@ def forward(self, im, augment=False, visualize=False, val=False):
if int8:
scale, zero_point = output['quantization']
y = (y.astype(np.float32) - zero_point) * scale # re-scale
y[..., 0] *= im.shape[2] # x
y[..., 1] *= im.shape[1] # y
y[..., 2] *= im.shape[2] # w
y[..., 3] *= im.shape[1] # h
y[..., 0] *= w # x
y[..., 1] *= h # y
y[..., 2] *= w # w
y[..., 3] *= h # h
y = torch.tensor(y)
return (y, []) if val else y

Expand Down

0 comments on commit dc0b748

Please sign in to comment.