diff --git a/classify/predict.py b/classify/predict.py index 011e7b83f09b..d3bec8eea7ba 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -104,7 +104,7 @@ def run( seen, windows, dt = 0, [], (Profile(), Profile(), Profile()) for path, im, im0s, vid_cap, s in dataset: with dt[0]: - im = torch.Tensor(im).to(device) + im = torch.Tensor(im).to(model.device) im = im.half() if model.fp16 else im.float() # uint8 to fp16/32 if len(im.shape) == 3: im = im[None] # expand for batch dim diff --git a/detect.py b/detect.py index 9036b26263e5..e442ed75f4c7 100644 --- a/detect.py +++ b/detect.py @@ -49,7 +49,7 @@ @smart_inference_mode() def run( - weights=ROOT / 'yolov5s.pt', # model.pt path(s) + weights=ROOT / 'yolov5s.pt', # model path or triton URL source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam) data=ROOT / 'data/coco128.yaml', # dataset.yaml path imgsz=(640, 640), # inference size (height, width) @@ -108,11 +108,11 @@ def run( vid_path, vid_writer = [None] * bs, [None] * bs # Run inference - model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup + model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup seen, windows, dt = 0, [], (Profile(), Profile(), Profile()) for path, im, im0s, vid_cap, s in dataset: with dt[0]: - im = torch.from_numpy(im).to(device) + im = torch.from_numpy(im).to(model.device) im = im.half() if model.fp16 else im.float() # uint8 to fp16/32 im /= 255 # 0 - 255 to 0.0 - 1.0 if len(im.shape) == 3: @@ -214,7 +214,7 @@ def run( def parse_opt(): parser = argparse.ArgumentParser() - parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path(s)') + parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path or triton URL') parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)') parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path') parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w') diff --git a/models/common.py b/models/common.py index fac95a82fdb9..177704849d3d 100644 --- a/models/common.py +++ b/models/common.py @@ -10,6 +10,7 @@ from collections import OrderedDict, namedtuple from copy import copy from pathlib import Path +from urllib.parse import urlparse import cv2 import numpy as np @@ -327,11 +328,13 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, super().__init__() w = str(weights[0] if isinstance(weights, list) else weights) - pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = self._model_type(w) # type - w = attempt_download(w) # download if not local + pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w) fp16 &= pt or jit or onnx or engine # FP16 + nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH) stride = 32 # default stride cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA + if not (pt or triton): + w = attempt_download(w) # download if not local if pt: # PyTorch model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse) @@ -342,7 +345,7 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, elif jit: # TorchScript LOGGER.info(f'Loading {w} for TorchScript inference...') extra_files = {'config.txt': ''} # model metadata - model = torch.jit.load(w, _extra_files=extra_files) + model = torch.jit.load(w, _extra_files=extra_files, map_location=device) model.half() if fp16 else model.float() if extra_files['config.txt']: # load metadata dict d = json.loads(extra_files['config.txt'], @@ -472,6 +475,12 @@ def gd_outputs(gd): predictor = pdi.create_predictor(config) input_handle = predictor.get_input_handle(predictor.get_input_names()[0]) output_names = predictor.get_output_names() + elif triton: # NVIDIA Triton Inference Server + LOGGER.info(f'Using {w} as Triton Inference Server...') + check_requirements('tritonclient[all]') + from utils.triton import TritonRemoteModel + model = TritonRemoteModel(url=w) + nhwc = model.runtime.startswith("tensorflow") else: raise NotImplementedError(f'ERROR: {w} is not a supported format') @@ -488,6 +497,8 @@ def forward(self, im, augment=False, visualize=False): b, ch, h, w = im.shape # batch, channel, height, width if self.fp16 and im.dtype != torch.float16: im = im.half() # to FP16 + if self.nhwc: + im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3) if self.pt: # PyTorch y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im) @@ -517,7 +528,7 @@ def forward(self, im, augment=False, visualize=False): self.context.execute_v2(list(self.binding_addrs.values())) y = [self.bindings[x].data for x in sorted(self.output_names)] elif self.coreml: # CoreML - im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) + im = im.cpu().numpy() im = Image.fromarray((im[0] * 255).astype('uint8')) # im = im.resize((192, 320), Image.ANTIALIAS) y = self.model.predict({'image': im}) # coordinates are xywh normalized @@ -532,8 +543,10 @@ def forward(self, im, augment=False, visualize=False): self.input_handle.copy_from_cpu(im) self.predictor.run() y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names] + elif self.triton: # NVIDIA Triton Inference Server + y = self.model(im) else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) - im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3) + im = im.cpu().numpy() if self.saved_model: # SavedModel y = self.model(im, training=False) if self.keras else self.model(im) elif self.pb: # GraphDef @@ -566,8 +579,8 @@ def from_numpy(self, x): def warmup(self, imgsz=(1, 3, 640, 640)): # Warmup model by running inference once - warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb - if any(warmup_types) and self.device.type != 'cpu': + warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton + if any(warmup_types) and (self.device.type != 'cpu' or self.triton): im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input for _ in range(2 if self.jit else 1): # self.forward(im) # warmup @@ -575,14 +588,17 @@ def warmup(self, imgsz=(1, 3, 640, 640)): @staticmethod def _model_type(p='path/to/model.pt'): # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx + # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle] from export import export_formats - sf = list(export_formats().Suffix) + ['.xml'] # export suffixes - check_suffix(p, sf) # checks - p = Path(p).name # eliminate trailing separators - pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, xml2 = (s in p for s in sf) - xml |= xml2 # *_openvino_model or *.xml - tflite &= not edgetpu # *.tflite - return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle + from utils.downloads import is_url + sf = list(export_formats().Suffix) # export suffixes + if not is_url(p, check=False): + check_suffix(p, sf) # checks + url = urlparse(p) # if url may be Triton inference server + types = [s in Path(p).name for s in sf] + types[8] &= not types[9] # tflite &= not edgetpu + triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc]) + return types + [triton] @staticmethod def _load_metadata(f=Path('path/to/meta.yaml')): diff --git a/requirements.txt b/requirements.txt index 914da54e73fc..4d6ec3509efa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,6 +34,9 @@ seaborn>=0.11.0 # tensorflowjs>=3.9.0 # TF.js export # openvino-dev # OpenVINO export +# Deploy -------------------------------------- +# tritonclient[all]~=2.24.0 + # Extras -------------------------------------- ipython # interactive notebook psutil # system utilization diff --git a/segment/predict.py b/segment/predict.py index 43cebc706371..2e794c342de1 100644 --- a/segment/predict.py +++ b/segment/predict.py @@ -114,7 +114,7 @@ def run( seen, windows, dt = 0, [], (Profile(), Profile(), Profile()) for path, im, im0s, vid_cap, s in dataset: with dt[0]: - im = torch.from_numpy(im).to(device) + im = torch.from_numpy(im).to(model.device) im = im.half() if model.fp16 else im.float() # uint8 to fp16/32 im /= 255 # 0 - 255 to 0.0 - 1.0 if len(im.shape) == 3: diff --git a/utils/downloads.py b/utils/downloads.py index bd495068522d..433de84b51ca 100644 --- a/utils/downloads.py +++ b/utils/downloads.py @@ -16,13 +16,13 @@ import torch -def is_url(url, check_exists=True): +def is_url(url, check=True): # Check if string is URL and check if URL exists try: url = str(url) result = urllib.parse.urlparse(url) assert all([result.scheme, result.netloc, result.path]) # check if is url - return (urllib.request.urlopen(url).getcode() == 200) if check_exists else True # check if exists online + return (urllib.request.urlopen(url).getcode() == 200) if check else True # check if exists online except (AssertionError, urllib.request.HTTPError): return False diff --git a/utils/triton.py b/utils/triton.py new file mode 100644 index 000000000000..a94ef0ad197d --- /dev/null +++ b/utils/triton.py @@ -0,0 +1,85 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" Utils to interact with the Triton Inference Server +""" + +import typing +from urllib.parse import urlparse + +import torch + + +class TritonRemoteModel: + """ A wrapper over a model served by the Triton Inference Server. It can + be configured to communicate over GRPC or HTTP. It accepts Torch Tensors + as input and returns them as outputs. + """ + + def __init__(self, url: str): + """ + Keyword arguments: + url: Fully qualified address of the Triton server - for e.g. grpc://localhost:8000 + """ + + parsed_url = urlparse(url) + if parsed_url.scheme == "grpc": + from tritonclient.grpc import InferenceServerClient, InferInput + + self.client = InferenceServerClient(parsed_url.netloc) # Triton GRPC client + model_repository = self.client.get_model_repository_index() + self.model_name = model_repository.models[0].name + self.metadata = self.client.get_model_metadata(self.model_name, as_json=True) + + def create_input_placeholders() -> typing.List[InferInput]: + return [ + InferInput(i['name'], [int(s) for s in i["shape"]], i['datatype']) for i in self.metadata['inputs']] + + else: + from tritonclient.http import InferenceServerClient, InferInput + + self.client = InferenceServerClient(parsed_url.netloc) # Triton HTTP client + model_repository = self.client.get_model_repository_index() + self.model_name = model_repository[0]['name'] + self.metadata = self.client.get_model_metadata(self.model_name) + + def create_input_placeholders() -> typing.List[InferInput]: + return [ + InferInput(i['name'], [int(s) for s in i["shape"]], i['datatype']) for i in self.metadata['inputs']] + + self._create_input_placeholders_fn = create_input_placeholders + + @property + def runtime(self): + """Returns the model runtime""" + return self.metadata.get("backend", self.metadata.get("platform")) + + def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, ...]]: + """ Invokes the model. Parameters can be provided via args or kwargs. + args, if provided, are assumed to match the order of inputs of the model. + kwargs are matched with the model input names. + """ + inputs = self._create_inputs(*args, **kwargs) + response = self.client.infer(model_name=self.model_name, inputs=inputs) + result = [] + for output in self.metadata['outputs']: + tensor = torch.as_tensor(response.as_numpy(output['name'])) + result.append(tensor) + return result[0] if len(result) == 1 else result + + def _create_inputs(self, *args, **kwargs): + args_len, kwargs_len = len(args), len(kwargs) + if not args_len and not kwargs_len: + raise RuntimeError("No inputs provided.") + if args_len and kwargs_len: + raise RuntimeError("Cannot specify args and kwargs at the same time") + + placeholders = self._create_input_placeholders_fn() + if args_len: + if args_len != len(placeholders): + raise RuntimeError(f"Expected {len(placeholders)} inputs, got {args_len}.") + for input, value in zip(placeholders, args): + input.set_data_from_numpy(value.cpu().numpy()) + else: + for input in placeholders: + value = kwargs[input.name] + input.set_data_from_numpy(value.cpu().numpy()) + return placeholders