Skip to content

Commit

Permalink
Fix #3154 (#3263)
Browse files Browse the repository at this point in the history
* Resolve #3154 for Python

* Resolve #3154 for Node.js

* Address PR comments

* Run clang-format

* Move PyConnection::checkAndWrapQueryResult method

---------

Co-authored-by: CI Bot <mewim@users.noreply.github.com>
  • Loading branch information
mewim and mewim committed Apr 12, 2024
1 parent 65a080f commit 35bc707
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 30 deletions.
53 changes: 40 additions & 13 deletions tools/nodejs_api/src_cpp/include/node_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using namespace kuzu::common;
class NodeConnection : public Napi::ObjectWrap<NodeConnection> {
friend class ConnectionInitAsyncWorker;
friend class ConnectionExecuteAsyncWorker;
friend class ConnectionQueryAsyncWorker;
friend class NodePreparedStatement;

public:
Expand All @@ -27,6 +28,7 @@ class NodeConnection : public Napi::ObjectWrap<NodeConnection> {
void SetMaxNumThreadForExec(const Napi::CallbackInfo& info);
void SetQueryTimeout(const Napi::CallbackInfo& info);
Napi::Value ExecuteAsync(const Napi::CallbackInfo& info);
Napi::Value QueryAsync(const Napi::CallbackInfo& info);

private:
std::shared_ptr<Database> database;
Expand All @@ -40,29 +42,22 @@ class ConnectionInitAsyncWorker : public Napi::AsyncWorker {

~ConnectionInitAsyncWorker() override = default;

inline void Execute() override {
void Execute() override {
try {
nodeConnection->InitCppConnection();
} catch (const std::exception& exc) {
SetError(std::string(exc.what()));
}
}

inline void OnOK() override { Callback().Call({Env().Null()}); }
void OnOK() override { Callback().Call({Env().Null()}); }

inline void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); }
void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); }

private:
NodeConnection* nodeConnection;
};

enum GetTableMetadataType {
NODE_TABLE_NAME,
REL_TABLE_NAME,
NODE_PROPERTY_NAME,
REL_PROPERTY_NAME
};

class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker {
public:
ConnectionExecuteAsyncWorker(Napi::Function& callback, std::shared_ptr<Connection>& connection,
Expand All @@ -73,7 +68,7 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker {
params(std::move(params)) {}
~ConnectionExecuteAsyncWorker() override = default;

inline void Execute() override {
void Execute() override {
try {
std::shared_ptr<QueryResult> result =
connection->executeWithParams(preparedStatement.get(), std::move(params));
Expand All @@ -87,13 +82,45 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker {
}
}

inline void OnOK() override { Callback().Call({Env().Null()}); }
void OnOK() override { Callback().Call({Env().Null()}); }

inline void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); }
void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); }

private:
std::shared_ptr<Connection> connection;
std::shared_ptr<PreparedStatement> preparedStatement;
NodeQueryResult* nodeQueryResult;
std::unordered_map<std::string, std::unique_ptr<Value>> params;
};

class ConnectionQueryAsyncWorker : public Napi::AsyncWorker {
public:
ConnectionQueryAsyncWorker(Napi::Function& callback, std::shared_ptr<Connection>& connection,
std::string statement, NodeQueryResult* nodeQueryResult)
: Napi::AsyncWorker(callback), connection(connection), statement(std::move(statement)),
nodeQueryResult(nodeQueryResult) {}

~ConnectionQueryAsyncWorker() override = default;

void Execute() override {
try {
std::shared_ptr<QueryResult> result = connection->query(statement);
nodeQueryResult->SetQueryResult(result);
if (!result->isSuccess()) {
SetError(result->getErrorMessage());
return;
}
} catch (const std::exception& exc) {
SetError(std::string(exc.what()));
}
}

void OnOK() override { Callback().Call({Env().Null()}); }

void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); }

private:
std::shared_ptr<Connection> connection;
std::string statement;
NodeQueryResult* nodeQueryResult;
};
13 changes: 13 additions & 0 deletions tools/nodejs_api/src_cpp/node_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Napi::Object NodeConnection::Init(Napi::Env env, Napi::Object exports) {
{
InstanceMethod("initAsync", &NodeConnection::InitAsync),
InstanceMethod("executeAsync", &NodeConnection::ExecuteAsync),
InstanceMethod("queryAsync", &NodeConnection::QueryAsync),
InstanceMethod("setMaxNumThreadForExec", &NodeConnection::SetMaxNumThreadForExec),
InstanceMethod("setQueryTimeout", &NodeConnection::SetQueryTimeout),
});
Expand Down Expand Up @@ -85,3 +86,15 @@ Napi::Value NodeConnection::ExecuteAsync(const Napi::CallbackInfo& info) {
}
return info.Env().Undefined();
}

Napi::Value NodeConnection::QueryAsync(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);
auto statement = info[0].As<Napi::String>().Utf8Value();
auto nodeQueryResult = Napi::ObjectWrap<NodeQueryResult>::Unwrap(info[1].As<Napi::Object>());
auto callback = info[2].As<Napi::Function>();
auto asyncWorker =
new ConnectionQueryAsyncWorker(callback, connection, statement, nodeQueryResult);
asyncWorker->Queue();
return info.Env().Undefined();
}
26 changes: 19 additions & 7 deletions tools/nodejs_api/src_js/connection.js
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,25 @@ class Connection {
* @param {String} statement the statement to execute.
* @returns {Promise<kuzu.QueryResult>} a promise that resolves to the query result. The promise is rejected if there is an error.
*/
async query(statement) {
if (typeof statement !== "string") {
throw new Error("statement must be a string.");
}
const preparedStatement = await this.prepare(statement);
const queryResult = await this.execute(preparedStatement);
return queryResult;
query(statement) {
return new Promise((resolve, reject) => {
if (typeof statement !== "string") {
return reject(new Error("statement must be a string."));
}
this._getConnection().then((connection) => {
const nodeQueryResult = new KuzuNative.NodeQueryResult();
try {
connection.queryAsync(statement, nodeQueryResult, (err) => {
if (err) {
return reject(err);
}
return resolve(new QueryResult(this, nodeQueryResult));
});
} catch (e) {
return reject(e);
}
});
});
}

/**
Expand Down
5 changes: 5 additions & 0 deletions tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class PyConnection {
std::unique_ptr<PyQueryResult> execute(PyPreparedStatement* preparedStatement,
const py::dict& params);

std::unique_ptr<PyQueryResult> query(const std::string& statement);

void setMaxNumThreadForExec(uint64_t numThreads);

PyPreparedStatement prepare(const std::string& query);
Expand All @@ -36,4 +38,7 @@ class PyConnection {
private:
std::unique_ptr<StorageDriver> storageDriver;
std::unique_ptr<Connection> conn;

static std::unique_ptr<PyQueryResult> checkAndWrapQueryResult(
std::unique_ptr<QueryResult>& queryResult);
};
25 changes: 19 additions & 6 deletions tools/python_api/src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ void PyConnection::initialize(py::handle& m) {
.def(py::init<PyDatabase*, uint64_t>(), py::arg("database"), py::arg("num_threads") = 0)
.def("execute", &PyConnection::execute, py::arg("prepared_statement"),
py::arg("parameters") = py::dict())
.def("query", &PyConnection::query, py::arg("statement"))
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
py::arg("num_threads"))
.def("prepare", &PyConnection::prepare, py::arg("query"))
Expand Down Expand Up @@ -56,12 +57,14 @@ std::unique_ptr<PyQueryResult> PyConnection::execute(PyPreparedStatement* prepar
auto queryResult =
conn->executeWithParams(preparedStatement->preparedStatement.get(), std::move(parameters));
py::gil_scoped_acquire acquire;
if (!queryResult->isSuccess()) {
throw std::runtime_error(queryResult->getErrorMessage());
}
auto pyQueryResult = std::make_unique<PyQueryResult>();
pyQueryResult->queryResult = std::move(queryResult);
return pyQueryResult;
return checkAndWrapQueryResult(queryResult);
}

std::unique_ptr<PyQueryResult> PyConnection::query(const std::string& statement) {
py::gil_scoped_release release;
auto queryResult = conn->query(statement);
py::gil_scoped_acquire acquire;
return checkAndWrapQueryResult(queryResult);
}

void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) {
Expand Down Expand Up @@ -387,3 +390,13 @@ Value transformPythonValue(py::handle val) {
auto type = pyLogicalType(val);
return transformPythonValueAs(val, type.get());
}

std::unique_ptr<PyQueryResult> PyConnection::checkAndWrapQueryResult(
std::unique_ptr<QueryResult>& queryResult) {
if (!queryResult->isSuccess()) {
throw std::runtime_error(queryResult->getErrorMessage());
}
auto pyQueryResult = std::make_unique<PyQueryResult>();
pyQueryResult->queryResult = std::move(queryResult);
return pyQueryResult;
}
19 changes: 15 additions & 4 deletions tools/python_api/src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ def __init__(self, database: Database, num_threads: int = 0):
self.init_connection()

def __getstate__(self) -> dict[str, Any]:
state = {"database": self.database, "num_threads": self.num_threads, "_connection": None}
state = {
"database": self.database,
"num_threads": self.num_threads,
"_connection": None,
}
return state

def init_connection(self) -> None:
Expand Down Expand Up @@ -88,8 +92,11 @@ def execute(
msg = f"Parameters must be a dict; found {type(parameters)}."
raise RuntimeError(msg) # noqa: TRY004

prepared_statement = self.prepare(query) if isinstance(query, str) else query
_query_result = self._connection.execute(prepared_statement._prepared_statement, parameters)
if len(parameters) == 0:
_query_result = self._connection.query(query)
else:
prepared_statement = self.prepare(query) if isinstance(query, str) else query
_query_result = self._connection.execute(prepared_statement._prepared_statement, parameters)
if not _query_result.isSuccess():
raise RuntimeError(_query_result.getErrorMessage())
return QueryResult(self, _query_result)
Expand Down Expand Up @@ -132,7 +139,11 @@ def _get_node_property_names(self, table_name: str) -> dict[str, Any]:
if s != "":
shape.append(int(s))
prop_type = splitted[0]
results[prop_name] = {"type": prop_type, "dimension": dimension, "is_primary_key": is_primary_key}
results[prop_name] = {
"type": prop_type,
"dimension": dimension,
"is_primary_key": is_primary_key,
}
if len(shape) > 0:
results[prop_name]["shape"] = tuple(shape)
return results
Expand Down

0 comments on commit 35bc707

Please sign in to comment.