Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ONNX inference providers #5918

Merged
merged 2 commits into from
Dec 8, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,11 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False):
net = cv2.dnn.readNetFromONNX(w)
elif onnx: # ONNX Runtime
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
check_requirements(('onnx', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'))
cuda = torch.cuda.is_available()
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
session = onnxruntime.InferenceSession(w, providers=providers)
elif engine: # TensorRT
LOGGER.info(f'Loading {w} for TensorRT inference...')
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
Expand Down