Skip to content

Commit

Permalink
Merge pull request #1974 from kuzudb/fix-1959
Browse files Browse the repository at this point in the history
Implement #1959
  • Loading branch information
mewim committed Aug 29, 2023
2 parents e1cfe67 + 05e067a commit 58ee272
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 52 deletions.
9 changes: 3 additions & 6 deletions tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#pragma once

#include "main/storage_driver.h"
#include "py_database.h"
#include "py_prepared_statement.h"
#include "py_query_result.h"
#include "main/storage_driver.h"

class PyConnection {

Expand All @@ -16,7 +16,7 @@ class PyConnection {

void setQueryTimeout(uint64_t timeoutInMS);

std::unique_ptr<PyQueryResult> execute(PyPreparedStatement* preparedStatement, py::list params);
std::unique_ptr<PyQueryResult> execute(PyPreparedStatement* preparedStatement, py::dict params);

void setMaxNumThreadForExec(uint64_t numThreads);

Expand All @@ -40,10 +40,7 @@ class PyConnection {

private:
std::unordered_map<std::string, std::shared_ptr<kuzu::common::Value>> transformPythonParameters(
py::list params);

std::pair<std::string, std::shared_ptr<kuzu::common::Value>> transformPythonParameter(
py::tuple param);
py::dict params);

kuzu::common::Value transformPythonValue(py::handle val);

Expand Down
35 changes: 11 additions & 24 deletions tools/python_api/src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void PyConnection::initialize(py::handle& m) {
py::class_<PyConnection>(m, "Connection")
.def(py::init<PyDatabase*, uint64_t>(), py::arg("database"), py::arg("num_threads") = 0)
.def("execute", &PyConnection::execute, py::arg("prepared_statement"),
py::arg("parameters") = py::list())
py::arg("parameters") = py::dict())
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
py::arg("num_threads"))
.def("get_node_property_names", &PyConnection::getNodePropertyNames, py::arg("table_name"))
Expand Down Expand Up @@ -44,7 +44,7 @@ void PyConnection::setQueryTimeout(uint64_t timeoutInMS) {
}

std::unique_ptr<PyQueryResult> PyConnection::execute(
PyPreparedStatement* preparedStatement, py::list params) {
PyPreparedStatement* preparedStatement, py::dict params) {
auto parameters = transformPythonParameters(params);
py::gil_scoped_release release;
auto queryResult =
Expand Down Expand Up @@ -133,8 +133,7 @@ void PyConnection::getAllEdgesForTorchGeometric(py::array_t<int64_t>& npArray,
if (tableSchema->getColumn(0)->isFlat() && !tableSchema->getColumn(1)->isFlat()) {
for (auto i = 0u; i < table->getNumTuples(); ++i) {
auto tuple = table->getTuple(i);
auto overflowValue =
(overflow_value_t*)(tuple + tableSchema->getColOffset(1));
auto overflowValue = (overflow_value_t*)(tuple + tableSchema->getColOffset(1));
for (auto j = 0u; j < overflowValue->numElements; ++j) {
srcBuffer[j] = *(int64_t*)(tuple + tableSchema->getColOffset(0));
}
Expand All @@ -147,8 +146,7 @@ void PyConnection::getAllEdgesForTorchGeometric(py::array_t<int64_t>& npArray,
} else if (tableSchema->getColumn(1)->isFlat() && !tableSchema->getColumn(0)->isFlat()) {
for (auto i = 0u; i < table->getNumTuples(); ++i) {
auto tuple = table->getTuple(i);
auto overflowValue =
(overflow_value_t*)(tuple + tableSchema->getColOffset(0));
auto overflowValue = (overflow_value_t*)(tuple + tableSchema->getColOffset(0));
for (auto j = 0u; j < overflowValue->numElements; ++j) {
srcBuffer[j] = ((int64_t*)overflowValue->value)[j];
}
Expand All @@ -166,31 +164,20 @@ void PyConnection::getAllEdgesForTorchGeometric(py::array_t<int64_t>& npArray,
}

std::unordered_map<std::string, std::shared_ptr<Value>> PyConnection::transformPythonParameters(
py::list params) {
py::dict params) {
std::unordered_map<std::string, std::shared_ptr<Value>> result;
for (auto param : params) {
if (!py::isinstance<py::tuple>(param)) {
throw std::runtime_error("Each parameter must be in the form of <name, val>");
for (auto& [key, value] : params) {
if (!py::isinstance<py::str>(key)) {
throw std::runtime_error("Parameter name must be of type string but get " +
py::str(key.get_type()).cast<std::string>());
}
auto [name, val] = transformPythonParameter(param.cast<py::tuple>());
auto name = key.cast<std::string>();
auto val = std::make_shared<Value>(transformPythonValue(value));
result.insert({name, val});
}
return result;
}

std::pair<std::string, std::shared_ptr<Value>> PyConnection::transformPythonParameter(
py::tuple param) {
if (py::len(param) != 2) {
throw std::runtime_error("Each parameter must be in the form of <name, val>");
}
if (!py::isinstance<py::str>(param[0])) {
throw std::runtime_error("Parameter name must be of type string but get " +
py::str(param[0].get_type()).cast<std::string>());
}
auto val = transformPythonValue(param[1]);
return make_pair(param[0].cast<std::string>(), std::make_shared<Value>(val));
}

Value PyConnection::transformPythonValue(py::handle val) {
auto datetime_mod = py::module::import("datetime");
auto datetime_datetime = datetime_mod.attr("datetime");
Expand Down
6 changes: 4 additions & 2 deletions tools/python_api/src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def set_max_threads_for_exec(self, num_threads):
self.init_connection()
self._connection.set_max_threads_for_exec(num_threads)

def execute(self, query, parameters=[]):
def execute(self, query, parameters={}):
"""
Execute a query.
Expand All @@ -73,7 +73,7 @@ def execute(self, query, parameters=[]):
A prepared statement or a query string.
If a query string is given, a prepared statement will be created
automatically.
parameters : list[tuple(str, any)]
parameters : dict[str, Any]
Parameters for the query.
Returns
Expand All @@ -82,6 +82,8 @@ def execute(self, query, parameters=[]):
Query result.
"""
self.init_connection()
if type(parameters) != dict:
raise RuntimeError("Parameters must be a dict")
prepared_statement = self.prepare(
query) if type(query) == str else query
return QueryResult(self,
Expand Down
2 changes: 1 addition & 1 deletion tools/python_api/test/test_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_exception(establish_connection):
conn, db = establish_connection

with pytest.raises(RuntimeError, match="Parameter asd not found."):
conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);", [("asd", 1)])
conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);", {"asd": 1})

with pytest.raises(RuntimeError, match="Binder exception: Cannot find property dummy for a."):
conn.execute("MATCH (a:person) RETURN a.dummy;")
Expand Down
18 changes: 9 additions & 9 deletions tools/python_api/test/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
def test_bool_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.isStudent = $1 AND a.isWorker = $k RETURN COUNT(*)",
[("1", False), ("k", False)])
{"1": False, "k": False})
assert result.has_next()
assert result.get_next() == [1]
assert not result.has_next()
Expand All @@ -14,7 +14,7 @@ def test_bool_param(establish_connection):

def test_int_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.age < $AGE RETURN COUNT(*)", [("AGE", 1)])
result = conn.execute("MATCH (a:person) WHERE a.age < $AGE RETURN COUNT(*)", { "AGE": 1 })
assert result.has_next()
assert result.get_next() == [0]
assert not result.has_next()
Expand All @@ -23,7 +23,7 @@ def test_int_param(establish_connection):

def test_double_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.eyeSight = $E RETURN COUNT(*)", [("E", 5.0)])
result = conn.execute("MATCH (a:person) WHERE a.eyeSight = $E RETURN COUNT(*)", { "E": 5.0 })
assert result.has_next()
assert result.get_next() == [2]
assert not result.has_next()
Expand All @@ -32,7 +32,7 @@ def test_double_param(establish_connection):

def test_str_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN concat(a.fName, $S);", [("S", "HH")])
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN concat(a.fName, $S);", { "S": "HH" })
assert result.has_next()
assert result.get_next() == ["AliceHH"]
assert not result.has_next()
Expand All @@ -42,7 +42,7 @@ def test_str_param(establish_connection):
def test_date_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.birthdate = $1 RETURN COUNT(*);",
[("1", datetime.date(1900, 1, 1))])
{ "1": datetime.date(1900, 1, 1) })
assert result.has_next()
assert result.get_next() == [2]
assert not result.has_next()
Expand All @@ -52,7 +52,7 @@ def test_date_param(establish_connection):
def test_timestamp_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);",
[("1", datetime.datetime(2011, 8, 20, 11, 25, 30))])
{ "1": datetime.datetime(2011, 8, 20, 11, 25, 30) })
assert result.has_next()
assert result.get_next() == [1]
assert not result.has_next()
Expand All @@ -62,16 +62,16 @@ def test_timestamp_param(establish_connection):
def test_param_error1(establish_connection):
conn, db = establish_connection
with pytest.raises(RuntimeError, match="Parameter name must be of type string but get <class 'int'>"):
conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);", [(1, 1)])
conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);", {1: 1})


def test_param_error2(establish_connection):
conn, db = establish_connection
with pytest.raises(RuntimeError, match="Each parameter must be in the form of <name, val>"):
with pytest.raises(RuntimeError, match="Parameters must be a dict"):
conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);", ["asd"])


def test_param_error3(establish_connection):
conn, db = establish_connection
with pytest.raises(RuntimeError, match="Each parameter must be in the form of <name, val>"):
with pytest.raises(RuntimeError, match="Parameters must be a dict"):
conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);", [("asd", 1, 1)])
16 changes: 6 additions & 10 deletions tools/python_api/test/test_prepared_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,25 @@ def test_read(establish_connection):
assert prepared_statement.is_success()
assert prepared_statement.get_error_message() == ""

result = conn.execute(prepared_statement,
[("1", False), ("k", False)])
result = conn.execute(prepared_statement, {"1": False, "k": False})
assert result.is_success()
assert result.has_next()
assert result.get_next() == [1]
assert not result.has_next()

result = conn.execute(prepared_statement,
[("1", True), ("k", False)])
result = conn.execute(prepared_statement, {"1": True, "k": False})
assert result.is_success()
assert result.has_next()
assert result.get_next() == [3]
assert not result.has_next()

result = conn.execute(prepared_statement,
[("1", False), ("k", True)])
result = conn.execute(prepared_statement, {"1": False, "k": True})
assert result.is_success()
assert result.has_next()
assert result.get_next() == [4]
assert not result.has_next()

result = conn.execute(prepared_statement,
[("1", True), ("k", True)])
result = conn.execute(prepared_statement, {"1": True, "k": True})
assert result.is_success()
assert result.has_next()
assert result.get_next() == [0]
Expand Down Expand Up @@ -77,8 +73,8 @@ def test_write(establish_connection):
"CREATE (n:organisation {ID: $ID, name: $name, orgCode: $orgCode, mark: $mark, score: $score, history: $history, licenseValidInterval: $licenseValidInterval, rating: $rating})")
assert prepared_statement.is_success()
for org in orgs:
org_tuples = [(k, v) for k, v in org.items()]
conn.execute(prepared_statement, org_tuples)
org_dict = {str(k): v for k, v in org.items()}
conn.execute(prepared_statement, org_dict)

all_orgs_res = conn.execute("MATCH (n:organisation) RETURN n")
while all_orgs_res.has_next():
Expand Down

0 comments on commit 58ee272

Please sign in to comment.