Skip to content

Commit

Permalink
Remove busy wait completely in TaskScheduler (#2573)
Browse files Browse the repository at this point in the history
* Remove busy wait completely in TaskScheduler

* Remove const include

* Fix timer issue

* Temp: limit test time on CI

* Remove additional condition check on interrupt

* Revert timeout

* small fix for windows, maybe (#2574)

* Address PR comments

---------

Co-authored-by: Keenan G <41458184+Riolku@users.noreply.github.com>
  • Loading branch information
mewim and Riolku committed Dec 12, 2023
1 parent 00d8dda commit 60ebb05
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 17 deletions.
4 changes: 4 additions & 0 deletions src/common/task_system/task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ void Task::deRegisterThreadAndFinalizeTaskIfNecessary() {
finalizeIfNecessary();
} catch (std::exception& e) { setExceptionNoLock(std::current_exception()); }
}
if (isCompletedNoLock()) {
lck.unlock();
cv.notify_all();
}
}

} // namespace common
Expand Down
39 changes: 25 additions & 14 deletions src/common/task_system/task_scheduler.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include "common/task_system/task_scheduler.h"

#include "common/constants.h"

using namespace kuzu::common;

namespace kuzu {
Expand Down Expand Up @@ -30,15 +28,35 @@ void TaskScheduler::scheduleTaskAndWaitOrError(
}
auto scheduledTask = pushTaskIntoQueue(task);
cv.notify_all();
while (!task->isCompleted()) {
std::unique_lock<std::mutex> taskLck{task->mtx, std::defer_lock};
while (true) {
taskLck.lock();
bool timedWait = false;
auto timeout = 0u;
if (task->isCompletedNoLock()) {
// Note: we do not remove completed tasks from the queue in this function. They will be
// removed by the worker threads when they traverse down the queue for a task to work on
// (see getTaskAndRegister()).
taskLck.unlock();
break;
}
if (context->clientContext->isTimeOutEnabled()) {
interruptTaskIfTimeOutNoLock(context);
} else if (task->hasException()) {
timeout = context->clientContext->getTimeoutRemainingInMS();
if (timeout == 0) {
context->clientContext->interrupt();
} else {
timedWait = true;
}
} else if (task->hasExceptionNoLock()) {
// Interrupt tasks that errored, so other threads can stop working on them early.
context->clientContext->interrupt();
}
std::this_thread::sleep_for(
std::chrono::microseconds(THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS));
if (timedWait) {
task->cv.wait_for(taskLck, std::chrono::milliseconds(timeout));
} else {
task->cv.wait(taskLck);
}
taskLck.unlock();
}
if (task->hasException()) {
removeErroringTask(scheduledTask->ID);
Expand Down Expand Up @@ -112,12 +130,5 @@ void TaskScheduler::runWorkerThread() {
}
}
}

void TaskScheduler::interruptTaskIfTimeOutNoLock(processor::ExecutionContext* context) {
if (context->clientContext->isTimeOut()) {
context->clientContext->interrupt();
}
}

} // namespace common
} // namespace kuzu
3 changes: 3 additions & 0 deletions src/include/common/task_system/task.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <condition_variable>
#include <memory>
#include <mutex>
#include <utility>
Expand All @@ -23,6 +24,7 @@ using lock_t = std::unique_lock<std::mutex>;
* finalizeIfNecessary(). See ProcessorTask for an example of this.
*/
class Task {
friend class TaskScheduler;

public:
explicit Task(uint64_t maxNumThreads);
Expand Down Expand Up @@ -100,6 +102,7 @@ class Task {

protected:
std::mutex mtx;
std::condition_variable cv;
uint64_t maxNumThreads, numThreadsFinished{0}, numThreadsRegistered{0};
std::exception_ptr exceptionsPtr = nullptr;
uint64_t ID;
Expand Down
4 changes: 1 addition & 3 deletions src/include/common/task_system/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,11 @@ class TaskScheduler {
void runWorkerThread();
std::shared_ptr<ScheduledTask> getTaskAndRegister();

void interruptTaskIfTimeOutNoLock(processor::ExecutionContext* context);

private:
std::mutex mtx;
std::deque<std::shared_ptr<ScheduledTask>> taskQueue;
bool stopThreads;
std::vector<std::thread> threads;
std::mutex mtx;
std::condition_variable cv;
uint64_t nextScheduledTaskID;
};
Expand Down
6 changes: 6 additions & 0 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ class ClientContext {

inline bool isTimeOutEnabled() const { return timeoutInMS != 0; }

inline uint64_t getTimeoutRemainingInMS() {
KU_ASSERT(isTimeOutEnabled());
auto elapsed = activeQuery.timer.getElapsedTimeInMS();
return elapsed >= timeoutInMS ? 0 : timeoutInMS - elapsed;
}

inline bool isEnableSemiMask() const { return enableSemiMask; }

void startTimingIfEnabled();
Expand Down

0 comments on commit 60ebb05

Please sign in to comment.