From 45dc3a8827efd80afbc994c32c34d687b8ac71ee Mon Sep 17 00:00:00 2001 From: Andrews Arokiam Date: Fri, 16 Feb 2024 13:39:35 +0530 Subject: [PATCH 1/5] model ready status endpoint Signed-off-by: Andrews Arokiam --- .../OpenInferenceProtocolRequestHandler.java | 18 ++++++++++++++++++ .../org/pytorch/serve/wlm/ModelManager.java | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/OpenInferenceProtocolRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/OpenInferenceProtocolRequestHandler.java index 8cec95cf06..8f9a844762 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/OpenInferenceProtocolRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/OpenInferenceProtocolRequestHandler.java @@ -12,6 +12,7 @@ import org.pytorch.serve.http.HttpRequestHandlerChain; import org.pytorch.serve.util.ConfigManager; import org.pytorch.serve.util.NettyUtils; +import org.pytorch.serve.wlm.ModelManager; import org.pytorch.serve.wlm.WorkerInitializationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -29,6 +30,7 @@ public class OpenInferenceProtocolRequestHandler extends HttpRequestHandlerChain private static final String SERVER_METADATA_API = "/v2"; private static final String SERVER_LIVE_API = "/v2/health/live"; private static final String SERVER_READY_API = "/v2/health/ready"; + private static final String MODEL_READY_ENDPOINT_PATTERN = "^/v2/models/([^/]+)(?:/versions/([^/]+))?/ready$"; /** Creates a new {@code OpenInferenceProtocolRequestHandler} instance. */ public OpenInferenceProtocolRequestHandler() {} @@ -65,6 +67,22 @@ public void handleRequest( supportedExtensions.add("kubeflow"); response.add("extenstion", supportedExtensions); NettyUtils.sendJsonResponse(ctx, response); + } else if (concatenatedSegments.matches(MODEL_READY_ENDPOINT_PATTERN)) { + String modelName = segments[3]; + String modelVersion = null; + if (segments.length > 5) { + modelVersion = segments[5]; + } + + ModelManager modelManager = ModelManager.getInstance(); + + boolean isModelReady = modelManager.isModelReady(modelName, modelVersion); + + JsonObject response = new JsonObject(); + response.addProperty("name", modelName); + response.addProperty("ready", isModelReady); + NettyUtils.sendJsonResponse(ctx, response); + } else if (segments.length > 5 && concatenatedSegments.contains("/versions")) { // As of now kserve not implemented versioning, we just throws not implemented. 
JsonObject response = new JsonObject(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java index e3935c9d56..120a4edf2c 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java @@ -680,6 +680,24 @@ public boolean scaleRequestStatus(String modelName, String versionId) { return model == null || model.getMinWorkers() <= numWorkers; } + public boolean isModelReady(String modelName, String modelVersion) + throws ModelVersionNotFoundException, ModelNotFoundException { + + if (modelVersion == null || "".equals(modelVersion)) { + modelVersion = null; + } + + Model model = getModel(modelName, modelVersion); + if (model == null) { + throw new ModelNotFoundException("Model not found: " + modelName); + } + + int numScaled = model.getMinWorkers(); + int numHealthy = modelManager.getNumHealthyWorkers(model.getModelVersionName()); + + return numHealthy >= numScaled; + } + public void submitTask(Runnable runnable) { wlm.scheduleAsync(runnable); }

From 3ac48f50b6d6fb03a92eb36d3b8bbde88876ed03 Mon Sep 17 00:00:00 2001 From: Andrews Arokiam Date: Tue, 27 Feb 2024 18:48:23 +0530 Subject: [PATCH 2/5] e2e test for oip model ready Signed-off-by: Andrews Arokiam --- kubernetes/kserve/tests/scripts/test_mnist.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/kubernetes/kserve/tests/scripts/test_mnist.sh b/kubernetes/kserve/tests/scripts/test_mnist.sh index 5d2d7de0f0..3710b4def9 100755 --- a/kubernetes/kserve/tests/scripts/test_mnist.sh +++ b/kubernetes/kserve/tests/scripts/test_mnist.sh @@ -206,6 +206,12 @@ URL="http://${INGRESS_HOST}:${INGRESS_PORT}/v2/health/live" EXPECTED_OUTPUT='{"live":true}' make_cluster_accessible ${SERVICE_NAME} ${URL} "" ${EXPECTED_OUTPUT} +# ServerLive +echo "HTTP ServerLive method call" +URL="http://${INGRESS_HOST}:${INGRESS_PORT}/v2/models/${MODEL_NAME}/ready" +EXPECTED_OUTPUT='{"name" : "mnist", "ready": true}' +make_cluster_accessible ${SERVICE_NAME} ${URL} "" ${EXPECTED_OUTPUT} + # delete oip http isvc kubectl delete inferenceservice ${SERVICE_NAME}

From 876dc9da3ba1ad8839352cd8dc150decc8a4d20c Mon Sep 17 00:00:00 2001 From: Andrews Arokiam Date: Tue, 27 Feb 2024 19:05:04 +0530 Subject: [PATCH 3/5] add oip model metadata endpoint and e2e tests Signed-off-by: Andrews Arokiam --- .../OpenInferenceProtocolRequestHandler.java | 15 ++++++++ .../org/pytorch/serve/wlm/ModelManager.java | 38 +++++++++++++++++++ kubernetes/kserve/tests/scripts/test_mnist.sh | 12 ++++-- 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/OpenInferenceProtocolRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/OpenInferenceProtocolRequestHandler.java index 8f9a844762..397084ca36 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/OpenInferenceProtocolRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/api/rest/OpenInferenceProtocolRequestHandler.java @@ -9,6 +9,7 @@ import org.pytorch.serve.archive.DownloadArchiveException; import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.archive.workflow.WorkflowException; +import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc.ModelMetadataResponse; import org.pytorch.serve.http.HttpRequestHandlerChain; import org.pytorch.serve.util.ConfigManager; import
org.pytorch.serve.util.NettyUtils; @@ -31,6 +32,7 @@ public class OpenInferenceProtocolRequestHandler extends HttpRequestHandlerChain private static final String SERVER_LIVE_API = "/v2/health/live"; private static final String SERVER_READY_API = "/v2/health/ready"; private static final String MODEL_READY_ENDPOINT_PATTERN = "^/v2/models/([^/]+)(?:/versions/([^/]+))?/ready$"; + private static final String MODEL_METADATA_ENDPOINT_PATTERN = "^/v2/models/([^/]+)(?:/versions/([^/]+))?"; /** Creates a new {@code OpenInferenceProtocolRequestHandler} instance. */ public OpenInferenceProtocolRequestHandler() {} @@ -83,6 +85,19 @@ public void handleRequest( response.addProperty("ready", isModelReady); NettyUtils.sendJsonResponse(ctx, response); + } else if (concatenatedSegments.matches(MODEL_METADATA_ENDPOINT_PATTERN)) { + String modelName = segments[3]; + String modelVersion = null; + if (segments.length > 5) { + modelVersion = segments[5]; + } + + ModelManager modelManager = ModelManager.getInstance(); + + ModelMetadataResponse.Builder response = modelManager.modelMetadata(modelName, modelVersion); + + NettyUtils.sendJsonResponse(ctx, response.build()); + } else if (segments.length > 5 && concatenatedSegments.contains("/versions")) { // As of now kserve not implemented versioning, we just throws not implemented. JsonObject response = new JsonObject(); diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java index 120a4edf2c..4a1bfa66c8 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java @@ -1,6 +1,9 @@ package org.pytorch.serve.wlm; import com.google.gson.JsonObject; + +import io.grpc.Status; + import java.io.BufferedReader; import java.io.File; import java.io.IOException; @@ -28,6 +31,9 @@ import org.pytorch.serve.archive.model.ModelException; import org.pytorch.serve.archive.model.ModelNotFoundException; import org.pytorch.serve.archive.model.ModelVersionNotFoundException; +import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc.ModelMetadataResponse; +import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc.ModelMetadataResponse.TensorMetadata; +import org.pytorch.serve.http.BadRequestException; import org.pytorch.serve.http.ConflictStatusException; import org.pytorch.serve.http.InvalidModelVersionException; import org.pytorch.serve.http.messages.RegisterModelRequest; @@ -698,6 +704,38 @@ public boolean isModelReady(String modelName, String modelVersion) return numHealthy >= numScaled; } + public ModelMetadataResponse.Builder modelMetadata(String modelName, String modelVersion) + throws ModelVersionNotFoundException, ModelNotFoundException { + + ModelManager modelManager = ModelManager.getInstance(); + ModelMetadataResponse.Builder response = ModelMetadataResponse.newBuilder(); + List inputs = new ArrayList<>(); + List outputs = new ArrayList<>(); + List versions = new ArrayList<>(); + + if (modelVersion == null || "".equals(modelVersion)) { + modelVersion = null; + } + + Model model = modelManager.getModel(modelName, modelVersion); + if (model == null) { + throw new ModelNotFoundException("Model not found: " + modelName); + } + modelManager + .getAllModelVersions(modelName) + .forEach(entry -> versions.add(entry.getKey())); + response.setName(modelName); + response.addAllVersions(versions); + response.setPlatform(""); + response.addAllInputs(inputs); + 
response.addAllOutputs(outputs); + + return response; + + } + + // return numHealthy >= numScaled; + public void submitTask(Runnable runnable) { wlm.scheduleAsync(runnable); } diff --git a/kubernetes/kserve/tests/scripts/test_mnist.sh b/kubernetes/kserve/tests/scripts/test_mnist.sh index 3710b4def9..0144153b17 100755 --- a/kubernetes/kserve/tests/scripts/test_mnist.sh +++ b/kubernetes/kserve/tests/scripts/test_mnist.sh @@ -206,10 +206,16 @@ URL="http://${INGRESS_HOST}:${INGRESS_PORT}/v2/health/live" EXPECTED_OUTPUT='{"live":true}' make_cluster_accessible ${SERVICE_NAME} ${URL} "" ${EXPECTED_OUTPUT} -# ServerLive -echo "HTTP ServerLive method call" +# ModelReady +echo "HTTP ModelReady method call" URL="http://${INGRESS_HOST}:${INGRESS_PORT}/v2/models/${MODEL_NAME}/ready" -EXPECTED_OUTPUT='{"name" : "mnist", "ready": true}' +EXPECTED_OUTPUT='{"name":"mnist","ready":true}' +make_cluster_accessible ${SERVICE_NAME} ${URL} "" ${EXPECTED_OUTPUT} + +# ModelMetadata +echo "HTTP ModelMetadata method call" +URL="http://${INGRESS_HOST}:${INGRESS_PORT}/v2/models/${MODEL_NAME}" +EXPECTED_OUTPUT='{"name_":"mnist","versions_":["1.0"],"platform_":"","inputs_":[],"outputs_":[],"memoizedIsInitialized":1,"unknownFields":{"fields":{}},"memoizedSize":-1,"memoizedHashCode":0}' make_cluster_accessible ${SERVICE_NAME} ${URL} "" ${EXPECTED_OUTPUT} # delete oip http isvc From 95f09b29a5f3fd662e90030356d7a587b3b2441f Mon Sep 17 00:00:00 2001 From: Andrews Arokiam Date: Mon, 8 Apr 2024 17:34:06 +0530 Subject: [PATCH 4/5] fix formatting Signed-off-by: Andrews Arokiam --- .../org/pytorch/serve/wlm/ModelManager.java | 135 ++++++++---------- 1 file changed, 61 insertions(+), 74 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java index 4a1bfa66c8..7d94456aa8 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java @@ -33,7 +33,6 @@ import org.pytorch.serve.archive.model.ModelVersionNotFoundException; import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc.ModelMetadataResponse; import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc.ModelMetadataResponse.TensorMetadata; -import org.pytorch.serve.http.BadRequestException; import org.pytorch.serve.http.ConflictStatusException; import org.pytorch.serve.http.InvalidModelVersionException; import org.pytorch.serve.http.messages.RegisterModelRequest; @@ -93,7 +92,7 @@ public ModelArchive registerModel(String url, String defaultModelName) public void registerAndUpdateModel(String modelName, JsonObject modelInfo) throws ModelException, IOException, InterruptedException, DownloadArchiveException, - WorkerInitializationException { + WorkerInitializationException { boolean defaultVersion = modelInfo.get(Model.DEFAULT_VERSION).getAsBoolean(); String url = modelInfo.get(Model.MAR_NAME).getAsString(); @@ -133,7 +132,7 @@ public ModelArchive registerModel( throws ModelException, IOException, InterruptedException, DownloadArchiveException { ModelArchive archive; - if (isWorkflowModel && url == null) { // This is a workflow function + if (isWorkflowModel && url == null) { // This is a workflow function Manifest manifest = new Manifest(); manifest.getModel().setVersion("1.0"); manifest.getModel().setModelVersion("1.0"); @@ -143,13 +142,11 @@ public ModelArchive registerModel( File f = new File(handler.substring(0, handler.lastIndexOf(':'))); archive = new 
ModelArchive(manifest, url, f.getParentFile(), true); } else { - archive = - createModelArchive( - modelName, url, handler, runtime, defaultModelName, s3SseKms); + archive = createModelArchive( + modelName, url, handler, runtime, defaultModelName, s3SseKms); } - Model tempModel = - createModel(archive, batchSize, maxBatchDelay, responseTimeout, isWorkflowModel); + Model tempModel = createModel(archive, batchSize, maxBatchDelay, responseTimeout, isWorkflowModel); String versionId = archive.getModelVersion(); @@ -179,12 +176,11 @@ private ModelArchive createModelArchive( boolean s3SseKms) throws ModelException, IOException, DownloadArchiveException { - ModelArchive archive = - ModelArchive.downloadModel( - configManager.getAllowedUrls(), - configManager.getModelStore(), - url, - s3SseKms); + ModelArchive archive = ModelArchive.downloadModel( + configManager.getAllowedUrls(), + configManager.getModelStore(), + url, + s3SseKms); Manifest.Model model = archive.getManifest().getModel(); if (modelName == null || modelName.isEmpty()) { if (archive.getModelName() == null || archive.getModelName().isEmpty()) { @@ -242,11 +238,10 @@ private void setupModelVenv(Model model) + venvPath.toString()); } Map environment = processBuilder.environment(); - String[] envp = - EnvironmentUtils.getEnvString( - configManager.getModelServerHome(), - model.getModelDir().getAbsolutePath(), - null); + String[] envp = EnvironmentUtils.getEnvString( + configManager.getModelServerHome(), + model.getModelDir().getAbsolutePath(), + null); for (String envVar : envp) { String[] parts = envVar.split("=", 2); if (parts.length == 2) { @@ -283,16 +278,14 @@ private void setupModelVenv(Model model) private void setupModelDependencies(Model model) throws IOException, InterruptedException, ModelException { - String requirementsFile = - model.getModelArchive().getManifest().getModel().getRequirementsFile(); + String requirementsFile = model.getModelArchive().getManifest().getModel().getRequirementsFile(); if (!configManager.getInstallPyDepPerModel() || requirementsFile == null) { return; } String pythonRuntime = EnvironmentUtils.getPythonRunTime(model); - Path requirementsFilePath = - Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile).toAbsolutePath(); + Path requirementsFilePath = Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile).toAbsolutePath(); List commandParts = new ArrayList<>(); ProcessBuilder processBuilder = new ProcessBuilder(); @@ -341,11 +334,10 @@ private void setupModelDependencies(Model model) } processBuilder.command(commandParts); - String[] envp = - EnvironmentUtils.getEnvString( - configManager.getModelServerHome(), - model.getModelDir().getAbsolutePath(), - null); + String[] envp = EnvironmentUtils.getEnvString( + configManager.getModelServerHome(), + model.getModelDir().getAbsolutePath(), + null); Map environment = processBuilder.environment(); for (String envVar : envp) { String[] parts = envVar.split("=", 2); @@ -398,21 +390,19 @@ private Model createModel( if (batchSize == -1 * RegisterModelRequest.DEFAULT_BATCH_SIZE) { if (archive.getModelConfig() != null) { int marBatchSize = archive.getModelConfig().getBatchSize(); - batchSize = - marBatchSize > 0 - ? marBatchSize - : configManager.getJsonIntValue( - archive.getModelName(), - archive.getModelVersion(), - Model.BATCH_SIZE, - RegisterModelRequest.DEFAULT_BATCH_SIZE); - } else { - batchSize = - configManager.getJsonIntValue( + batchSize = marBatchSize > 0 + ? 
marBatchSize + : configManager.getJsonIntValue( archive.getModelName(), archive.getModelVersion(), Model.BATCH_SIZE, RegisterModelRequest.DEFAULT_BATCH_SIZE); + } else { + batchSize = configManager.getJsonIntValue( + archive.getModelName(), + archive.getModelVersion(), + Model.BATCH_SIZE, + RegisterModelRequest.DEFAULT_BATCH_SIZE); } } model.setBatchSize(batchSize); @@ -420,42 +410,38 @@ private Model createModel( if (maxBatchDelay == -1 * RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY) { if (archive.getModelConfig() != null) { int marMaxBatchDelay = archive.getModelConfig().getMaxBatchDelay(); - maxBatchDelay = - marMaxBatchDelay > 0 - ? marMaxBatchDelay - : configManager.getJsonIntValue( - archive.getModelName(), - archive.getModelVersion(), - Model.MAX_BATCH_DELAY, - RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY); - } else { - maxBatchDelay = - configManager.getJsonIntValue( + maxBatchDelay = marMaxBatchDelay > 0 + ? marMaxBatchDelay + : configManager.getJsonIntValue( archive.getModelName(), archive.getModelVersion(), Model.MAX_BATCH_DELAY, RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY); + } else { + maxBatchDelay = configManager.getJsonIntValue( + archive.getModelName(), + archive.getModelVersion(), + Model.MAX_BATCH_DELAY, + RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY); } } model.setMaxBatchDelay(maxBatchDelay); if (archive.getModelConfig() != null) { int marResponseTimeout = archive.getModelConfig().getResponseTimeout(); - responseTimeout = - marResponseTimeout > 0 - ? marResponseTimeout - : configManager.getJsonIntValue( - archive.getModelName(), - archive.getModelVersion(), - Model.RESPONSE_TIMEOUT, - responseTimeout); - } else { - responseTimeout = - configManager.getJsonIntValue( + responseTimeout = marResponseTimeout > 0 + ? marResponseTimeout + : configManager.getJsonIntValue( archive.getModelName(), archive.getModelVersion(), Model.RESPONSE_TIMEOUT, responseTimeout); + } else { + responseTimeout = configManager.getJsonIntValue( + archive.getModelName(), + archive.getModelVersion(), + Model.RESPONSE_TIMEOUT, + responseTimeout); } model.setResponseTimeout(responseTimeout); model.setWorkflowModel(isWorkflowModel); @@ -586,7 +572,8 @@ public CompletableFuture updateModel( } if (model.getParallelLevel() > 0 && model.getDeviceType() == ModelConfig.DeviceType.GPU) { /** - * Current capacity check for LMI is based on single node. TODO: multiple nodes check + * Current capacity check for LMI is based on single node. TODO: multiple nodes + * check * will be based on --proc-per-node + numCores. 
*/ int capacity = model.getNumCores() / model.getParallelLevel(); @@ -687,25 +674,25 @@ public boolean scaleRequestStatus(String modelName, String versionId) { } public boolean isModelReady(String modelName, String modelVersion) - throws ModelVersionNotFoundException, ModelNotFoundException { + throws ModelVersionNotFoundException, ModelNotFoundException { - if (modelVersion == null || "".equals(modelVersion)) { - modelVersion = null; - } + if (modelVersion == null || "".equals(modelVersion)) { + modelVersion = null; + } - Model model = getModel(modelName, modelVersion); - if (model == null) { - throw new ModelNotFoundException("Model not found: " + modelName); - } + Model model = getModel(modelName, modelVersion); + if (model == null) { + throw new ModelNotFoundException("Model not found: " + modelName); + } - int numScaled = model.getMinWorkers(); - int numHealthy = modelManager.getNumHealthyWorkers(model.getModelVersionName()); + int numScaled = model.getMinWorkers(); + int numHealthy = modelManager.getNumHealthyWorkers(model.getModelVersionName()); - return numHealthy >= numScaled; + return numHealthy >= numScaled; } public ModelMetadataResponse.Builder modelMetadata(String modelName, String modelVersion) - throws ModelVersionNotFoundException, ModelNotFoundException { + throws ModelVersionNotFoundException, ModelNotFoundException { ModelManager modelManager = ModelManager.getInstance(); ModelMetadataResponse.Builder response = ModelMetadataResponse.newBuilder(); From a5ba78d9dc38ff1a668eb50aa9d4521817dafbdc Mon Sep 17 00:00:00 2001 From: Andrews Arokiam Date: Mon, 8 Apr 2024 19:33:34 +0530 Subject: [PATCH 5/5] fix formatter issue Signed-off-by: Andrews Arokiam --- .../org/pytorch/serve/wlm/ModelManager.java | 52 +++++++++++-------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java index 7d94456aa8..f15cd65a6e 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java @@ -1,9 +1,6 @@ package org.pytorch.serve.wlm; import com.google.gson.JsonObject; - -import io.grpc.Status; - import java.io.BufferedReader; import java.io.File; import java.io.IOException; @@ -142,11 +139,13 @@ public ModelArchive registerModel( File f = new File(handler.substring(0, handler.lastIndexOf(':'))); archive = new ModelArchive(manifest, url, f.getParentFile(), true); } else { - archive = createModelArchive( + archive = + createModelArchive( modelName, url, handler, runtime, defaultModelName, s3SseKms); } - Model tempModel = createModel(archive, batchSize, maxBatchDelay, responseTimeout, isWorkflowModel); + Model tempModel = + createModel(archive, batchSize, maxBatchDelay, responseTimeout, isWorkflowModel); String versionId = archive.getModelVersion(); @@ -176,7 +175,8 @@ private ModelArchive createModelArchive( boolean s3SseKms) throws ModelException, IOException, DownloadArchiveException { - ModelArchive archive = ModelArchive.downloadModel( + ModelArchive archive = + ModelArchive.downloadModel( configManager.getAllowedUrls(), configManager.getModelStore(), url, @@ -238,7 +238,8 @@ private void setupModelVenv(Model model) + venvPath.toString()); } Map environment = processBuilder.environment(); - String[] envp = EnvironmentUtils.getEnvString( + String[] envp = + EnvironmentUtils.getEnvString( configManager.getModelServerHome(), 
model.getModelDir().getAbsolutePath(), null); @@ -278,14 +279,16 @@ private void setupModelVenv(Model model) private void setupModelDependencies(Model model) throws IOException, InterruptedException, ModelException { - String requirementsFile = model.getModelArchive().getManifest().getModel().getRequirementsFile(); + String requirementsFile = + model.getModelArchive().getManifest().getModel().getRequirementsFile(); if (!configManager.getInstallPyDepPerModel() || requirementsFile == null) { return; } String pythonRuntime = EnvironmentUtils.getPythonRunTime(model); - Path requirementsFilePath = Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile).toAbsolutePath(); + Path requirementsFilePath = + Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile).toAbsolutePath(); List commandParts = new ArrayList<>(); ProcessBuilder processBuilder = new ProcessBuilder(); @@ -334,7 +337,8 @@ private void setupModelDependencies(Model model) } processBuilder.command(commandParts); - String[] envp = EnvironmentUtils.getEnvString( + String[] envp = + EnvironmentUtils.getEnvString( configManager.getModelServerHome(), model.getModelDir().getAbsolutePath(), null); @@ -390,7 +394,8 @@ private Model createModel( if (batchSize == -1 * RegisterModelRequest.DEFAULT_BATCH_SIZE) { if (archive.getModelConfig() != null) { int marBatchSize = archive.getModelConfig().getBatchSize(); - batchSize = marBatchSize > 0 + batchSize = + marBatchSize > 0 ? marBatchSize : configManager.getJsonIntValue( archive.getModelName(), @@ -398,7 +403,8 @@ private Model createModel( Model.BATCH_SIZE, RegisterModelRequest.DEFAULT_BATCH_SIZE); } else { - batchSize = configManager.getJsonIntValue( + batchSize = + configManager.getJsonIntValue( archive.getModelName(), archive.getModelVersion(), Model.BATCH_SIZE, @@ -410,7 +416,8 @@ private Model createModel( if (maxBatchDelay == -1 * RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY) { if (archive.getModelConfig() != null) { int marMaxBatchDelay = archive.getModelConfig().getMaxBatchDelay(); - maxBatchDelay = marMaxBatchDelay > 0 + maxBatchDelay = + marMaxBatchDelay > 0 ? marMaxBatchDelay : configManager.getJsonIntValue( archive.getModelName(), @@ -418,7 +425,8 @@ private Model createModel( Model.MAX_BATCH_DELAY, RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY); } else { - maxBatchDelay = configManager.getJsonIntValue( + maxBatchDelay = + configManager.getJsonIntValue( archive.getModelName(), archive.getModelVersion(), Model.MAX_BATCH_DELAY, @@ -429,7 +437,8 @@ private Model createModel( if (archive.getModelConfig() != null) { int marResponseTimeout = archive.getModelConfig().getResponseTimeout(); - responseTimeout = marResponseTimeout > 0 + responseTimeout = + marResponseTimeout > 0 ? marResponseTimeout : configManager.getJsonIntValue( archive.getModelName(), @@ -437,7 +446,8 @@ private Model createModel( Model.RESPONSE_TIMEOUT, responseTimeout); } else { - responseTimeout = configManager.getJsonIntValue( + responseTimeout = + configManager.getJsonIntValue( archive.getModelName(), archive.getModelVersion(), Model.RESPONSE_TIMEOUT, @@ -572,8 +582,7 @@ public CompletableFuture updateModel( } if (model.getParallelLevel() > 0 && model.getDeviceType() == ModelConfig.DeviceType.GPU) { /** - * Current capacity check for LMI is based on single node. TODO: multiple nodes - * check + * Current capacity check for LMI is based on single node. TODO: multiple nodes check * will be based on --proc-per-node + numCores. 
*/ int capacity = model.getNumCores() / model.getParallelLevel(); @@ -708,9 +717,7 @@ public ModelMetadataResponse.Builder modelMetadata(String modelName, String mode if (model == null) { throw new ModelNotFoundException("Model not found: " + modelName); } - modelManager - .getAllModelVersions(modelName) - .forEach(entry -> versions.add(entry.getKey())); + modelManager.getAllModelVersions(modelName).forEach(entry -> versions.add(entry.getKey())); response.setName(modelName); response.addAllVersions(versions); response.setPlatform(""); @@ -718,7 +725,6 @@ public ModelMetadataResponse.Builder modelMetadata(String modelName, String mode response.addAllOutputs(outputs); return response; - } // return numHealthy >= numScaled; @@ -765,4 +771,4 @@ public int getNumRunningWorkers(ModelVersionName modelVersionName) { public int getNumHealthyWorkers(ModelVersionName modelVersionName) { return wlm.getNumHealthyWorkers(modelVersionName); } -} +} \ No newline at end of file
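A quick way to exercise the two endpoints added by this series, outside the Kubernetes e2e script, is to curl a running TorchServe instance directly. The sketch below is illustrative only: the host and port (localhost:8080, TorchServe's default inference port), the model name (mnist), and the version (1.0) are assumptions rather than anything the patches pin down, and the exact metadata body depends on how the protobuf builder is serialized.

#!/usr/bin/env bash
# Illustrative smoke test for the OIP endpoints added in this series.
# Assumptions (not part of the patches): TorchServe's OIP REST frontend is
# reachable on localhost:8080, and a model named "mnist" with version "1.0"
# is registered and scaled up.
set -euo pipefail

BASE="http://localhost:8080"
MODEL_NAME="mnist"
MODEL_VERSION="1.0"

# ModelReady (PATCH 1/5): ModelManager.isModelReady reports true once the
# number of healthy workers reaches the model's minWorkers, so this should
# print {"name":"mnist","ready":true} for a fully scaled model.
curl -s "${BASE}/v2/models/${MODEL_NAME}/ready"; echo

# Versioned form, captured by the optional (?:/versions/([^/]+))? group of
# MODEL_READY_ENDPOINT_PATTERN.
curl -s "${BASE}/v2/models/${MODEL_NAME}/versions/${MODEL_VERSION}/ready"; echo

# ModelMetadata (PATCH 3/5): the handler hands the protobuf builder to
# NettyUtils.sendJsonResponse, so the JSON carries the builder's internal
# field names ("name_", "versions_", ...), matching the EXPECTED_OUTPUT in
# the updated e2e test.
curl -s "${BASE}/v2/models/${MODEL_NAME}"; echo

Two routing details are worth noting. handleRequest tests MODEL_READY_ENDPOINT_PATTERN before MODEL_METADATA_ENDPOINT_PATTERN, so a /ready request is never consumed by the metadata branch; and because Java's String.matches() requires the whole string to match, the missing $ anchor on the metadata pattern does not let it match /ready or /versions/{v}/ready paths either.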