Skip to content

Commit

Permalink
Make Connection::interrupt safe to use
Browse files Browse the repository at this point in the history
Interrupting before a query has executed would mean the activeQuery was null, leading to undefined behaviour
  • Loading branch information
benjaminwinger committed Jul 17, 2023
1 parent 1e9d4fd commit d5f954d
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 19 deletions.
20 changes: 10 additions & 10 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,6 @@
namespace kuzu {
namespace main {

struct ActiveQuery {
explicit ActiveQuery();
std::atomic<bool> interrupted;
common::Timer timer;
};

/**
* @brief Contain client side configuration. We make profiler associated per query, so profiler is
* not maintained in client context.
Expand All @@ -33,12 +27,12 @@ class ClientContext {

~ClientContext() = default;

inline void interrupt() { activeQuery->interrupted = true; }
inline void interrupt() { interrupted = true; }

bool isInterrupted() const { return activeQuery->interrupted; }
bool isInterrupted() const { return interrupted; }

inline bool isTimeOut() {
return isTimeOutEnabled() && activeQuery->timer.getElapsedTimeInMS() > timeoutInMS;
return isTimeOutEnabled() && timer.getElapsedTimeInMS() > timeoutInMS;
}

inline bool isTimeOutEnabled() const { return timeoutInMS != 0; }
Expand All @@ -48,8 +42,14 @@ class ClientContext {
std::string getCurrentSetting(std::string optionName);

private:
inline void reset() {
interrupted = false;
timer = common::Timer();
}

uint64_t numThreadsForExecution;
std::unique_ptr<ActiveQuery> activeQuery;
std::atomic<bool> interrupted;
common::Timer timer;
uint64_t timeoutInMS;
};

Expand Down
4 changes: 1 addition & 3 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@
namespace kuzu {
namespace main {

ActiveQuery::ActiveQuery() : interrupted{false} {}

ClientContext::ClientContext()
: numThreadsForExecution{std::thread::hardware_concurrency()},
timeoutInMS{common::ClientContextConstants::TIMEOUT_IN_MS} {}

void ClientContext::startTimingIfEnabled() {
if (isTimeOutEnabled()) {
activeQuery->timer.start();
timer.start();
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/main/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ std::string Connection::getRelPropertyNames(const std::string& relTableName) {
}

void Connection::interrupt() {
clientContext->activeQuery->interrupted = true;
clientContext->interrupt();
}

void Connection::setQueryTimeOut(uint64_t timeoutInMS) {
Expand Down Expand Up @@ -328,7 +328,7 @@ void Connection::bindParametersNoLock(PreparedStatement* preparedStatement,

std::unique_ptr<QueryResult> Connection::executeAndAutoCommitIfNecessaryNoLock(
PreparedStatement* preparedStatement, uint32_t planIdx) {
clientContext->activeQuery = std::make_unique<ActiveQuery>();
clientContext->reset();
clientContext->startTimingIfEnabled();
auto mapper = PlanMapper(
*database->storageManager, database->memoryManager.get(), database->catalog.get());
Expand Down
2 changes: 1 addition & 1 deletion test/c_api/connection_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ TEST_F(CApiConnectionTest, Interrupt) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
kuzu_connection_interrupt(connection);
});
t.detach();
auto result = kuzu_connection_query(
connection, "MATCH (a:person)-[:knows*1..28]->(b:person) RETURN COUNT(*);");
ASSERT_NE(result, nullptr);
Expand All @@ -223,4 +222,5 @@ TEST_F(CApiConnectionTest, Interrupt) {
ASSERT_FALSE(resultCpp->isSuccess());
ASSERT_EQ(resultCpp->getErrorMessage(), "Interrupted.");
kuzu_query_result_destroy(result);
t.join();
}
4 changes: 2 additions & 2 deletions test/runner/e2e_copy_transaction_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->statementResult->getColumns());
clientContext->activeQuery = std::make_unique<ActiveQuery>();
clientContext->reset();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
auto tableID = catalog->getReadOnlyVersion()->getTableID("person");
validateDatabaseStateBeforeCheckPointCopyNode(tableID);
Expand Down Expand Up @@ -159,7 +159,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->statementResult->getColumns());
clientContext->activeQuery = std::make_unique<ActiveQuery>();
clientContext->reset();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
auto tableID = catalog->getReadOnlyVersion()->getTableID("knows");
validateDatabaseStateBeforeCheckPointCopyRel(tableID);
Expand Down
2 changes: 1 addition & 1 deletion test/runner/e2e_ddl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ class TinySnbDDLTest : public DBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->statementResult->getColumns());
executionContext->clientContext->activeQuery = std::make_unique<ActiveQuery>();
executionContext->clientContext->reset();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
}

Expand Down

0 comments on commit d5f954d

Please sign in to comment.