Add TritonLauncher #1059

Merged
Commits
34 commits
19b79a5
Update to the latest OV API
vinnamkim Jun 12, 2023
2429d46
Make specs.json more interpretable
vinnamkim Jun 13, 2023
d4d0bb7
Add missing src/datumaro/plugins/sampler/algorithm/__init__.py
vinnamkim Jun 13, 2023
3c9783b
Merge branch 'hotfix/make-specs-more-interpretable' into bugfix/updat…
vinnamkim Jun 13, 2023
10ed762
Merge remote-tracking branch 'upstream/develop' into bugfix/update-to…
vinnamkim Jun 13, 2023
5005cfb
Change ModelInterpreter logic
vinnamkim Jun 13, 2023
25a38e5
Did some refactoring for OpenvinoLauncher and AnnotationMatchers
vinnamkim Jun 14, 2023
0bb6ee4
Add MissingLabelDetection transform
vinnamkim Jun 15, 2023
2157b70
Update CHANGELOG.md
vinnamkim Jun 15, 2023
72a013c
Merge remote-tracking branch 'upstream/develop' into feature/add-miss…
vinnamkim Jun 20, 2023
6aee4df
Add OVMS launcher
vinnamkim Jun 20, 2023
5e1f5df
Refactor Launcher
vinnamkim Jun 21, 2023
aab9693
Remove commented code pieces
vinnamkim Jun 21, 2023
07a6f1d
Fix linterror
vinnamkim Jun 21, 2023
cc49310
Remove OVMS plugin part
vinnamkim Jun 21, 2023
82d4945
Merge remote-tracking branch 'upstream/develop' into refactor/launcher
vinnamkim Jun 23, 2023
0e10f25
Fix
vinnamkim Jun 23, 2023
46cc252
Update CHANGELOG
vinnamkim Jun 23, 2023
bdbc043
Fix unittest error
vinnamkim Jun 23, 2023
429e2b1
Fix
vinnamkim Jun 23, 2023
56342d7
More loose constraint
vinnamkim Jun 23, 2023
b15b267
Implement OVMSLauncher and FaceDetectionModelInterpreter
vinnamkim Jun 21, 2023
5fefb82
Merge remote-tracking branch 'upstream/develop' into feature/add-ovms
vinnamkim Jun 26, 2023
a2aa005
Update CHANGELOG.md
vinnamkim Jun 26, 2023
a6566f4
Add missing __init__.py
vinnamkim Jun 26, 2023
87b755b
Update specs.json
vinnamkim Jun 26, 2023
ec2baf0
Add triton launcher
vinnamkim Jun 23, 2023
7dcece7
Refactoring
vinnamkim Jun 26, 2023
e000c4f
Add TritonLauncher
vinnamkim Jun 26, 2023
03fba8f
Add unit test
vinnamkim Jun 26, 2023
10262ff
Merge remote-tracking branch 'upstream/develop' into feature/add-trit…
vinnamkim Jun 27, 2023
ded9068
Remove old dir
vinnamkim Jun 27, 2023
71e86e9
Update CHANGELOG.md
vinnamkim Jun 27, 2023
676db6d
Update missing comments from the previous commit
vinnamkim Jun 30, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1049>)
- Add OVMSLauncher
(<https://github.com/openvinotoolkit/datumaro/pull/1056>)
- Add TritonLauncher
(<https://github.com/openvinotoolkit/datumaro/pull/1059>)

### Enhancements
- Enhance import performance for built-in plugins
1 change: 1 addition & 0 deletions requirements-core.txt
@@ -50,3 +50,4 @@ protobuf<4

# Model inference launcher from the dedicated inference server
ovmsclient
tritonclient[all]
8 changes: 8 additions & 0 deletions src/datumaro/plugins/inference_server_plugin/__init__.py
@@ -0,0 +1,8 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from datumaro.plugins.inference_server_plugin.ovms import OVMSLauncher
from datumaro.plugins.inference_server_plugin.triton import TritonLauncher

__all__ = ["OVMSLauncher", "TritonLauncher"]
114 changes: 114 additions & 0 deletions src/datumaro/plugins/inference_server_plugin/base.py
@@ -0,0 +1,114 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, Generic, Optional, TypeVar

from grpc import ChannelCredentials, ssl_channel_credentials
from ovmsclient.tfs_compat.base.serving_client import ServingClient

from datumaro.components.errors import DatumaroError, MediaTypeError
from datumaro.components.launcher import LauncherWithModelInterpreter
from datumaro.components.media import Image


class ProtocolType(IntEnum):
"""Protocol type for communication with dedicated inference server"""

grpc = 0
http = 1


@dataclass(frozen=True)
class TLSConfig:
"""TLS configuration dataclass

Parameters:
client_key_path: Path to client key file
client_cert_path: Path to client certificate file
server_cert_path: Path to server certificate file
"""

client_key_path: str
client_cert_path: str
server_cert_path: str

def as_dict(self) -> Dict[str, str]:
return {
"client_key_path": self.client_key_path,
"client_cert_path": self.client_cert_path,
"server_cert_path": self.server_cert_path,
}

def as_grpc_creds(self) -> ChannelCredentials:
server_cert, client_cert, client_key = ServingClient._prepare_certs(
self.server_cert_path, self.client_cert_path, self.client_key_path
)
return ssl_channel_credentials(
root_certificates=server_cert, private_key=client_key, certificate_chain=client_cert
)


TClient = TypeVar("TClient")


class LauncherForDedicatedInferenceServer(Generic[TClient], LauncherWithModelInterpreter):
"""Inference launcher for dedicated inference server

Parameters:
model_name: Name of the model. It should match the model name loaded in the server instance.
model_interpreter_path: Path to the Python source code that implements a model interpreter.
The model interpreter implements pre-processing of the model input and post-processing of the model output.
model_version: Version of the model loaded in the server instance
host: Host address of the server instance
port: Port number of the server instance
timeout: Timeout limit during communication between the client and the server instance
tls_config: Configuration required if the server instance is in the secure mode
protocol_type: Communication protocol type with the server instance
"""

def __init__(
self,
model_name: str,
model_interpreter_path: str,
model_version: int = 0,
host: str = "localhost",
port: int = 9000,
timeout: float = 10.0,
tls_config: Optional[TLSConfig] = None,
protocol_type: ProtocolType = ProtocolType.grpc,
):
super().__init__(model_interpreter_path=model_interpreter_path)

self.model_name = model_name
self.model_version = model_version
self.url = f"{host}:{port}"
self.timeout = timeout
self.tls_config = tls_config
self.protocol_type = protocol_type

try:
self._client = self._init_client()
self._check_server_health()
self._init_metadata()
except Exception as e:
raise DatumaroError(
f"Health check failed for model_name={self.model_name}, "
f"model_version={self.model_version}, url={self.url} and tls_config={self.tls_config}"
) from e

def _init_client(self) -> TClient:
raise NotImplementedError()

def _check_server_health(self) -> None:
raise NotImplementedError()

def _init_metadata(self) -> None:
raise NotImplementedError()

def type_check(self, item):
if not isinstance(item.media, Image):
raise MediaTypeError(f"Media type should be Image, Current type={type(item.media)}")
return True
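
The base class above leaves `_init_client`, `_check_server_health`, and `_init_metadata` as `NotImplementedError` hooks that `__init__` calls in order. The following sketch is not part of this change; `EchoClient` and its `ping`/`metadata` methods are hypothetical stand-ins for a real inference-server client, shown only to illustrate how a concrete subclass is expected to fill in the hooks.

```python
# Hypothetical subclass sketch (not part of this PR). The base __init__ calls
# the hooks in order: _init_client() -> _check_server_health() -> _init_metadata().
from typing import Dict, List

import numpy as np

from datumaro.components.abstracts.model_interpreter import ModelPred
from datumaro.plugins.inference_server_plugin.base import LauncherForDedicatedInferenceServer


class EchoClient:
    """Toy client that pretends to talk to a server at `url`."""

    def __init__(self, url: str, timeout: float) -> None:
        self.url = url
        self.timeout = timeout

    def ping(self) -> bool:
        return True

    def metadata(self) -> Dict[str, str]:
        return {"input": "image", "output": "logits"}


class EchoLauncher(LauncherForDedicatedInferenceServer[EchoClient]):
    def _init_client(self) -> EchoClient:
        # self.url and self.timeout are set by the base __init__ before this hook runs.
        return EchoClient(self.url, self.timeout)

    def _check_server_health(self) -> None:
        # Any exception raised here is wrapped into DatumaroError by the base __init__.
        if not self._client.ping():
            raise RuntimeError("server is not ready")

    def _init_metadata(self) -> None:
        # Cache model metadata for later use in infer().
        self._metadata = self._client.metadata()

    def infer(self, inputs: np.ndarray) -> List[ModelPred]:
        # A real launcher would send `inputs` to the server; here we echo them back.
        return [{"logits": row} for row in inputs]
```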
121 changes: 121 additions & 0 deletions src/datumaro/plugins/inference_server_plugin/ovms.py
@@ -0,0 +1,121 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import logging as log
from typing import List, Union

import numpy as np
from ovmsclient import make_grpc_client, make_http_client
from ovmsclient.tfs_compat.grpc.serving_client import GrpcClient
from ovmsclient.tfs_compat.http.serving_client import HttpClient

from datumaro.components.abstracts.model_interpreter import ModelPred
from datumaro.components.errors import DatumaroError
from datumaro.plugins.inference_server_plugin.base import (
LauncherForDedicatedInferenceServer,
ProtocolType,
)

__all__ = ["OVMSLauncher"]

TClient = Union[GrpcClient, HttpClient]


class OVMSLauncher(LauncherForDedicatedInferenceServer[TClient]):
"""Inference launcher for OVMS (OpenVINO™ Model Server) (https://github.com/openvinotoolkit/model_server)

Parameters:
model_name: Name of the model. It should match the model name loaded in the server instance.
model_interpreter_path: Path to the Python source code that implements a model interpreter.
The model interpreter implements pre-processing of the model input and post-processing of the model output.
model_version: Version of the model loaded in the server instance
host: Host address of the server instance
port: Port number of the server instance
timeout: Timeout limit during communication between the client and the server instance
tls_config: Configuration required if the server instance is in the secure mode
protocol_type: Communication protocol type with the server instance
"""

def _init_client(self) -> TClient:
tls_config = self.tls_config.as_dict() if self.tls_config is not None else None

if self.protocol_type == ProtocolType.grpc:
return make_grpc_client(self.url, tls_config)
if self.protocol_type == ProtocolType.http:
return make_http_client(self.url, tls_config)

raise NotImplementedError(self.protocol_type)

def _check_server_health(self) -> None:
status = self._client.get_model_status(
model_name=self.model_name,
model_version=self.model_version,
timeout=self.timeout,
)
log.info(f"Health check succeeded: {status}")

def _init_metadata(self):
self._metadata = self._client.get_model_metadata(
model_name=self.model_name,
model_version=self.model_version,
timeout=self.timeout,
)
log.info(f"Received metadata: {self._metadata}")

def infer(self, inputs: np.ndarray) -> List[ModelPred]:
results = self._client.predict(
inputs={self._input_key: inputs},
model_name=self.model_name,
model_version=self.model_version,
timeout=self.timeout,
)

# If there is only one output key,
# it returns `np.ndarray` rather than `Dict[str, np.ndarray]`.
# Please see ovmsclient.tfs_compat.grpc.responses.GrpcPredictResponse
if isinstance(results, np.ndarray):
results = {self._output_key: results}

outputs_group_by_item = [
{key: output for key, output in zip(results.keys(), outputs)}
for outputs in zip(*results.values())
]

return outputs_group_by_item

@property
def _input_key(self):
if hasattr(self, "__input_key"):
return self.__input_key

metadata_inputs = self._metadata.get("inputs")

if metadata_inputs is None:
raise DatumaroError("Cannot get metadata of the outputs.")

if len(metadata_inputs.keys()) > 1:
raise DatumaroError(
f"More than two model inputs are not allowed: {metadata_inputs.keys()}."
)

self.__input_key = next(iter(metadata_inputs.keys()))
return self.__input_key

@property
def _output_key(self):
if hasattr(self, "__output_key"):
return self.__output_key

metadata_outputs = self._metadata.get("outputs")

if metadata_outputs is None:
raise DatumaroError("Cannot get metadata of the outputs.")

if len(metadata_outputs.keys()) > 1:
raise DatumaroError(
f"More than two model outputs are not allowed: {metadata_outputs.keys()}."
)

self.__output_key = next(iter(metadata_outputs.keys()))
return self.__output_key
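
A hedged usage sketch for `OVMSLauncher` follows. The model name, version, port, interpreter path, and input shape below are placeholders, not values defined anywhere in this PR; substitute the values of the model actually being served.

```python
# Hypothetical usage of OVMSLauncher against a local OVMS instance.
# "face-detection", "my_interpreter.py", port 9000, and the batch layout are
# illustrative placeholders only.
import numpy as np

from datumaro.plugins.inference_server_plugin import OVMSLauncher
from datumaro.plugins.inference_server_plugin.base import ProtocolType

launcher = OVMSLauncher(
    model_name="face-detection",                 # must match the model name loaded in the server
    model_interpreter_path="my_interpreter.py",  # user-provided pre/post-processing code
    model_version=1,
    host="localhost",
    port=9000,
    timeout=10.0,
    protocol_type=ProtocolType.grpc,             # or ProtocolType.http
)

# A batch of two dummy images; the real layout depends on the served model.
batch = np.zeros((2, 256, 256, 3), dtype=np.uint8)
predictions = launcher.infer(batch)  # one Dict[str, np.ndarray] per item in the batch
print(len(predictions))
```

In a full pipeline the launcher would normally be used through Datumaro's dataset inference APIs rather than by calling `infer()` directly; the direct call just keeps the sketch self-contained.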
110 changes: 110 additions & 0 deletions src/datumaro/plugins/inference_server_plugin/triton.py
@@ -0,0 +1,110 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

import logging as log
from typing import Dict, List, Type, Union

import numpy as np
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient

from datumaro.components.abstracts.model_interpreter import ModelPred
from datumaro.components.errors import DatumaroError
from datumaro.plugins.inference_server_plugin.base import (
LauncherForDedicatedInferenceServer,
ProtocolType,
)

__all__ = ["TritonLauncher"]

TClient = Union[grpcclient.InferenceServerClient, httpclient.InferenceServerClient]
TInferInput = Union[grpcclient.InferInput, httpclient.InferInput]
TInferOutput = Union[grpcclient.InferResult, httpclient.InferResult]


class TritonLauncher(LauncherForDedicatedInferenceServer[TClient]):
"""Inference launcher for Triton Inference Server (https://github.com/triton-inference-server)

Parameters:
model_name: Name of the model. It should match the model name loaded in the server instance.
model_interpreter_path: Path to the Python source code that implements a model interpreter.
The model interpreter implements pre-processing of the model input and post-processing of the model output.
model_version: Version of the model loaded in the server instance
host: Host address of the server instance
port: Port number of the server instance
timeout: Timeout limit during communication between the client and the server instance
tls_config: Configuration required if the server instance is in the secure mode
protocol_type: Communication protocol type with the server instance
"""

def _init_client(self) -> TClient:
creds = self.tls_config.as_grpc_creds() if self.tls_config is not None else None

if self.protocol_type == ProtocolType.grpc:
return grpcclient.InferenceServerClient(url=self.url, creds=creds)
if self.protocol_type == ProtocolType.http:
return httpclient.InferenceServerClient(url=self.url)

raise NotImplementedError(self.protocol_type)

def _check_server_health(self) -> None:
status = self._client.is_model_ready(
model_name=self.model_name,
model_version=str(self.model_version),
)
if not status:
raise DatumaroError("Model is not ready.")
log.info(f"Health check succeeded: {status}")

def _init_metadata(self) -> None:
self._metadata = self._client.get_model_metadata(
model_name=self.model_name,
model_version=str(self.model_version),
)
log.info(f"Received metadata: {self._metadata}")

def _get_infer_input(self, inputs: np.ndarray) -> List[TInferInput]:
def _fix_dynamic_batch_dim(shape):
if shape[0] == -1:
shape[0] = inputs.shape[0]
return shape

def _create(infer_input_cls: Type[TInferInput]) -> List[TInferInput]:
infer_inputs = [
infer_input_cls(
name=inp.name,
shape=_fix_dynamic_batch_dim(inp.shape),
datatype=inp.datatype,
)
for inp in self._metadata.inputs
]
for infer_input in infer_inputs:
infer_input.set_data_from_numpy(inputs)
return infer_inputs

if self.protocol_type == ProtocolType.grpc:
return _create(grpcclient.InferInput)
if self.protocol_type == ProtocolType.http:
return _create(httpclient.InferInput)

raise NotImplementedError(self.protocol_type)

def infer(self, inputs: np.ndarray) -> List[ModelPred]:
infer_outputs: TInferOutput = self._client.infer(
inputs=self._get_infer_input(inputs),
model_name=self.model_name,
model_version=str(self.model_version),
)

results: Dict[str, np.ndarray] = {
output.name: infer_outputs.as_numpy(name=output.name)
for output in self._metadata.outputs
}

outputs_group_by_item = [
{key: output for key, output in zip(results.keys(), outputs)}
for outputs in zip(*results.values())
]

return outputs_group_by_item
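
Both launchers end `infer()` with the same regrouping step: the server returns one batched array per output name, and the comprehension turns that into one dictionary per dataset item. Below is a standalone illustration of just that step, with made-up output names and shapes and no server involved.

```python
# Standalone illustration of the regrouping done at the end of infer():
# {output_name: batched_array} -> [{output_name: per_item_array}, ...]
import numpy as np

results = {
    "boxes": np.zeros((2, 100, 4)),   # batch of 2 items, 100 boxes each
    "scores": np.zeros((2, 100)),
}

outputs_group_by_item = [
    {key: output for key, output in zip(results.keys(), outputs)}
    for outputs in zip(*results.values())
]

assert len(outputs_group_by_item) == 2
assert outputs_group_by_item[0]["boxes"].shape == (100, 4)
assert outputs_group_by_item[0]["scores"].shape == (100,)
```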