Skip to content

Commit

Permalink
Add query timeout mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Mar 22, 2023
1 parent b5805d6 commit a793c97
Show file tree
Hide file tree
Showing 17 changed files with 111 additions and 40 deletions.
21 changes: 19 additions & 2 deletions src/common/task_system/task_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ void TaskScheduler::waitAllTasksToCompleteOrError() {
}
}

void TaskScheduler::scheduleTaskAndWaitOrError(const std::shared_ptr<Task>& task) {
void TaskScheduler::scheduleTaskAndWaitOrError(
const std::shared_ptr<Task>& task, processor::ExecutionContext* context) {
for (auto& dependency : task->children) {
scheduleTaskAndWaitOrError(dependency);
scheduleTaskAndWaitOrError(dependency, context);
}
auto scheduledTask = scheduleTask(task);
while (!task->isCompleted()) {
interruptTaskIfTimeOutNoLock(context, scheduledTask->ID);
std::this_thread::sleep_for(
std::chrono::microseconds(THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS));
}
Expand Down Expand Up @@ -116,6 +118,11 @@ void TaskScheduler::removeErroringTask(uint64_t scheduledTaskID) {
lock_t lck{mtx};
for (auto it = taskQueue.begin(); it != taskQueue.end(); ++it) {
if (scheduledTaskID == (*it)->ID) {
// We need to wait for all worker threads to abort execution of the task.
while (!(*it)->task->isCompleted()) {
std::this_thread::sleep_for(
std::chrono::microseconds(THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS));
}
taskQueue.erase(it);
return;
}
Expand Down Expand Up @@ -144,5 +151,15 @@ void TaskScheduler::runWorkerThread() {
}
}

void TaskScheduler::interruptTaskIfTimeOutNoLock(
processor::ExecutionContext* context, uint64_t scheduledTaskID) {
if (context->clientContext->isQueryTimeOut()) {
context->clientContext->interrupt();
removeErroringTask(scheduledTaskID);
throw common::InterruptException{common::StringUtils::string_format(
"query timeout after {} ms.", context->clientContext->getQueryTimeOut())};
}
}

} // namespace common
} // namespace kuzu
2 changes: 1 addition & 1 deletion src/include/common/exception.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class TransactionManagerException : public Exception {

class InterruptException : public Exception {
public:
explicit InterruptException() : Exception("Interrupted by the user."){};
explicit InterruptException(std::string msg) : Exception("Query interrupted by: " + msg){};
};

} // namespace common
Expand Down
7 changes: 6 additions & 1 deletion src/include/common/task_system/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "common/task_system/task.h"
#include "common/utils.h"
#include "processor/execution_context.h"

namespace kuzu {
namespace common {
Expand Down Expand Up @@ -57,7 +58,8 @@ class TaskScheduler {
// whether or not the given task or one of its dependencies errors, when this function
// returns, no task related to the given task will be in the task queue. Further no worker
// thread will be working on the given task.
void scheduleTaskAndWaitOrError(const std::shared_ptr<Task>& task);
void scheduleTaskAndWaitOrError(
const std::shared_ptr<Task>& task, processor::ExecutionContext* context);

// If a user, e.g., currently the copier, adds a set of tasks T1, ..., Tk, to the task scheduler
// without waiting for them to finish, the user needs to call waitAllTasksToCompleteOrError() if
Expand Down Expand Up @@ -94,6 +96,9 @@ class TaskScheduler {
void runWorkerThread();
std::shared_ptr<ScheduledTask> getTaskAndRegister();

void interruptTaskIfTimeOutNoLock(
processor::ExecutionContext* context, uint64_t scheduledTaskID);

private:
std::shared_ptr<spdlog::logger> logger;
std::mutex mtx;
Expand Down
45 changes: 31 additions & 14 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
@@ -1,38 +1,55 @@
#pragma once

#include <atomic>
#include <chrono>
#include <cstdint>
#include <memory>

#include "common/api.h"
#include "main/kuzu_fwd.h"

namespace kuzu {
namespace main {

struct ActiveQuery {
explicit ActiveQuery();

std::chrono::time_point<std::chrono::high_resolution_clock> startTime;
std::atomic<bool> interrupted;
};

/**
* @brief Contain client side configuration. We make profiler associated per query, so profiler is
* not maintained in client context.
*/
class ClientContext {
friend class Connection;
friend class testing::TinySnbDDLTest;
friend class testing::TinySnbCopyCSVTransactionTest;

public:
/**
* @brief Constructs the ClientContext object.
*/
KUZU_API explicit ClientContext();
/**
* @brief Deconstructs the ClientContext object.
*/
KUZU_API ~ClientContext() = default;

/**
* @brief Returns whether the current query is interrupted or not.
*/
KUZU_API bool isInterrupted() const { return interrupted; }
explicit ClientContext();

~ClientContext() = default;

inline void interrupt() { activeQuery->interrupted = true; }

bool isInterrupted() const { return activeQuery->interrupted; }

inline bool isQueryTimeOut() {
auto currentTime = std::chrono::high_resolution_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(
currentTime - activeQuery->startTime)
.count() > timeoutInMS;
}

inline uint64_t getQueryTimeOut() const { return timeoutInMS; }

private:
uint64_t numThreadsForExecution;
std::atomic<bool> interrupted;
std::unique_ptr<ActiveQuery> activeQuery;
// The default query timeout is 120 seconds.
uint64_t timeoutInMS = 120000;
};

} // namespace main
Expand Down
5 changes: 5 additions & 0 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ class Connection {
*/
KUZU_API void interrupt();

/**
* @brief sets the query timeout value of the current connection.
*/
KUZU_API void setQueryTimeOut(uint64_t timeoutInMS);

protected:
ConnectionTransactionMode getTransactionMode();
void setTransactionModeNoLock(ConnectionTransactionMode newTransactionMode);
Expand Down
4 changes: 2 additions & 2 deletions src/include/main/kuzu_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ class ApiTest;
class BaseGraphTest;
class TestHelper;
class TestHelper;
class TinySnbDDLTest;
class TinySnbCopyCSVTransactionTest;
} // namespace testing

namespace benchmark {
Expand Down Expand Up @@ -51,8 +53,6 @@ namespace transaction {
class Transaction;
enum class TransactionType : uint8_t;
class TransactionManager;
class TinySnbDDLTest;
class TinySnbCopyCSVTransactionTest;
} // namespace transaction

} // namespace kuzu
Expand Down
6 changes: 3 additions & 3 deletions src/include/main/prepared_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ namespace main {
*/
class PreparedStatement {
friend class Connection;
friend class kuzu::testing::TestHelper;
friend class kuzu::transaction::TinySnbDDLTest;
friend class kuzu::transaction::TinySnbCopyCSVTransactionTest;
friend class testing::TestHelper;
friend class testing::TinySnbDDLTest;
friend class testing::TinySnbCopyCSVTransactionTest;

public:
/**
Expand Down
2 changes: 1 addition & 1 deletion src/include/processor/operator/physical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class PhysicalOperator {

inline bool getNextTuple(ExecutionContext* context) {
if (context->clientContext->isInterrupted()) {
throw common::InterruptException{};
throw common::InterruptException{"user interrupt."};
}
metrics->executionTime.start();
auto result = getNextTuplesInternal(context);
Expand Down
6 changes: 4 additions & 2 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
namespace kuzu {
namespace main {

ClientContext::ClientContext()
: numThreadsForExecution{std::thread::hardware_concurrency()}, interrupted{false} {}
ActiveQuery::ActiveQuery()
: startTime{std::chrono::high_resolution_clock::now()}, interrupted{false} {}

ClientContext::ClientContext() : numThreadsForExecution{std::thread::hardware_concurrency()} {}

} // namespace main
} // namespace kuzu
10 changes: 6 additions & 4 deletions src/main/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,11 @@ std::unique_ptr<QueryResult> Connection::kuzu_query(const char* queryString) {
}

void Connection::interrupt() {
clientContext->interrupted = true;
clientContext->activeQuery->interrupted = true;
}

void Connection::setQueryTimeOut(uint64_t timeoutInMS) {
clientContext->timeoutInMS = timeoutInMS;
}

std::unique_ptr<QueryResult> Connection::executeWithParams(PreparedStatement* preparedStatement,
Expand Down Expand Up @@ -309,6 +313,7 @@ void Connection::bindParametersNoLock(PreparedStatement* preparedStatement,

std::unique_ptr<QueryResult> Connection::executeAndAutoCommitIfNecessaryNoLock(
PreparedStatement* preparedStatement, uint32_t planIdx) {
clientContext->activeQuery = std::make_unique<ActiveQuery>();
auto mapper = PlanMapper(
*database->storageManager, database->memoryManager.get(), database->catalog.get());
std::unique_ptr<PhysicalPlan> physicalPlan;
Expand Down Expand Up @@ -345,9 +350,6 @@ std::unique_ptr<QueryResult> Connection::executeAndAutoCommitIfNecessaryNoLock(
if (ConnectionTransactionMode::AUTO_COMMIT == transactionMode) {
commitNoLock();
}
} catch (InterruptException& exception) {
clientContext->interrupted = false;
return getQueryResultWithError(exception.what());
} catch (Exception& exception) { return getQueryResultWithError(exception.what()); }
executingTimer.stop();
queryResult->querySummary->executionTime = executingTimer.getElapsedTimeMS();
Expand Down
2 changes: 1 addition & 1 deletion src/processor/processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ std::shared_ptr<FactorizedTable> QueryProcessor::execute(
// one.
auto task = std::make_shared<ProcessorTask>(resultCollector, context);
decomposePlanIntoTasks(lastOperator, nullptr, task.get(), context);
taskScheduler->scheduleTaskAndWaitOrError(task);
taskScheduler->scheduleTaskAndWaitOrError(task, context);
return resultCollector->getResultFactorizedTable();
}
}
Expand Down
2 changes: 1 addition & 1 deletion test/include/main_test_helper/main_test_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ApiTest : public BaseGraphTest {
static void executeLongRunningQuery(Connection* conn) {
auto result = conn->query("MATCH (a:person)-[:knows*1..28]->(b:person) RETURN COUNT(*)");
ASSERT_FALSE(result->isSuccess());
ASSERT_EQ(result->getErrorMessage(), "Interrupted by the user.");
ASSERT_EQ(result->getErrorMessage(), "Query interrupted by: user interrupt.");
}
};

Expand Down
7 changes: 7 additions & 0 deletions test/main/connection_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,10 @@ TEST_F(ApiTest, Interrupt) {
conn->interrupt();
longRunningQueryThread.join();
}

TEST_F(ApiTest, TimeOut) {
conn->setQueryTimeOut(3000 /* timeoutInMS */);
auto result = conn->query("MATCH (a:person)-[:knows*1..28]->(b:person) RETURN COUNT(*);");
ASSERT_FALSE(result->isSuccess());
ASSERT_EQ(result->getErrorMessage(), "Query interrupted by: query timeout after 3000 ms.");
}
14 changes: 8 additions & 6 deletions test/runner/e2e_copy_transaction_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using namespace kuzu::storage;
using namespace kuzu::testing;

namespace kuzu {
namespace transaction {
namespace testing {

class TinySnbCopyCSVTransactionTest : public EmptyDBTest {

Expand Down Expand Up @@ -58,7 +58,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
ASSERT_EQ(getStorageManager(*database)
->getNodesStore()
.getNodesStatisticsAndDeletedIDs()
.getMaxNodeOffset(TransactionType::READ_ONLY, tableID),
.getMaxNodeOffset(transaction::TransactionType::READ_ONLY, tableID),
UINT64_MAX);
}

Expand All @@ -76,7 +76,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
ASSERT_EQ(getStorageManager(*database)
->getNodesStore()
.getNodesStatisticsAndDeletedIDs()
.getMaxNodeOffset(TransactionType::READ_ONLY, tableID),
.getMaxNodeOffset(transaction::TransactionType::READ_ONLY, tableID),
7);
}

Expand All @@ -89,6 +89,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->getExpressionsToCollect(), preparedStatement->statementType);
clientContext->activeQuery = std::make_unique<ActiveQuery>();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
auto tableID = catalog->getReadOnlyVersion()->getTableID("person");
validateDatabaseStateBeforeCheckPointCopyNode(tableID);
Expand Down Expand Up @@ -136,7 +137,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
relTableSchema, DBFileType::WAL_VERSION, true /* existence */);
validateRelColumnAndListFilesExistence(
relTableSchema, DBFileType::ORIGINAL, true /* existence */);
auto dummyWriteTrx = Transaction::getDummyWriteTrx();
auto dummyWriteTrx = transaction::Transaction::getDummyWriteTrx();
ASSERT_EQ(getStorageManager(*database)->getRelsStore().getRelsStatistics().getNextRelOffset(
dummyWriteTrx.get(), tableID),
14);
Expand All @@ -153,7 +154,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
relTableSchema, DBFileType::ORIGINAL, true /* existence */);
validateTinysnbKnowsDateProperty();
auto& relsStatistics = getStorageManager(*database)->getRelsStore().getRelsStatistics();
auto dummyWriteTrx = Transaction::getDummyWriteTrx();
auto dummyWriteTrx = transaction::Transaction::getDummyWriteTrx();
ASSERT_EQ(relsStatistics.getNextRelOffset(dummyWriteTrx.get(), knowsTableID), 14);
ASSERT_EQ(relsStatistics.getReadOnlyVersion()->tableStatisticPerTable.size(), 1);
auto knowsRelStatistics = (RelStatistics*)relsStatistics.getReadOnlyVersion()
Expand All @@ -173,6 +174,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->getExpressionsToCollect(), preparedStatement->statementType);
clientContext->activeQuery = std::make_unique<ActiveQuery>();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
auto tableID = catalog->getReadOnlyVersion()->getTableID("knows");
validateDatabaseStateBeforeCheckPointCopyRel(tableID);
Expand Down Expand Up @@ -245,5 +247,5 @@ TEST_F(TinySnbCopyCSVTransactionTest, CopyCSVStatementWithActiveTransactionError
"previous transaction and issue a ddl query without opening a transaction.");
}

} // namespace transaction
} // namespace testing
} // namespace kuzu
5 changes: 3 additions & 2 deletions test/runner/e2e_ddl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace kuzu::storage;
using namespace kuzu::testing;

namespace kuzu {
namespace transaction {
namespace testing {

class PrimaryKeyTest : public EmptyDBTest {
public:
Expand Down Expand Up @@ -365,6 +365,7 @@ class TinySnbDDLTest : public DBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->getExpressionsToCollect(), preparedStatement->statementType);
executionContext->clientContext->activeQuery = std::make_unique<ActiveQuery>();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
}

Expand Down Expand Up @@ -933,5 +934,5 @@ TEST_F(TinySnbDDLTest, RenamePropertyRecovery) {
renameProperty(TransactionTestType::RECOVERY);
}

} // namespace transaction
} // namespace testing
} // namespace kuzu
Loading

0 comments on commit a793c97

Please sign in to comment.