Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Python API binding for prepared statement #1305

Merged
merged 1 commit into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tools/python_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pybind11_add_module(_kuzu
src_cpp/kuzu_binding.cpp
src_cpp/py_connection.cpp
src_cpp/py_database.cpp
src_cpp/py_prepared_statement.cpp
src_cpp/py_query_result.cpp
src_cpp/py_query_result_converter.cpp)

Expand Down
5 changes: 4 additions & 1 deletion tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "py_database.h"
#include "py_prepared_statement.h"
#include "py_query_result.h"

class PyConnection {
Expand All @@ -12,12 +13,14 @@ class PyConnection {

~PyConnection() = default;

std::unique_ptr<PyQueryResult> execute(const std::string& query, py::list params);
std::unique_ptr<PyQueryResult> execute(PyPreparedStatement* preparedStatement, py::list params);

void setMaxNumThreadForExec(uint64_t numThreads);

py::str getNodePropertyNames(const std::string& tableName);

PyPreparedStatement prepare(const std::string& query);

private:
std::unordered_map<std::string, std::shared_ptr<kuzu::common::Value>> transformPythonParameters(
py::list params);
Expand Down
21 changes: 21 additions & 0 deletions tools/python_api/src_cpp/include/py_prepared_statement.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once

#include "main/kuzu.h"
#include "main/prepared_statement.h"
#include "pybind_include.h"

using namespace kuzu::main;

class PyPreparedStatement {
friend class PyConnection;

public:
static void initialize(py::handle& m);

py::str getErrorMessage() const;

bool isSuccess() const;

private:
std::unique_ptr<PreparedStatement> preparedStatement;
};
4 changes: 2 additions & 2 deletions tools/python_api/src_cpp/kuzu_binding.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#include "include/py_connection.h"
#include "include/py_database.h"
#include "include/py_prepared_statement.h"
#include "spdlog/spdlog.h"

void bind(py::module& m) {
PyDatabase::initialize(m);
PyConnection::initialize(m);
PyPreparedStatement::initialize(m);
PyQueryResult::initialize(m);

m.doc() = "Kuzu is an embedded graph database";
}

PYBIND11_MODULE(_kuzu, m) {
Expand Down
32 changes: 26 additions & 6 deletions tools/python_api/src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ using namespace kuzu::common;
void PyConnection::initialize(py::handle& m) {
py::class_<PyConnection>(m, "Connection")
.def(py::init<PyDatabase*, uint64_t>(), py::arg("database"), py::arg("num_threads") = 0)
.def(
"execute", &PyConnection::execute, py::arg("query"), py::arg("parameters") = py::list())
.def("execute", &PyConnection::execute, py::arg("prepared_statement"),
py::arg("parameters") = py::list())
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
py::arg("num_threads"))
.def("get_node_property_names", &PyConnection::getNodePropertyNames, py::arg("table_name"));
.def("get_node_property_names", &PyConnection::getNodePropertyNames, py::arg("table_name"))
.def("prepare", &PyConnection::prepare, py::arg("query"));
PyDateTime_IMPORT;
}

Expand All @@ -27,11 +28,12 @@ PyConnection::PyConnection(PyDatabase* pyDatabase, uint64_t numThreads) {
}
}

std::unique_ptr<PyQueryResult> PyConnection::execute(const std::string& query, py::list params) {
auto preparedStatement = conn->prepare(query);
std::unique_ptr<PyQueryResult> PyConnection::execute(
PyPreparedStatement* preparedStatement, py::list params) {
auto parameters = transformPythonParameters(params);
py::gil_scoped_release release;
auto queryResult = conn->executeWithParams(preparedStatement.get(), parameters);
auto queryResult =
conn->executeWithParams(preparedStatement->preparedStatement.get(), parameters);
py::gil_scoped_acquire acquire;
if (!queryResult->isSuccess()) {
throw std::runtime_error(queryResult->getErrorMessage());
Expand All @@ -49,6 +51,13 @@ py::str PyConnection::getNodePropertyNames(const std::string& tableName) {
return conn->getNodePropertyNames(tableName);
}

PyPreparedStatement PyConnection::prepare(const std::string& query) {
auto preparedStatement = conn->prepare(query);
PyPreparedStatement pyPreparedStatement;
pyPreparedStatement.preparedStatement = std::move(preparedStatement);
return pyPreparedStatement;
}

std::unordered_map<std::string, std::shared_ptr<Value>> PyConnection::transformPythonParameters(
py::list params) {
std::unordered_map<std::string, std::shared_ptr<Value>> result;
Expand Down Expand Up @@ -78,6 +87,7 @@ std::pair<std::string, std::shared_ptr<Value>> PyConnection::transformPythonPara
Value PyConnection::transformPythonValue(py::handle val) {
auto datetime_mod = py::module::import("datetime");
auto datetime_datetime = datetime_mod.attr("datetime");
auto time_delta = datetime_mod.attr("timedelta");
auto datetime_time = datetime_mod.attr("time");
auto datetime_date = datetime_mod.attr("date");
if (py::isinstance<py::bool_>(val)) {
Expand Down Expand Up @@ -106,6 +116,16 @@ Value PyConnection::transformPythonValue(py::handle val) {
auto month = PyDateTime_GET_MONTH(ptr);
auto day = PyDateTime_GET_DAY(ptr);
return Value::createValue<date_t>(Date::FromDate(year, month, day));
} else if (py::isinstance(val, time_delta)) {
auto ptr = val.ptr();
auto days = PyDateTime_DELTA_GET_DAYS(ptr);
auto seconds = PyDateTime_DELTA_GET_SECONDS(ptr);
auto microseconds = PyDateTime_DELTA_GET_MICROSECONDS(ptr);
interval_t interval;
Interval::addition(interval, days, "days");
Interval::addition(interval, seconds, "seconds");
Interval::addition(interval, microseconds, "microseconds");
return Value::createValue<interval_t>(interval);
} else {
throw std::runtime_error(
"Unknown parameter type " + py::str(val.get_type()).cast<std::string>());
Expand Down
18 changes: 18 additions & 0 deletions tools/python_api/src_cpp/py_prepared_statement.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "include/py_prepared_statement.h"

#include "binder/binder.h"
#include "planner/logical_plan/logical_plan.h"

void PyPreparedStatement::initialize(py::handle& m) {
py::class_<PyPreparedStatement>(m, "PreparedStatement")
.def("get_error_message", &PyPreparedStatement::getErrorMessage)
.def("is_success", &PyPreparedStatement::isSuccess);
}

py::str PyPreparedStatement::getErrorMessage() const {
return preparedStatement->getErrorMessage();
}

bool PyPreparedStatement::isSuccess() const {
return preparedStatement->isSuccess();
}
39 changes: 34 additions & 5 deletions tools/python_api/src_py/connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .query_result import QueryResult
from .prepared_statement import PreparedStatement
from . import _kuzu


Expand All @@ -10,11 +11,14 @@ class Connection:
-------
set_max_threads_for_exec(num_threads)
Set the maximum number of threads for executing queries.
execute(query, parameters=[])
Execute a query.
prepare(query)
Create a prepared statement for a query.
"""

def __init__(self, database, num_threads=0):
"""
Parameters
Expand Down Expand Up @@ -46,8 +50,10 @@ def execute(self, query, parameters=[]):
Parameters
----------
query : str
Query to execute.
query : str | PreparedStatement
A prepared statement or a query string.
If a query string is given, a prepared statement will be created
automatically.
parameters : list
Parameters for the query.
Expand All @@ -57,7 +63,30 @@ def execute(self, query, parameters=[]):
Query result.
"""

return QueryResult(self, self._connection.execute(query, parameters))
prepared_statement = self.prepare(
query) if type(query) == str else query
return QueryResult(self,
self._connection.execute(
prepared_statement._prepared_statement,
parameters)
)

def prepare(self, query):
"""
Create a prepared statement for a query.
Parameters
----------
query : str
Query to prepare.
Returns
-------
PreparedStatement
Prepared statement.
"""

return PreparedStatement(self, query)

def _get_node_property_names(self, table_name):
PRIMARY_KEY_SYMBOL = "(PRIMARY KEY)"
Expand Down
51 changes: 51 additions & 0 deletions tools/python_api/src_py/prepared_statement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
class PreparedStatement:
"""
A prepared statement is a parameterized query which can avoid planning the
same query for repeated execution.
Methods
-------
is_success()
Check if the prepared statement is successfully prepared.
get_error_message()
Get the error message if the query is not prepared successfully.
"""

def __init__(self, connection, query):
"""
Parameters
----------
connection : Connection
Connection to a database.
query : str
Query to prepare.
"""

self._prepared_statement = connection._connection.prepare(query)

def is_success(self):
"""
Check if the prepared statement is successfully prepared.
Returns
-------
bool
True if the prepared statement is successfully prepared.
"""

return self._prepared_statement.is_success()

def get_error_message(self):
"""
Get the error message if the query is not prepared successfully.
Returns
-------
str
Error message.
"""

return self._prepared_statement.get_error_message()
4 changes: 2 additions & 2 deletions tools/python_api/src_py/query_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

class QueryResult:
"""
Query result.
QueryResult stores the result of a query execution.
Methods
-------
check_for_query_result_close()
Expand Down
1 change: 1 addition & 0 deletions tools/python_api/test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from test_get_header import *
from test_networkx import *
from test_parameter import *
from test_prepared_statement import *
from test_query_result_close import *
from test_torch_geometric import *
from test_write_to_csv import *
Expand Down
Loading