Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add query timeout mechanism #1395

Merged
merged 1 commit into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/common/task_system/task_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,16 @@ void TaskScheduler::waitAllTasksToCompleteOrError() {
}
}

void TaskScheduler::scheduleTaskAndWaitOrError(const std::shared_ptr<Task>& task) {
void TaskScheduler::scheduleTaskAndWaitOrError(
const std::shared_ptr<Task>& task, processor::ExecutionContext* context) {
for (auto& dependency : task->children) {
scheduleTaskAndWaitOrError(dependency);
scheduleTaskAndWaitOrError(dependency, context);
}
auto scheduledTask = scheduleTask(task);
while (!task->isCompleted()) {
if (context->clientContext->isTimeOutEnabled()) {
interruptTaskIfTimeOutNoLock(context);
}
std::this_thread::sleep_for(
std::chrono::microseconds(THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS));
}
Expand Down Expand Up @@ -144,5 +148,11 @@ void TaskScheduler::runWorkerThread() {
}
}

void TaskScheduler::interruptTaskIfTimeOutNoLock(processor::ExecutionContext* context) {
if (context->clientContext->isTimeOut()) {
context->clientContext->interrupt();
}
}

} // namespace common
} // namespace kuzu
5 changes: 5 additions & 0 deletions src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,5 +133,10 @@ struct EnumeratorKnobs {
static constexpr double FLAT_PROBE_PENALTY = 10;
};

struct ClientContextConstants {
// We disable query timeout by default.
static constexpr uint64_t TIMEOUT_IN_MS = 0;
};

} // namespace common
} // namespace kuzu
2 changes: 1 addition & 1 deletion src/include/common/exception.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class TransactionManagerException : public Exception {

class InterruptException : public Exception {
public:
explicit InterruptException() : Exception("Interrupted by the user."){};
explicit InterruptException() : Exception("Interrupted."){};
};

} // namespace common
Expand Down
6 changes: 5 additions & 1 deletion src/include/common/task_system/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "common/task_system/task.h"
#include "common/utils.h"
#include "processor/execution_context.h"

namespace kuzu {
namespace common {
Expand Down Expand Up @@ -57,7 +58,8 @@ class TaskScheduler {
// whether or not the given task or one of its dependencies errors, when this function
// returns, no task related to the given task will be in the task queue. Further no worker
// thread will be working on the given task.
void scheduleTaskAndWaitOrError(const std::shared_ptr<Task>& task);
void scheduleTaskAndWaitOrError(
const std::shared_ptr<Task>& task, processor::ExecutionContext* context);

// If a user, e.g., currently the copier, adds a set of tasks T1, ..., Tk, to the task scheduler
// without waiting for them to finish, the user needs to call waitAllTasksToCompleteOrError() if
Expand Down Expand Up @@ -94,6 +96,8 @@ class TaskScheduler {
void runWorkerThread();
std::shared_ptr<ScheduledTask> getTaskAndRegister();

void interruptTaskIfTimeOutNoLock(processor::ExecutionContext* context);

private:
std::shared_ptr<spdlog::logger> logger;
std::mutex mtx;
Expand Down
6 changes: 6 additions & 0 deletions src/include/common/timer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ class Timer {
throw Exception("Timer is still running.");
}

int64_t getElapsedTimeInMS() {
auto now = std::chrono::high_resolution_clock::now();
auto duration = now - startTime;
return std::chrono::duration_cast<std::chrono::milliseconds>(duration).count();
}

private:
std::chrono::time_point<std::chrono::high_resolution_clock> startTime;
std::chrono::time_point<std::chrono::high_resolution_clock> stopTime;
Expand Down
43 changes: 29 additions & 14 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,52 @@

#include <atomic>
#include <cstdint>
#include <memory>

#include "common/api.h"
#include "common/timer.h"
#include "main/kuzu_fwd.h"

namespace kuzu {
namespace main {

struct ActiveQuery {
explicit ActiveQuery();

std::atomic<bool> interrupted;
common::Timer timer;
};

/**
* @brief Contain client side configuration. We make profiler associated per query, so profiler is
* not maintained in client context.
*/
class ClientContext {
friend class Connection;
friend class testing::TinySnbDDLTest;
friend class testing::TinySnbCopyCSVTransactionTest;

public:
/**
* @brief Constructs the ClientContext object.
*/
KUZU_API explicit ClientContext();
/**
* @brief Deconstructs the ClientContext object.
*/
KUZU_API ~ClientContext() = default;

/**
* @brief Returns whether the current query is interrupted or not.
*/
KUZU_API bool isInterrupted() const { return interrupted; }
explicit ClientContext();
acquamarin marked this conversation as resolved.
Show resolved Hide resolved

~ClientContext() = default;

inline void interrupt() { activeQuery->interrupted = true; }

bool isInterrupted() const { return activeQuery->interrupted; }

inline bool isTimeOut() { return activeQuery->timer.getElapsedTimeInMS() > timeoutInMS; }

inline bool isTimeOutEnabled() const { return timeoutInMS != 0; }

inline uint64_t getTimeOutMS() const { return timeoutInMS; }

void startTimingIfEnabled();

private:
uint64_t numThreadsForExecution;
std::atomic<bool> interrupted;
std::unique_ptr<ActiveQuery> activeQuery;
uint64_t timeoutInMS;
};

} // namespace main
Expand Down
6 changes: 6 additions & 0 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ class Connection {
*/
KUZU_API void interrupt();

/**
* @brief sets the query timeout value of the current connection. A value of zero (the default)
* disables the timeout.
*/
KUZU_API void setQueryTimeOut(uint64_t timeoutInMS);

protected:
ConnectionTransactionMode getTransactionMode();
void setTransactionModeNoLock(ConnectionTransactionMode newTransactionMode);
Expand Down
4 changes: 2 additions & 2 deletions src/include/main/kuzu_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ class ApiTest;
class BaseGraphTest;
class TestHelper;
class TestHelper;
class TinySnbDDLTest;
class TinySnbCopyCSVTransactionTest;
} // namespace testing

namespace benchmark {
Expand Down Expand Up @@ -51,8 +53,6 @@ namespace transaction {
class Transaction;
enum class TransactionType : uint8_t;
class TransactionManager;
class TinySnbDDLTest;
class TinySnbCopyCSVTransactionTest;
} // namespace transaction

} // namespace kuzu
Expand Down
6 changes: 3 additions & 3 deletions src/include/main/prepared_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ namespace main {
*/
class PreparedStatement {
friend class Connection;
friend class kuzu::testing::TestHelper;
friend class kuzu::transaction::TinySnbDDLTest;
friend class kuzu::transaction::TinySnbCopyCSVTransactionTest;
friend class testing::TestHelper;
friend class testing::TinySnbDDLTest;
friend class testing::TinySnbCopyCSVTransactionTest;

public:
/**
Expand Down
13 changes: 12 additions & 1 deletion src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,22 @@

#include <thread>

#include "common/constants.h"

namespace kuzu {
namespace main {

ActiveQuery::ActiveQuery() : interrupted{false} {}

ClientContext::ClientContext()
: numThreadsForExecution{std::thread::hardware_concurrency()}, interrupted{false} {}
: numThreadsForExecution{std::thread::hardware_concurrency()},
timeoutInMS{common::ClientContextConstants::TIMEOUT_IN_MS} {}

void ClientContext::startTimingIfEnabled() {
if (isTimeOutEnabled()) {
activeQuery->timer.start();
}
}

} // namespace main
} // namespace kuzu
12 changes: 8 additions & 4 deletions src/main/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,12 @@ std::unique_ptr<QueryResult> Connection::kuzu_query(const char* queryString) {
}

void Connection::interrupt() {
clientContext->interrupted = true;
clientContext->activeQuery->interrupted = true;
}

void Connection::setQueryTimeOut(uint64_t timeoutInMS) {
lock_t lck{mtx};
clientContext->timeoutInMS = timeoutInMS;
}

std::unique_ptr<QueryResult> Connection::executeWithParams(PreparedStatement* preparedStatement,
Expand Down Expand Up @@ -309,6 +314,8 @@ void Connection::bindParametersNoLock(PreparedStatement* preparedStatement,

std::unique_ptr<QueryResult> Connection::executeAndAutoCommitIfNecessaryNoLock(
PreparedStatement* preparedStatement, uint32_t planIdx) {
clientContext->activeQuery = std::make_unique<ActiveQuery>();
clientContext->startTimingIfEnabled();
auto mapper = PlanMapper(
*database->storageManager, database->memoryManager.get(), database->catalog.get());
std::unique_ptr<PhysicalPlan> physicalPlan;
Expand Down Expand Up @@ -345,9 +352,6 @@ std::unique_ptr<QueryResult> Connection::executeAndAutoCommitIfNecessaryNoLock(
if (ConnectionTransactionMode::AUTO_COMMIT == transactionMode) {
commitNoLock();
}
} catch (InterruptException& exception) {
clientContext->interrupted = false;
return getQueryResultWithError(exception.what());
} catch (Exception& exception) { return getQueryResultWithError(exception.what()); }
executingTimer.stop();
queryResult->querySummary->executionTime = executingTimer.getElapsedTimeMS();
Expand Down
2 changes: 1 addition & 1 deletion src/processor/processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ std::shared_ptr<FactorizedTable> QueryProcessor::execute(
// one.
auto task = std::make_shared<ProcessorTask>(resultCollector, context);
decomposePlanIntoTasks(lastOperator, nullptr, task.get(), context);
taskScheduler->scheduleTaskAndWaitOrError(task);
taskScheduler->scheduleTaskAndWaitOrError(task, context);
return resultCollector->getResultFactorizedTable();
}
}
Expand Down
2 changes: 1 addition & 1 deletion test/include/main_test_helper/main_test_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ApiTest : public BaseGraphTest {
static void executeLongRunningQuery(Connection* conn) {
auto result = conn->query("MATCH (a:person)-[:knows*1..28]->(b:person) RETURN COUNT(*)");
ASSERT_FALSE(result->isSuccess());
ASSERT_EQ(result->getErrorMessage(), "Interrupted by the user.");
ASSERT_EQ(result->getErrorMessage(), "Interrupted.");
}
};

Expand Down
7 changes: 7 additions & 0 deletions test/main/connection_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,10 @@ TEST_F(ApiTest, Interrupt) {
conn->interrupt();
longRunningQueryThread.join();
}

TEST_F(ApiTest, TimeOut) {
conn->setQueryTimeOut(1000 /* timeoutInMS */);
auto result = conn->query("MATCH (a:person)-[:knows*1..28]->(b:person) RETURN COUNT(*);");
ASSERT_FALSE(result->isSuccess());
ASSERT_EQ(result->getErrorMessage(), "Interrupted.");
}
14 changes: 8 additions & 6 deletions test/runner/e2e_copy_transaction_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using namespace kuzu::storage;
using namespace kuzu::testing;

namespace kuzu {
namespace transaction {
namespace testing {

class TinySnbCopyCSVTransactionTest : public EmptyDBTest {

Expand Down Expand Up @@ -58,7 +58,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
ASSERT_EQ(getStorageManager(*database)
->getNodesStore()
.getNodesStatisticsAndDeletedIDs()
.getMaxNodeOffset(TransactionType::READ_ONLY, tableID),
.getMaxNodeOffset(transaction::TransactionType::READ_ONLY, tableID),
UINT64_MAX);
}

Expand All @@ -76,7 +76,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
ASSERT_EQ(getStorageManager(*database)
->getNodesStore()
.getNodesStatisticsAndDeletedIDs()
.getMaxNodeOffset(TransactionType::READ_ONLY, tableID),
.getMaxNodeOffset(transaction::TransactionType::READ_ONLY, tableID),
7);
}

Expand All @@ -89,6 +89,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->getExpressionsToCollect(), preparedStatement->statementType);
clientContext->activeQuery = std::make_unique<ActiveQuery>();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
auto tableID = catalog->getReadOnlyVersion()->getTableID("person");
validateDatabaseStateBeforeCheckPointCopyNode(tableID);
Expand Down Expand Up @@ -136,7 +137,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
relTableSchema, DBFileType::WAL_VERSION, true /* existence */);
validateRelColumnAndListFilesExistence(
relTableSchema, DBFileType::ORIGINAL, true /* existence */);
auto dummyWriteTrx = Transaction::getDummyWriteTrx();
auto dummyWriteTrx = transaction::Transaction::getDummyWriteTrx();
ASSERT_EQ(getStorageManager(*database)->getRelsStore().getRelsStatistics().getNextRelOffset(
dummyWriteTrx.get(), tableID),
14);
Expand All @@ -153,7 +154,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
relTableSchema, DBFileType::ORIGINAL, true /* existence */);
validateTinysnbKnowsDateProperty();
auto& relsStatistics = getStorageManager(*database)->getRelsStore().getRelsStatistics();
auto dummyWriteTrx = Transaction::getDummyWriteTrx();
auto dummyWriteTrx = transaction::Transaction::getDummyWriteTrx();
ASSERT_EQ(relsStatistics.getNextRelOffset(dummyWriteTrx.get(), knowsTableID), 14);
ASSERT_EQ(relsStatistics.getReadOnlyVersion()->tableStatisticPerTable.size(), 1);
auto knowsRelStatistics = (RelStatistics*)relsStatistics.getReadOnlyVersion()
Expand All @@ -173,6 +174,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->getExpressionsToCollect(), preparedStatement->statementType);
clientContext->activeQuery = std::make_unique<ActiveQuery>();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
auto tableID = catalog->getReadOnlyVersion()->getTableID("knows");
validateDatabaseStateBeforeCheckPointCopyRel(tableID);
Expand Down Expand Up @@ -245,5 +247,5 @@ TEST_F(TinySnbCopyCSVTransactionTest, CopyCSVStatementWithActiveTransactionError
"previous transaction and issue a ddl query without opening a transaction.");
}

} // namespace transaction
} // namespace testing
} // namespace kuzu
5 changes: 3 additions & 2 deletions test/runner/e2e_ddl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace kuzu::storage;
using namespace kuzu::testing;

namespace kuzu {
namespace transaction {
namespace testing {

class PrimaryKeyTest : public EmptyDBTest {
public:
Expand Down Expand Up @@ -365,6 +365,7 @@ class TinySnbDDLTest : public DBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->getExpressionsToCollect(), preparedStatement->statementType);
executionContext->clientContext->activeQuery = std::make_unique<ActiveQuery>();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
}

Expand Down Expand Up @@ -933,5 +934,5 @@ TEST_F(TinySnbDDLTest, RenamePropertyRecovery) {
renameProperty(TransactionTestType::RECOVERY);
}

} // namespace transaction
} // namespace testing
} // namespace kuzu
Loading