Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #3154 #3263

Merged
merged 5 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions tools/nodejs_api/src_cpp/include/node_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using namespace kuzu::common;
class NodeConnection : public Napi::ObjectWrap<NodeConnection> {
friend class ConnectionInitAsyncWorker;
friend class ConnectionExecuteAsyncWorker;
friend class ConnectionQueryAsyncWorker;
friend class NodePreparedStatement;

public:
Expand All @@ -27,6 +28,7 @@ class NodeConnection : public Napi::ObjectWrap<NodeConnection> {
void SetMaxNumThreadForExec(const Napi::CallbackInfo& info);
void SetQueryTimeout(const Napi::CallbackInfo& info);
Napi::Value ExecuteAsync(const Napi::CallbackInfo& info);
Napi::Value QueryAsync(const Napi::CallbackInfo& info);

private:
std::shared_ptr<Database> database;
Expand All @@ -40,29 +42,22 @@ class ConnectionInitAsyncWorker : public Napi::AsyncWorker {

~ConnectionInitAsyncWorker() override = default;

inline void Execute() override {
void Execute() override {
try {
nodeConnection->InitCppConnection();
} catch (const std::exception& exc) {
SetError(std::string(exc.what()));
}
}

inline void OnOK() override { Callback().Call({Env().Null()}); }
void OnOK() override { Callback().Call({Env().Null()}); }

inline void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); }
void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); }

private:
NodeConnection* nodeConnection;
};

enum GetTableMetadataType {
NODE_TABLE_NAME,
REL_TABLE_NAME,
NODE_PROPERTY_NAME,
REL_PROPERTY_NAME
};

class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker {
public:
ConnectionExecuteAsyncWorker(Napi::Function& callback, std::shared_ptr<Connection>& connection,
Expand All @@ -73,7 +68,7 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker {
params(std::move(params)) {}
~ConnectionExecuteAsyncWorker() override = default;

inline void Execute() override {
void Execute() override {
try {
std::shared_ptr<QueryResult> result =
connection->executeWithParams(preparedStatement.get(), std::move(params));
Expand All @@ -87,13 +82,45 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker {
}
}

inline void OnOK() override { Callback().Call({Env().Null()}); }
void OnOK() override { Callback().Call({Env().Null()}); }

inline void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); }
void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); }

private:
std::shared_ptr<Connection> connection;
std::shared_ptr<PreparedStatement> preparedStatement;
NodeQueryResult* nodeQueryResult;
std::unordered_map<std::string, std::unique_ptr<Value>> params;
};

class ConnectionQueryAsyncWorker : public Napi::AsyncWorker {
public:
ConnectionQueryAsyncWorker(Napi::Function& callback, std::shared_ptr<Connection>& connection,
std::string statement, NodeQueryResult* nodeQueryResult)
: Napi::AsyncWorker(callback), connection(connection), statement(std::move(statement)),
nodeQueryResult(nodeQueryResult) {}

~ConnectionQueryAsyncWorker() override = default;

void Execute() override {
try {
std::shared_ptr<QueryResult> result = connection->query(statement);
nodeQueryResult->SetQueryResult(result);
if (!result->isSuccess()) {
SetError(result->getErrorMessage());
return;
}
} catch (const std::exception& exc) {
SetError(std::string(exc.what()));
}
}

void OnOK() override { Callback().Call({Env().Null()}); }

void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); }

private:
std::shared_ptr<Connection> connection;
std::string statement;
NodeQueryResult* nodeQueryResult;
};
13 changes: 13 additions & 0 deletions tools/nodejs_api/src_cpp/node_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Napi::Object NodeConnection::Init(Napi::Env env, Napi::Object exports) {
{
InstanceMethod("initAsync", &NodeConnection::InitAsync),
InstanceMethod("executeAsync", &NodeConnection::ExecuteAsync),
InstanceMethod("queryAsync", &NodeConnection::QueryAsync),
InstanceMethod("setMaxNumThreadForExec", &NodeConnection::SetMaxNumThreadForExec),
InstanceMethod("setQueryTimeout", &NodeConnection::SetQueryTimeout),
});
Expand Down Expand Up @@ -85,3 +86,15 @@ Napi::Value NodeConnection::ExecuteAsync(const Napi::CallbackInfo& info) {
}
return info.Env().Undefined();
}

Napi::Value NodeConnection::QueryAsync(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);
auto statement = info[0].As<Napi::String>().Utf8Value();
auto nodeQueryResult = Napi::ObjectWrap<NodeQueryResult>::Unwrap(info[1].As<Napi::Object>());
auto callback = info[2].As<Napi::Function>();
auto asyncWorker =
new ConnectionQueryAsyncWorker(callback, connection, statement, nodeQueryResult);
asyncWorker->Queue();
return info.Env().Undefined();
}
26 changes: 19 additions & 7 deletions tools/nodejs_api/src_js/connection.js
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,25 @@ class Connection {
* @param {String} statement the statement to execute.
* @returns {Promise<kuzu.QueryResult>} a promise that resolves to the query result. The promise is rejected if there is an error.
*/
async query(statement) {
if (typeof statement !== "string") {
throw new Error("statement must be a string.");
}
const preparedStatement = await this.prepare(statement);
const queryResult = await this.execute(preparedStatement);
return queryResult;
query(statement) {
return new Promise((resolve, reject) => {
if (typeof statement !== "string") {
return reject(new Error("statement must be a string."));
}
this._getConnection().then((connection) => {
const nodeQueryResult = new KuzuNative.NodeQueryResult();
try {
connection.queryAsync(statement, nodeQueryResult, (err) => {
if (err) {
return reject(err);
}
return resolve(new QueryResult(this, nodeQueryResult));
});
} catch (e) {
return reject(e);
}
});
});
}

/**
Expand Down
5 changes: 5 additions & 0 deletions tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class PyConnection {
std::unique_ptr<PyQueryResult> execute(PyPreparedStatement* preparedStatement,
const py::dict& params);

std::unique_ptr<PyQueryResult> query(const std::string& statement);

void setMaxNumThreadForExec(uint64_t numThreads);

PyPreparedStatement prepare(const std::string& query);
Expand All @@ -36,4 +38,7 @@ class PyConnection {
private:
std::unique_ptr<StorageDriver> storageDriver;
std::unique_ptr<Connection> conn;

static std::unique_ptr<PyQueryResult> checkAndWrapQueryResult(
std::unique_ptr<QueryResult>& queryResult);
};
25 changes: 19 additions & 6 deletions tools/python_api/src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ void PyConnection::initialize(py::handle& m) {
.def(py::init<PyDatabase*, uint64_t>(), py::arg("database"), py::arg("num_threads") = 0)
.def("execute", &PyConnection::execute, py::arg("prepared_statement"),
py::arg("parameters") = py::dict())
.def("query", &PyConnection::query, py::arg("statement"))
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
py::arg("num_threads"))
.def("prepare", &PyConnection::prepare, py::arg("query"))
Expand Down Expand Up @@ -56,12 +57,14 @@ std::unique_ptr<PyQueryResult> PyConnection::execute(PyPreparedStatement* prepar
auto queryResult =
conn->executeWithParams(preparedStatement->preparedStatement.get(), std::move(parameters));
py::gil_scoped_acquire acquire;
if (!queryResult->isSuccess()) {
throw std::runtime_error(queryResult->getErrorMessage());
}
auto pyQueryResult = std::make_unique<PyQueryResult>();
pyQueryResult->queryResult = std::move(queryResult);
return pyQueryResult;
return checkAndWrapQueryResult(queryResult);
}

std::unique_ptr<PyQueryResult> PyConnection::query(const std::string& statement) {
py::gil_scoped_release release;
auto queryResult = conn->query(statement);
py::gil_scoped_acquire acquire;
return checkAndWrapQueryResult(queryResult);
}

void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) {
Expand Down Expand Up @@ -387,3 +390,13 @@ Value transformPythonValue(py::handle val) {
auto type = pyLogicalType(val);
return transformPythonValueAs(val, type.get());
}

std::unique_ptr<PyQueryResult> PyConnection::checkAndWrapQueryResult(
std::unique_ptr<QueryResult>& queryResult) {
if (!queryResult->isSuccess()) {
throw std::runtime_error(queryResult->getErrorMessage());
}
auto pyQueryResult = std::make_unique<PyQueryResult>();
pyQueryResult->queryResult = std::move(queryResult);
return pyQueryResult;
}
19 changes: 15 additions & 4 deletions tools/python_api/src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ def __init__(self, database: Database, num_threads: int = 0):
self.init_connection()

def __getstate__(self) -> dict[str, Any]:
state = {"database": self.database, "num_threads": self.num_threads, "_connection": None}
state = {
"database": self.database,
"num_threads": self.num_threads,
"_connection": None,
}
return state

def init_connection(self) -> None:
Expand Down Expand Up @@ -88,8 +92,11 @@ def execute(
msg = f"Parameters must be a dict; found {type(parameters)}."
raise RuntimeError(msg) # noqa: TRY004

prepared_statement = self.prepare(query) if isinstance(query, str) else query
_query_result = self._connection.execute(prepared_statement._prepared_statement, parameters)
if len(parameters) == 0:
_query_result = self._connection.query(query)
else:
prepared_statement = self.prepare(query) if isinstance(query, str) else query
_query_result = self._connection.execute(prepared_statement._prepared_statement, parameters)
if not _query_result.isSuccess():
raise RuntimeError(_query_result.getErrorMessage())
return QueryResult(self, _query_result)
Expand Down Expand Up @@ -132,7 +139,11 @@ def _get_node_property_names(self, table_name: str) -> dict[str, Any]:
if s != "":
shape.append(int(s))
prop_type = splitted[0]
results[prop_name] = {"type": prop_type, "dimension": dimension, "is_primary_key": is_primary_key}
results[prop_name] = {
"type": prop_type,
"dimension": dimension,
"is_primary_key": is_primary_key,
}
if len(shape) > 0:
results[prop_name]["shape"] = tuple(shape)
return results
Expand Down
Loading