Skip to content

Commit

Permalink
Wrap pybind11 API and Fix #1106
Browse files Browse the repository at this point in the history
  • Loading branch information
mewim committed Dec 19, 2022
1 parent 08fba27 commit b3b8d75
Show file tree
Hide file tree
Showing 25 changed files with 178 additions and 72 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ bazel-*

test_temp/
build/
cmake-build-debug/

### Python
# Byte-compiled / optimized / DLL files
Expand Down
20 changes: 14 additions & 6 deletions tools/python_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,21 @@ project(_kuzu)

set(CMAKE_CXX_STANDARD 20)

file(GLOB SOURCE_PY
"src_py/*"
)

file(GLOB SOURCE_CPP
"src_cpp/*"
)

pybind11_add_module(_kuzu
SHARED
kuzu_binding.cpp
py_connection.cpp
py_database.cpp
py_query_result.cpp
py_query_result_converter.cpp)
${SOURCE_CPP})

set_target_properties(_kuzu
PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/build/")
LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/build/kuzu")

target_link_libraries(_kuzu
PRIVATE
Expand All @@ -24,3 +28,7 @@ target_include_directories(
PUBLIC
../../src/include
)

get_target_property(PYTHON_DEST _kuzu LIBRARY_OUTPUT_DIRECTORY)

file(COPY ${SOURCE_PY} DESTINATION ${PYTHON_DEST})
Empty file removed tools/python_api/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
5 changes: 5 additions & 0 deletions tools/python_api/src_py/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._kuzu import *
# The following imports will override C++ implementations with Python
# implementations.
from .connection import *
from .query_result import *
14 changes: 14 additions & 0 deletions tools/python_api/src_py/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .query_result import QueryResult
from . import _kuzu


class connection:
    """Python wrapper around the pybind11 ``_kuzu.connection`` object.

    Every call is forwarded to the underlying binding; results of
    ``execute`` are wrapped in the Python-level ``QueryResult``.
    """

    def __init__(self, database, num_threads=0):
        """Open a connection on *database*.

        Args:
            database: Database object handed straight to ``_kuzu.connection``.
            num_threads: Thread count forwarded to the binding (default 0).
        """
        # A reference to the database is stored on the wrapper
        # (presumably so it outlives the native connection -- confirm).
        self.database = database
        self._connection = _kuzu.connection(database, num_threads)

    def set_max_threads_for_exec(self, num_threads):
        """Forward the maximum execution thread count to the binding."""
        self._connection.set_max_threads_for_exec(num_threads)

    def execute(self, query, parameters=None):
        """Run *query* and return the result wrapped in a ``QueryResult``.

        Args:
            query: Query text passed to the binding.
            parameters: Optional list of parameters, e.g. ``(name, value)``
                pairs. ``None`` (the default) is treated as an empty list.

        Returns:
            A ``QueryResult`` that also keeps a reference to this connection.
        """
        # Use a None sentinel instead of a shared mutable default list.
        if parameters is None:
            parameters = []
        return QueryResult(self, self._connection.execute(query, parameters))
29 changes: 29 additions & 0 deletions tools/python_api/src_py/query_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
class QueryResult:
    """Python wrapper around the query result object returned by the binding.

    The wrapper exposes snake_case methods that delegate to the camelCase
    methods of the wrapped object, and closes the wrapped result when it is
    finalized.
    """

    def __init__(self, connection, query_result):
        """Wrap *query_result*.

        Args:
            connection: The connection that produced the result; a reference
                is stored on the wrapper.
            query_result: The underlying result object to delegate to.
        """
        self.connection = connection
        self._query_result = query_result
        self._is_closed = False

    def __del__(self):
        """Close the wrapped result when the wrapper is finalized."""
        self.close()

    def has_next(self):
        """Return the wrapped result's ``hasNext()`` value."""
        return self._query_result.hasNext()

    def get_next(self):
        """Return the wrapped result's ``getNext()`` value."""
        return self._query_result.getNext()

    def write_to_csv(self, filename, delimiter=',', escapeCharacter='"', newline='\n'):
        """Delegate to ``writeToCSV`` with the given file name and CSV options."""
        self._query_result.writeToCSV(
            filename, delimiter, escapeCharacter, newline)

    def close(self):
        """Close the wrapped result; calling this more than once is a no-op."""
        # getattr defaults keep this safe for a partially initialised
        # instance, e.g. when __init__ raised before the attributes were set
        # and __del__ still runs.
        if getattr(self, "_is_closed", False):
            return
        # Mark as closed first so a failing close is not retried by __del__.
        self._is_closed = True
        wrapped = getattr(self, "_query_result", None)
        if wrapped is not None:
            wrapped.close()

    def get_as_df(self):
        """Return the wrapped result's ``getAsDF()`` value."""
        return self._query_result.getAsDF()

    def get_column_data_types(self):
        """Return the wrapped result's ``getColumnDataTypes()`` value."""
        return self._query_result.getColumnDataTypes()

    def get_column_names(self):
        """Return the wrapped result's ``getColumnNames()`` value."""
        return self._query_result.getColumnNames()
9 changes: 7 additions & 2 deletions tools/python_api/test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import os
import sys
import pytest
import shutil
sys.path.append('../build/')
import _kuzu as kuzu
import kuzu


# Note: conftest.py is the default file name pytest uses for sharing fixtures across multiple test files. Do not rename this file.
@pytest.fixture
def init_tiny_snb(tmp_path):
if os.path.exists(tmp_path):
os.rmdir(tmp_path)
shutil.rmtree(tmp_path)
output_path = str(tmp_path)
db = kuzu.database(output_path)
conn = kuzu.connection(db)
Expand All @@ -29,3 +30,7 @@ def establish_connection(init_tiny_snb):
db = kuzu.database(init_tiny_snb, buffer_pool_size=256 * 1024 * 1024)
conn = kuzu.connection(db, num_threads=4)
return conn, db

@pytest.fixture
def get_tmp_path(tmp_path):
    """Provide the per-test ``tmp_path`` directory as a plain string."""
    path_as_text = str(tmp_path)
    return path_as_text
8 changes: 5 additions & 3 deletions tools/python_api/test/example.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from tools.python_api import _kuzu as kuzu
import sys
sys.path.append('../build/')
import kuzu

databaseDir = "path to database file"
db = kuzu.database(databaseDir)
conn = kuzu.connection(db)
query = "MATCH (a:person) RETURN *;"
result = conn.execute(query)
while result.hasNext():
print(result.getNext())
while result.has_next():
print(result.get_next())
result.close()
48 changes: 24 additions & 24 deletions tools/python_api/test/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,71 +4,71 @@
def test_bool_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.isStudent;")
assert result.hasNext()
assert result.getNext() == [True]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [True]
assert not result.has_next()
result.close()


def test_int_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.age;")
assert result.hasNext()
assert result.getNext() == [35]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [35]
assert not result.has_next()
result.close()


def test_double_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.eyeSight;")
assert result.hasNext()
assert result.getNext() == [5.0]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [5.0]
assert not result.has_next()
result.close()


def test_string_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.fName;")
assert result.hasNext()
assert result.getNext() == ['Alice']
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == ['Alice']
assert not result.has_next()
result.close()


def test_date_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.birthdate;")
assert result.hasNext()
assert result.getNext() == [datetime.date(1900, 1, 1)]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [datetime.date(1900, 1, 1)]
assert not result.has_next()
result.close()


def test_timestamp_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.registerTime;")
assert result.hasNext()
assert result.getNext() == [datetime.datetime(2011, 8, 20, 11, 25, 30)]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [datetime.datetime(2011, 8, 20, 11, 25, 30)]
assert not result.has_next()
result.close()


def test_interval_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.lastJobDuration;")
assert result.hasNext()
assert result.getNext() == [datetime.timedelta(days=1082, seconds=46920)]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [datetime.timedelta(days=1082, seconds=46920)]
assert not result.has_next()
result.close()


def test_list_wrap(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.courseScoresPerTerm;")
assert result.hasNext()
assert result.getNext() == [[[10, 8], [6, 7, 8]]]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [[[10, 8], [6, 7, 8]]]
assert not result.has_next()
result.close()

4 changes: 2 additions & 2 deletions tools/python_api/test/test_df.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import numpy as np
import sys
sys.path.append('../build/')
import _kuzu as kuzu
import kuzu
from pandas import Timestamp, Timedelta, isna

def test_to_df(establish_connection):
conn, db = establish_connection

def _test_to_df(conn):
query = "MATCH (p:person) return * ORDER BY p.ID"
pd = conn.execute(query).getAsDF()
pd = conn.execute(query).get_as_df()
assert pd['p.ID'].tolist() == [0, 2, 3, 5, 7, 8, 9, 10]
assert str(pd['p.ID'].dtype) == "int64"
assert pd['p.fName'].tolist() == ["Alice", "Bob", "Carol", "Dan", "Elizabeth", "Farooq", "Greg",
Expand Down
4 changes: 2 additions & 2 deletions tools/python_api/test/test_get_header.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
def test_get_column_names(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person)-[e:knows]->(b:person) RETURN a.fName, e.date, b.ID;")
column_names = result.getColumnNames()
column_names = result.get_column_names()
assert column_names[0] == 'a.fName'
assert column_names[1] == 'e.date'
assert column_names[2] == 'b.ID'
Expand All @@ -13,7 +13,7 @@ def test_get_column_data_types(establish_connection):
result = conn.execute(
"MATCH (p:person) RETURN p.ID, p.fName, p.isStudent, p.eyeSight, p.birthdate, p.registerTime, "
"p.lastJobDuration, p.workedHours, p.courseScoresPerTerm;")
column_data_types = result.getColumnDataTypes()
column_data_types = result.get_column_data_types()
assert column_data_types[0] == 'INT64'
assert column_data_types[1] == 'STRING'
assert column_data_types[2] == 'BOOL'
Expand Down
1 change: 1 addition & 0 deletions tools/python_api/test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from test_df import *
from test_write_to_csv import *
from test_get_header import *
from test_query_result_close import *

if __name__ == "__main__":
raise SystemExit(pytest.main([__file__]))
36 changes: 18 additions & 18 deletions tools/python_api/test/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,56 +6,56 @@ def test_bool_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.isStudent = $1 AND a.isWorker = $k RETURN COUNT(*)",
[("1", False), ("k", False)])
assert result.hasNext()
assert result.getNext() == [1]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [1]
assert not result.has_next()
result.close()


def test_int_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.age < $AGE RETURN COUNT(*)", [("AGE", 1)])
assert result.hasNext()
assert result.getNext() == [0]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [0]
assert not result.has_next()
result.close()


def test_double_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.eyeSight = $E RETURN COUNT(*)", [("E", 5.0)])
assert result.hasNext()
assert result.getNext() == [2]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [2]
assert not result.has_next()
result.close()


def test_str_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN concat(a.fName, $S);", [("S", "HH")])
assert result.hasNext()
assert result.getNext() == ["AliceHH"]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == ["AliceHH"]
assert not result.has_next()
result.close()


def test_date_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.birthdate = $1 RETURN COUNT(*);",
[("1", datetime.date(1900, 1, 1))])
assert result.hasNext()
assert result.getNext() == [2]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [2]
assert not result.has_next()
result.close()


def test_timestamp_param(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);",
[("1", datetime.datetime(2011, 8, 20, 11, 25, 30))])
assert result.hasNext()
assert result.getNext() == [1]
assert not result.hasNext()
assert result.has_next()
assert result.get_next() == [1]
assert not result.has_next()
result.close()


Expand Down
23 changes: 23 additions & 0 deletions tools/python_api/test/test_query_result_close.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import subprocess
import sys


def test_query_result_close(get_tmp_path):
    """Check that a child interpreter exits cleanly with an unclosed result.

    A short script is run in a fresh Python process: it creates a database
    in the temporary directory, loads the ``person`` table, runs one query
    and then ends without calling ``result.close()``. The test passes when
    the child process exits with status 0.
    """
    create_person_table = (
        "conn.execute('CREATE NODE TABLE person (ID INT64, fName STRING, "
        "gender INT64, isStudent BOOLEAN, isWorker BOOLEAN, age INT64, "
        "eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, "
        "lastJobDuration INTERVAL, workedHours INT64[], usedNames STRING[], "
        "courseScoresPerTerm INT64[][], PRIMARY KEY (ID))')"
    )
    statements = [
        'import sys',
        'sys.path.append("../build/")',
        'import kuzu',
        # repr() produces a valid Python string literal even when the path
        # contains backslashes or quote characters.
        'db = kuzu.database(' + repr(get_tmp_path) + ')',
        'conn = kuzu.connection(db)',
        create_person_table,
        'conn.execute(\'COPY person FROM "../../../dataset/tinysnb/vPerson.csv" (HEADER=true)\')',
        'result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a.isStudent;")',
        # result.close() is intentionally not called; the child process must
        # still exit with status 0.
    ]
    completed = subprocess.run(
        [sys.executable, '-c', ';'.join(statements)],
        capture_output=True,
        text=True,
    )
    # Surface the child's stderr in the failure message to ease debugging.
    assert completed.returncode == 0, completed.stderr
Loading

0 comments on commit b3b8d75

Please sign in to comment.