-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
208 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -136,4 +136,6 @@ dmypy.json | |
|
||
# Cython debug symbols | ||
cython_debug/ | ||
.vscode/ | ||
.vscode/ | ||
*.jpg | ||
test.ipynb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
onnxruntime~=1.15.1 | ||
opencv-python-headless~=4.8.0.76 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,10 @@ | ||
torch~=2.0.1 | ||
torchvision~=0.15.1 | ||
pyyaml | ||
ultralytics~=8.0.114 | ||
av~=10.0.0 | ||
torch~=2.0.1 | ||
torchvision~=0.15.1 | ||
onnxsim~=0.4.33 | ||
onnx~=1.14.0 | ||
ipykernel | ||
--extra-index-url https://download.pytorch.org/whl/cu118 | ||
-r ./requirements.inference.txt |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import numpy as np | ||
|
||
from .image_utils import PadInfo, ScaleInfo | ||
|
||
|
||
def restore_original_coordinates_inplace(bbox_xywh: np.ndarray, pad_info: PadInfo): | ||
bbox_xywh[:, :2] -= (pad_info.pad_left, pad_info.pad_top) | ||
return bbox_xywh | ||
|
||
|
||
def yolo_bbox2xywh_inplace(yolo_xywh_bboxes: np.ndarray): | ||
yolo_xywh_bboxes[:, 0] = yolo_xywh_bboxes[:, 0] - yolo_xywh_bboxes[:, 2] / 2 | ||
yolo_xywh_bboxes[:, 1] = yolo_xywh_bboxes[:, 1] - yolo_xywh_bboxes[:, 3] / 2 | ||
return yolo_xywh_bboxes | ||
|
||
def xyhw2xyxy_inplace(xy2wh: np.ndarray): | ||
xy2wh[:, 2] += xy2wh[:, 0] | ||
xy2wh[:, 3] += xy2wh[:, 1] | ||
|
||
return xy2wh | ||
|
||
def clip_boxes_inplace(boxes: np.ndarray, width: int, height: int): | ||
""" | ||
It takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the | ||
shape | ||
Args: | ||
boxes: the bounding boxes to clip | ||
width: image width | ||
height: image height | ||
""" | ||
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, width) # x1, x2 | ||
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, height) # y1, y2 | ||
|
||
|
||
def scale_bbox_inplace(bbox: np.ndarray, scale_info: ScaleInfo): | ||
bbox *= (scale_info.to_orig_scale_width, scale_info.to_orig_scale_height, | ||
scale_info.to_orig_scale_width, scale_info.to_orig_scale_height) | ||
return bbox |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
def parse_class_mapping_from_str(class_mapping: str) -> dict: | ||
class_mapping = class_mapping.strip()[1:-1] | ||
|
||
mapping = {} | ||
|
||
for pair in class_mapping.split(","): | ||
class_label, class_name = pair.split(":") | ||
mapping[int(class_label)] = class_name.strip().replace("'", "") | ||
|
||
return mapping |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .detector import ONNXYoloV8Detector |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from typing import List, NamedTuple, Optional, Sequence | ||
|
||
import cv2 | ||
import numpy as np | ||
import onnxruntime as ort | ||
|
||
from ..bbox_utils import (clip_boxes_inplace, | ||
restore_original_coordinates_inplace, | ||
scale_bbox_inplace, xyhw2xyxy_inplace, | ||
yolo_bbox2xywh_inplace) | ||
from ..classes import parse_class_mapping_from_str | ||
from ..image_utils import (PadInfo, ScaleInfo, pad_image, | ||
resize_to_required_size_keep_aspect_ratio) | ||
|
||
|
||
class DetectionInfo(NamedTuple): | ||
scores: np.ndarray | ||
xyxy_boxes: np.ndarray | ||
classes: Sequence[str] | ||
|
||
|
||
class ONNXYoloV8Detector: | ||
def __init__(self, path_to_model: str, providers: Optional[List[str]] = None): | ||
self.session = ort.InferenceSession( | ||
path_to_model, | ||
providers=providers | ||
) | ||
inputs = self.session.get_inputs() | ||
self.session.get_modelmeta() | ||
outputs = self.session.get_outputs() | ||
|
||
assert len(inputs) == 1, f"Detection model expected only one input, but found {len(inputs)}" | ||
assert len( | ||
outputs) == 1, f"Detection model expected only one input, but found {len(outputs)}" | ||
|
||
self._class_mapping = parse_class_mapping_from_str( | ||
self.session.get_modelmeta().custom_metadata_map["names"]) | ||
input_shape = inputs[0].shape | ||
self._max_image_size = max(input_shape[2:]) | ||
self.input_name = inputs[0].name | ||
|
||
def postpocess(self, | ||
pred: np.ndarray, | ||
score_threshold: float, | ||
nms_threshold: float, | ||
pad_info: PadInfo, | ||
scale_info: ScaleInfo, | ||
original_width: int, | ||
original_height: int) -> DetectionInfo: | ||
"""pred: [1 x 84 x 8400] | ||
1 - bath size | ||
84 - 0,1,2,3 is x,y,width,height, 4,5,6,7,8,9.... probability for each class | ||
8400 - number of possible detected objects | ||
https://github.com/ultralytics/ultralytics/issues/2670#issuecomment-1551453142 | ||
""" | ||
xywh_yolo = pred[0, :4, :].T | ||
raw_scores = pred[0, 4:, :] | ||
scores = raw_scores.max(axis=0) | ||
|
||
xywh = yolo_bbox2xywh_inplace(xywh_yolo) | ||
bbox_indices = cv2.dnn.NMSBoxes(xywh, scores, score_threshold, nms_threshold) | ||
|
||
det_xywh = xywh[bbox_indices] | ||
class_indices = raw_scores[:, bbox_indices].argmax(axis=0) | ||
|
||
restore_original_coordinates_inplace(det_xywh, pad_info) | ||
scale_bbox_inplace(det_xywh, scale_info) | ||
det_xyxy = xyhw2xyxy_inplace(det_xywh).round().astype(int) | ||
clip_boxes_inplace(det_xyxy, original_width, original_height) | ||
|
||
return DetectionInfo( | ||
scores[bbox_indices], | ||
det_xyxy, | ||
[self._class_mapping[int(class_index)] for class_index in class_indices] | ||
) | ||
|
||
def preprocess_image(self, image: np.ndarray): | ||
image, scale_info = resize_to_required_size_keep_aspect_ratio(image, self._max_image_size) | ||
image, pad_info = pad_image(image, self._max_image_size, self._max_image_size) | ||
image = image.transpose((2, 0, 1)).astype(np.float32) | ||
image /= 255 | ||
return image[np.newaxis, ...], scale_info, pad_info | ||
|
||
def _raw_predict(self, image: np.ndarray): | ||
"""image is RGB image | ||
""" | ||
return self.session.run(None, {self.input_name: image})[0] | ||
|
||
def predict(self, image: np.ndarray, score_threshold: float, nms_threshold: float) -> DetectionInfo: | ||
original_height, original_width = image.shape[:2] | ||
image, scale_info, pad_info = self.preprocess_image(image) | ||
raw_pred = self._raw_predict(image) | ||
return self.postpocess(raw_pred, score_threshold, nms_threshold, pad_info, scale_info, original_width=original_width, original_height=original_height) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from typing import NamedTuple, Tuple | ||
|
||
import cv2 | ||
import numpy as np | ||
|
||
|
||
class PadInfo(NamedTuple): | ||
pad_left: int | ||
pad_right: int | ||
pad_top: int | ||
pad_bottom: int | ||
|
||
|
||
class ScaleInfo(NamedTuple): | ||
to_orig_scale_width: float | ||
to_orig_scale_height: float | ||
|
||
|
||
def pad_image(image: np.ndarray, required_width: int, required_height: int) -> Tuple[np.ndarray, PadInfo]: | ||
image_height = image.shape[0] | ||
image_width = image.shape[1] | ||
|
||
assert image_height <= required_height | ||
assert image_width <= required_width | ||
|
||
pad_width = required_width - image_width | ||
pad_height = required_height - image_height | ||
|
||
pad_width_left = pad_width // 2 | ||
pad_width_right = required_width - pad_width_left - image_width | ||
pad_height_top = pad_height // 2 | ||
pad_height_bottom = required_height - pad_height_top - image_height | ||
return np.pad(image, ((pad_height_top, pad_height_bottom), (pad_width_left, pad_width_right), (0, 0)), constant_values=0), PadInfo(pad_width_left, pad_width_right, pad_height_top, pad_height_bottom) | ||
|
||
|
||
def resize_to_required_size_keep_aspect_ratio(image: np.ndarray, max_size: int) -> Tuple[np.ndarray, ScaleInfo]: | ||
height, width = image.shape[:2] | ||
aspect_ratio = width / height | ||
|
||
if aspect_ratio > 1: | ||
# Landscape orientation | ||
new_width = max_size | ||
new_height = min(round(new_width / aspect_ratio), max_size) | ||
else: | ||
# Portrait orientation | ||
new_height = max_size | ||
new_width = min(round(new_height * aspect_ratio), max_size) | ||
|
||
scale_info = ScaleInfo(width / new_width, height / new_height) | ||
|
||
return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4), scale_info |
File renamed without changes.
File renamed without changes.
File renamed without changes.