diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java index 7d94456aa82..f15cd65a6e8 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java @@ -1,9 +1,6 @@ package org.pytorch.serve.wlm; import com.google.gson.JsonObject; - -import io.grpc.Status; - import java.io.BufferedReader; import java.io.File; import java.io.IOException; @@ -142,11 +139,13 @@ public ModelArchive registerModel( File f = new File(handler.substring(0, handler.lastIndexOf(':'))); archive = new ModelArchive(manifest, url, f.getParentFile(), true); } else { - archive = createModelArchive( + archive = + createModelArchive( modelName, url, handler, runtime, defaultModelName, s3SseKms); } - Model tempModel = createModel(archive, batchSize, maxBatchDelay, responseTimeout, isWorkflowModel); + Model tempModel = + createModel(archive, batchSize, maxBatchDelay, responseTimeout, isWorkflowModel); String versionId = archive.getModelVersion(); @@ -176,7 +175,8 @@ private ModelArchive createModelArchive( boolean s3SseKms) throws ModelException, IOException, DownloadArchiveException { - ModelArchive archive = ModelArchive.downloadModel( + ModelArchive archive = + ModelArchive.downloadModel( configManager.getAllowedUrls(), configManager.getModelStore(), url, @@ -238,7 +238,8 @@ private void setupModelVenv(Model model) + venvPath.toString()); } Map environment = processBuilder.environment(); - String[] envp = EnvironmentUtils.getEnvString( + String[] envp = + EnvironmentUtils.getEnvString( configManager.getModelServerHome(), model.getModelDir().getAbsolutePath(), null); @@ -278,14 +279,16 @@ private void setupModelVenv(Model model) private void setupModelDependencies(Model model) throws IOException, InterruptedException, ModelException { - String requirementsFile = model.getModelArchive().getManifest().getModel().getRequirementsFile(); + String requirementsFile = + model.getModelArchive().getManifest().getModel().getRequirementsFile(); if (!configManager.getInstallPyDepPerModel() || requirementsFile == null) { return; } String pythonRuntime = EnvironmentUtils.getPythonRunTime(model); - Path requirementsFilePath = Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile).toAbsolutePath(); + Path requirementsFilePath = + Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile).toAbsolutePath(); List commandParts = new ArrayList<>(); ProcessBuilder processBuilder = new ProcessBuilder(); @@ -334,7 +337,8 @@ private void setupModelDependencies(Model model) } processBuilder.command(commandParts); - String[] envp = EnvironmentUtils.getEnvString( + String[] envp = + EnvironmentUtils.getEnvString( configManager.getModelServerHome(), model.getModelDir().getAbsolutePath(), null); @@ -390,7 +394,8 @@ private Model createModel( if (batchSize == -1 * RegisterModelRequest.DEFAULT_BATCH_SIZE) { if (archive.getModelConfig() != null) { int marBatchSize = archive.getModelConfig().getBatchSize(); - batchSize = marBatchSize > 0 + batchSize = + marBatchSize > 0 ? marBatchSize : configManager.getJsonIntValue( archive.getModelName(), @@ -398,7 +403,8 @@ private Model createModel( Model.BATCH_SIZE, RegisterModelRequest.DEFAULT_BATCH_SIZE); } else { - batchSize = configManager.getJsonIntValue( + batchSize = + configManager.getJsonIntValue( archive.getModelName(), archive.getModelVersion(), Model.BATCH_SIZE, @@ -410,7 +416,8 @@ private Model createModel( if (maxBatchDelay == -1 * RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY) { if (archive.getModelConfig() != null) { int marMaxBatchDelay = archive.getModelConfig().getMaxBatchDelay(); - maxBatchDelay = marMaxBatchDelay > 0 + maxBatchDelay = + marMaxBatchDelay > 0 ? marMaxBatchDelay : configManager.getJsonIntValue( archive.getModelName(), @@ -418,7 +425,8 @@ private Model createModel( Model.MAX_BATCH_DELAY, RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY); } else { - maxBatchDelay = configManager.getJsonIntValue( + maxBatchDelay = + configManager.getJsonIntValue( archive.getModelName(), archive.getModelVersion(), Model.MAX_BATCH_DELAY, @@ -429,7 +437,8 @@ private Model createModel( if (archive.getModelConfig() != null) { int marResponseTimeout = archive.getModelConfig().getResponseTimeout(); - responseTimeout = marResponseTimeout > 0 + responseTimeout = + marResponseTimeout > 0 ? marResponseTimeout : configManager.getJsonIntValue( archive.getModelName(), @@ -437,7 +446,8 @@ private Model createModel( Model.RESPONSE_TIMEOUT, responseTimeout); } else { - responseTimeout = configManager.getJsonIntValue( + responseTimeout = + configManager.getJsonIntValue( archive.getModelName(), archive.getModelVersion(), Model.RESPONSE_TIMEOUT, @@ -572,8 +582,7 @@ public CompletableFuture updateModel( } if (model.getParallelLevel() > 0 && model.getDeviceType() == ModelConfig.DeviceType.GPU) { /** - * Current capacity check for LMI is based on single node. TODO: multiple nodes - * check + * Current capacity check for LMI is based on single node. TODO: multiple nodes check * will be based on --proc-per-node + numCores. */ int capacity = model.getNumCores() / model.getParallelLevel(); @@ -708,9 +717,7 @@ public ModelMetadataResponse.Builder modelMetadata(String modelName, String mode if (model == null) { throw new ModelNotFoundException("Model not found: " + modelName); } - modelManager - .getAllModelVersions(modelName) - .forEach(entry -> versions.add(entry.getKey())); + modelManager.getAllModelVersions(modelName).forEach(entry -> versions.add(entry.getKey())); response.setName(modelName); response.addAllVersions(versions); response.setPlatform(""); @@ -718,7 +725,6 @@ public ModelMetadataResponse.Builder modelMetadata(String modelName, String mode response.addAllOutputs(outputs); return response; - } // return numHealthy >= numScaled; @@ -765,4 +771,4 @@ public int getNumRunningWorkers(ModelVersionName modelVersionName) { public int getNumHealthyWorkers(ModelVersionName modelVersionName) { return wlm.getNumHealthyWorkers(modelVersionName); } -} +} \ No newline at end of file