Skip to content

Commit

Permalink
Clean a jobGroup immediately when it finishes (#3222)
Browse files Browse the repository at this point in the history
* refactor

* fix typo

* add log

* clean jobGroup when it has ended

* clean up log

* clean up log

* refactor job polling

* update comments
  • Loading branch information
lxning committed Jul 3, 2024
1 parent dc08344 commit 55bcf9d
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ public class JobGroup {
LinkedBlockingDeque<Job> jobs;
int maxJobQueueSize;
AtomicBoolean finished;
AtomicBoolean polling;

// Creates a job group identified by groupId with a bounded job queue.
// "finished" marks the end of the sequence; "polling" is a flag used to
// prevent concurrent polls on the same group.
public JobGroup(String groupId, int maxJobQueueSize) {
this.groupId = groupId;
this.maxJobQueueSize = maxJobQueueSize;
// Bounded deque: appendJob offers will fail once maxJobQueueSize is reached.
this.jobs = new LinkedBlockingDeque<>(maxJobQueueSize);
this.finished = new AtomicBoolean(false);
this.polling = new AtomicBoolean(false);
}

public boolean appendJob(Job job) {
Expand Down Expand Up @@ -47,4 +49,8 @@ public void setFinished(boolean sequenceEnd) {
// Returns true once the sequence owning this group has been marked finished
// via setFinished(true).
public boolean isFinished() {
return this.finished.get();
}

// Exposes the mutable polling flag so callers can claim exclusive polling
// rights with getAndSet(true) and release them with set(false).
// NOTE(review): returning the AtomicBoolean itself (rather than a boolean)
// appears intentional — callers CAS on it — but it does leak internal state;
// confirm no caller replaces or misuses the reference.
public AtomicBoolean getPolling() {
return this.polling;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,10 @@ private boolean addJobInGroup(Job job) {
logger.info("added jobGroup for sequenceId:{}", job.getGroupId());
} else {
logger.warn(
"Skip the requestId: {} for sequence: {} due to exceeding maxNumSequence: {}",
"Skip the requestId: {} for sequence: {} due to jobGroups size: {} exceeding maxNumSequence: {}",
job.getJobId(),
job.getGroupId(),
jobGroups.size(),
maxNumSequence);
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.job.JobGroup;
import org.pytorch.serve.util.messages.BaseModelRequest;
Expand All @@ -36,18 +34,14 @@ public class SequenceBatching extends BatchAggregator {
// A list of jobGroupIds which are added into current batch. These jobGroupIds need to be added
// back to eventJobGroupIds once their jobs are processed by a batch.
protected LinkedList<String> currentJobGroupIds;
private AtomicInteger localCapacity;
private int localCapacity;
private AtomicBoolean running = new AtomicBoolean(true);
// HashMap to track poll queue tasks in the executor queue
private ConcurrentHashMap<String, CompletableFuture<Void>> pollQueueTasks =
new ConcurrentHashMap<String, CompletableFuture<Void>>();

public SequenceBatching(Model model) {
super(model);
this.localCapacity =
new AtomicInteger(Math.max(1, model.getMaxNumSequence() / model.getMinWorkers()));
this.localCapacity = Math.max(1, model.getMaxNumSequence() / model.getMinWorkers());
this.currentJobGroupIds = new LinkedList<>();
this.pollExecutors = Executors.newFixedThreadPool(localCapacity.get() + 1);
this.pollExecutors = Executors.newFixedThreadPool(model.getBatchSize() + 1);
this.jobsQueue = new LinkedBlockingDeque<>();
this.isPollJobGroup = new AtomicBoolean(false);
this.eventJobGroupIds = new LinkedBlockingDeque<>();
Expand Down Expand Up @@ -76,7 +70,7 @@ private void pollJobGroup() throws InterruptedException {

int quota =
Math.min(
this.localCapacity.get(),
this.localCapacity - jobsQueue.size(),
Math.max(
1, model.getPendingJobGroups().size() / model.getMaxWorkers()));
if (quota > 0 && model.getPendingJobGroups().size() > 0) {
Expand Down Expand Up @@ -123,12 +117,10 @@ public void pollBatch(String threadName, WorkerState state)
}
}

private void cleanJobGroup(String jobGroupId) {
protected void cleanJobGroup(String jobGroupId) {
logger.debug("Clean jobGroup: {}", jobGroupId);
if (jobGroupId != null) {
model.removeJobGroup(jobGroupId);
pollQueueTasks.remove(jobGroupId);
localCapacity.incrementAndGet();
}
}

Expand Down Expand Up @@ -185,7 +177,6 @@ public void shutdownExecutors() {

/**
 * Enqueues a job group id onto {@code eventJobGroupIds} so the polling loop
 * will pick it up; a {@code null} id is ignored.
 *
 * @param jobGroupId id of the job group to schedule; may be {@code null}
 */
private void addJobGroup(String jobGroupId) {
    if (jobGroupId == null) {
        return;
    }
    eventJobGroupIds.add(jobGroupId);
}
Expand All @@ -202,39 +193,21 @@ public void run() {
String jobGroupId =
eventJobGroupIds.poll(model.getMaxBatchDelay(), TimeUnit.MILLISECONDS);
if (jobGroupId == null || jobGroupId.isEmpty()) {
// Skip fetching new job groups when no capacity is available
if (localCapacity.get() <= 0) {
continue;
}
// Avoid duplicate poll tasks in the executor queue
if (pollQueueTasks.containsKey("pollJobGroup")
&& !pollQueueTasks.get("pollJobGroup").isDone()) {
continue;
}
CompletableFuture<Void> pollTask =
CompletableFuture.runAsync(
() -> {
try {
pollJobGroup();
} catch (InterruptedException e) {
logger.error("Failed to poll a job group", e);
}
},
pollExecutors);
pollQueueTasks.put("pollJobGroup", pollTask);
CompletableFuture.runAsync(
() -> {
try {
pollJobGroup();
} catch (InterruptedException e) {
logger.error("Failed to poll a job group", e);
}
},
pollExecutors);
} else {
// Avoid duplicate poll tasks in the executor queue
if (pollQueueTasks.containsKey(jobGroupId)
&& !pollQueueTasks.get(jobGroupId).isDone()) {
continue;
}
CompletableFuture<Void> pollTask =
CompletableFuture.runAsync(
() -> {
pollJobFromJobGroup(jobGroupId);
},
pollExecutors);
pollQueueTasks.put(jobGroupId, pollTask);
CompletableFuture.runAsync(
() -> {
pollJobFromJobGroup(jobGroupId);
},
pollExecutors);
}
} catch (InterruptedException e) {
if (running.get()) {
Expand All @@ -248,10 +221,16 @@ private void pollJobFromJobGroup(String jobGroupId) {
// Poll a job from a jobGroup
JobGroup jobGroup = model.getJobGroup(jobGroupId);
Job job = null;
AtomicBoolean isPolling = jobGroup.getPolling();
if (!jobGroup.isFinished()) {
job = jobGroup.pollJob(model.getSequenceMaxIdleMSec());
if (!isPolling.getAndSet(true)) {
job = jobGroup.pollJob(model.getSequenceMaxIdleMSec());
isPolling.set(false);
} else {
return;
}
}
if (job == null || jobGroup.isFinished()) {
if (job == null) {
// JobGroup expired, clean it.
cleanJobGroup(jobGroupId);
// intent to add new job groups.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ private void setJobGroupFinished(Predictions prediction) {
JobGroup jobGroup = model.getJobGroup(jobGroupId);
if (jobGroup != null) {
jobGroup.setFinished(true);
// JobGroup finished, clean it.
cleanJobGroup(jobGroupId);
// intent to add new job groups.
eventJobGroupIds.add("");
}
}
}
Expand Down

0 comments on commit 55bcf9d

Please sign in to comment.