Skip to content

Commit

Permalink
Merge pull request #1124 from kuzudb/py-api-wrapper
Browse files Browse the repository at this point in the history
Wrap pybind11 API and and Fix #1106
  • Loading branch information
mewim committed Dec 19, 2022
2 parents 08fba27 + b94aca9 commit 448b38f
Show file tree
Hide file tree
Showing 24 changed files with 202 additions and 82 deletions.
22 changes: 14 additions & 8 deletions tools/python_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@ project(_kuzu)

set(CMAKE_CXX_STANDARD 20)

file(GLOB SOURCE_PY
"src_py/*")

pybind11_add_module(_kuzu
SHARED
kuzu_binding.cpp
py_connection.cpp
py_database.cpp
py_query_result.cpp
py_query_result_converter.cpp)
src_cpp/kuzu_binding.cpp
src_cpp/py_connection.cpp
src_cpp/py_database.cpp
src_cpp/py_query_result.cpp
src_cpp/py_query_result_converter.cpp)

set_target_properties(_kuzu
PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/build/")
LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/build/kuzu")

target_link_libraries(_kuzu
PRIVATE
Expand All @@ -22,5 +25,8 @@ target_link_libraries(_kuzu
target_include_directories(
_kuzu
PUBLIC
../../src/include
)
../../src/include)

get_target_property(PYTHON_DEST _kuzu LIBRARY_OUTPUT_DIRECTORY)

file(COPY ${SOURCE_PY} DESTINATION ${PYTHON_DEST})
Empty file removed tools/python_api/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include "datetime.h" // from Python

void PyConnection::initialize(py::handle& m) {
py::class_<PyConnection>(m, "connection")
py::class_<PyConnection>(m, "Connection")
.def(py::init<PyDatabase*, uint64_t>(), py::arg("database"), py::arg("num_threads") = 0)
.def(
"execute", &PyConnection::execute, py::arg("query"), py::arg("parameters") = py::list())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "include/py_database.h"

void PyDatabase::initialize(py::handle& m) {
py::class_<PyDatabase>(m, "database")
py::class_<PyDatabase>(m, "Database")
.def(py::init<const string&, uint64_t>(), py::arg("database_path"),
py::arg("buffer_pool_size") = 0)
.def("resize_buffer_manager", &PyDatabase::resizeBufferManager, py::arg("new_size"))
Expand Down
File renamed without changes.
5 changes: 5 additions & 0 deletions tools/python_api/src_py/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._kuzu import *
# The following imports will override C++ implementations with Python
# implementations.
from .connection import *
from .query_result import *
14 changes: 14 additions & 0 deletions tools/python_api/src_py/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .query_result import QueryResult
from . import _kuzu


class Connection:
def __init__(self, database, num_threads=0):
self.database = database
self._connection = _kuzu.Connection(database, num_threads)

def set_max_threads_for_exec(self, num_threads):
self._connection.set_max_threads_for_exec(num_threads)

def execute(self, query, parameters=[]):
return QueryResult(self, self._connection.execute(query, parameters))
46 changes: 46 additions & 0 deletions tools/python_api/src_py/query_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
class QueryResult:
def __init__(self, connection, query_result):
self.connection = connection
self._query_result = query_result
self.is_closed = False

def __del__(self):
self.close()

def check_for_query_result_close(self):
if self.is_closed:
raise Exception("Query result is closed")

def has_next(self):
self.check_for_query_result_close()
return self._query_result.hasNext()

def get_next(self):
self.check_for_query_result_close()
return self._query_result.getNext()

def write_to_csv(self, filename, delimiter=',', escapeCharacter='"', newline='\n'):
self.check_for_query_result_close()
self._query_result.writeToCSV(
filename, delimiter, escapeCharacter, newline)

def close(self):
if self.is_closed:
return
self._query_result.close()
# Allows the connection to be garbage collected if the query result
# is closed manually by the user.
self.connection = None
self.is_closed = True

def get_as_df(self):
self.check_for_query_result_close()
return self._query_result.getAsDF()

def get_column_data_types(self):
self.check_for_query_result_close()
return self._query_result.getColumnDataTypes()

def get_column_names(self):
self.check_for_query_result_close()
return self._query_result.getColumnNames()
17 changes: 11 additions & 6 deletions tools/python_api/test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import os
import sys
import pytest
import shutil
sys.path.append('../build/')
import _kuzu as kuzu
import kuzu


# Note conftest is the default file name for sharing fixture through multiple test files. Do not change file name.
@pytest.fixture
def init_tiny_snb(tmp_path):
if os.path.exists(tmp_path):
os.rmdir(tmp_path)
shutil.rmtree(tmp_path)
output_path = str(tmp_path)
db = kuzu.database(output_path)
conn = kuzu.connection(db)
db = kuzu.Database(output_path)
conn = kuzu.Connection(db)
conn.execute("CREATE NODE TABLE person (ID INT64, fName STRING, gender INT64, isStudent BOOLEAN, isWorker BOOLEAN, "
"age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration "
"INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], PRIMARY "
Expand All @@ -26,6 +27,10 @@ def init_tiny_snb(tmp_path):

@pytest.fixture
def establish_connection(init_tiny_snb):
db = kuzu.database(init_tiny_snb, buffer_pool_size=256 * 1024 * 1024)
conn = kuzu.connection(db, num_threads=4)
db = kuzu.Database(init_tiny_snb, buffer_pool_size=256 * 1024 * 1024)
conn = kuzu.Connection(db, num_threads=4)
return conn, db

@pytest.fixture
def get_tmp_path(tmp_path):
return str(tmp_path)
12 changes: 7 additions & 5 deletions tools/python_api/test/example.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from tools.python_api import _kuzu as kuzu
import sys
sys.path.append('../build/')
import kuzu

databaseDir = "path to database file"
db = kuzu.database(databaseDir)
conn = kuzu.connection(db)
db = kuzu.Database(databaseDir)
conn = kuzu.Connection(db)
query = "MATCH (a:person) RETURN *;"
result = conn.execute(query)
while result.hasNext():
print(result.getNext())
while result.has_next():
print(result.get_next())
result.close()
48 changes: 24 additions & 24 deletions tools/python_api/test/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,71 +4,71 @@
def test_bool_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.isStudent;")
assert result.hasNext()
assert result.getNext() == [True]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [True]
assert not result.has_next()
result.close()


def test_int_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.age;")
assert result.hasNext()
assert result.getNext() == [35]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [35]
assert not result.has_next()
result.close()


def test_double_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.eyeSight;")
assert result.hasNext()
assert result.getNext() == [5.0]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [5.0]
assert not result.has_next()
result.close()


def test_string_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.fName;")
assert result.hasNext()
assert result.getNext() == ['Alice']
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == ['Alice']
assert not result.has_next()
result.close()


def test_date_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.birthdate;")
assert result.hasNext()
assert result.getNext() == [datetime.date(1900, 1, 1)]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [datetime.date(1900, 1, 1)]
assert not result.has_next()
result.close()


def test_timestamp_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.registerTime;")
assert result.hasNext()
assert result.getNext() == [datetime.datetime(2011, 8, 20, 11, 25, 30)]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [datetime.datetime(2011, 8, 20, 11, 25, 30)]
assert not result.has_next()
result.close()


def test_interval_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.lastJobDuration;")
assert result.hasNext()
assert result.getNext() == [datetime.timedelta(days=1082, seconds=46920)]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [datetime.timedelta(days=1082, seconds=46920)]
assert not result.has_next()
result.close()


def test_list_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.courseScoresPerTerm;")
assert result.hasNext()
assert result.getNext() == [[[10, 8], [6, 7, 8]]]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [[[10, 8], [6, 7, 8]]]
assert not result.has_next()
result.close()

4 changes: 2 additions & 2 deletions tools/python_api/test/test_df.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import numpy as np
import sys
sys.path.append('../build/')
import _kuzu as kuzu
import kuzu
from pandas import Timestamp, Timedelta, isna

def test_to_df(establish_connection):
conn, db = establish_connection

def _test_to_df(conn):
query = "MATCH (p:person) return * ORDER BY p.ID"
pd = conn.execute(query).getAsDF()
pd = conn.execute(query).get_as_df()
assert pd['p.ID'].tolist() == [0, 2, 3, 5, 7, 8, 9, 10]
assert str(pd['p.ID'].dtype) == "int64"
assert pd['p.fName'].tolist() == ["Alice", "Bob", "Carol", "Dan", "Elizabeth", "Farooq", "Greg",
Expand Down
4 changes: 2 additions & 2 deletions tools/python_api/test/test_get_header.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
def test_get_column_names(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person)-[e:knows]->(b:person) RETURN a.fName, e.date, b.ID;")
column_names = result.getColumnNames()
column_names = result.get_column_names()
assert column_names[0] == 'a.fName'
assert column_names[1] == 'e.date'
assert column_names[2] == 'b.ID'
Expand All @@ -13,7 +13,7 @@ def test_get_column_data_types(establish_connection):
result = conn.execute(
"MATCH (p:person) RETURN p.ID, p.fName, p.isStudent, p.eyeSight, p.birthdate, p.registerTime, "
"p.lastJobDuration, p.workedHours, p.courseScoresPerTerm;")
column_data_types = result.getColumnDataTypes()
column_data_types = result.get_column_data_types()
assert column_data_types[0] == 'INT64'
assert column_data_types[1] == 'STRING'
assert column_data_types[2] == 'BOOL'
Expand Down
1 change: 1 addition & 0 deletions tools/python_api/test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from test_df import *
from test_write_to_csv import *
from test_get_header import *
from test_query_result_close import *

if __name__ == "__main__":
raise SystemExit(pytest.main([__file__]))
36 changes: 18 additions & 18 deletions tools/python_api/test/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,56 +6,56 @@ def test_bool_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.isStudent = $1 AND a.isWorker = $k RETURN COUNT(*)",
[("1", False), ("k", False)])
assert result.hasNext()
assert result.getNext() == [1]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [1]
assert not result.has_next()
result.close()


def test_int_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.age < $AGE RETURN COUNT(*)", [("AGE", 1)])
assert result.hasNext()
assert result.getNext() == [0]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [0]
assert not result.has_next()
result.close()


def test_double_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.eyeSight = $E RETURN COUNT(*)", [("E", 5.0)])
assert result.hasNext()
assert result.getNext() == [2]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [2]
assert not result.has_next()
result.close()


def test_str_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN concat(a.fName, $S);", [("S", "HH")])
assert result.hasNext()
assert result.getNext() == ["AliceHH"]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == ["AliceHH"]
assert not result.has_next()
result.close()


def test_date_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.birthdate = $1 RETURN COUNT(*);",
[("1", datetime.date(1900, 1, 1))])
assert result.hasNext()
assert result.getNext() == [2]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [2]
assert not result.has_next()
result.close()


def test_timestamp_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);",
[("1", datetime.datetime(2011, 8, 20, 11, 25, 30))])
assert result.hasNext()
assert result.getNext() == [1]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [1]
assert not result.has_next()
result.close()


Expand Down
Loading

0 comments on commit 448b38f

Please sign in to comment.