Skip to content

Commit

Permalink
Make Connection::interrupt safe to use
Browse files Browse the repository at this point in the history
Interrupting before a query has executed would mean the activeQuery was null, leading to undefined behaviour
  • Loading branch information
benjaminwinger committed Jul 18, 2023
1 parent 1e9d4fd commit 883134a
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 14 deletions.
15 changes: 11 additions & 4 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ struct ActiveQuery {
explicit ActiveQuery();
std::atomic<bool> interrupted;
common::Timer timer;

inline void reset() {
interrupted = false;
timer = common::Timer();
}
};

/**
Expand All @@ -33,12 +38,12 @@ class ClientContext {

~ClientContext() = default;

inline void interrupt() { activeQuery->interrupted = true; }
inline void interrupt() { activeQuery.interrupted = true; }

bool isInterrupted() const { return activeQuery->interrupted; }
bool isInterrupted() const { return activeQuery.interrupted; }

inline bool isTimeOut() {
return isTimeOutEnabled() && activeQuery->timer.getElapsedTimeInMS() > timeoutInMS;
return isTimeOutEnabled() && activeQuery.timer.getElapsedTimeInMS() > timeoutInMS;
}

inline bool isTimeOutEnabled() const { return timeoutInMS != 0; }
Expand All @@ -48,8 +53,10 @@ class ClientContext {
std::string getCurrentSetting(std::string optionName);

private:
inline void resetActiveQuery() { activeQuery.reset(); }

uint64_t numThreadsForExecution;
std::unique_ptr<ActiveQuery> activeQuery;
ActiveQuery activeQuery;
uint64_t timeoutInMS;
};

Expand Down
2 changes: 1 addition & 1 deletion src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ ClientContext::ClientContext()

void ClientContext::startTimingIfEnabled() {
if (isTimeOutEnabled()) {
activeQuery->timer.start();
activeQuery.timer.start();
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/main/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ std::string Connection::getRelPropertyNames(const std::string& relTableName) {
}

void Connection::interrupt() {
clientContext->activeQuery->interrupted = true;
clientContext->interrupt();
}

void Connection::setQueryTimeOut(uint64_t timeoutInMS) {
Expand Down Expand Up @@ -328,7 +328,7 @@ void Connection::bindParametersNoLock(PreparedStatement* preparedStatement,

std::unique_ptr<QueryResult> Connection::executeAndAutoCommitIfNecessaryNoLock(
PreparedStatement* preparedStatement, uint32_t planIdx) {
clientContext->activeQuery = std::make_unique<ActiveQuery>();
clientContext->resetActiveQuery();
clientContext->startTimingIfEnabled();
auto mapper = PlanMapper(
*database->storageManager, database->memoryManager.get(), database->catalog.get());
Expand Down
14 changes: 10 additions & 4 deletions test/c_api/connection_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,19 +208,25 @@ TEST_F(CApiConnectionTest, QueryTimeout) {

TEST_F(CApiConnectionTest, Interrupt) {
auto connection = getConnection();
bool finished = false;

// Interrupt the query after 100ms
std::thread t([&connection]() {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
kuzu_connection_interrupt(connection);
// This may happen too early, so try again until the query function finishes.
std::thread t([&connection, &finished]() {
do {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
kuzu_connection_interrupt(connection);
} while (!finished);
});
t.detach();
auto result = kuzu_connection_query(
connection, "MATCH (a:person)-[:knows*1..28]->(b:person) RETURN COUNT(*);");
finished = true;
ASSERT_NE(result, nullptr);
ASSERT_NE(result->_query_result, nullptr);
auto resultCpp = static_cast<QueryResult*>(result->_query_result);
ASSERT_NE(resultCpp, nullptr);
ASSERT_FALSE(resultCpp->isSuccess());
ASSERT_EQ(resultCpp->getErrorMessage(), "Interrupted.");
kuzu_query_result_destroy(result);
t.join();
}
4 changes: 2 additions & 2 deletions test/runner/e2e_copy_transaction_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->statementResult->getColumns());
clientContext->activeQuery = std::make_unique<ActiveQuery>();
clientContext->resetActiveQuery();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
auto tableID = catalog->getReadOnlyVersion()->getTableID("person");
validateDatabaseStateBeforeCheckPointCopyNode(tableID);
Expand Down Expand Up @@ -159,7 +159,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->statementResult->getColumns());
clientContext->activeQuery = std::make_unique<ActiveQuery>();
clientContext->resetActiveQuery();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
auto tableID = catalog->getReadOnlyVersion()->getTableID("knows");
validateDatabaseStateBeforeCheckPointCopyRel(tableID);
Expand Down
2 changes: 1 addition & 1 deletion test/runner/e2e_ddl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ class TinySnbDDLTest : public DBTest {
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
preparedStatement->statementResult->getColumns());
executionContext->clientContext->activeQuery = std::make_unique<ActiveQuery>();
executionContext->clientContext->resetActiveQuery();
getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get());
}

Expand Down

0 comments on commit 883134a

Please sign in to comment.