Skip to content

Commit

Permalink
Merge pull request #1305 from kuzudb/python-prepared-statement
Browse files Browse the repository at this point in the history
Add Python API binding for prepared statement
  • Loading branch information
mewim committed Feb 21, 2023
2 parents 07f3b74 + ca409c3 commit 93140d2
Show file tree
Hide file tree
Showing 11 changed files with 268 additions and 16 deletions.
1 change: 1 addition & 0 deletions tools/python_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pybind11_add_module(_kuzu
src_cpp/kuzu_binding.cpp
src_cpp/py_connection.cpp
src_cpp/py_database.cpp
src_cpp/py_prepared_statement.cpp
src_cpp/py_query_result.cpp
src_cpp/py_query_result_converter.cpp)

Expand Down
5 changes: 4 additions & 1 deletion tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "py_database.h"
#include "py_prepared_statement.h"
#include "py_query_result.h"

class PyConnection {
Expand All @@ -12,12 +13,14 @@ class PyConnection {

~PyConnection() = default;

std::unique_ptr<PyQueryResult> execute(const std::string& query, py::list params);
std::unique_ptr<PyQueryResult> execute(PyPreparedStatement* preparedStatement, py::list params);

void setMaxNumThreadForExec(uint64_t numThreads);

py::str getNodePropertyNames(const std::string& tableName);

PyPreparedStatement prepare(const std::string& query);

private:
std::unordered_map<std::string, std::shared_ptr<kuzu::common::Value>> transformPythonParameters(
py::list params);
Expand Down
21 changes: 21 additions & 0 deletions tools/python_api/src_cpp/include/py_prepared_statement.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once

#include "main/kuzu.h"
#include "main/prepared_statement.h"
#include "pybind_include.h"

using namespace kuzu::main;

class PyPreparedStatement {
friend class PyConnection;

public:
static void initialize(py::handle& m);

py::str getErrorMessage() const;

bool isSuccess() const;

private:
std::unique_ptr<PreparedStatement> preparedStatement;
};
4 changes: 2 additions & 2 deletions tools/python_api/src_cpp/kuzu_binding.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#include "include/py_connection.h"
#include "include/py_database.h"
#include "include/py_prepared_statement.h"
#include "spdlog/spdlog.h"

void bind(py::module& m) {
PyDatabase::initialize(m);
PyConnection::initialize(m);
PyPreparedStatement::initialize(m);
PyQueryResult::initialize(m);

m.doc() = "Kuzu is an embedded graph database";
}

PYBIND11_MODULE(_kuzu, m) {
Expand Down
32 changes: 26 additions & 6 deletions tools/python_api/src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ using namespace kuzu::common;
void PyConnection::initialize(py::handle& m) {
py::class_<PyConnection>(m, "Connection")
.def(py::init<PyDatabase*, uint64_t>(), py::arg("database"), py::arg("num_threads") = 0)
.def(
"execute", &PyConnection::execute, py::arg("query"), py::arg("parameters") = py::list())
.def("execute", &PyConnection::execute, py::arg("prepared_statement"),
py::arg("parameters") = py::list())
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
py::arg("num_threads"))
.def("get_node_property_names", &PyConnection::getNodePropertyNames, py::arg("table_name"));
.def("get_node_property_names", &PyConnection::getNodePropertyNames, py::arg("table_name"))
.def("prepare", &PyConnection::prepare, py::arg("query"));
PyDateTime_IMPORT;
}

Expand All @@ -27,11 +28,12 @@ PyConnection::PyConnection(PyDatabase* pyDatabase, uint64_t numThreads) {
}
}

std::unique_ptr<PyQueryResult> PyConnection::execute(const std::string& query, py::list params) {
auto preparedStatement = conn->prepare(query);
std::unique_ptr<PyQueryResult> PyConnection::execute(
PyPreparedStatement* preparedStatement, py::list params) {
auto parameters = transformPythonParameters(params);
py::gil_scoped_release release;
auto queryResult = conn->executeWithParams(preparedStatement.get(), parameters);
auto queryResult =
conn->executeWithParams(preparedStatement->preparedStatement.get(), parameters);
py::gil_scoped_acquire acquire;
if (!queryResult->isSuccess()) {
throw std::runtime_error(queryResult->getErrorMessage());
Expand All @@ -49,6 +51,13 @@ py::str PyConnection::getNodePropertyNames(const std::string& tableName) {
return conn->getNodePropertyNames(tableName);
}

PyPreparedStatement PyConnection::prepare(const std::string& query) {
auto preparedStatement = conn->prepare(query);
PyPreparedStatement pyPreparedStatement;
pyPreparedStatement.preparedStatement = std::move(preparedStatement);
return pyPreparedStatement;
}

std::unordered_map<std::string, std::shared_ptr<Value>> PyConnection::transformPythonParameters(
py::list params) {
std::unordered_map<std::string, std::shared_ptr<Value>> result;
Expand Down Expand Up @@ -78,6 +87,7 @@ std::pair<std::string, std::shared_ptr<Value>> PyConnection::transformPythonPara
Value PyConnection::transformPythonValue(py::handle val) {
auto datetime_mod = py::module::import("datetime");
auto datetime_datetime = datetime_mod.attr("datetime");
auto time_delta = datetime_mod.attr("timedelta");
auto datetime_time = datetime_mod.attr("time");
auto datetime_date = datetime_mod.attr("date");
if (py::isinstance<py::bool_>(val)) {
Expand Down Expand Up @@ -106,6 +116,16 @@ Value PyConnection::transformPythonValue(py::handle val) {
auto month = PyDateTime_GET_MONTH(ptr);
auto day = PyDateTime_GET_DAY(ptr);
return Value::createValue<date_t>(Date::FromDate(year, month, day));
} else if (py::isinstance(val, time_delta)) {
auto ptr = val.ptr();
auto days = PyDateTime_DELTA_GET_DAYS(ptr);
auto seconds = PyDateTime_DELTA_GET_SECONDS(ptr);
auto microseconds = PyDateTime_DELTA_GET_MICROSECONDS(ptr);
interval_t interval;
Interval::addition(interval, days, "days");
Interval::addition(interval, seconds, "seconds");
Interval::addition(interval, microseconds, "microseconds");
return Value::createValue<interval_t>(interval);
} else {
throw std::runtime_error(
"Unknown parameter type " + py::str(val.get_type()).cast<std::string>());
Expand Down
18 changes: 18 additions & 0 deletions tools/python_api/src_cpp/py_prepared_statement.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "include/py_prepared_statement.h"

#include "binder/binder.h"
#include "planner/logical_plan/logical_plan.h"

void PyPreparedStatement::initialize(py::handle& m) {
py::class_<PyPreparedStatement>(m, "PreparedStatement")
.def("get_error_message", &PyPreparedStatement::getErrorMessage)
.def("is_success", &PyPreparedStatement::isSuccess);
}

py::str PyPreparedStatement::getErrorMessage() const {
return preparedStatement->getErrorMessage();
}

bool PyPreparedStatement::isSuccess() const {
return preparedStatement->isSuccess();
}
39 changes: 34 additions & 5 deletions tools/python_api/src_py/connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .query_result import QueryResult
from .prepared_statement import PreparedStatement
from . import _kuzu


Expand All @@ -10,11 +11,14 @@ class Connection:
-------
set_max_threads_for_exec(num_threads)
Set the maximum number of threads for executing queries.
execute(query, parameters=[])
Execute a query.
prepare(query)
Create a prepared statement for a query.
"""

def __init__(self, database, num_threads=0):
"""
Parameters
Expand Down Expand Up @@ -46,8 +50,10 @@ def execute(self, query, parameters=[]):
Parameters
----------
query : str
Query to execute.
query : str | PreparedStatement
A prepared statement or a query string.
If a query string is given, a prepared statement will be created
automatically.
parameters : list
Parameters for the query.
Expand All @@ -57,7 +63,30 @@ def execute(self, query, parameters=[]):
Query result.
"""

return QueryResult(self, self._connection.execute(query, parameters))
prepared_statement = self.prepare(
query) if type(query) == str else query
return QueryResult(self,
self._connection.execute(
prepared_statement._prepared_statement,
parameters)
)

def prepare(self, query):
"""
Create a prepared statement for a query.
Parameters
----------
query : str
Query to prepare.
Returns
-------
PreparedStatement
Prepared statement.
"""

return PreparedStatement(self, query)

def _get_node_property_names(self, table_name):
PRIMARY_KEY_SYMBOL = "(PRIMARY KEY)"
Expand Down
51 changes: 51 additions & 0 deletions tools/python_api/src_py/prepared_statement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
class PreparedStatement:
"""
A prepared statement is a parameterized query which can avoid planning the
same query for repeated execution.
Methods
-------
is_success()
Check if the prepared statement is successfully prepared.
get_error_message()
Get the error message if the query is not prepared successfully.
"""

def __init__(self, connection, query):
"""
Parameters
----------
connection : Connection
Connection to a database.
query : str
Query to prepare.
"""

self._prepared_statement = connection._connection.prepare(query)

def is_success(self):
"""
Check if the prepared statement is successfully prepared.
Returns
-------
bool
True if the prepared statement is successfully prepared.
"""

return self._prepared_statement.is_success()

def get_error_message(self):
"""
Get the error message if the query is not prepared successfully.
Returns
-------
str
Error message.
"""

return self._prepared_statement.get_error_message()
4 changes: 2 additions & 2 deletions tools/python_api/src_py/query_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

class QueryResult:
"""
Query result.
QueryResult stores the result of a query execution.
Methods
-------
check_for_query_result_close()
Expand Down
1 change: 1 addition & 0 deletions tools/python_api/test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from test_get_header import *
from test_networkx import *
from test_parameter import *
from test_prepared_statement import *
from test_query_result_close import *
from test_torch_geometric import *
from test_write_to_csv import *
Expand Down
Loading

0 comments on commit 93140d2

Please sign in to comment.