From 55bcf9d36d63c4a47178b2a7d8959d4c22c68a8c Mon Sep 17 00:00:00 2001 From: lxning <23464292+lxning@users.noreply.github.com> Date: Wed, 3 Jul 2024 12:01:37 -0700 Subject: [PATCH] clean a jobGroup immediately when it finished (#3222) * refactor * fix typo * add log * clean jobgroup if it is end * clean up log * clean up log * refactor job polling * update comments --- .../java/org/pytorch/serve/job/JobGroup.java | 6 ++ .../java/org/pytorch/serve/wlm/Model.java | 3 +- .../pytorch/serve/wlm/SequenceBatching.java | 75 +++++++------------ .../serve/wlm/SequenceContinuousBatching.java | 4 + 4 files changed, 39 insertions(+), 49 deletions(-) diff --git a/frontend/server/src/main/java/org/pytorch/serve/job/JobGroup.java b/frontend/server/src/main/java/org/pytorch/serve/job/JobGroup.java index 75c2b1b5da..78babb5077 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/job/JobGroup.java +++ b/frontend/server/src/main/java/org/pytorch/serve/job/JobGroup.java @@ -12,12 +12,14 @@ public class JobGroup { LinkedBlockingDeque jobs; int maxJobQueueSize; AtomicBoolean finished; + AtomicBoolean polling; public JobGroup(String groupId, int maxJobQueueSize) { this.groupId = groupId; this.maxJobQueueSize = maxJobQueueSize; this.jobs = new LinkedBlockingDeque<>(maxJobQueueSize); this.finished = new AtomicBoolean(false); + this.polling = new AtomicBoolean(false); } public boolean appendJob(Job job) { @@ -47,4 +49,8 @@ public void setFinished(boolean sequenceEnd) { public boolean isFinished() { return this.finished.get(); } + + public AtomicBoolean getPolling() { + return this.polling; + } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index cfe7ae4efb..459cc87d38 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -311,9 +311,10 @@ private boolean addJobInGroup(Job job) { logger.info("added jobGroup for sequenceId:{}", job.getGroupId()); } else { logger.warn( - "Skip the requestId: {} for sequence: {} due to exceeding maxNumSequence: {}", + "Skip the requestId: {} for sequence: {} due to jobGroups size: {} exceeding maxNumSequence: {}", job.getJobId(), job.getGroupId(), + jobGroups.size(), maxNumSequence); return false; } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatching.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatching.java index 4705956ec3..d14a5e7047 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatching.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatching.java @@ -3,14 +3,12 @@ import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import org.pytorch.serve.job.Job; import org.pytorch.serve.job.JobGroup; import org.pytorch.serve.util.messages.BaseModelRequest; @@ -36,18 +34,14 @@ public class SequenceBatching extends BatchAggregator { // A list of jobGroupIds which are added into current batch. These jobGroupIds need to be added // back to eventJobGroupIds once their jobs are processed by a batch. protected LinkedList currentJobGroupIds; - private AtomicInteger localCapacity; + private int localCapacity; private AtomicBoolean running = new AtomicBoolean(true); - // HashMap to track poll queue tasks in the executor queue - private ConcurrentHashMap> pollQueueTasks = - new ConcurrentHashMap>(); public SequenceBatching(Model model) { super(model); - this.localCapacity = - new AtomicInteger(Math.max(1, model.getMaxNumSequence() / model.getMinWorkers())); + this.localCapacity = Math.max(1, model.getMaxNumSequence() / model.getMinWorkers()); this.currentJobGroupIds = new LinkedList<>(); - this.pollExecutors = Executors.newFixedThreadPool(localCapacity.get() + 1); + this.pollExecutors = Executors.newFixedThreadPool(model.getBatchSize() + 1); this.jobsQueue = new LinkedBlockingDeque<>(); this.isPollJobGroup = new AtomicBoolean(false); this.eventJobGroupIds = new LinkedBlockingDeque<>(); @@ -76,7 +70,7 @@ private void pollJobGroup() throws InterruptedException { int quota = Math.min( - this.localCapacity.get(), + this.localCapacity - jobsQueue.size(), Math.max( 1, model.getPendingJobGroups().size() / model.getMaxWorkers())); if (quota > 0 && model.getPendingJobGroups().size() > 0) { @@ -123,12 +117,10 @@ public void pollBatch(String threadName, WorkerState state) } } - private void cleanJobGroup(String jobGroupId) { + protected void cleanJobGroup(String jobGroupId) { logger.debug("Clean jobGroup: {}", jobGroupId); if (jobGroupId != null) { model.removeJobGroup(jobGroupId); - pollQueueTasks.remove(jobGroupId); - localCapacity.incrementAndGet(); } } @@ -185,7 +177,6 @@ public void shutdownExecutors() { private void addJobGroup(String jobGroupId) { if (jobGroupId != null) { - localCapacity.decrementAndGet(); eventJobGroupIds.add(jobGroupId); } } @@ -202,39 +193,21 @@ public void run() { String jobGroupId = eventJobGroupIds.poll(model.getMaxBatchDelay(), TimeUnit.MILLISECONDS); if (jobGroupId == null || jobGroupId.isEmpty()) { - // Skip fetching new job groups when no capacity is available - if (localCapacity.get() <= 0) { - continue; - } - // Avoid duplicate poll tasks in the executor queue - if (pollQueueTasks.containsKey("pollJobGroup") - && !pollQueueTasks.get("pollJobGroup").isDone()) { - continue; - } - CompletableFuture pollTask = - CompletableFuture.runAsync( - () -> { - try { - pollJobGroup(); - } catch (InterruptedException e) { - logger.error("Failed to poll a job group", e); - } - }, - pollExecutors); - pollQueueTasks.put("pollJobGroup", pollTask); + CompletableFuture.runAsync( + () -> { + try { + pollJobGroup(); + } catch (InterruptedException e) { + logger.error("Failed to poll a job group", e); + } + }, + pollExecutors); } else { - // Avoid duplicate poll tasks in the executor queue - if (pollQueueTasks.containsKey(jobGroupId) - && !pollQueueTasks.get(jobGroupId).isDone()) { - continue; - } - CompletableFuture pollTask = - CompletableFuture.runAsync( - () -> { - pollJobFromJobGroup(jobGroupId); - }, - pollExecutors); - pollQueueTasks.put(jobGroupId, pollTask); + CompletableFuture.runAsync( + () -> { + pollJobFromJobGroup(jobGroupId); + }, + pollExecutors); } } catch (InterruptedException e) { if (running.get()) { @@ -248,10 +221,16 @@ private void pollJobFromJobGroup(String jobGroupId) { // Poll a job from a jobGroup JobGroup jobGroup = model.getJobGroup(jobGroupId); Job job = null; + AtomicBoolean isPolling = jobGroup.getPolling(); if (!jobGroup.isFinished()) { - job = jobGroup.pollJob(model.getSequenceMaxIdleMSec()); + if (!isPolling.getAndSet(true)) { + job = jobGroup.pollJob(model.getSequenceMaxIdleMSec()); + isPolling.set(false); + } else { + return; + } } - if (job == null || jobGroup.isFinished()) { + if (job == null) { // JobGroup expired, clean it. cleanJobGroup(jobGroupId); // intent to add new job groups. diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceContinuousBatching.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceContinuousBatching.java index 4ec5e4747c..3284e8606a 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceContinuousBatching.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceContinuousBatching.java @@ -152,6 +152,10 @@ private void setJobGroupFinished(Predictions prediction) { JobGroup jobGroup = model.getJobGroup(jobGroupId); if (jobGroup != null) { jobGroup.setFinished(true); + // JobGroup finished, clean it. + cleanJobGroup(jobGroupId); + // intent to add new job groups. + eventJobGroupIds.add(""); } } }