diff --git a/CHANGELOG.md b/CHANGELOG.md
index b0951e1914..101500a284 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   ()
 - Add OVMSLauncher
   ()
+- Add TritonLauncher
+  ()
 
 ### Enhancements
 - Enhance import performance for built-in plugins
diff --git a/requirements-core.txt b/requirements-core.txt
index 87a9212300..f3ea34d933 100644
--- a/requirements-core.txt
+++ b/requirements-core.txt
@@ -50,3 +50,4 @@ protobuf<4
 
 # Model inference launcher from the dedicated inference server
 ovmsclient
+tritonclient[all]
diff --git a/src/datumaro/plugins/inference_server_plugin/__init__.py b/src/datumaro/plugins/inference_server_plugin/__init__.py
new file mode 100644
index 0000000000..1ac4d76935
--- /dev/null
+++ b/src/datumaro/plugins/inference_server_plugin/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (C) 2023 Intel Corporation
+#
+# SPDX-License-Identifier: MIT
+
+from datumaro.plugins.inference_server_plugin.ovms import OVMSLauncher
+from datumaro.plugins.inference_server_plugin.triton import TritonLauncher
+
+__all__ = ["OVMSLauncher", "TritonLauncher"]
diff --git a/src/datumaro/plugins/inference_server_plugin/base.py b/src/datumaro/plugins/inference_server_plugin/base.py
new file mode 100644
index 0000000000..71063e7455
--- /dev/null
+++ b/src/datumaro/plugins/inference_server_plugin/base.py
@@ -0,0 +1,114 @@
+# Copyright (C) 2023 Intel Corporation
+#
+# SPDX-License-Identifier: MIT
+
+from dataclasses import dataclass
+from enum import IntEnum
+from typing import Dict, Generic, Optional, TypeVar
+
+from grpc import ChannelCredentials, ssl_channel_credentials
+from ovmsclient.tfs_compat.base.serving_client import ServingClient
+
+from datumaro.components.errors import DatumaroError, MediaTypeError
+from datumaro.components.launcher import LauncherWithModelInterpreter
+from datumaro.components.media import Image
+
+
+class ProtocolType(IntEnum):
+    """Protocol type for communication with dedicated inference server"""
+
+    grpc = 0
+    http = 1
+
+
+@dataclass(frozen=True)
+class TLSConfig:
+    """TLS configuration dataclass
+
+    Parameters:
+        client_key_path: Path to client key file
+        client_cert_path: Path to client certificate file
+        server_cert_path: Path to server certificate file
+    """
+
+    client_key_path: str
+    client_cert_path: str
+    server_cert_path: str
+
+    def as_dict(self) -> Dict[str, str]:
+        return {
+            "client_key_path": self.client_key_path,
+            "client_cert_path": self.client_cert_path,
+            "server_cert_path": self.server_cert_path,
+        }
+
+    def as_grpc_creds(self) -> ChannelCredentials:
+        server_cert, client_cert, client_key = ServingClient._prepare_certs(
+            self.server_cert_path, self.client_cert_path, self.client_key_path
+        )
+        return ssl_channel_credentials(
+            root_certificates=server_cert, private_key=client_key, certificate_chain=client_cert
+        )
+
+
+TClient = TypeVar("TClient")
+
+
+class LauncherForDedicatedInferenceServer(Generic[TClient], LauncherWithModelInterpreter):
+    """Inference launcher for dedicated inference server
+
+    Parameters:
+        model_name: Name of the model. It should match with the model name loaded in the server instance.
+        model_interpreter_path: Python source code path which implements a model interpreter.
+            The model interpreter implements pre-processing of the model input and post-processing of the model output.
+        model_version: Version of the model loaded in the server instance
+        host: Host address of the server instance
+        port: Port number of the server instance
+        timeout: Timeout limit during communication between the client and the server instance
+        tls_config: Configuration required if the server instance is in the secure mode
+        protocol_type: Communication protocol type with the server instance
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        model_interpreter_path: str,
+        model_version: int = 0,
+        host: str = "localhost",
+        port: int = 9000,
+        timeout: float = 10.0,
+        tls_config: Optional[TLSConfig] = None,
+        protocol_type: ProtocolType = ProtocolType.grpc,
+    ):
+        super().__init__(model_interpreter_path=model_interpreter_path)
+
+        self.model_name = model_name
+        self.model_version = model_version
+        self.url = f"{host}:{port}"
+        self.timeout = timeout
+        self.tls_config = tls_config
+        self.protocol_type = protocol_type
+
+        try:
+            self._client = self._init_client()
+            self._check_server_health()
+            self._init_metadata()
+        except Exception as e:
+            raise DatumaroError(
+                f"Health check failed for model_name={self.model_name}, "
+                f"model_version={self.model_version}, url={self.url} and tls_config={self.tls_config}"
+            ) from e
+
+    def _init_client(self) -> TClient:
+        raise NotImplementedError()
+
+    def _check_server_health(self) -> None:
+        raise NotImplementedError()
+
+    def _init_metadata(self) -> None:
+        raise NotImplementedError()
+
+    def type_check(self, item):
+        if not isinstance(item.media, Image):
+            raise MediaTypeError(f"Media type should be Image, Current type={type(item.media)}")
+        return True
diff --git a/src/datumaro/plugins/inference_server_plugin/ovms.py b/src/datumaro/plugins/inference_server_plugin/ovms.py
new file mode 100644
index 0000000000..049604c70d
--- /dev/null
+++ b/src/datumaro/plugins/inference_server_plugin/ovms.py
@@ -0,0 +1,126 @@
+# Copyright (C) 2023 Intel Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import logging as log
+from typing import List, Union
+
+import numpy as np
+from ovmsclient import make_grpc_client, make_http_client
+from ovmsclient.tfs_compat.grpc.serving_client import GrpcClient
+from ovmsclient.tfs_compat.http.serving_client import HttpClient
+
+from datumaro.components.abstracts.model_interpreter import ModelPred
+from datumaro.components.errors import DatumaroError
+from datumaro.plugins.inference_server_plugin.base import (
+    LauncherForDedicatedInferenceServer,
+    ProtocolType,
+)
+
+__all__ = ["OVMSLauncher"]
+
+TClient = Union[GrpcClient, HttpClient]
+
+
+class OVMSLauncher(LauncherForDedicatedInferenceServer[TClient]):
+    """Inference launcher for OVMS (OpenVINO™ Model Server) (https://github.com/openvinotoolkit/model_server)
+
+    Parameters:
+        model_name: Name of the model. It should match with the model name loaded in the server instance.
+        model_interpreter_path: Python source code path which implements a model interpreter.
+            The model interpreter implements pre-processing of the model input and post-processing of the model output.
+        model_version: Version of the model loaded in the server instance
+        host: Host address of the server instance
+        port: Port number of the server instance
+        timeout: Timeout limit during communication between the client and the server instance
+        tls_config: Configuration required if the server instance is in the secure mode
+        protocol_type: Communication protocol type with the server instance
+    """
+
+    def _init_client(self) -> TClient:
+        tls_config = self.tls_config.as_dict() if self.tls_config is not None else None
+
+        if self.protocol_type == ProtocolType.grpc:
+            return make_grpc_client(self.url, tls_config)
+        if self.protocol_type == ProtocolType.http:
+            return make_http_client(self.url, tls_config)
+
+        raise NotImplementedError(self.protocol_type)
+
+    def _check_server_health(self) -> None:
+        status = self._client.get_model_status(
+            model_name=self.model_name,
+            model_version=self.model_version,
+            timeout=self.timeout,
+        )
+        log.info(f"Health check succeeded: {status}")
+
+    def _init_metadata(self):
+        self._metadata = self._client.get_model_metadata(
+            model_name=self.model_name,
+            model_version=self.model_version,
+            timeout=self.timeout,
+        )
+        log.info(f"Received metadata: {self._metadata}")
+
+    def infer(self, inputs: np.ndarray) -> List[ModelPred]:
+        # Please see the following link for the input and output type of self._client.predict()
+        # https://github.com/openvinotoolkit/model_server/blob/releases/2022/3/client/python/ovmsclient/lib/docs/grpc_client.md#method-predict
+        # The input is Dict[str, np.ndarray].
+        # The output is Dict[str, np.ndarray] (If the model has multiple outputs),
+        # or np.ndarray (If the model has one single output).
+        results = self._client.predict(
+            inputs={self._input_key: inputs},
+            model_name=self.model_name,
+            model_version=self.model_version,
+            timeout=self.timeout,
+        )
+
+        # If there is only one output key,
+        # it returns `np.ndarray` rather than `Dict[str, np.ndarray]`.
+        # Please see ovmsclient.tfs_compat.grpc.responses.GrpcPredictResponse
+        if isinstance(results, np.ndarray):
+            results = {self._output_key: results}
+
+        outputs_group_by_item = [
+            {key: output for key, output in zip(results.keys(), outputs)}
+            for outputs in zip(*results.values())
+        ]
+
+        return outputs_group_by_item
+
+    @property
+    def _input_key(self):
+        if hasattr(self, "__input_key"):
+            return self.__input_key
+
+        metadata_inputs = self._metadata.get("inputs")
+
+        if metadata_inputs is None:
+            raise DatumaroError("Cannot get metadata of the inputs.")
+
+        if len(metadata_inputs.keys()) > 1:
+            raise DatumaroError(
+                f"More than one model input is not allowed: {metadata_inputs.keys()}."
+            )
+
+        self.__input_key = next(iter(metadata_inputs.keys()))
+        return self.__input_key
+
+    @property
+    def _output_key(self):
+        if hasattr(self, "__output_key"):
+            return self.__output_key
+
+        metadata_outputs = self._metadata.get("outputs")
+
+        if metadata_outputs is None:
+            raise DatumaroError("Cannot get metadata of the outputs.")
+
+        if len(metadata_outputs.keys()) > 1:
+            raise DatumaroError(
+                f"More than one model output is not allowed: {metadata_outputs.keys()}."
+            )
+
+        self.__output_key = next(iter(metadata_outputs.keys()))
+        return self.__output_key
diff --git a/src/datumaro/plugins/ovms_plugin/__init__.py b/src/datumaro/plugins/inference_server_plugin/samples/__init__.py
similarity index 100%
rename from src/datumaro/plugins/ovms_plugin/__init__.py
rename to src/datumaro/plugins/inference_server_plugin/samples/__init__.py
diff --git a/src/datumaro/plugins/ovms_plugin/samples/face_detection.py b/src/datumaro/plugins/inference_server_plugin/samples/face_detection.py
similarity index 100%
rename from src/datumaro/plugins/ovms_plugin/samples/face_detection.py
rename to src/datumaro/plugins/inference_server_plugin/samples/face_detection.py
diff --git a/src/datumaro/plugins/inference_server_plugin/triton.py b/src/datumaro/plugins/inference_server_plugin/triton.py
new file mode 100644
index 0000000000..9d0f21115f
--- /dev/null
+++ b/src/datumaro/plugins/inference_server_plugin/triton.py
@@ -0,0 +1,110 @@
+# Copyright (C) 2023 Intel Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import logging as log
+from typing import Dict, List, Type, Union
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+import tritonclient.http as httpclient
+
+from datumaro.components.abstracts.model_interpreter import ModelPred
+from datumaro.components.errors import DatumaroError
+from datumaro.plugins.inference_server_plugin.base import (
+    LauncherForDedicatedInferenceServer,
+    ProtocolType,
+)
+
+__all__ = ["TritonLauncher"]
+
+TClient = Union[grpcclient.InferenceServerClient, httpclient.InferenceServerClient]
+TInferInput = Union[grpcclient.InferInput, httpclient.InferInput]
+TInferOutput = Union[grpcclient.InferResult, httpclient.InferResult]
+
+
+class TritonLauncher(LauncherForDedicatedInferenceServer[TClient]):
+    """Inference launcher for Triton Inference Server (https://github.com/triton-inference-server)
+
+    Parameters:
+        model_name: Name of the model. It should match with the model name loaded in the server instance.
+        model_interpreter_path: Python source code path which implements a model interpreter.
+            The model interpreter implements pre-processing of the model input and post-processing of the model output.
+        model_version: Version of the model loaded in the server instance
+        host: Host address of the server instance
+        port: Port number of the server instance
+        timeout: Timeout limit during communication between the client and the server instance
+        tls_config: Configuration required if the server instance is in the secure mode
+        protocol_type: Communication protocol type with the server instance
+    """
+
+    def _init_client(self) -> TClient:
+        creds = self.tls_config.as_grpc_creds() if self.tls_config is not None else None
+
+        if self.protocol_type == ProtocolType.grpc:
+            return grpcclient.InferenceServerClient(url=self.url, creds=creds)
+        if self.protocol_type == ProtocolType.http:
+            return httpclient.InferenceServerClient(url=self.url)
+
+        raise NotImplementedError(self.protocol_type)
+
+    def _check_server_health(self) -> None:
+        status = self._client.is_model_ready(
+            model_name=self.model_name,
+            model_version=str(self.model_version),
+        )
+        if not status:
+            raise DatumaroError("Model is not ready.")
+        log.info(f"Health check succeeded: {status}")
+
+    def _init_metadata(self) -> None:
+        self._metadata = self._client.get_model_metadata(
+            model_name=self.model_name,
+            model_version=str(self.model_version),
+        )
+        log.info(f"Received metadata: {self._metadata}")
+
+    def _get_infer_input(self, inputs: np.ndarray) -> List[TInferInput]:
+        def _fix_dynamic_batch_dim(shape):
+            if shape[0] == -1:
+                shape[0] = inputs.shape[0]
+            return shape
+
+        def _create(infer_input_cls: Type[TInferInput]) -> List[TInferInput]:
+            infer_inputs = [
+                infer_input_cls(
+                    name=inp.name,
+                    shape=_fix_dynamic_batch_dim(inp.shape),
+                    datatype=inp.datatype,
+                )
+                for inp in self._metadata.inputs
+            ]
+            for infer_input in infer_inputs:
+                infer_input.set_data_from_numpy(inputs)
+            return infer_inputs
+
+        if self.protocol_type == ProtocolType.grpc:
+            return _create(grpcclient.InferInput)
+        if self.protocol_type == ProtocolType.http:
+            return _create(httpclient.InferInput)
+
+        raise NotImplementedError(self.protocol_type)
+
+    def infer(self, inputs: np.ndarray) -> List[ModelPred]:
+        infer_outputs: TInferOutput = self._client.infer(
+            inputs=self._get_infer_input(inputs),
+            model_name=self.model_name,
+            model_version=str(self.model_version),
+        )
+
+        results: Dict[str, np.ndarray] = {
+            output.name: infer_outputs.as_numpy(name=output.name)
+            for output in self._metadata.outputs
+        }
+
+        outputs_group_by_item = [
+            {key: output for key, output in zip(results.keys(), outputs)}
+            for outputs in zip(*results.values())
+        ]
+
+        return outputs_group_by_item
diff --git a/src/datumaro/plugins/ovms_plugin/launcher.py b/src/datumaro/plugins/ovms_plugin/launcher.py
deleted file mode 100644
index 3836d79f61..0000000000
--- a/src/datumaro/plugins/ovms_plugin/launcher.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# Copyright (C) 2023 Intel Corporation
-#
-# SPDX-License-Identifier: MIT
-
-import logging as log
-from dataclasses import dataclass
-from enum import IntEnum
-from typing import Dict, List, Optional, Union
-
-import numpy as np
-from ovmsclient import make_grpc_client, make_http_client
-from ovmsclient.tfs_compat.grpc.serving_client import GrpcClient
-from ovmsclient.tfs_compat.http.serving_client import HttpClient
-
-from datumaro.components.abstracts.model_interpreter import ModelPred
-from datumaro.components.errors import DatumaroError, MediaTypeError
-from datumaro.components.launcher import LauncherWithModelInterpreter
-from datumaro.components.media import Image
-
-
-class OVMSClientType(IntEnum):
-    """API types of OVMS client
-
-    OVMS client can accept gRPC or HTTP REST API.
-    """
-
-    grpc = 0
-    http = 1
-
-
-@dataclass(frozen=True)
-class TLSConfig:
-    """TLS configuration dataclass
-
-    Parameters:
-        client_key_path: Path to client key file
-        client_cert_path: Path to client certificate file
-        server_cert_path: Path to server certificate file
-    """
-
-    client_key_path: str
-    client_cert_path: str
-    server_cert_path: str
-
-    def as_dict(self) -> Dict[str, str]:
-        return {
-            "client_key_path": self.client_key_path,
-            "client_cert_path": self.client_cert_path,
-            "server_cert_path": self.server_cert_path,
-        }
-
-
-class OVMSLauncher(LauncherWithModelInterpreter):
-    """Inference launcher for OVMS (OpenVINO™ Model Server)
-
-    Parameters:
-        model_name: Name of the model. It should match with the model name loaded in the OVMS instance.
-        model_interpreter_path: Python source code path which implements a model interpreter.
-            The model interpreter implement pre-processing of the model input and post-processing of the model output.
-        host: Host address of the OVMS instance
-        port: Port number of the OVMS instance
-        model_version: Version of the model loaded in the OVMS instance
-        timeout: Timeout limit during communication between the client and the OVMS instance
-        tls_config: Configuration required if the OVMS instance is in the secure mode
-        ovms_client_type: OVMS client API type
-    """
-
-    def __init__(
-        self,
-        model_name: str,
-        model_interpreter_path: str,
-        host: str = "localhost",
-        port: int = 9000,
-        model_version: int = 0,
-        timeout: float = 10.0,
-        tls_config: Optional[TLSConfig] = None,
-        ovms_client_type: OVMSClientType = OVMSClientType.grpc,
-    ):
-        super().__init__(model_interpreter_path=model_interpreter_path)
-
-        self._client = self._init_client(
-            model_name,
-            host,
-            port,
-            tls_config,
-            ovms_client_type,
-        )
-        self._check_server_health(model_version, timeout)
-        self._init_input_name(model_version, timeout)
-
-        self.model_version = model_version
-        self.timeout = timeout
-
-    def _init_client(
-        self,
-        model_name,
-        host,
-        port,
-        tls_config,
-        ovms_client_type,
-    ) -> Union[GrpcClient, HttpClient]:
-        self.model_name = model_name
-        self.url = f"{host}:{port}"
-        self.tls_config = tls_config
-
-        if ovms_client_type == OVMSClientType.grpc:
-            return make_grpc_client(self.url, self.tls_config)
-        elif ovms_client_type == OVMSClientType.http:
-            return make_http_client(self.url, self.tls_config)
-        else:
-            raise NotImplementedError(ovms_client_type)
-
-    def _check_server_health(self, model_version, timeout):
-        try:
-            status = self._client.get_model_status(
-                model_name=self.model_name,
-                model_version=model_version,
-                timeout=timeout,
-            )
-            log.info(f"Health check succeeded: {status}")
-        except Exception as e:
-            raise DatumaroError(
-                f"Health check failed for model_name={self.model_name}, "
-                f"model_version={model_version}, url={self.url} and tls_config={self.tls_config}"
-            ) from e
-
-    def _init_input_name(self, model_version, timeout):
-        metadata = self._client.get_model_metadata(
-            model_name=self.model_name,
-            model_version=model_version,
-            timeout=timeout,
-        )
-        metadata_inputs = metadata.get("inputs")
-        if metadata_inputs is None:
-            raise DatumaroError("Cannot get metadata of the inputs.")
-
-        if len(metadata_inputs.keys()) > 1:
-            raise DatumaroError(
-                f"More than two model inputs are not allowed: {metadata_inputs.keys()}."
- ) - - self._input_key = next(iter(metadata_inputs.keys())) - log.info(f"Model input key is {self._input_key}") - - def infer(self, inputs: np.ndarray) -> List[ModelPred]: - # Please see the following link for the input and output type of self._client.predict() - # https://github.com/openvinotoolkit/model_server/blob/releases/2022/3/client/python/ovmsclient/lib/docs/grpc_client.md#method-predict - # The input is Dict[str, np.ndarray]. - # The output is Dict[str, np.ndarray] (If the model has multiple outputs), - # or np.ndarray (If the model has one single output). - results = self._client.predict( - inputs={self._input_key: inputs}, - model_name=self.model_name, - model_version=self.model_version, - timeout=self.timeout, - ) - - # If there is only one output key, - # it returns `np.ndarray`` rather than `Dict[str, np.ndarray]`. - # Please see ovmsclient.tfs_compat.grpc.responses.GrpcPredictResponse - if isinstance(results, np.ndarray): - results = {self._output_key: results} - - outputs_group_by_item = [ - {key: output for key, output in zip(results.keys(), outputs)} - for outputs in zip(*results.values()) - ] - - return outputs_group_by_item - - @property - def _output_key(self): - if not hasattr(self, "__output_key"): - metadata = self._client.get_model_metadata( - model_name=self.model_name, - model_version=self.model_version, - timeout=self.timeout, - ) - metadata_outputs = metadata.get("outputs") - - if metadata_outputs is None: - raise DatumaroError("Cannot get metadata of the outputs.") - - if len(metadata_outputs.keys()) > 1: - raise DatumaroError( - f"More than two model outputs are not allowed: {metadata_outputs.keys()}." - ) - - self.__output_key = next(iter(metadata_outputs.keys())) - - return self.__output_key - - def type_check(self, item): - if not isinstance(item.media, Image): - raise MediaTypeError(f"Media type should be Image, Current type={type(item.media)}") - return True diff --git a/src/datumaro/plugins/ovms_plugin/samples/__init__.py b/src/datumaro/plugins/ovms_plugin/samples/__init__.py deleted file mode 100644 index ff847f0120..0000000000 --- a/src/datumaro/plugins/ovms_plugin/samples/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (C) 2023 Intel Corporation -# -# SPDX-License-Identifier: MIT diff --git a/src/datumaro/plugins/specs.json b/src/datumaro/plugins/specs.json index 1eb989928a..da2a53226c 100644 --- a/src/datumaro/plugins/specs.json +++ b/src/datumaro/plugins/specs.json @@ -1004,25 +1004,37 @@ "extra_deps": [] }, { - "import_path": "datumaro.components.launcher.LauncherWithModelInterpreter", - "plugin_name": "launcher_with_model_interpreter", + "import_path": "datumaro.plugins.accuracy_checker_plugin.ac_launcher.AcLauncher", + "plugin_name": "ac", + "plugin_type": "Launcher", + "extra_deps": [ + "openvino.tools", + "tensorflow" + ] + }, + { + "import_path": "datumaro.plugins.inference_server_plugin.base.LauncherForDedicatedInferenceServer", + "plugin_name": "launcher_for_dedicated_inference_server", + "plugin_type": "Launcher", + "extra_deps": [] + }, + { + "import_path": "datumaro.plugins.inference_server_plugin.triton.TritonLauncher", + "plugin_name": "triton", "plugin_type": "Launcher", "extra_deps": [] }, { - "import_path": "datumaro.plugins.ovms_plugin.launcher.OVMSLauncher", + "import_path": "datumaro.plugins.inference_server_plugin.ovms.OVMSLauncher", "plugin_name": "ovmslauncher", "plugin_type": "Launcher", "extra_deps": [] }, { - "import_path": "datumaro.plugins.accuracy_checker_plugin.ac_launcher.AcLauncher", - "plugin_name": "ac", + 
"import_path": "datumaro.components.launcher.LauncherWithModelInterpreter", + "plugin_name": "launcher_with_model_interpreter", "plugin_type": "Launcher", - "extra_deps": [ - "openvino.tools", - "tensorflow" - ] + "extra_deps": [] }, { "import_path": "datumaro.plugins.openvino_plugin.shift_launcher.ShiftLauncher", diff --git a/tests/unit/launchers/test_ovms_launcher.py b/tests/unit/launchers/test_ovms_launcher.py index adff996777..ca18c5623c 100644 --- a/tests/unit/launchers/test_ovms_launcher.py +++ b/tests/unit/launchers/test_ovms_launcher.py @@ -9,10 +9,10 @@ import numpy as np import pytest -import datumaro.plugins.ovms_plugin.samples.face_detection as face_det_model_interp +import datumaro.plugins.inference_server_plugin.samples.face_detection as face_det_model_interp from datumaro.components.dataset_base import DatasetItem from datumaro.components.media import Image -from datumaro.plugins.ovms_plugin.launcher import OVMSLauncher +from datumaro.plugins.inference_server_plugin.ovms import OVMSLauncher from ...requirements import Requirements, mark_requirement @@ -50,7 +50,7 @@ def test_launchers(self, fxt_input, fxt_output, fxt_metadata): mock_client.predict.return_value = fxt_output with patch( - "datumaro.plugins.ovms_plugin.launcher.make_grpc_client", + "datumaro.plugins.inference_server_plugin.ovms.make_grpc_client", return_value=mock_client, ): launcher = OVMSLauncher( diff --git a/tests/unit/launchers/test_triton_launcher.py b/tests/unit/launchers/test_triton_launcher.py new file mode 100644 index 0000000000..9e4a9702dc --- /dev/null +++ b/tests/unit/launchers/test_triton_launcher.py @@ -0,0 +1,83 @@ +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import os +from typing import Dict, List +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +import datumaro.plugins.inference_server_plugin.samples.face_detection as face_det_model_interp +from datumaro.components.dataset_base import DatasetItem +from datumaro.components.media import Image +from datumaro.plugins.inference_server_plugin.triton import TritonLauncher + +from ...requirements import Requirements, mark_requirement + + +class TritonLauncherTest: + @pytest.fixture + def fxt_input(self) -> List[DatasetItem]: + return [ + DatasetItem( + id="test", + media=Image.from_numpy(np.zeros(shape=[10, 10, 3], dtype=np.uint8)), + annotations=[], + ) + ] + + @pytest.fixture + def fxt_output(self) -> np.ndarray: + # Output of face-detection model + np.random.seed(3003) + return np.random.rand(1, 1, 200, 7) + + @pytest.fixture + def fxt_metadata(self) -> MagicMock: + # Metadata of face-detection model + metadata = MagicMock() + + inp = MagicMock() + inp.name = "data" + inp.shape = [-1, 3, 400, 600] + inp.datatype = "FP32" + metadata.inputs = [inp] + + out = MagicMock() + out.name = "detection_out" + out.shape = [-1, 1, 200, 7] + out.datatype = "FP32" + metadata.outputs = [out] + + return metadata + + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_launchers(self, fxt_input, fxt_output, fxt_metadata): + mock_client = MagicMock() + mock_client.is_model_ready.return_value = True + mock_client.get_model_metadata.return_value = fxt_metadata + + outputs = MagicMock() + outputs.as_numpy.return_value = fxt_output + mock_client.infer.return_value = outputs + + with patch( + "datumaro.plugins.inference_server_plugin.triton.grpcclient.InferenceServerClient", + return_value=mock_client, + ): + launcher = TritonLauncher( + model_name="face-detection", + 
+                model_interpreter_path=os.path.abspath(face_det_model_interp.__file__),
+            )
+
+        mock_client.get_model_metadata.assert_called_once()
+        mock_client.is_model_ready.assert_called_once()
+
+        outputs = launcher.launch(fxt_input)
+        mock_client.infer.assert_called_once()
+
+        assert len(outputs) > 0
+        for anns in outputs:
+            assert len(anns) > 0
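For reference, a minimal usage sketch for the new launcher, based only on the constructor parameters defined in LauncherForDedicatedInferenceServer and the launch() call exercised in test_triton_launcher.py above. The host, port, and model name are placeholder values for an assumed local Triton deployment, and the bundled face_detection sample is reused as the model interpreter.

import os

import numpy as np

import datumaro.plugins.inference_server_plugin.samples.face_detection as face_det_model_interp
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.media import Image
from datumaro.plugins.inference_server_plugin import TritonLauncher
from datumaro.plugins.inference_server_plugin.base import ProtocolType

# Assumed server address; 8001 is a common Triton gRPC port, not a value taken from this change.
launcher = TritonLauncher(
    model_name="face-detection",
    model_interpreter_path=os.path.abspath(face_det_model_interp.__file__),
    host="localhost",
    port=8001,
    protocol_type=ProtocolType.grpc,
)

items = [
    DatasetItem(
        id="sample",
        media=Image.from_numpy(np.zeros(shape=[10, 10, 3], dtype=np.uint8)),
    )
]

# launch() returns one collection of predicted annotations per input item,
# as asserted at the end of the unit test above.
predictions = launcher.launch(items)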