Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement #1959 #1974

Merged
merged 1 commit into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions tools/python_api/src_cpp/include/py_connection.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#pragma once

#include "main/storage_driver.h"
#include "py_database.h"
#include "py_prepared_statement.h"
#include "py_query_result.h"
#include "main/storage_driver.h"

class PyConnection {

Expand All @@ -16,7 +16,7 @@ class PyConnection {

void setQueryTimeout(uint64_t timeoutInMS);

std::unique_ptr<PyQueryResult> execute(PyPreparedStatement* preparedStatement, py::list params);
std::unique_ptr<PyQueryResult> execute(PyPreparedStatement* preparedStatement, py::dict params);

void setMaxNumThreadForExec(uint64_t numThreads);

Expand All @@ -40,10 +40,7 @@ class PyConnection {

private:
std::unordered_map<std::string, std::shared_ptr<kuzu::common::Value>> transformPythonParameters(
py::list params);

std::pair<std::string, std::shared_ptr<kuzu::common::Value>> transformPythonParameter(
py::tuple param);
py::dict params);

kuzu::common::Value transformPythonValue(py::handle val);

Expand Down
35 changes: 11 additions & 24 deletions tools/python_api/src_cpp/py_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
py::class_<PyConnection>(m, "Connection")
.def(py::init<PyDatabase*, uint64_t>(), py::arg("database"), py::arg("num_threads") = 0)
.def("execute", &PyConnection::execute, py::arg("prepared_statement"),
py::arg("parameters") = py::list())
py::arg("parameters") = py::dict())
.def("set_max_threads_for_exec", &PyConnection::setMaxNumThreadForExec,
py::arg("num_threads"))
.def("get_node_property_names", &PyConnection::getNodePropertyNames, py::arg("table_name"))
Expand Down Expand Up @@ -44,7 +44,7 @@
}

std::unique_ptr<PyQueryResult> PyConnection::execute(
PyPreparedStatement* preparedStatement, py::list params) {
PyPreparedStatement* preparedStatement, py::dict params) {
auto parameters = transformPythonParameters(params);
py::gil_scoped_release release;
auto queryResult =
Expand Down Expand Up @@ -133,8 +133,7 @@
if (tableSchema->getColumn(0)->isFlat() && !tableSchema->getColumn(1)->isFlat()) {
for (auto i = 0u; i < table->getNumTuples(); ++i) {
auto tuple = table->getTuple(i);
auto overflowValue =
(overflow_value_t*)(tuple + tableSchema->getColOffset(1));
auto overflowValue = (overflow_value_t*)(tuple + tableSchema->getColOffset(1));

Check warning on line 136 in tools/python_api/src_cpp/py_connection.cpp

View check run for this annotation

Codecov / codecov/patch

tools/python_api/src_cpp/py_connection.cpp#L136

Added line #L136 was not covered by tests
for (auto j = 0u; j < overflowValue->numElements; ++j) {
srcBuffer[j] = *(int64_t*)(tuple + tableSchema->getColOffset(0));
}
Expand All @@ -147,8 +146,7 @@
} else if (tableSchema->getColumn(1)->isFlat() && !tableSchema->getColumn(0)->isFlat()) {
for (auto i = 0u; i < table->getNumTuples(); ++i) {
auto tuple = table->getTuple(i);
auto overflowValue =
(overflow_value_t*)(tuple + tableSchema->getColOffset(0));
auto overflowValue = (overflow_value_t*)(tuple + tableSchema->getColOffset(0));
for (auto j = 0u; j < overflowValue->numElements; ++j) {
srcBuffer[j] = ((int64_t*)overflowValue->value)[j];
}
Expand All @@ -166,31 +164,20 @@
}

std::unordered_map<std::string, std::shared_ptr<Value>> PyConnection::transformPythonParameters(
py::list params) {
py::dict params) {
std::unordered_map<std::string, std::shared_ptr<Value>> result;
for (auto param : params) {
if (!py::isinstance<py::tuple>(param)) {
throw std::runtime_error("Each parameter must be in the form of <name, val>");
for (auto& [key, value] : params) {
if (!py::isinstance<py::str>(key)) {
throw std::runtime_error("Parameter name must be of type string but get " +
py::str(key.get_type()).cast<std::string>());
}
auto [name, val] = transformPythonParameter(param.cast<py::tuple>());
auto name = key.cast<std::string>();
auto val = std::make_shared<Value>(transformPythonValue(value));
result.insert({name, val});
}
return result;
}

std::pair<std::string, std::shared_ptr<Value>> PyConnection::transformPythonParameter(
py::tuple param) {
if (py::len(param) != 2) {
throw std::runtime_error("Each parameter must be in the form of <name, val>");
}
if (!py::isinstance<py::str>(param[0])) {
throw std::runtime_error("Parameter name must be of type string but get " +
py::str(param[0].get_type()).cast<std::string>());
}
auto val = transformPythonValue(param[1]);
return make_pair(param[0].cast<std::string>(), std::make_shared<Value>(val));
}

Value PyConnection::transformPythonValue(py::handle val) {
auto datetime_mod = py::module::import("datetime");
auto datetime_datetime = datetime_mod.attr("datetime");
Expand Down
6 changes: 4 additions & 2 deletions tools/python_api/src_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def set_max_threads_for_exec(self, num_threads):
self.init_connection()
self._connection.set_max_threads_for_exec(num_threads)

def execute(self, query, parameters=[]):
def execute(self, query, parameters={}):
"""
Execute a query.

Expand All @@ -73,7 +73,7 @@ def execute(self, query, parameters=[]):
A prepared statement or a query string.
If a query string is given, a prepared statement will be created
automatically.
parameters : list[tuple(str, any)]
parameters : dict[str, Any]
Parameters for the query.

Returns
Expand All @@ -82,6 +82,8 @@ def execute(self, query, parameters=[]):
Query result.
"""
self.init_connection()
if type(parameters) != dict:
raise RuntimeError("Parameters must be a dict")
prepared_statement = self.prepare(
query) if type(query) == str else query
return QueryResult(self,
Expand Down
2 changes: 1 addition & 1 deletion tools/python_api/test/test_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_exception(establish_connection):
conn, db = establish_connection

with pytest.raises(RuntimeError, match="Parameter asd not found."):
conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);", [("asd", 1)])
conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);", {"asd": 1})

with pytest.raises(RuntimeError, match="Binder exception: Cannot find property dummy for a."):
conn.execute("MATCH (a:person) RETURN a.dummy;")
Expand Down
18 changes: 9 additions & 9 deletions tools/python_api/test/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
def test_bool_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.isStudent = $1 AND a.isWorker = $k RETURN COUNT(*)",
[("1", False), ("k", False)])
{"1": False, "k": False})
assert result.has_next()
assert result.get_next() == [1]
assert not result.has_next()
Expand All @@ -14,7 +14,7 @@ def test_bool_param(establish_connection):

def test_int_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.age < $AGE RETURN COUNT(*)", [("AGE", 1)])
result = conn.execute("MATCH (a:person) WHERE a.age < $AGE RETURN COUNT(*)", { "AGE": 1 })
assert result.has_next()
assert result.get_next() == [0]
assert not result.has_next()
Expand All @@ -23,7 +23,7 @@ def test_int_param(establish_connection):

def test_double_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.eyeSight = $E RETURN COUNT(*)", [("E", 5.0)])
result = conn.execute("MATCH (a:person) WHERE a.eyeSight = $E RETURN COUNT(*)", { "E": 5.0 })
assert result.has_next()
assert result.get_next() == [2]
assert not result.has_next()
Expand All @@ -32,7 +32,7 @@ def test_double_param(establish_connection):

def test_str_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN concat(a.fName, $S);", [("S", "HH")])
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN concat(a.fName, $S);", { "S": "HH" })
assert result.has_next()
assert result.get_next() == ["AliceHH"]
assert not result.has_next()
Expand All @@ -42,7 +42,7 @@ def test_str_param(establish_connection):
def test_date_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.birthdate = $1 RETURN COUNT(*);",
[("1", datetime.date(1900, 1, 1))])
{ "1": datetime.date(1900, 1, 1) })
assert result.has_next()
assert result.get_next() == [2]
assert not result.has_next()
Expand All @@ -52,7 +52,7 @@ def test_date_param(establish_connection):
def test_timestamp_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);",
[("1", datetime.datetime(2011, 8, 20, 11, 25, 30))])
{ "1": datetime.datetime(2011, 8, 20, 11, 25, 30) })
assert result.has_next()
assert result.get_next() == [1]
assert not result.has_next()
Expand All @@ -62,16 +62,16 @@ def test_timestamp_param(establish_connection):
def test_param_error1(establish_connection):
conn, db = establish_connection
with pytest.raises(RuntimeError, match="Parameter name must be of type string but get <class 'int'>"):
conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);", [(1, 1)])
conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);", {1: 1})


def test_param_error2(establish_connection):
conn, db = establish_connection
with pytest.raises(RuntimeError, match="Each parameter must be in the form of <name, val>"):
with pytest.raises(RuntimeError, match="Parameters must be a dict"):
conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);", ["asd"])


def test_param_error3(establish_connection):
conn, db = establish_connection
with pytest.raises(RuntimeError, match="Each parameter must be in the form of <name, val>"):
with pytest.raises(RuntimeError, match="Parameters must be a dict"):
conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);", [("asd", 1, 1)])
16 changes: 6 additions & 10 deletions tools/python_api/test/test_prepared_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,25 @@ def test_read(establish_connection):
assert prepared_statement.is_success()
assert prepared_statement.get_error_message() == ""

result = conn.execute(prepared_statement,
[("1", False), ("k", False)])
result = conn.execute(prepared_statement, {"1": False, "k": False})
assert result.is_success()
assert result.has_next()
assert result.get_next() == [1]
assert not result.has_next()

result = conn.execute(prepared_statement,
[("1", True), ("k", False)])
result = conn.execute(prepared_statement, {"1": True, "k": False})
assert result.is_success()
assert result.has_next()
assert result.get_next() == [3]
assert not result.has_next()

result = conn.execute(prepared_statement,
[("1", False), ("k", True)])
result = conn.execute(prepared_statement, {"1": False, "k": True})
assert result.is_success()
assert result.has_next()
assert result.get_next() == [4]
assert not result.has_next()

result = conn.execute(prepared_statement,
[("1", True), ("k", True)])
result = conn.execute(prepared_statement, {"1": True, "k": True})
assert result.is_success()
assert result.has_next()
assert result.get_next() == [0]
Expand Down Expand Up @@ -77,8 +73,8 @@ def test_write(establish_connection):
"CREATE (n:organisation {ID: $ID, name: $name, orgCode: $orgCode, mark: $mark, score: $score, history: $history, licenseValidInterval: $licenseValidInterval, rating: $rating})")
assert prepared_statement.is_success()
for org in orgs:
org_tuples = [(k, v) for k, v in org.items()]
conn.execute(prepared_statement, org_tuples)
org_dict = {str(k): v for k, v in org.items()}
conn.execute(prepared_statement, org_dict)

all_orgs_res = conn.execute("MATCH (n:organisation) RETURN n")
while all_orgs_res.has_next():
Expand Down