Skip to content

Commit

Permalink
Commit for OIP model metadata endpoint and e2e tests
Browse files Browse the repository at this point in the history
Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>
  • Loading branch information
andyi2it committed Apr 8, 2024
1 parent 3def3fa commit 7e8f76c
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/kserve_cpu_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
uses: actions/checkout@v4
with:
repository: kserve/kserve
ref: v0.11.1
ref: v0.11.2
path: kserve
- name: Validate torchserve-kfs and Open Inference Protocol
run: ./kubernetes/kserve/tests/scripts/test_mnist.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.archive.workflow.WorkflowException;
import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc.ModelMetadataResponse;
import org.pytorch.serve.http.HttpRequestHandlerChain;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.NettyUtils;
Expand All @@ -31,6 +32,7 @@ public class OpenInferenceProtocolRequestHandler extends HttpRequestHandlerChain
private static final String SERVER_LIVE_API = "/v2/health/live";
private static final String SERVER_READY_API = "/v2/health/ready";
private static final String MODEL_READY_ENDPOINT_PATTERN = "^/v2/models/([^/]+)(?:/versions/([^/]+))?/ready$";
private static final String MODEL_METADATA_ENDPOINT_PATTERN = "^/v2/models/([^/]+)(?:/versions/([^/]+))?";

/** Creates a new {@code OpenInferenceProtocolRequestHandler} instance. */
public OpenInferenceProtocolRequestHandler() {}
Expand Down Expand Up @@ -83,6 +85,19 @@ public void handleRequest(
response.addProperty("ready", isModelReady);
NettyUtils.sendJsonResponse(ctx, response);

} else if (concatenatedSegments.matches(MODEL_METADATA_ENDPOINT_PATTERN)) {
String modelName = segments[3];
String modelVersion = null;
if (segments.length > 5) {
modelVersion = segments[5];
}

ModelManager modelManager = ModelManager.getInstance();

ModelMetadataResponse.Builder response = modelManager.modelMetadata(modelName, modelVersion);

NettyUtils.sendJsonResponse(ctx, response.build());

} else if (segments.length > 5 && concatenatedSegments.contains("/versions")) {
// As of now KServe has not implemented versioning, so we just throw Not Implemented.
JsonObject response = new JsonObject();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package org.pytorch.serve.wlm;

import com.google.gson.JsonObject;

import io.grpc.Status;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
Expand Down Expand Up @@ -28,6 +31,9 @@
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc.ModelMetadataResponse;
import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc.ModelMetadataResponse.TensorMetadata;
import org.pytorch.serve.http.BadRequestException;
import org.pytorch.serve.http.ConflictStatusException;
import org.pytorch.serve.http.InvalidModelVersionException;
import org.pytorch.serve.http.messages.RegisterModelRequest;
Expand Down Expand Up @@ -698,6 +704,38 @@ public boolean isModelReady(String modelName, String modelVersion)
return numHealthy >= numScaled;
}

public ModelMetadataResponse.Builder modelMetadata(String modelName, String modelVersion)
throws ModelVersionNotFoundException, ModelNotFoundException {

ModelManager modelManager = ModelManager.getInstance();
ModelMetadataResponse.Builder response = ModelMetadataResponse.newBuilder();
List<TensorMetadata> inputs = new ArrayList<>();
List<TensorMetadata> outputs = new ArrayList<>();
List<String> versions = new ArrayList<>();

if (modelVersion == null || "".equals(modelVersion)) {
modelVersion = null;
}

Model model = modelManager.getModel(modelName, modelVersion);
if (model == null) {
throw new ModelNotFoundException("Model not found: " + modelName);
}
modelManager
.getAllModelVersions(modelName)
.forEach(entry -> versions.add(entry.getKey()));
response.setName(modelName);
response.addAllVersions(versions);
response.setPlatform("");
response.addAllInputs(inputs);
response.addAllOutputs(outputs);

return response;

}

// return numHealthy >= numScaled;

public void submitTask(Runnable runnable) {
wlm.scheduleAsync(runnable);
}
Expand Down
2 changes: 1 addition & 1 deletion kubernetes/kserve/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
kserve[storage]>=0.11.0
kserve[storage]==0.11.2
transformers
captum
grpcio
Expand Down
12 changes: 9 additions & 3 deletions kubernetes/kserve/tests/scripts/test_mnist.sh
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,16 @@ URL="http://${INGRESS_HOST}:${INGRESS_PORT}/v2/health/live"
EXPECTED_OUTPUT='{"live":true}'
make_cluster_accessible ${SERVICE_NAME} ${URL} "" ${EXPECTED_OUTPUT}

# ServerLive
echo "HTTP ServerLive method call"
# ModelReady
echo "HTTP ModelReady method call"
URL="http://${INGRESS_HOST}:${INGRESS_PORT}/v2/models/${MODEL_NAME}/ready"
EXPECTED_OUTPUT='{"name" : "mnist", "ready": true}'
EXPECTED_OUTPUT='{"name":"mnist","ready":true}'
make_cluster_accessible ${SERVICE_NAME} ${URL} "" ${EXPECTED_OUTPUT}

# ModelMetadata
echo "HTTP ModelMetadata method call"
URL="http://${INGRESS_HOST}:${INGRESS_PORT}/v2/models/${MODEL_NAME}"
EXPECTED_OUTPUT='{"name_":"mnist","versions_":["1.0"],"platform_":"","inputs_":[],"outputs_":[],"memoizedIsInitialized":1,"unknownFields":{"fields":{}},"memoizedSize":-1,"memoizedHashCode":0}'
make_cluster_accessible ${SERVICE_NAME} ${URL} "" ${EXPECTED_OUTPUT}

# delete oip http isvc
Expand Down

0 comments on commit 7e8f76c

Please sign in to comment.