Detect.py supports running against a Triton container #9228

Merged on Sep 23, 2022 (44 commits)

Commits
92a3ff0
update coco128-seg comments
glenn-jocher Aug 24, 2022
a3f9e99
Merge branch 'master' of https://github.com/gaziqbal/yolov5
Aug 30, 2022
e58e5fa
Enables detect.py to use Triton for inference
Aug 24, 2022
f27d974
added triton client to requirements
Aug 26, 2022
e764d00
fixed support for TFSavedModels in Triton
Aug 26, 2022
4ef650b
reverted change
Aug 30, 2022
a665a85
Merge branch 'master' into triton-support
glenn-jocher Sep 4, 2022
de0073f
Merge branch 'master' into triton-support
glenn-jocher Sep 4, 2022
7e8182d
Merge branch 'master' into triton-support
glenn-jocher Sep 7, 2022
304bcbd
Merge branch 'master' into triton-support
glenn-jocher Sep 18, 2022
8c48452
Merge branch 'master' into triton-support
glenn-jocher Sep 18, 2022
4f59fcb
Test CoreML update
glenn-jocher Sep 18, 2022
9c78738
Merge branch 'master' into triton-support
glenn-jocher Sep 21, 2022
9fd00c5
Update ci-testing.yml
glenn-jocher Sep 21, 2022
8407641
Use pathlib
glenn-jocher Sep 21, 2022
1227689
Refacto DetectMultiBackend to directly accept triton url as --weights…
glenn-jocher Sep 21, 2022
e5a4c0a
Deploy category
glenn-jocher Sep 21, 2022
97b15c7
Update detect.py
glenn-jocher Sep 21, 2022
43dce21
Update common.py
glenn-jocher Sep 21, 2022
2def487
Update common.py
glenn-jocher Sep 21, 2022
9dd71a6
Merge branch 'master' into triton-support
glenn-jocher Sep 21, 2022
6ed4f48
Update predict.py
glenn-jocher Sep 21, 2022
4d8fca8
Update predict.py
glenn-jocher Sep 21, 2022
ee181d5
Update predict.py
glenn-jocher Sep 21, 2022
bd0b381
Update triton.py
glenn-jocher Sep 21, 2022
9b1fb72
Update triton.py
glenn-jocher Sep 21, 2022
52739fd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2022
6be250c
Add printout and requirements check
glenn-jocher Sep 21, 2022
d47b0df
Cleanup
glenn-jocher Sep 21, 2022
141deed
Merge branch 'master' into triton-support
glenn-jocher Sep 21, 2022
1ead357
Merge branch 'master' into triton-support
glenn-jocher Sep 21, 2022
51bfe4c
Merge branch 'master' into triton-support
glenn-jocher Sep 22, 2022
c0ad66f
triton fixes
Sep 23, 2022
27d8b4d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2022
80ad40d
fixed triton model query over grpc
Sep 23, 2022
f353125
Merge branch 'triton-support' of https://github.com/gaziqbal/yolov5 i…
Sep 23, 2022
9e8583d
Merge branch 'master' into triton-support
glenn-jocher Sep 23, 2022
fe3f304
Update check_requirements('tritonclient[all]')
glenn-jocher Sep 23, 2022
ec7dfda
group imports
glenn-jocher Sep 23, 2022
7acb42c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2022
bc56b7c
Fix likely remote URL bug
glenn-jocher Sep 23, 2022
c52c4f8
update comment
glenn-jocher Sep 23, 2022
1cd9880
Update is_url()
glenn-jocher Sep 23, 2022
f0cb782
Fix 2x download attempt on http://path/to/model.pt
glenn-jocher Sep 23, 2022
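For context, the end-to-end flow this PR enables is sketched below: export a YOLOv5 model, lay it out in a Triton model repository, and launch a Triton container to serve it. The repository layout, image tag and ports are illustrative assumptions, not part of this diff.

```bash
# Illustrative sketch only, not part of this PR.
python export.py --weights yolov5s.pt --include onnx   # export model to ONNX
mkdir -p triton_repo/yolov5/1                          # <repo>/<model_name>/<version>/
mv yolov5s.onnx triton_repo/yolov5/1/model.onnx
# Ports: 8000 = HTTP, 8001 = gRPC. Depending on the Triton version, ONNX models
# may need a config.pbtxt (or --strict-model-config=false) to be auto-configured.
docker run --rm -p 8000:8000 -p 8001:8001 -v "$PWD/triton_repo":/models \
    nvcr.io/nvidia/tritonserver:22.08-py3 tritonserver --model-repository=/models
```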
Files changed
classify/predict.py (1 addition, 1 deletion)

@@ -104,7 +104,7 @@ def run(
     seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
     for path, im, im0s, vid_cap, s in dataset:
         with dt[0]:
-            im = torch.Tensor(im).to(device)
+            im = torch.Tensor(im).to(model.device)
             im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
             if len(im.shape) == 3:
                 im = im[None]  # expand for batch dim
detect.py (4 additions, 4 deletions)

@@ -49,7 +49,7 @@

 @smart_inference_mode()
 def run(
-        weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
+        weights=ROOT / 'yolov5s.pt',  # model path or triton URL
         source=ROOT / 'data/images',  # file/dir/URL/glob/screen/0(webcam)
         data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
         imgsz=(640, 640),  # inference size (height, width)
@@ -108,11 +108,11 @@ def run(
     vid_path, vid_writer = [None] * bs, [None] * bs

     # Run inference
-    model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup
+    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
     seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
     for path, im, im0s, vid_cap, s in dataset:
         with dt[0]:
-            im = torch.from_numpy(im).to(device)
+            im = torch.from_numpy(im).to(model.device)
             im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
             im /= 255  # 0 - 255 to 0.0 - 1.0
             if len(im.shape) == 3:
@@ -214,7 +214,7 @@ def run(

 def parse_opt():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)')
+    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path or triton URL')
     parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
     parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
     parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
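With these changes, --weights accepts a Triton endpoint in place of a local model path. A minimal invocation sketch, assuming a server like the one above is reachable (host and ports are placeholders):

```bash
python detect.py --weights grpc://localhost:8001 --source data/images   # gRPC
python detect.py --weights http://localhost:8000 --source data/images   # HTTP
```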
models/common.py (30 additions, 14 deletions)

@@ -10,6 +10,7 @@
 from collections import OrderedDict, namedtuple
 from copy import copy
 from pathlib import Path
+from urllib.parse import urlparse

 import cv2
 import numpy as np
@@ -327,11 +328,13 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,

         super().__init__()
         w = str(weights[0] if isinstance(weights, list) else weights)
-        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = self._model_type(w)  # type
-        w = attempt_download(w)  # download if not local
+        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
         fp16 &= pt or jit or onnx or engine  # FP16
         nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCWH)
         stride = 32  # default stride
         cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
+        if not (pt or triton):
+            w = attempt_download(w)  # download if not local

         if pt:  # PyTorch
             model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
@@ -342,7 +345,7 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False,
         elif jit:  # TorchScript
             LOGGER.info(f'Loading {w} for TorchScript inference...')
             extra_files = {'config.txt': ''}  # model metadata
-            model = torch.jit.load(w, _extra_files=extra_files)
+            model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
             model.half() if fp16 else model.float()
             if extra_files['config.txt']:  # load metadata dict
                 d = json.loads(extra_files['config.txt'],
@@ -472,6 +475,12 @@ def gd_outputs(gd):
             predictor = pdi.create_predictor(config)
             input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
             output_names = predictor.get_output_names()
+        elif triton:  # NVIDIA Triton Inference Server
+            LOGGER.info(f'Using {w} as Triton Inference Server...')
+            check_requirements('tritonclient[all]')
+            from utils.triton import TritonRemoteModel
+            model = TritonRemoteModel(url=w)
+            nhwc = model.runtime.startswith("tensorflow")
         else:
             raise NotImplementedError(f'ERROR: {w} is not a supported format')

@@ -488,6 +497,8 @@ def forward(self, im, augment=False, visualize=False):
         b, ch, h, w = im.shape  # batch, channel, height, width
         if self.fp16 and im.dtype != torch.float16:
             im = im.half()  # to FP16
+        if self.nhwc:
+            im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

         if self.pt:  # PyTorch
             y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
@@ -517,7 +528,7 @@ def forward(self, im, augment=False, visualize=False):
             self.context.execute_v2(list(self.binding_addrs.values()))
             y = [self.bindings[x].data for x in sorted(self.output_names)]
         elif self.coreml:  # CoreML
-            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            im = im.cpu().numpy()
             im = Image.fromarray((im[0] * 255).astype('uint8'))
             # im = im.resize((192, 320), Image.ANTIALIAS)
             y = self.model.predict({'image': im})  # coordinates are xywh normalized
@@ -532,8 +543,10 @@ def forward(self, im, augment=False, visualize=False):
             self.input_handle.copy_from_cpu(im)
             self.predictor.run()
             y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
+        elif self.triton:  # NVIDIA Triton Inference Server
+            y = self.model(im)
         else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
-            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            im = im.cpu().numpy()
             if self.saved_model:  # SavedModel
                 y = self.model(im, training=False) if self.keras else self.model(im)
             elif self.pb:  # GraphDef
@@ -566,23 +579,26 @@ def from_numpy(self, x):

     def warmup(self, imgsz=(1, 3, 640, 640)):
         # Warmup model by running inference once
-        warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb
-        if any(warmup_types) and self.device.type != 'cpu':
+        warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton
+        if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
             im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
             for _ in range(2 if self.jit else 1):  #
                 self.forward(im)  # warmup

     @staticmethod
     def _model_type(p='path/to/model.pt'):
         # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
         # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
         from export import export_formats
-        sf = list(export_formats().Suffix) + ['.xml']  # export suffixes
-        check_suffix(p, sf)  # checks
-        p = Path(p).name  # eliminate trailing separators
-        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, xml2 = (s in p for s in sf)
-        xml |= xml2  # *_openvino_model or *.xml
-        tflite &= not edgetpu  # *.tflite
-        return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle
+        from utils.downloads import is_url
+        sf = list(export_formats().Suffix)  # export suffixes
+        if not is_url(p, check=False):
+            check_suffix(p, sf)  # checks
+        url = urlparse(p)  # if url may be Triton inference server
+        types = [s in Path(p).name for s in sf]
+        types[8] &= not types[9]  # tflite &= not edgetpu
+        triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
+        return types + [triton]

     @staticmethod
     def _load_metadata(f=Path('path/to/meta.yaml')):
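The Triton branch of the rewritten `_model_type` hinges on the URL test in its last two lines. A standalone sketch of that predicate, simplified: the real code additionally requires that no export suffix matched:

```python
from urllib.parse import urlparse

def looks_like_triton(p: str) -> bool:
    # True for http/https/grpc URLs with a network location, as in _model_type above
    url = urlparse(p)
    return all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])

print(looks_like_triton('grpc://localhost:8001'))  # True
print(looks_like_triton('yolov5s.onnx'))           # False: no scheme or netloc
```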
requirements.txt (3 additions, 0 deletions)

@@ -34,6 +34,9 @@ seaborn>=0.11.0
 # tensorflowjs>=3.9.0  # TF.js export
 # openvino-dev  # OpenVINO export

+# Deploy --------------------------------------
+# tritonclient[all]~=2.24.0
+
 # Extras --------------------------------------
 ipython  # interactive notebook
 psutil  # system utilization
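The client library ships commented out; the check_requirements('tritonclient[all]') call in models/common.py will try to install it on first Triton use. To install it up front:

```bash
pip install "tritonclient[all]~=2.24.0"   # matches the pinned requirements.txt entry
```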
segment/predict.py (1 addition, 1 deletion)

@@ -114,7 +114,7 @@ def run(
     seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
     for path, im, im0s, vid_cap, s in dataset:
         with dt[0]:
-            im = torch.from_numpy(im).to(device)
+            im = torch.from_numpy(im).to(model.device)
             im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
             im /= 255  # 0 - 255 to 0.0 - 1.0
             if len(im.shape) == 3:
utils/downloads.py (2 additions, 2 deletions)

@@ -16,13 +16,13 @@
 import torch


-def is_url(url, check_exists=True):
+def is_url(url, check=True):
     # Check if string is URL and check if URL exists
     try:
         url = str(url)
         result = urllib.parse.urlparse(url)
         assert all([result.scheme, result.netloc, result.path])  # check if is url
-        return (urllib.request.urlopen(url).getcode() == 200) if check_exists else True  # check if exists online
+        return (urllib.request.urlopen(url).getcode() == 200) if check else True  # check if exists online
     except (AssertionError, urllib.request.HTTPError):
         return False

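The rename from check_exists to check keeps call sites like is_url(p, check=False) compact: the URL's shape is validated without a network round-trip. Two offline examples of the resulting behavior:

```python
from utils.downloads import is_url

print(is_url('https://ultralytics.com/images/zidane.jpg', check=False))  # True: scheme, netloc and path all present
print(is_url('not a url', check=False))                                  # False: the AssertionError is caught
```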
utils/triton.py (new file, 85 additions)

@@ -0,0 +1,85 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
""" Utils to interact with the Triton Inference Server
"""

import typing
from urllib.parse import urlparse

import torch


class TritonRemoteModel:
    """ A wrapper over a model served by the Triton Inference Server. It can
    be configured to communicate over GRPC or HTTP. It accepts Torch Tensors
    as input and returns them as outputs.
    """

    def __init__(self, url: str):
        """
        Keyword arguments:
        url: Fully qualified address of the Triton server - for e.g. grpc://localhost:8000
        """

        parsed_url = urlparse(url)
        if parsed_url.scheme == "grpc":
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository.models[0].name
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)

            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i["shape"]], i['datatype']) for i in self.metadata['inputs']]

        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository[0]['name']
            self.metadata = self.client.get_model_metadata(self.model_name)

            def create_input_placeholders() -> typing.List[InferInput]:
                return [
                    InferInput(i['name'], [int(s) for s in i["shape"]], i['datatype']) for i in self.metadata['inputs']]

        self._create_input_placeholders_fn = create_input_placeholders

    @property
    def runtime(self):
        """Returns the model runtime"""
        return self.metadata.get("backend", self.metadata.get("platform"))

    def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, ...]]:
        """ Invokes the model. Parameters can be provided via args or kwargs.
        args, if provided, are assumed to match the order of inputs of the model.
        kwargs are matched with the model input names.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        result = []
        for output in self.metadata['outputs']:
            tensor = torch.as_tensor(response.as_numpy(output['name']))
            result.append(tensor)
        return result[0] if len(result) == 1 else result

    def _create_inputs(self, *args, **kwargs):
        args_len, kwargs_len = len(args), len(kwargs)
        if not args_len and not kwargs_len:
            raise RuntimeError("No inputs provided.")
        if args_len and kwargs_len:
            raise RuntimeError("Cannot specify args and kwargs at the same time")

        placeholders = self._create_input_placeholders_fn()
        if args_len:
            if args_len != len(placeholders):
                raise RuntimeError(f"Expected {len(placeholders)} inputs, got {args_len}.")
            for input, value in zip(placeholders, args):
                input.set_data_from_numpy(value.cpu().numpy())
        else:
            for input in placeholders:
                value = kwargs[input.name]
                input.set_data_from_numpy(value.cpu().numpy())
        return placeholders
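A direct usage sketch of the wrapper (DetectMultiBackend does the equivalent internally). It assumes a reachable server; note the class binds to the first model in the server's repository index, so a single-model server is the intended setup. The URL is a placeholder:

```python
import torch

from utils.triton import TritonRemoteModel

model = TritonRemoteModel(url='grpc://localhost:8001')
print(model.runtime)                    # backend/platform reported by the server, e.g. 'onnxruntime'
y = model(torch.zeros(1, 3, 640, 640))  # positional tensors map to model inputs in order
```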