diff --git a/.github/workflows/kserve_cpu_tests.yml b/.github/workflows/kserve_cpu_tests.yml index 6d92b8c1ad2..a4079f1bb85 100644 --- a/.github/workflows/kserve_cpu_tests.yml +++ b/.github/workflows/kserve_cpu_tests.yml @@ -39,7 +39,7 @@ jobs: uses: actions/checkout@v4 with: repository: kserve/kserve - ref: v0.11.1 + ref: v0.11.2 path: kserve - name: Validate torchserve-kfs and Open Inference Protocol run: ./kubernetes/kserve/tests/scripts/test_mnist.sh diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/OpenInferenceProtocolRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/OpenInferenceProtocolRequestHandler.java index 8f9a8447622..397084ca369 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/OpenInferenceProtocolRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/OpenInferenceProtocolRequestHandler.java @@ -9,6 +9,7 @@ import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.archive.workflow.WorkflowException; +import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc.ModelMetadataResponse; import org.pytorch.serve.http.HttpRequestHandlerChain; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.NettyUtils; @@ -31,6 +32,7 @@ public class OpenInferenceProtocolRequestHandler extends HttpRequestHandlerChain private static final String SERVER_LIVE_API = "/v2/health/live"; private static final String SERVER_READY_API = "/v2/health/ready"; private static final String MODEL_READY_ENDPOINT_PATTERN = "^/v2/models/([^/]+)(?:/versions/([^/]+))?/ready$"; + private static final String MODEL_METADATA_ENDPOINT_PATTERN = "^/v2/models/([^/]+)(?:/versions/([^/]+))?$"; /** Creates a new {@code OpenInferenceProtocolRequestHandler} instance. 
*/ public OpenInferenceProtocolRequestHandler() {} @@ -83,6 +85,19 @@ public void handleRequest( response.addProperty("ready", isModelReady); NettyUtils.sendJsonResponse(ctx, response); + } else if (concatenatedSegments.matches(MODEL_METADATA_ENDPOINT_PATTERN)) { + String modelName = segments[3]; + String modelVersion = null; + if (segments.length > 5) { + modelVersion = segments[5]; + } + + ModelManager modelManager = ModelManager.getInstance(); + + ModelMetadataResponse.Builder response = modelManager.modelMetadata(modelName, modelVersion); + + NettyUtils.sendJsonResponse(ctx, response.build()); + } else if (segments.length > 5 && concatenatedSegments.contains("/versions")) { // As of now kserve not implemented versioning, we just throws not implemented. JsonObject response = new JsonObject(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java index 120a4edf2ce..4a1bfa66c84 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java @@ -1,6 +1,9 @@ package org.pytorch.serve.wlm; import com.google.gson.JsonObject; + +import io.grpc.Status; + import java.io.BufferedReader; import java.io.File; import java.io.IOException; @@ -28,6 +31,9 @@ import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.archive.model.ModelNotFoundException; import org.pytorch.serve.archive.model.ModelVersionNotFoundException; +import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc.ModelMetadataResponse; +import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc.ModelMetadataResponse.TensorMetadata; +import org.pytorch.serve.http.BadRequestException; import org.pytorch.serve.http.ConflictStatusException; import org.pytorch.serve.http.InvalidModelVersionException; import org.pytorch.serve.http.messages.RegisterModelRequest; @@ -698,6 +704,38 @@ public boolean 
isModelReady(String modelName, String modelVersion) return numHealthy >= numScaled; } + public ModelMetadataResponse.Builder modelMetadata(String modelName, String modelVersion) + throws ModelVersionNotFoundException, ModelNotFoundException { + + ModelManager modelManager = ModelManager.getInstance(); + ModelMetadataResponse.Builder response = ModelMetadataResponse.newBuilder(); + List<TensorMetadata> inputs = new ArrayList<>(); + List<TensorMetadata> outputs = new ArrayList<>(); + List<String> versions = new ArrayList<>(); + + if (modelVersion == null || "".equals(modelVersion)) { + modelVersion = null; + } + + Model model = modelManager.getModel(modelName, modelVersion); + if (model == null) { + throw new ModelNotFoundException("Model not found: " + modelName); + } + modelManager + .getAllModelVersions(modelName) + .forEach(entry -> versions.add(entry.getKey())); + response.setName(modelName); + response.addAllVersions(versions); + response.setPlatform(""); + response.addAllInputs(inputs); + response.addAllOutputs(outputs); + + return response; + + } + + /** Schedules the given task for asynchronous execution on the worker load manager. */ + public void submitTask(Runnable runnable) { wlm.scheduleAsync(runnable); } diff --git a/kubernetes/kserve/requirements.txt b/kubernetes/kserve/requirements.txt index 9d1898d4699..2e2316aff42 100644 --- a/kubernetes/kserve/requirements.txt +++ b/kubernetes/kserve/requirements.txt @@ -1,4 +1,4 @@ -kserve[storage]>=0.11.0 +kserve[storage]==0.11.2 transformers captum grpcio diff --git a/kubernetes/kserve/tests/scripts/test_mnist.sh b/kubernetes/kserve/tests/scripts/test_mnist.sh index 3710b4def9b..0144153b177 100755 --- a/kubernetes/kserve/tests/scripts/test_mnist.sh +++ b/kubernetes/kserve/tests/scripts/test_mnist.sh @@ -206,10 +206,16 @@ URL="http://${INGRESS_HOST}:${INGRESS_PORT}/v2/health/live" EXPECTED_OUTPUT='{"live":true}' make_cluster_accessible ${SERVICE_NAME} ${URL} "" ${EXPECTED_OUTPUT} -# ServerLive -echo "HTTP ServerLive method call" +# ModelReady +echo "HTTP ModelReady method call" 
URL="http://${INGRESS_HOST}:${INGRESS_PORT}/v2/models/${MODEL_NAME}/ready" -EXPECTED_OUTPUT='{"name" : "mnist", "ready": true}' +EXPECTED_OUTPUT='{"name":"mnist","ready":true}' +make_cluster_accessible ${SERVICE_NAME} ${URL} "" ${EXPECTED_OUTPUT} + +# ModelMetadata +echo "HTTP ModelMetadata method call" +URL="http://${INGRESS_HOST}:${INGRESS_PORT}/v2/models/${MODEL_NAME}" +EXPECTED_OUTPUT='{"name_":"mnist","versions_":["1.0"],"platform_":"","inputs_":[],"outputs_":[],"memoizedIsInitialized":1,"unknownFields":{"fields":{}},"memoizedSize":-1,"memoizedHashCode":0}' make_cluster_accessible ${SERVICE_NAME} ${URL} "" ${EXPECTED_OUTPUT} # delete oip http isvc