From 35bc7077e2d3fefe888e56d83d774e8013185b21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=9B=A7=E5=9B=A7?= Date: Fri, 12 Apr 2024 23:38:01 +0800 Subject: [PATCH] Fix #3154 (#3263) * Resolve #3154 for Python * Resolve #3154 for Node.js * Address PR comments * Run clang-format * Move PyConnection::checkAndWrapQueryResult method --------- Co-authored-by: CI Bot --- .../src_cpp/include/node_connection.h | 53 ++++++++++++++----- tools/nodejs_api/src_cpp/node_connection.cpp | 13 +++++ tools/nodejs_api/src_js/connection.js | 26 ++++++--- .../src_cpp/include/py_connection.h | 5 ++ tools/python_api/src_cpp/py_connection.cpp | 25 ++++++--- tools/python_api/src_py/connection.py | 19 +++++-- 6 files changed, 111 insertions(+), 30 deletions(-) diff --git a/tools/nodejs_api/src_cpp/include/node_connection.h b/tools/nodejs_api/src_cpp/include/node_connection.h index 7afeb5703e..2575c597f0 100644 --- a/tools/nodejs_api/src_cpp/include/node_connection.h +++ b/tools/nodejs_api/src_cpp/include/node_connection.h @@ -14,6 +14,7 @@ using namespace kuzu::common; class NodeConnection : public Napi::ObjectWrap { friend class ConnectionInitAsyncWorker; friend class ConnectionExecuteAsyncWorker; + friend class ConnectionQueryAsyncWorker; friend class NodePreparedStatement; public: @@ -27,6 +28,7 @@ class NodeConnection : public Napi::ObjectWrap { void SetMaxNumThreadForExec(const Napi::CallbackInfo& info); void SetQueryTimeout(const Napi::CallbackInfo& info); Napi::Value ExecuteAsync(const Napi::CallbackInfo& info); + Napi::Value QueryAsync(const Napi::CallbackInfo& info); private: std::shared_ptr database; @@ -40,7 +42,7 @@ class ConnectionInitAsyncWorker : public Napi::AsyncWorker { ~ConnectionInitAsyncWorker() override = default; - inline void Execute() override { + void Execute() override { try { nodeConnection->InitCppConnection(); } catch (const std::exception& exc) { @@ -48,21 +50,14 @@ class ConnectionInitAsyncWorker : public Napi::AsyncWorker { } } - inline void OnOK() override { Callback().Call({Env().Null()}); } + void OnOK() override { Callback().Call({Env().Null()}); } - inline void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); } + void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); } private: NodeConnection* nodeConnection; }; -enum GetTableMetadataType { - NODE_TABLE_NAME, - REL_TABLE_NAME, - NODE_PROPERTY_NAME, - REL_PROPERTY_NAME -}; - class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker { public: ConnectionExecuteAsyncWorker(Napi::Function& callback, std::shared_ptr& connection, @@ -73,7 +68,7 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker { params(std::move(params)) {} ~ConnectionExecuteAsyncWorker() override = default; - inline void Execute() override { + void Execute() override { try { std::shared_ptr result = connection->executeWithParams(preparedStatement.get(), std::move(params)); @@ -87,9 +82,9 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker { } } - inline void OnOK() override { Callback().Call({Env().Null()}); } + void OnOK() override { Callback().Call({Env().Null()}); } - inline void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); } + void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); } private: std::shared_ptr connection; @@ -97,3 +92,35 @@ class ConnectionExecuteAsyncWorker : public Napi::AsyncWorker { NodeQueryResult* nodeQueryResult; std::unordered_map> params; }; + +class ConnectionQueryAsyncWorker : public Napi::AsyncWorker { +public: + ConnectionQueryAsyncWorker(Napi::Function& callback, std::shared_ptr& connection, + std::string statement, NodeQueryResult* nodeQueryResult) + : Napi::AsyncWorker(callback), connection(connection), statement(std::move(statement)), + nodeQueryResult(nodeQueryResult) {} + + ~ConnectionQueryAsyncWorker() override = default; + + void Execute() override { + try { + std::shared_ptr result = connection->query(statement); + nodeQueryResult->SetQueryResult(result); + if (!result->isSuccess()) { + SetError(result->getErrorMessage()); + return; + } + } catch (const std::exception& exc) { + SetError(std::string(exc.what())); + } + } + + void OnOK() override { Callback().Call({Env().Null()}); } + + void OnError(Napi::Error const& error) override { Callback().Call({error.Value()}); } + +private: + std::shared_ptr connection; + std::string statement; + NodeQueryResult* nodeQueryResult; +}; diff --git a/tools/nodejs_api/src_cpp/node_connection.cpp b/tools/nodejs_api/src_cpp/node_connection.cpp index 98c4bbcf46..d7b978de9a 100644 --- a/tools/nodejs_api/src_cpp/node_connection.cpp +++ b/tools/nodejs_api/src_cpp/node_connection.cpp @@ -14,6 +14,7 @@ Napi::Object NodeConnection::Init(Napi::Env env, Napi::Object exports) { { InstanceMethod("initAsync", &NodeConnection::InitAsync), InstanceMethod("executeAsync", &NodeConnection::ExecuteAsync), + InstanceMethod("queryAsync", &NodeConnection::QueryAsync), InstanceMethod("setMaxNumThreadForExec", &NodeConnection::SetMaxNumThreadForExec), InstanceMethod("setQueryTimeout", &NodeConnection::SetQueryTimeout), }); @@ -85,3 +86,15 @@ Napi::Value NodeConnection::ExecuteAsync(const Napi::CallbackInfo& info) { } return info.Env().Undefined(); } + +Napi::Value NodeConnection::QueryAsync(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + Napi::HandleScope scope(env); + auto statement = info[0].As().Utf8Value(); + auto nodeQueryResult = Napi::ObjectWrap::Unwrap(info[1].As()); + auto callback = info[2].As(); + auto asyncWorker = + new ConnectionQueryAsyncWorker(callback, connection, statement, nodeQueryResult); + asyncWorker->Queue(); + return info.Env().Undefined(); +} diff --git a/tools/nodejs_api/src_js/connection.js b/tools/nodejs_api/src_js/connection.js index 05ff41ccf1..c1528f741f 100644 --- a/tools/nodejs_api/src_js/connection.js +++ b/tools/nodejs_api/src_js/connection.js @@ -179,13 +179,25 @@ class Connection { * @param {String} statement the statement to execute. * @returns {Promise} a promise that resolves to the query result. The promise is rejected if there is an error. */ - async query(statement) { - if (typeof statement !== "string") { - throw new Error("statement must be a string."); - } - const preparedStatement = await this.prepare(statement); - const queryResult = await this.execute(preparedStatement); - return queryResult; + query(statement) { + return new Promise((resolve, reject) => { + if (typeof statement !== "string") { + return reject(new Error("statement must be a string.")); + } + this._getConnection().then((connection) => { + const nodeQueryResult = new KuzuNative.NodeQueryResult(); + try { + connection.queryAsync(statement, nodeQueryResult, (err) => { + if (err) { + return reject(err); + } + return resolve(new QueryResult(this, nodeQueryResult)); + }); + } catch (e) { + return reject(e); + } + }); + }); } /** diff --git a/tools/python_api/src_cpp/include/py_connection.h b/tools/python_api/src_cpp/include/py_connection.h index 0841447496..57446dd89e 100644 --- a/tools/python_api/src_cpp/include/py_connection.h +++ b/tools/python_api/src_cpp/include/py_connection.h @@ -19,6 +19,8 @@ class PyConnection { std::unique_ptr execute(PyPreparedStatement* preparedStatement, const py::dict& params); + std::unique_ptr query(const std::string& statement); + void setMaxNumThreadForExec(uint64_t numThreads); PyPreparedStatement prepare(const std::string& query); @@ -36,4 +38,7 @@ class PyConnection { private: std::unique_ptr storageDriver; std::unique_ptr conn; + + static std::unique_ptr checkAndWrapQueryResult( + std::unique_ptr& queryResult); }; diff --git a/tools/python_api/src_cpp/py_connection.cpp b/tools/python_api/src_cpp/py_connection.cpp index 7a208e2b4d..ff4a17486d 100644 --- a/tools/python_api/src_cpp/py_connection.cpp +++ b/tools/python_api/src_cpp/py_connection.cpp @@ -21,6 +21,7 @@ void PyConnection::initialize(py::handle& m) { .def(py::init(), py::arg("database"), py::arg("num_threads") = 0) .def("execute", &PyConnection::execute, py::arg("prepared_statement"), py::arg("parameters") = py::dict()) + .def("query", &PyConnection::query, py::arg("statement")) .def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec, py::arg("num_threads")) .def("prepare", &PyConnection::prepare, py::arg("query")) @@ -56,12 +57,14 @@ std::unique_ptr PyConnection::execute(PyPreparedStatement* prepar auto queryResult = conn->executeWithParams(preparedStatement->preparedStatement.get(), std::move(parameters)); py::gil_scoped_acquire acquire; - if (!queryResult->isSuccess()) { - throw std::runtime_error(queryResult->getErrorMessage()); - } - auto pyQueryResult = std::make_unique(); - pyQueryResult->queryResult = std::move(queryResult); - return pyQueryResult; + return checkAndWrapQueryResult(queryResult); +} + +std::unique_ptr PyConnection::query(const std::string& statement) { + py::gil_scoped_release release; + auto queryResult = conn->query(statement); + py::gil_scoped_acquire acquire; + return checkAndWrapQueryResult(queryResult); } void PyConnection::setMaxNumThreadForExec(uint64_t numThreads) { @@ -387,3 +390,13 @@ Value transformPythonValue(py::handle val) { auto type = pyLogicalType(val); return transformPythonValueAs(val, type.get()); } + +std::unique_ptr PyConnection::checkAndWrapQueryResult( + std::unique_ptr& queryResult) { + if (!queryResult->isSuccess()) { + throw std::runtime_error(queryResult->getErrorMessage()); + } + auto pyQueryResult = std::make_unique(); + pyQueryResult->queryResult = std::move(queryResult); + return pyQueryResult; +} diff --git a/tools/python_api/src_py/connection.py b/tools/python_api/src_py/connection.py index 50f8f8cff1..2ac372cd80 100644 --- a/tools/python_api/src_py/connection.py +++ b/tools/python_api/src_py/connection.py @@ -32,7 +32,11 @@ def __init__(self, database: Database, num_threads: int = 0): self.init_connection() def __getstate__(self) -> dict[str, Any]: - state = {"database": self.database, "num_threads": self.num_threads, "_connection": None} + state = { + "database": self.database, + "num_threads": self.num_threads, + "_connection": None, + } return state def init_connection(self) -> None: @@ -88,8 +92,11 @@ def execute( msg = f"Parameters must be a dict; found {type(parameters)}." raise RuntimeError(msg) # noqa: TRY004 - prepared_statement = self.prepare(query) if isinstance(query, str) else query - _query_result = self._connection.execute(prepared_statement._prepared_statement, parameters) + if len(parameters) == 0: + _query_result = self._connection.query(query) + else: + prepared_statement = self.prepare(query) if isinstance(query, str) else query + _query_result = self._connection.execute(prepared_statement._prepared_statement, parameters) if not _query_result.isSuccess(): raise RuntimeError(_query_result.getErrorMessage()) return QueryResult(self, _query_result) @@ -132,7 +139,11 @@ def _get_node_property_names(self, table_name: str) -> dict[str, Any]: if s != "": shape.append(int(s)) prop_type = splitted[0] - results[prop_name] = {"type": prop_type, "dimension": dimension, "is_primary_key": is_primary_key} + results[prop_name] = { + "type": prop_type, + "dimension": dimension, + "is_primary_key": is_primary_key, + } if len(shape) > 0: results[prop_name]["shape"] = tuple(shape) return results