Skip to content

Commit

Permalink
Add query timeout mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Mar 21, 2023
1 parent b5805d6 commit 24cf468
Show file tree
Hide file tree
Showing 25 changed files with 222 additions and 119 deletions.
44 changes: 40 additions & 4 deletions src/common/task_system/task_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,28 @@ void TaskScheduler::errorIfThereIsAnExceptionNoLock() {
}
}

void TaskScheduler::waitAllTasksToCompleteOrError() {
void TaskScheduler::waitAllTasksToCompleteOrError(processor::ExecutionContext* context) {
while (true) {
lock_t lck{mtx};
if (taskQueue.empty()) {
return;
}
errorIfThereIsAnExceptionNoLock();
interruptAllTasksIfTimeOutNoLock(context);
lck.unlock();
std::this_thread::sleep_for(
std::chrono::microseconds(THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS));
}
}

void TaskScheduler::scheduleTaskAndWaitOrError(const std::shared_ptr<Task>& task) {
void TaskScheduler::scheduleTaskAndWaitOrError(
const std::shared_ptr<Task>& task, processor::ExecutionContext* context) {
for (auto& dependency : task->children) {
scheduleTaskAndWaitOrError(dependency);
scheduleTaskAndWaitOrError(dependency, context);
}
auto scheduledTask = scheduleTask(task);
while (!task->isCompleted()) {
interruptTaskIfTimeOutNoLock(context, scheduledTask->ID);
std::this_thread::sleep_for(
std::chrono::microseconds(THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS));
}
Expand All @@ -77,9 +80,11 @@ void TaskScheduler::scheduleTaskAndWaitOrError(const std::shared_ptr<Task>& task
}
}

void TaskScheduler::waitUntilEnoughTasksFinish(int64_t minimumNumTasksToScheduleMore) {
void TaskScheduler::waitUntilEnoughTasksFinish(
int64_t minimumNumTasksToScheduleMore, processor::ExecutionContext* context) {
while (getNumTasks() > minimumNumTasksToScheduleMore) {
errorIfThereIsAnException();
interruptAllTasksIfTimeOutNoLock(context);
std::this_thread::sleep_for(
std::chrono::microseconds(THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS));
}
Expand Down Expand Up @@ -116,6 +121,11 @@ void TaskScheduler::removeErroringTask(uint64_t scheduledTaskID) {
lock_t lck{mtx};
for (auto it = taskQueue.begin(); it != taskQueue.end(); ++it) {
if (scheduledTaskID == (*it)->ID) {
// We need to wait for all worker threads to abort execution of the task.
while (!(*it)->task->isCompleted()) {
std::this_thread::sleep_for(
std::chrono::microseconds(THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS));
}
taskQueue.erase(it);
return;
}
Expand Down Expand Up @@ -144,5 +154,31 @@ void TaskScheduler::runWorkerThread() {
}
}

void TaskScheduler::interruptTaskIfTimeOutNoLock(
processor::ExecutionContext* context, uint64_t scheduledTaskID) {
if (context->clientContext->isQueryTimeOut()) {
context->clientContext->interrupt();
removeErroringTask(scheduledTaskID);
throw common::InterruptException{common::StringUtils::string_format(
"query timeout after {} ms.", context->clientContext->getQueryTimeOut())};
}
}

void TaskScheduler::interruptAllTasksIfTimeOutNoLock(processor::ExecutionContext* context) {
if (context->clientContext->isQueryTimeOut()) {
context->clientContext->interrupt();
for (auto it = taskQueue.begin(); it != taskQueue.end(); ++it) {
// We need to wait for all worker threads to abort execution of the task.
while (!(*it)->task->isCompleted()) {
std::this_thread::sleep_for(
std::chrono::microseconds(THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS));
}
taskQueue.erase(it);
}
throw common::InterruptException{common::StringUtils::string_format(
"query timeout after {} ms.", context->clientContext->getQueryTimeOut())};
}
}

} // namespace common
} // namespace kuzu
2 changes: 1 addition & 1 deletion src/include/common/exception.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class TransactionManagerException : public Exception {

class InterruptException : public Exception {
public:
explicit InterruptException() : Exception("Interrupted by the user."){};
explicit InterruptException(std::string msg) : Exception("Query interrupted by: " + msg){};
};

} // namespace common
Expand Down
14 changes: 11 additions & 3 deletions src/include/common/task_system/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "common/task_system/task.h"
#include "common/utils.h"
#include "processor/execution_context.h"

namespace kuzu {
namespace common {
Expand Down Expand Up @@ -57,7 +58,8 @@ class TaskScheduler {
// whether or not the given task or one of its dependencies errors, when this function
// returns, no task related to the given task will be in the task queue. Further no worker
// thread will be working on the given task.
void scheduleTaskAndWaitOrError(const std::shared_ptr<Task>& task);
void scheduleTaskAndWaitOrError(
const std::shared_ptr<Task>& task, processor::ExecutionContext* context);

// If a user, e.g., currently the copier, adds a set of tasks T1, ..., Tk, to the task scheduler
// without waiting for them to finish, the user needs to call waitAllTasksToCompleteOrError() if
Expand All @@ -75,9 +77,10 @@ class TaskScheduler {
// it will remove only that one. Other tasks that may have failed many not be removed
// from the task queue and remain in the queue. So for now, use this function if you
// want the system to crash if any of the tasks fails.
void waitAllTasksToCompleteOrError();
void waitAllTasksToCompleteOrError(processor::ExecutionContext* context);

void waitUntilEnoughTasksFinish(int64_t minimumNumTasksToScheduleMore);
void waitUntilEnoughTasksFinish(
int64_t minimumNumTasksToScheduleMore, processor::ExecutionContext* context);

// Checks if there is an erroring task in the queue and if so, errors.
void errorIfThereIsAnException();
Expand All @@ -94,6 +97,11 @@ class TaskScheduler {
void runWorkerThread();
std::shared_ptr<ScheduledTask> getTaskAndRegister();

void interruptTaskIfTimeOutNoLock(
processor::ExecutionContext* context, uint64_t scheduledTaskID);

void interruptAllTasksIfTimeOutNoLock(processor::ExecutionContext* context);

private:
std::shared_ptr<spdlog::logger> logger;
std::mutex mtx;
Expand Down
45 changes: 31 additions & 14 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
@@ -1,38 +1,55 @@
#pragma once

#include <atomic>
#include <chrono>
#include <cstdint>
#include <memory>

#include "common/api.h"
#include "main/kuzu_fwd.h"

namespace kuzu {
namespace main {

struct ActiveQuery {
explicit ActiveQuery();

std::chrono::time_point<std::chrono::high_resolution_clock> startTime;
std::atomic<bool> interrupted;
};

/**
* @brief Contain client side configuration. We make profiler associated per query, so profiler is
* not maintained in client context.
*/
class ClientContext {
friend class Connection;
friend class testing::TinySnbDDLTest;
friend class testing::TinySnbCopyCSVTransactionTest;

public:
/**
* @brief Constructs the ClientContext object.
*/
KUZU_API explicit ClientContext();
/**
* @brief Deconstructs the ClientContext object.
*/
KUZU_API ~ClientContext() = default;

/**
* @brief Returns whether the current query is interrupted or not.
*/
KUZU_API bool isInterrupted() const { return interrupted; }
explicit ClientContext();

~ClientContext() = default;

inline void interrupt() { activeQuery->interrupted = true; }

bool isInterrupted() const { return activeQuery->interrupted; }

inline bool isQueryTimeOut() {
auto currentTime = std::chrono::high_resolution_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(
currentTime - activeQuery->startTime)
.count() > timeoutInMS;
}

inline uint64_t getQueryTimeOut() const { return timeoutInMS; }

private:
uint64_t numThreadsForExecution;
std::atomic<bool> interrupted;
std::unique_ptr<ActiveQuery> activeQuery;
// The default query timeout is 120 seconds.
uint64_t timeoutInMS = 120000;
};

} // namespace main
Expand Down
5 changes: 5 additions & 0 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ class Connection {
*/
KUZU_API void interrupt();

/**
* @brief sets the query timeout value of the current connection.
*/
KUZU_API void setQueryTimeOut(uint64_t timeoutInMS);

protected:
ConnectionTransactionMode getTransactionMode();
void setTransactionModeNoLock(ConnectionTransactionMode newTransactionMode);
Expand Down
4 changes: 2 additions & 2 deletions src/include/main/kuzu_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ class ApiTest;
class BaseGraphTest;
class TestHelper;
class TestHelper;
class TinySnbDDLTest;
class TinySnbCopyCSVTransactionTest;
} // namespace testing

namespace benchmark {
Expand Down Expand Up @@ -51,8 +53,6 @@ namespace transaction {
class Transaction;
enum class TransactionType : uint8_t;
class TransactionManager;
class TinySnbDDLTest;
class TinySnbCopyCSVTransactionTest;
} // namespace transaction

} // namespace kuzu
Expand Down
6 changes: 3 additions & 3 deletions src/include/main/prepared_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ namespace main {
*/
class PreparedStatement {
friend class Connection;
friend class kuzu::testing::TestHelper;
friend class kuzu::transaction::TinySnbDDLTest;
friend class kuzu::transaction::TinySnbCopyCSVTransactionTest;
friend class testing::TestHelper;
friend class testing::TinySnbDDLTest;
friend class testing::TinySnbCopyCSVTransactionTest;

public:
/**
Expand Down
2 changes: 1 addition & 1 deletion src/include/processor/operator/physical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class PhysicalOperator {

inline bool getNextTuple(ExecutionContext* context) {
if (context->clientContext->isInterrupted()) {
throw common::InterruptException{};
throw common::InterruptException{"user interrupt."};
}
metrics->executionTime.start();
auto result = getNextTuplesInternal(context);
Expand Down
16 changes: 9 additions & 7 deletions src/include/storage/copy_arrow/copy_node_arrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,20 @@ class CopyNodeArrow : public CopyStructuresArrow {

void initializeColumnsAndLists() override;

void populateColumnsAndLists() override;
void populateColumnsAndLists(processor::ExecutionContext* context) override;

void saveToFile() override;
void saveToFile(processor::ExecutionContext* context) override;

template<typename T>
arrow::Status populateColumns();
arrow::Status populateColumns(processor::ExecutionContext* context);

template<typename T>
arrow::Status populateColumnsFromCSV(std::unique_ptr<HashIndexBuilder<T>>& pkIndex);
arrow::Status populateColumnsFromCSV(
std::unique_ptr<HashIndexBuilder<T>>& pkIndex, processor::ExecutionContext* context);

template<typename T>
arrow::Status populateColumnsFromParquet(std::unique_ptr<HashIndexBuilder<T>>& pkIndex);
arrow::Status populateColumnsFromParquet(
std::unique_ptr<HashIndexBuilder<T>>& pkIndex, processor::ExecutionContext* context);

template<typename T>
static void putPropsOfLineIntoColumns(std::vector<std::unique_ptr<InMemColumn>>& columns,
Expand All @@ -57,12 +59,12 @@ class CopyNodeArrow : public CopyStructuresArrow {
template<typename T>
arrow::Status assignCopyCSVTasks(arrow::csv::StreamingReader* csvStreamingReader,
common::offset_t startOffset, std::string filePath,
std::unique_ptr<HashIndexBuilder<T>>& pkIndex);
std::unique_ptr<HashIndexBuilder<T>>& pkIndex, processor::ExecutionContext* context);

template<typename T>
arrow::Status assignCopyParquetTasks(parquet::arrow::FileReader* parquetReader,
common::offset_t startOffset, std::string filePath,
std::unique_ptr<HashIndexBuilder<T>>& pkIndex);
std::unique_ptr<HashIndexBuilder<T>>& pkIndex, processor::ExecutionContext* context);

private:
std::vector<std::unique_ptr<InMemColumn>> columns;
Expand Down
23 changes: 13 additions & 10 deletions src/include/storage/copy_arrow/copy_rel_arrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,32 @@ class CopyRelArrow : public CopyStructuresArrow {

void initializeColumnsAndLists() override;

void populateColumnsAndLists() override;
void populateColumnsAndLists(processor::ExecutionContext* context) override;

void saveToFile() override;
void saveToFile(processor::ExecutionContext* context) override;

void initializeColumns(common::RelDirection relDirection);

void initializeLists(common::RelDirection relDirection);

void initAdjListsHeaders();
void initAdjListsHeaders(processor::ExecutionContext* context);

void initListsMetadata();
void initListsMetadata(processor::ExecutionContext* context);

void initializePkIndexes(common::table_id_t nodeTableID, BufferManager& bufferManager);

arrow::Status executePopulateTask(PopulateTaskType populateTaskType);
arrow::Status executePopulateTask(
PopulateTaskType populateTaskType, processor::ExecutionContext* context);

arrow::Status populateFromCSV(PopulateTaskType populateTaskType);
arrow::Status populateFromCSV(
PopulateTaskType populateTaskType, processor::ExecutionContext* context);

arrow::Status populateFromParquet(PopulateTaskType populateTaskType);
arrow::Status populateFromParquet(
PopulateTaskType populateTaskType, processor::ExecutionContext* context);

void populateAdjColumnsAndCountRelsInAdjLists();
void populateAdjColumnsAndCountRelsInAdjLists(processor::ExecutionContext* context);

void populateLists();
void populateLists(processor::ExecutionContext* context);

// We store rel properties with overflows, e.g., strings or lists, in
// InMemColumn/ListsWithOverflowFile (e.g., InMemStringLists). When loading these properties
Expand All @@ -66,7 +69,7 @@ class CopyRelArrow : public CopyStructuresArrow {
// of scanning these overflow files, we also sort the overflow pointers based on nodeOffsets, so
// when scanning rels of consecutive nodes, the overflows of these rels appear consecutively on
// disk.
void sortAndCopyOverflowValues();
void sortAndCopyOverflowValues(processor::ExecutionContext* context);

template<typename T>
static void inferTableIDsAndOffsets(const std::vector<std::shared_ptr<T>>& batchColumns,
Expand Down
8 changes: 4 additions & 4 deletions src/include/storage/copy_arrow/copy_structures_arrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class CopyStructuresArrow {
common::TaskScheduler& taskScheduler, catalog::Catalog& catalog,
common::table_id_t tableID);

uint64_t copy();
uint64_t copy(processor::ExecutionContext* context);

virtual ~CopyStructuresArrow() = default;

Expand All @@ -47,11 +47,11 @@ class CopyStructuresArrow {

virtual void initializeColumnsAndLists() = 0;

virtual void populateColumnsAndLists() = 0;
virtual void populateColumnsAndLists(processor::ExecutionContext* context) = 0;

virtual void saveToFile() = 0;
virtual void saveToFile(processor::ExecutionContext* context) = 0;

void populateInMemoryStructures();
void populateInMemoryStructures(processor::ExecutionContext* context);

void countNumLines(const std::vector<std::string>& filePath);

Expand Down
6 changes: 4 additions & 2 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
namespace kuzu {
namespace main {

ClientContext::ClientContext()
: numThreadsForExecution{std::thread::hardware_concurrency()}, interrupted{false} {}
ActiveQuery::ActiveQuery()
: startTime{std::chrono::high_resolution_clock::now()}, interrupted{false} {}

ClientContext::ClientContext() : numThreadsForExecution{std::thread::hardware_concurrency()} {}

} // namespace main
} // namespace kuzu
Loading

0 comments on commit 24cf468

Please sign in to comment.