Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>
  • Loading branch information
andyi2it committed May 22, 2024
1 parent 876dc9d commit 95f09b2
Showing 1 changed file with 61 additions and 74 deletions.
135 changes: 61 additions & 74 deletions frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc.ModelMetadataResponse;
import org.pytorch.serve.grpc.openinference.OpenInferenceGrpc.ModelMetadataResponse.TensorMetadata;
import org.pytorch.serve.http.BadRequestException;
import org.pytorch.serve.http.ConflictStatusException;
import org.pytorch.serve.http.InvalidModelVersionException;
import org.pytorch.serve.http.messages.RegisterModelRequest;
Expand Down Expand Up @@ -93,7 +92,7 @@ public ModelArchive registerModel(String url, String defaultModelName)

public void registerAndUpdateModel(String modelName, JsonObject modelInfo)
throws ModelException, IOException, InterruptedException, DownloadArchiveException,
WorkerInitializationException {
WorkerInitializationException {

boolean defaultVersion = modelInfo.get(Model.DEFAULT_VERSION).getAsBoolean();
String url = modelInfo.get(Model.MAR_NAME).getAsString();
Expand Down Expand Up @@ -133,7 +132,7 @@ public ModelArchive registerModel(
throws ModelException, IOException, InterruptedException, DownloadArchiveException {

ModelArchive archive;
if (isWorkflowModel && url == null) { // This is a workflow function
if (isWorkflowModel && url == null) { // This is a workflow function
Manifest manifest = new Manifest();
manifest.getModel().setVersion("1.0");
manifest.getModel().setModelVersion("1.0");
Expand All @@ -143,13 +142,11 @@ public ModelArchive registerModel(
File f = new File(handler.substring(0, handler.lastIndexOf(':')));
archive = new ModelArchive(manifest, url, f.getParentFile(), true);
} else {
archive =
createModelArchive(
modelName, url, handler, runtime, defaultModelName, s3SseKms);
archive = createModelArchive(
modelName, url, handler, runtime, defaultModelName, s3SseKms);
}

Model tempModel =
createModel(archive, batchSize, maxBatchDelay, responseTimeout, isWorkflowModel);
Model tempModel = createModel(archive, batchSize, maxBatchDelay, responseTimeout, isWorkflowModel);

String versionId = archive.getModelVersion();

Expand Down Expand Up @@ -179,12 +176,11 @@ private ModelArchive createModelArchive(
boolean s3SseKms)
throws ModelException, IOException, DownloadArchiveException {

ModelArchive archive =
ModelArchive.downloadModel(
configManager.getAllowedUrls(),
configManager.getModelStore(),
url,
s3SseKms);
ModelArchive archive = ModelArchive.downloadModel(
configManager.getAllowedUrls(),
configManager.getModelStore(),
url,
s3SseKms);
Manifest.Model model = archive.getManifest().getModel();
if (modelName == null || modelName.isEmpty()) {
if (archive.getModelName() == null || archive.getModelName().isEmpty()) {
Expand Down Expand Up @@ -242,11 +238,10 @@ private void setupModelVenv(Model model)
+ venvPath.toString());
}
Map<String, String> environment = processBuilder.environment();
String[] envp =
EnvironmentUtils.getEnvString(
configManager.getModelServerHome(),
model.getModelDir().getAbsolutePath(),
null);
String[] envp = EnvironmentUtils.getEnvString(
configManager.getModelServerHome(),
model.getModelDir().getAbsolutePath(),
null);
for (String envVar : envp) {
String[] parts = envVar.split("=", 2);
if (parts.length == 2) {
Expand Down Expand Up @@ -283,16 +278,14 @@ private void setupModelVenv(Model model)

private void setupModelDependencies(Model model)
throws IOException, InterruptedException, ModelException {
String requirementsFile =
model.getModelArchive().getManifest().getModel().getRequirementsFile();
String requirementsFile = model.getModelArchive().getManifest().getModel().getRequirementsFile();

if (!configManager.getInstallPyDepPerModel() || requirementsFile == null) {
return;
}

String pythonRuntime = EnvironmentUtils.getPythonRunTime(model);
Path requirementsFilePath =
Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile).toAbsolutePath();
Path requirementsFilePath = Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile).toAbsolutePath();
List<String> commandParts = new ArrayList<>();
ProcessBuilder processBuilder = new ProcessBuilder();

Expand Down Expand Up @@ -341,11 +334,10 @@ private void setupModelDependencies(Model model)
}

processBuilder.command(commandParts);
String[] envp =
EnvironmentUtils.getEnvString(
configManager.getModelServerHome(),
model.getModelDir().getAbsolutePath(),
null);
String[] envp = EnvironmentUtils.getEnvString(
configManager.getModelServerHome(),
model.getModelDir().getAbsolutePath(),
null);
Map<String, String> environment = processBuilder.environment();
for (String envVar : envp) {
String[] parts = envVar.split("=", 2);
Expand Down Expand Up @@ -398,64 +390,58 @@ private Model createModel(
if (batchSize == -1 * RegisterModelRequest.DEFAULT_BATCH_SIZE) {
if (archive.getModelConfig() != null) {
int marBatchSize = archive.getModelConfig().getBatchSize();
batchSize =
marBatchSize > 0
? marBatchSize
: configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.BATCH_SIZE,
RegisterModelRequest.DEFAULT_BATCH_SIZE);
} else {
batchSize =
configManager.getJsonIntValue(
batchSize = marBatchSize > 0
? marBatchSize
: configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.BATCH_SIZE,
RegisterModelRequest.DEFAULT_BATCH_SIZE);
} else {
batchSize = configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.BATCH_SIZE,
RegisterModelRequest.DEFAULT_BATCH_SIZE);
}
}
model.setBatchSize(batchSize);

if (maxBatchDelay == -1 * RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY) {
if (archive.getModelConfig() != null) {
int marMaxBatchDelay = archive.getModelConfig().getMaxBatchDelay();
maxBatchDelay =
marMaxBatchDelay > 0
? marMaxBatchDelay
: configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.MAX_BATCH_DELAY,
RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY);
} else {
maxBatchDelay =
configManager.getJsonIntValue(
maxBatchDelay = marMaxBatchDelay > 0
? marMaxBatchDelay
: configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.MAX_BATCH_DELAY,
RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY);
} else {
maxBatchDelay = configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.MAX_BATCH_DELAY,
RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY);
}
}
model.setMaxBatchDelay(maxBatchDelay);

if (archive.getModelConfig() != null) {
int marResponseTimeout = archive.getModelConfig().getResponseTimeout();
responseTimeout =
marResponseTimeout > 0
? marResponseTimeout
: configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.RESPONSE_TIMEOUT,
responseTimeout);
} else {
responseTimeout =
configManager.getJsonIntValue(
responseTimeout = marResponseTimeout > 0
? marResponseTimeout
: configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.RESPONSE_TIMEOUT,
responseTimeout);
} else {
responseTimeout = configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.RESPONSE_TIMEOUT,
responseTimeout);
}
model.setResponseTimeout(responseTimeout);
model.setWorkflowModel(isWorkflowModel);
Expand Down Expand Up @@ -586,7 +572,8 @@ public CompletableFuture<Integer> updateModel(
}
if (model.getParallelLevel() > 0 && model.getDeviceType() == ModelConfig.DeviceType.GPU) {
/**
* Current capacity check for LMI is based on single node. TODO: multiple nodes check
* Current capacity check for LMI is based on single node. TODO: multiple nodes
* check
* will be based on --proc-per-node + numCores.
*/
int capacity = model.getNumCores() / model.getParallelLevel();
Expand Down Expand Up @@ -687,25 +674,25 @@ public boolean scaleRequestStatus(String modelName, String versionId) {
}

public boolean isModelReady(String modelName, String modelVersion)
throws ModelVersionNotFoundException, ModelNotFoundException {
throws ModelVersionNotFoundException, ModelNotFoundException {

if (modelVersion == null || "".equals(modelVersion)) {
modelVersion = null;
}
if (modelVersion == null || "".equals(modelVersion)) {
modelVersion = null;
}

Model model = getModel(modelName, modelVersion);
if (model == null) {
throw new ModelNotFoundException("Model not found: " + modelName);
}
Model model = getModel(modelName, modelVersion);
if (model == null) {
throw new ModelNotFoundException("Model not found: " + modelName);
}

int numScaled = model.getMinWorkers();
int numHealthy = modelManager.getNumHealthyWorkers(model.getModelVersionName());
int numScaled = model.getMinWorkers();
int numHealthy = modelManager.getNumHealthyWorkers(model.getModelVersionName());

return numHealthy >= numScaled;
return numHealthy >= numScaled;
}

public ModelMetadataResponse.Builder modelMetadata(String modelName, String modelVersion)
throws ModelVersionNotFoundException, ModelNotFoundException {
throws ModelVersionNotFoundException, ModelNotFoundException {

ModelManager modelManager = ModelManager.getInstance();
ModelMetadataResponse.Builder response = ModelMetadataResponse.newBuilder();
Expand Down

0 comments on commit 95f09b2

Please sign in to comment.