From 2282882d383e771c0dc2b9c0f5491c31a8fc272d Mon Sep 17 00:00:00 2001
From: Naman Nandan
Date: Mon, 1 Jul 2024 17:40:47 -0700
Subject: [PATCH 1/2] Restrict GPU access from worker to deviceIds

---
 .../pytorch/serve/wlm/WorkerLifeCycle.java | 20 +++++++------------
 1 file changed, 7 insertions(+), 13 deletions(-)

diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java
index 74b31dfd24..9749d59a0d 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java
@@ -130,16 +130,13 @@ private void startWorkerPython(int port, String deviceIds)
                                 modelPath.getAbsolutePath(),
                                 model.getModelArchive().getManifest().getModel().getHandler())));
 
-        if (model.getParallelLevel() > 0) {
-            if (model.getParallelType() != ParallelType.CUSTOM) {
-                attachRunner(argl, envp, port, deviceIds);
-            } else {
-                if (deviceIds != null) {
-                    envp.add("CUDA_VISIBLE_DEVICES=" + deviceIds);
-                }
-                argl.add(EnvironmentUtils.getPythonRunTime(model));
-            }
-        } else if (model.getParallelLevel() == 0) {
+        if (deviceIds != null) {
+            envp.add("CUDA_VISIBLE_DEVICES=" + deviceIds);
+        }
+
+        if (model.getParallelLevel() > 0 && model.getParallelType() != ParallelType.CUSTOM) {
+            attachRunner(argl, envp, port, deviceIds);
+        } else {
             argl.add(EnvironmentUtils.getPythonRunTime(model));
         }
 
@@ -291,9 +288,6 @@ private void startWorkerCPP(int port, String runtimeType, String deviceIds)
     private void attachRunner(
             ArrayList<String> argl, List<String> envp, int port, String deviceIds) {
         envp.add("LOGLEVEL=INFO");
-        if (deviceIds != null) {
-            envp.add("CUDA_VISIBLE_DEVICES=" + deviceIds);
-        }
         ModelConfig.TorchRun torchRun = model.getModelArchive().getModelConfig().getTorchRun();
         envp.add(String.format("OMP_NUM_THREADS=%d", torchRun.getOmpNumberThreads()));
         argl.add("torchrun");

From eba5d7b5c5e249de6ba7c3c4740ed95fd70fb2fb Mon Sep 17 00:00:00 2001
From: Naman Nandan
Date: Mon, 1 Jul 2024 18:24:18 -0700
Subject: [PATCH 2/2] Set device Ids to gpu id when parallel level is not configured

---
 .../main/java/org/pytorch/serve/wlm/WorkerThread.java | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java
index 8a73e91412..1628f955f1 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java
@@ -544,7 +544,7 @@ public void retry() {
 
     protected String getDeviceIds() {
         List<Integer> deviceIds;
-        if (gpuId == -1 || model.getParallelLevel() == 0) {
+        if (gpuId == -1) {
            return null;
         } else if (model.isHasCfgDeviceIds()) {
             return model.getDeviceIds().subList(gpuId, gpuId + model.getParallelLevel()).stream()
@@ -552,8 +552,12 @@ public void retry() {
                     .collect(Collectors.joining(","));
         } else {
             deviceIds = new ArrayList<>(model.getParallelLevel());
-            for (int i = gpuId; i < gpuId + model.getParallelLevel(); i++) {
-                deviceIds.add(i);
+            if (model.getParallelLevel() > 0) {
+                for (int i = gpuId; i < gpuId + model.getParallelLevel(); i++) {
+                    deviceIds.add(i);
+                }
+            } else {
+                deviceIds.add(gpuId);
             }
             return deviceIds.stream().map(String::valueOf).collect(Collectors.joining(","));
         }
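
Below is a minimal, standalone sketch (not part of the patch) of the device-ID selection these two commits converge on: the first commit makes WorkerLifeCycle export CUDA_VISIBLE_DEVICES from deviceIds whenever they are set, and the second makes WorkerThread.getDeviceIds() fall back to the worker's own gpuId when no parallel level is configured. The class name and the gpuId/parallelLevel parameters are hypothetical stand-ins for the worker and model state referenced in the diff.

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

public class DeviceIdsSketch {
    // gpuId: GPU assigned to this worker (-1 for a CPU worker);
    // parallelLevel: configured parallel degree (0 when not configured).
    static String deviceIds(int gpuId, int parallelLevel) {
        if (gpuId == -1) {
            return null; // CPU worker: leave CUDA_VISIBLE_DEVICES unset
        }
        List<Integer> ids = new ArrayList<>();
        if (parallelLevel > 0) {
            // Parallel worker: claim a contiguous range of devices starting at gpuId.
            for (int i = gpuId; i < gpuId + parallelLevel; i++) {
                ids.add(i);
            }
        } else {
            // Parallel level not configured: restrict the worker to its assigned GPU.
            ids.add(gpuId);
        }
        return ids.stream().map(String::valueOf).collect(Collectors.joining(","));
    }

    public static void main(String[] args) {
        System.out.println(deviceIds(1, 0));  // "1"       -> single assigned GPU
        System.out.println(deviceIds(0, 4));  // "0,1,2,3" -> range for a torchrun worker
        System.out.println(deviceIds(-1, 0)); // null      -> no GPU restriction applied
    }
}

Running main prints "1", "0,1,2,3", and null, corresponding to a single-GPU worker, a parallel worker spanning four devices, and a CPU worker respectively.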