Skip to content

Commit

Permalink
Add query timeout mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Mar 23, 2023
1 parent 956473d commit d521460
Show file tree
Hide file tree
Showing 18 changed files with 124 additions and 40 deletions.
14 changes: 12 additions & 2 deletions src/common/task_system/task_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,16 @@ void TaskScheduler::waitAllTasksToCompleteOrError() {
}
}

void TaskScheduler::scheduleTaskAndWaitOrError(const std::shared_ptr<Task>& task) {
void TaskScheduler::scheduleTaskAndWaitOrError(
const std::shared_ptr<Task>& task, processor::ExecutionContext* context) {
for (auto& dependency : task->children) {
scheduleTaskAndWaitOrError(dependency);
scheduleTaskAndWaitOrError(dependency, context);
}
auto scheduledTask = scheduleTask(task);
while (!task->isCompleted()) {
if (context->clientContext->isTimeOutEnabled()) {
interruptTaskIfTimeOutNoLock(context);
}
std::this_thread::sleep_for(
std::chrono::microseconds(THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS));
}
Expand Down Expand Up @@ -144,5 +148,11 @@ void TaskScheduler::runWorkerThread() {
}
}

void TaskScheduler::interruptTaskIfTimeOutNoLock(processor::ExecutionContext* context) {
if (context->clientContext->isTimeOut()) {
context->clientContext->interrupt();
}
}

} // namespace common
} // namespace kuzu
5 changes: 5 additions & 0 deletions src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,5 +133,10 @@ struct EnumeratorKnobs {
static constexpr double FLAT_PROBE_PENALTY = 10;
};

struct ClientContextConstants {
// The default query timeout is 120 seconds.
static constexpr uint64_t TIMEOUT_IN_MS = 0;
};

} // namespace common
} // namespace kuzu
2 changes: 1 addition & 1 deletion src/include/common/exception.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class TransactionManagerException : public Exception {

class InterruptException : public Exception {
public:
explicit InterruptException() : Exception("Interrupted by the user."){};
explicit InterruptException() : Exception("Interrupted."){};
};

} // namespace common
Expand Down
6 changes: 5 additions & 1 deletion src/include/common/task_system/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "common/task_system/task.h"
#include "common/utils.h"
#include "processor/execution_context.h"

namespace kuzu {
namespace common {
Expand Down Expand Up @@ -57,7 +58,8 @@ class TaskScheduler {
// whether or not the given task or one of its dependencies errors, when this function
// returns, no task related to the given task will be in the task queue. Further no worker
// thread will be working on the given task.
void scheduleTaskAndWaitOrError(const std::shared_ptr<Task>& task);
void scheduleTaskAndWaitOrError(
const std::shared_ptr<Task>& task, processor::ExecutionContext* context);

// If a user, e.g., currently the copier, adds a set of tasks T1, ..., Tk, to the task scheduler
// without waiting for them to finish, the user needs to call waitAllTasksToCompleteOrError() if
Expand Down Expand Up @@ -94,6 +96,8 @@ class TaskScheduler {
void runWorkerThread();
std::shared_ptr<ScheduledTask> getTaskAndRegister();

void interruptTaskIfTimeOutNoLock(processor::ExecutionContext* context);

private:
std::shared_ptr<spdlog::logger> logger;
std::mutex mtx;
Expand Down
6 changes: 6 additions & 0 deletions src/include/common/timer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ class Timer {
throw Exception("Timer is still running.");
}

int64_t getElapsedTimeInMS() {
auto now = std::chrono::high_resolution_clock::now();
auto duration = now - startTime;
return std::chrono::duration_cast<std::chrono::milliseconds>(duration).count();
}

private:
std::chrono::time_point<std::chrono::high_resolution_clock> startTime;
std::chrono::time_point<std::chrono::high_resolution_clock> stopTime;
Expand Down
43 changes: 29 additions & 14 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,52 @@

#include <atomic>
#include <cstdint>
#include <memory>

#include "common/api.h"
#include "common/timer.h"
#include "main/kuzu_fwd.h"

namespace kuzu {
namespace main {

struct ActiveQuery {
explicit ActiveQuery();

std::atomic<bool> interrupted;
common::Timer timer;
};

/**
* @brief Contain client side configuration. We make profiler associated per query, so profiler is
* not maintained in client context.
*/
class ClientContext {
friend class Connection;
friend class testing::TinySnbDDLTest;
friend class testing::TinySnbCopyCSVTransactionTest;

public:
/**
* @brief Constructs the ClientContext object.
*/
KUZU_API explicit ClientContext();
/**
* @brief Deconstructs the ClientContext object.
*/
KUZU_API ~ClientContext() = default;

/**
* @brief Returns whether the current query is interrupted or not.
*/
KUZU_API bool isInterrupted() const { return interrupted; }
explicit ClientContext();

~ClientContext() = default;

inline void interrupt() { activeQuery->interrupted = true; }

bool isInterrupted() const { return activeQuery->interrupted; }

inline bool isTimeOut() { return activeQuery->timer.getElapsedTimeInMS() > timeoutInMS; }

inline bool isTimeOutEnabled() const { return timeoutInMS != 0; }

inline uint64_t getTimeOutMS() const { return timeoutInMS; }

void startTimingIfEnabled();

private:
uint64_t numThreadsForExecution;
std::atomic<bool> interrupted;
std::unique_ptr<ActiveQuery> activeQuery;
uint64_t timeoutInMS;
};

} // namespace main
Expand Down
6 changes: 6 additions & 0 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ class Connection {
*/
KUZU_API void interrupt();

/**
* @brief sets the query timeout value of the current connection. A value of zero (the default)
* disables the timeout.
*/
KUZU_API void setQueryTimeOut(uint64_t timeoutInMS);

protected:
ConnectionTransactionMode getTransactionMode();
void setTransactionModeNoLock(ConnectionTransactionMode newTransactionMode);
Expand Down
4 changes: 2 additions & 2 deletions src/include/main/kuzu_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ class ApiTest;
class BaseGraphTest;
class TestHelper;
class TestHelper;
class TinySnbDDLTest;
class TinySnbCopyCSVTransactionTest;
} // namespace testing

namespace benchmark {
Expand Down Expand Up @@ -51,8 +53,6 @@ namespace transaction {
class Transaction;
enum class TransactionType : uint8_t;
class TransactionManager;
class TinySnbDDLTest;
class TinySnbCopyCSVTransactionTest;
} // namespace transaction

} // namespace kuzu
Expand Down
6 changes: 3 additions & 3 deletions src/include/main/prepared_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ namespace main {
*/
class PreparedStatement {
friend class Connection;
friend class kuzu::testing::TestHelper;
friend class kuzu::transaction::TinySnbDDLTest;
friend class kuzu::transaction::TinySnbCopyCSVTransactionTest;
friend class testing::TestHelper;
friend class testing::TinySnbDDLTest;
friend class testing::TinySnbCopyCSVTransactionTest;

public:
/**
Expand Down
13 changes: 12 additions & 1 deletion src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,22 @@

#include <thread>

#include "common/constants.h"

namespace kuzu {
namespace main {

ActiveQuery::ActiveQuery() : interrupted{false} {}

ClientContext::ClientContext()
: numThreadsForExecution{std::thread::hardware_concurrency()}, interrupted{false} {}
: numThreadsForExecution{std::thread::hardware_concurrency()},
timeoutInMS{common::ClientContextConstants::TIMEOUT_IN_MS} {}

void ClientContext::startTimingIfEnabled() {
if (isTimeOutEnabled()) {
activeQuery->timer.start();
}
}

} // namespace main
} // namespace kuzu
12 changes: 8 additions & 4 deletions src/main/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,12 @@ std::unique_ptr<QueryResult> Connection::kuzu_query(const char* queryString) {
}

void Connection::interrupt() {
clientContext->interrupted = true;
clientContext->activeQuery->interrupted = true;
}

void Connection::setQueryTimeOut(uint64_t timeoutInMS) {
lock_t lck{mtx};
clientContext->timeoutInMS = timeoutInMS;
}

std::unique_ptr<QueryResult> Connection::executeWithParams(PreparedStatement* preparedStatement,
Expand Down Expand Up @@ -309,6 +314,8 @@ void Connection::bindParametersNoLock(PreparedStatement* preparedStatement,

std::unique_ptr<QueryResult> Connection::executeAndAutoCommitIfNecessaryNoLock(
PreparedStatement* preparedStatement, uint32_t planIdx) {
clientContext->activeQuery = std::make_unique<ActiveQuery>();
clientContext->startTimingIfEnabled();
auto mapper = PlanMapper(
*database->storageManager, database->memoryManager.get(), database->catalog.get());
std::unique_ptr<PhysicalPlan> physicalPlan;
Expand Down Expand Up @@ -345,9 +352,6 @@ std::unique_ptr<QueryResult> Connection::executeAndAutoCommitIfNecessaryNoLock(
if (ConnectionTransactionMode::AUTO_COMMIT == transactionMode) {
commitNoLock();
}
} catch (InterruptException& exception) {
clientContext->interrupted = false;
return getQueryResultWithError(exception.what());
} catch (Exception& exception) { return getQueryResultWithError(exception.what()); }
executingTimer.stop();
queryResult->querySummary->executionTime = executingTimer.getElapsedTimeMS();
Expand Down
2 changes: 1 addition & 1 deletion src/processor/processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ std::shared_ptr<FactorizedTable> QueryProcessor::execute(
// one.
auto task = std::make_shared<ProcessorTask>(resultCollector, context);
decomposePlanIntoTasks(lastOperator, nullptr, task.get(), context);
taskScheduler->scheduleTaskAndWaitOrError(task);
taskScheduler->scheduleTaskAndWaitOrError(task, context);
return resultCollector->getResultFactorizedTable();
}
}
Expand Down
2 changes: 1 addition & 1 deletion test/include/main_test_helper/main_test_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ApiTest : public BaseGraphTest {
static void executeLongRunningQuery(Connection* conn) {
auto result = conn->query("MATCH (a:person)-[:knows*1..28]->(b:person) RETURN COUNT(*)");
ASSERT_FALSE(result->isSuccess());
ASSERT_EQ(result->getErrorMessage(), "Interrupted by the user.");
ASSERT_EQ(result->getErrorMessage(), "Interrupted.");
}
};

Expand Down
7 changes: 7 additions & 0 deletions test/main/connection_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,10 @@ TEST_F(ApiTest, Interrupt) {
conn->interrupt();
longRunningQueryThread.join();
}

TEST_F(ApiTest, TimeOut) {
conn->setQueryTimeOut(1000 /* timeoutInMS */);
auto result = conn->query("MATCH (a:person)-[:knows*1..28]->(b:person) RETURN COUNT(*);");
ASSERT_FALSE(result->isSuccess());
ASSERT_EQ(result->getErrorMessage(), "Interrupted.");
}
14 changes: 8 additions & 6 deletions test/runner/e2e_copy_transaction_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using namespace kuzu::storage;
using namespace kuzu::testing;

namespace kuzu {
namespace transaction {
namespace testing {

class TinySnbCopyCSVTransactionTest : public EmptyDBTest {

Expand Down Expand Up @@ -58,7 +58,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
ASSERT_EQ(getStorageManager(*database)
->getNodesStore()
.getNodesStatisticsAndDeletedIDs()
.getMaxNodeOffset(TransactionType::READ_ONLY, tableID),
.getMaxNodeOffset(transaction::TransactionType::READ_ONLY, tableID),
UINT64_MAX);
}

Expand All @@ -76,7 +76,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
ASSERT_EQ(getStorageManager(*database)
->getNodesStore()
.getNodesStatisticsAndDeletedIDs()
.getMaxNodeOffset(TransactionType::READ_ONLY, tableID),
.getMaxNodeOffset(transaction::TransactionType::READ_ONLY, tableID),
7);
}

Expand All @@ -89,6 +89,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->getExpressionsToCollect(), preparedStatement->statementType);
clientContext->activeQuery = std::make_unique<ActiveQuery>();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
auto tableID = catalog->getReadOnlyVersion()->getTableID("person");
validateDatabaseStateBeforeCheckPointCopyNode(tableID);
Expand Down Expand Up @@ -136,7 +137,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
relTableSchema, DBFileType::WAL_VERSION, true /* existence */);
validateRelColumnAndListFilesExistence(
relTableSchema, DBFileType::ORIGINAL, true /* existence */);
auto dummyWriteTrx = Transaction::getDummyWriteTrx();
auto dummyWriteTrx = transaction::Transaction::getDummyWriteTrx();
ASSERT_EQ(getStorageManager(*database)->getRelsStore().getRelsStatistics().getNextRelOffset(
dummyWriteTrx.get(), tableID),
14);
Expand All @@ -153,7 +154,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
relTableSchema, DBFileType::ORIGINAL, true /* existence */);
validateTinysnbKnowsDateProperty();
auto& relsStatistics = getStorageManager(*database)->getRelsStore().getRelsStatistics();
auto dummyWriteTrx = Transaction::getDummyWriteTrx();
auto dummyWriteTrx = transaction::Transaction::getDummyWriteTrx();
ASSERT_EQ(relsStatistics.getNextRelOffset(dummyWriteTrx.get(), knowsTableID), 14);
ASSERT_EQ(relsStatistics.getReadOnlyVersion()->tableStatisticPerTable.size(), 1);
auto knowsRelStatistics = (RelStatistics*)relsStatistics.getReadOnlyVersion()
Expand All @@ -173,6 +174,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->getExpressionsToCollect(), preparedStatement->statementType);
clientContext->activeQuery = std::make_unique<ActiveQuery>();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
auto tableID = catalog->getReadOnlyVersion()->getTableID("knows");
validateDatabaseStateBeforeCheckPointCopyRel(tableID);
Expand Down Expand Up @@ -245,5 +247,5 @@ TEST_F(TinySnbCopyCSVTransactionTest, CopyCSVStatementWithActiveTransactionError
"previous transaction and issue a ddl query without opening a transaction.");
}

} // namespace transaction
} // namespace testing
} // namespace kuzu
5 changes: 3 additions & 2 deletions test/runner/e2e_ddl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace kuzu::storage;
using namespace kuzu::testing;

namespace kuzu {
namespace transaction {
namespace testing {

class PrimaryKeyTest : public EmptyDBTest {
public:
Expand Down Expand Up @@ -365,6 +365,7 @@ class TinySnbDDLTest : public DBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->getExpressionsToCollect(), preparedStatement->statementType);
executionContext->clientContext->activeQuery = std::make_unique<ActiveQuery>();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
}

Expand Down Expand Up @@ -933,5 +934,5 @@ TEST_F(TinySnbDDLTest, RenamePropertyRecovery) {
renameProperty(TransactionTestType::RECOVERY);
}

} // namespace transaction
} // namespace testing
} // namespace kuzu
Loading

0 comments on commit d521460

Please sign in to comment.