Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap pybind11 API and Fix #1106 #1124

Merged
merged 1 commit into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions tools/python_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@ project(_kuzu)

set(CMAKE_CXX_STANDARD 20)

file(GLOB SOURCE_PY
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather not use GLOB, but list all necessary files here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed the C++ target to use the explicit path instead of using GLOB. But I think we can keep the GLOB for Python files since CMake is not actually running any build script against the Python files. It is just copying the files to the build directory.

"src_py/*")

pybind11_add_module(_kuzu
SHARED
kuzu_binding.cpp
py_connection.cpp
py_database.cpp
py_query_result.cpp
py_query_result_converter.cpp)
src_cpp/kuzu_binding.cpp
src_cpp/py_connection.cpp
src_cpp/py_database.cpp
src_cpp/py_query_result.cpp
src_cpp/py_query_result_converter.cpp)

set_target_properties(_kuzu
PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/build/")
LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/build/kuzu")

target_link_libraries(_kuzu
PRIVATE
Expand All @@ -22,5 +25,8 @@ target_link_libraries(_kuzu
target_include_directories(
_kuzu
PUBLIC
../../src/include
)
../../src/include)

get_target_property(PYTHON_DEST _kuzu LIBRARY_OUTPUT_DIRECTORY)

file(COPY ${SOURCE_PY} DESTINATION ${PYTHON_DEST})
Empty file removed tools/python_api/__init__.py
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include "datetime.h" // from Python

void PyConnection::initialize(py::handle& m) {
py::class_<PyConnection>(m, "connection")
py::class_<PyConnection>(m, "Connection")
.def(py::init<PyDatabase*, uint64_t>(), py::arg("database"), py::arg("num_threads") = 0)
.def(
"execute", &PyConnection::execute, py::arg("query"), py::arg("parameters") = py::list())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "include/py_database.h"

void PyDatabase::initialize(py::handle& m) {
py::class_<PyDatabase>(m, "database")
py::class_<PyDatabase>(m, "Database")
.def(py::init<const string&, uint64_t>(), py::arg("database_path"),
py::arg("buffer_pool_size") = 0)
.def("resize_buffer_manager", &PyDatabase::resizeBufferManager, py::arg("new_size"))
Expand Down
5 changes: 5 additions & 0 deletions tools/python_api/src_py/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._kuzu import *
# The following imports will override C++ implementations with Python
# implementations.
from .connection import *
from .query_result import *
14 changes: 14 additions & 0 deletions tools/python_api/src_py/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .query_result import QueryResult
from . import _kuzu


class Connection:
    """Python-side wrapper around the native ``_kuzu.Connection`` binding.

    The wrapper stores the database it was created from and returns
    ``QueryResult`` wrappers (rather than raw native results) from
    ``execute``.
    """

    def __init__(self, database, num_threads=0):
        """Open a connection on ``database``.

        :param database: database object, forwarded unchanged to
            ``_kuzu.Connection``.
        :param num_threads: thread count forwarded unchanged to
            ``_kuzu.Connection``; defaults to 0.
        """
        self.database = database
        self._connection = _kuzu.Connection(database, num_threads)

    def set_max_threads_for_exec(self, num_threads):
        """Forward ``num_threads`` to the native ``set_max_threads_for_exec``."""
        self._connection.set_max_threads_for_exec(num_threads)

    def execute(self, query, parameters=None):
        """Run ``query`` on the native connection and wrap its result.

        :param query: query text passed to the native ``execute``.
        :param parameters: optional list of query parameters, e.g.
            ``[("AGE", 1)]``. When omitted, an empty list is passed.
        :return: a ``QueryResult`` built from this connection and the
            native result.
        """
        # A ``None`` sentinel replaces the former ``[]`` default: a list
        # default is created once and shared by every call, so each call now
        # gets its own fresh empty list instead.
        if parameters is None:
            parameters = []
        return QueryResult(self, self._connection.execute(query, parameters))
46 changes: 46 additions & 0 deletions tools/python_api/src_py/query_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
class QueryResult:
    """Python-side wrapper around a native query result object.

    Each accessor first checks that the result has not been closed and then
    delegates to the corresponding camelCase method of the wrapped object.
    """

    def __init__(self, connection, query_result):
        """Wrap ``query_result`` and remember the ``connection`` it came from.

        :param connection: the connection that produced the result; the
            reference is dropped again in ``close``.
        :param query_result: the native result object to delegate to.
        """
        self.connection = connection
        self._query_result = query_result
        self.is_closed = False

    def __del__(self):
        """Close the result when the wrapper is garbage collected."""
        # If __init__ never ran to completion there is no state to release;
        # treating a missing flag as "closed" avoids an AttributeError being
        # raised inside the finalizer.
        if getattr(self, "is_closed", True):
            return
        self.close()

    def check_for_query_result_close(self):
        """Raise if the result has already been closed.

        :raises RuntimeError: when ``close`` was called earlier.
        """
        if self.is_closed:
            # RuntimeError is a subclass of Exception, so existing
            # ``except Exception`` handlers keep working.
            raise RuntimeError("Query result is closed")

    def has_next(self):
        """Return the wrapped result's ``hasNext()`` value."""
        self.check_for_query_result_close()
        return self._query_result.hasNext()

    def get_next(self):
        """Return the wrapped result's ``getNext()`` value."""
        self.check_for_query_result_close()
        return self._query_result.getNext()

    def write_to_csv(self, filename, delimiter=',', escapeCharacter='"', newline='\n'):
        """Forward all four arguments to the wrapped ``writeToCSV``.

        :param filename: destination passed to ``writeToCSV``.
        :param delimiter: defaults to ``','``.
        :param escapeCharacter: defaults to ``'"'``.
        :param newline: defaults to ``'\\n'``.
        """
        self.check_for_query_result_close()
        self._query_result.writeToCSV(
            filename, delimiter, escapeCharacter, newline)

    def close(self):
        """Close the wrapped result; calling it again is a no-op."""
        if self.is_closed:
            return
        self._query_result.close()
        # Allows the connection to be garbage collected if the query result
        # is closed manually by the user.
        self.connection = None
        self.is_closed = True

    def get_as_df(self):
        """Return the wrapped result's ``getAsDF()`` value."""
        self.check_for_query_result_close()
        return self._query_result.getAsDF()

    def get_column_data_types(self):
        """Return the wrapped result's ``getColumnDataTypes()`` value."""
        self.check_for_query_result_close()
        return self._query_result.getColumnDataTypes()

    def get_column_names(self):
        """Return the wrapped result's ``getColumnNames()`` value."""
        self.check_for_query_result_close()
        return self._query_result.getColumnNames()
17 changes: 11 additions & 6 deletions tools/python_api/test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import os
import sys
import pytest
import shutil
sys.path.append('../build/')
import _kuzu as kuzu
import kuzu


# Note: conftest.py is the default file name for sharing fixtures across multiple test files. Do not rename this file.
@pytest.fixture
def init_tiny_snb(tmp_path):
if os.path.exists(tmp_path):
os.rmdir(tmp_path)
shutil.rmtree(tmp_path)
output_path = str(tmp_path)
db = kuzu.database(output_path)
conn = kuzu.connection(db)
db = kuzu.Database(output_path)
conn = kuzu.Connection(db)
conn.execute("CREATE NODE TABLE person (ID INT64, fName STRING, gender INT64, isStudent BOOLEAN, isWorker BOOLEAN, "
"age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration "
"INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], PRIMARY "
Expand All @@ -26,6 +27,10 @@ def init_tiny_snb(tmp_path):

@pytest.fixture
def establish_connection(init_tiny_snb):
db = kuzu.database(init_tiny_snb, buffer_pool_size=256 * 1024 * 1024)
conn = kuzu.connection(db, num_threads=4)
db = kuzu.Database(init_tiny_snb, buffer_pool_size=256 * 1024 * 1024)
conn = kuzu.Connection(db, num_threads=4)
return conn, db

@pytest.fixture
def get_tmp_path(tmp_path):
    """Fixture that yields the ``tmp_path`` fixture value converted with ``str``."""
    path_text = str(tmp_path)
    return path_text
12 changes: 7 additions & 5 deletions tools/python_api/test/example.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from tools.python_api import _kuzu as kuzu
import sys
sys.path.append('../build/')
import kuzu

databaseDir = "path to database file"
db = kuzu.database(databaseDir)
conn = kuzu.connection(db)
db = kuzu.Database(databaseDir)
conn = kuzu.Connection(db)
query = "MATCH (a:person) RETURN *;"
result = conn.execute(query)
while result.hasNext():
print(result.getNext())
while result.has_next():
print(result.get_next())
result.close()
48 changes: 24 additions & 24 deletions tools/python_api/test/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,71 +4,71 @@
def test_bool_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.isStudent;")
assert result.hasNext()
assert result.getNext() == [True]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [True]
assert not result.has_next()
result.close()


def test_int_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.age;")
assert result.hasNext()
assert result.getNext() == [35]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [35]
assert not result.has_next()
result.close()


def test_double_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.eyeSight;")
assert result.hasNext()
assert result.getNext() == [5.0]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [5.0]
assert not result.has_next()
result.close()


def test_string_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.fName;")
assert result.hasNext()
assert result.getNext() == ['Alice']
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == ['Alice']
assert not result.has_next()
result.close()


def test_date_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.birthdate;")
assert result.hasNext()
assert result.getNext() == [datetime.date(1900, 1, 1)]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [datetime.date(1900, 1, 1)]
assert not result.has_next()
result.close()


def test_timestamp_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.registerTime;")
assert result.hasNext()
assert result.getNext() == [datetime.datetime(2011, 8, 20, 11, 25, 30)]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [datetime.datetime(2011, 8, 20, 11, 25, 30)]
assert not result.has_next()
result.close()


def test_interval_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.lastJobDuration;")
assert result.hasNext()
assert result.getNext() == [datetime.timedelta(days=1082, seconds=46920)]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [datetime.timedelta(days=1082, seconds=46920)]
assert not result.has_next()
result.close()


def test_list_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.courseScoresPerTerm;")
assert result.hasNext()
assert result.getNext() == [[[10, 8], [6, 7, 8]]]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [[[10, 8], [6, 7, 8]]]
assert not result.has_next()
result.close()

4 changes: 2 additions & 2 deletions tools/python_api/test/test_df.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import numpy as np
import sys
sys.path.append('../build/')
import _kuzu as kuzu
import kuzu
from pandas import Timestamp, Timedelta, isna

def test_to_df(establish_connection):
conn, db = establish_connection

def _test_to_df(conn):
query = "MATCH (p:person) return * ORDER BY p.ID"
pd = conn.execute(query).getAsDF()
pd = conn.execute(query).get_as_df()
assert pd['p.ID'].tolist() == [0, 2, 3, 5, 7, 8, 9, 10]
assert str(pd['p.ID'].dtype) == "int64"
assert pd['p.fName'].tolist() == ["Alice", "Bob", "Carol", "Dan", "Elizabeth", "Farooq", "Greg",
Expand Down
4 changes: 2 additions & 2 deletions tools/python_api/test/test_get_header.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
def test_get_column_names(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person)-[e:knows]->(b:person) RETURN a.fName, e.date, b.ID;")
column_names = result.getColumnNames()
column_names = result.get_column_names()
assert column_names[0] == 'a.fName'
assert column_names[1] == 'e.date'
assert column_names[2] == 'b.ID'
Expand All @@ -13,7 +13,7 @@ def test_get_column_data_types(establish_connection):
result = conn.execute(
"MATCH (p:person) RETURN p.ID, p.fName, p.isStudent, p.eyeSight, p.birthdate, p.registerTime, "
"p.lastJobDuration, p.workedHours, p.courseScoresPerTerm;")
column_data_types = result.getColumnDataTypes()
column_data_types = result.get_column_data_types()
assert column_data_types[0] == 'INT64'
assert column_data_types[1] == 'STRING'
assert column_data_types[2] == 'BOOL'
Expand Down
1 change: 1 addition & 0 deletions tools/python_api/test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from test_df import *
from test_write_to_csv import *
from test_get_header import *
from test_query_result_close import *

if __name__ == "__main__":
raise SystemExit(pytest.main([__file__]))
36 changes: 18 additions & 18 deletions tools/python_api/test/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,56 +6,56 @@ def test_bool_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.isStudent = $1 AND a.isWorker = $k RETURN COUNT(*)",
[("1", False), ("k", False)])
assert result.hasNext()
assert result.getNext() == [1]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [1]
assert not result.has_next()
result.close()


def test_int_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.age < $AGE RETURN COUNT(*)", [("AGE", 1)])
assert result.hasNext()
assert result.getNext() == [0]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [0]
assert not result.has_next()
result.close()


def test_double_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.eyeSight = $E RETURN COUNT(*)", [("E", 5.0)])
assert result.hasNext()
assert result.getNext() == [2]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [2]
assert not result.has_next()
result.close()


def test_str_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN concat(a.fName, $S);", [("S", "HH")])
assert result.hasNext()
assert result.getNext() == ["AliceHH"]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == ["AliceHH"]
assert not result.has_next()
result.close()


def test_date_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.birthdate = $1 RETURN COUNT(*);",
[("1", datetime.date(1900, 1, 1))])
assert result.hasNext()
assert result.getNext() == [2]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [2]
assert not result.has_next()
result.close()


def test_timestamp_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);",
[("1", datetime.datetime(2011, 8, 20, 11, 25, 30))])
assert result.hasNext()
assert result.getNext() == [1]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [1]
assert not result.has_next()
result.close()


Expand Down
Loading