Add Python binding for NODE & REL types; output query results to NetworkX #1192

Merged
merged 1 commit into from
Jan 19, 2023
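In short, this PR returns NODE and REL column values to Python as plain dictionaries and adds a `get_as_networkx()` method on `QueryResult` that converts a result set into a NetworkX graph. A minimal usage sketch, assuming the tinysnb test schema used by this PR's tests (the database path and query below are illustrative):

```python
import kuzu

# Open a database and run a query that returns node and rel values.
# "./testdb" is a placeholder path; "person"/"knows" come from the tinysnb test dataset.
db = kuzu.Database("./testdb")
conn = kuzu.Connection(db)
result = conn.execute(
    "MATCH (p:person)-[r:knows]->(q:person) RETURN p, r, q")

# NODE and REL columns now come back as dicts.
while result.has_next():
    p, r, q = result.get_next()
    print(p["_label"], p["fName"], "knows", q["fName"])

# Or convert the whole result into a NetworkX graph (new in this PR).
nx_graph = result.get_as_networkx(directed=True)
```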
10 changes: 8 additions & 2 deletions .github/workflows/ci-workflow.yml
@@ -22,7 +22,10 @@ jobs:

- name: test
run: CC=gcc CXX=g++ make test NUM_THREADS=32


- name: python test
run: CC=gcc CXX=g++ make pytest NUM_THREADS=32

gcc-build-test-with-asan:
name: gcc build & test with asan
needs: [gcc-build-test]
@@ -32,7 +35,7 @@
- run: pip install --user -r tools/python_api/requirements_dev.txt

- name: build debug
run: CC=gcc CXX=g++ make NUM_THREADS=32 alldebug
run: CC=gcc CXX=g++ make alldebug NUM_THREADS=32

- name: run test with asan
run: ctest
@@ -62,6 +65,9 @@

- name: test
run: CC=clang-14 CXX=clang++-14 make test NUM_THREADS=32

- name: python test
run: CC=clang-14 CXX=clang++-14 make pytest NUM_THREADS=32

clang-formatting-check:
name: clang-formatting-check
5 changes: 5 additions & 0 deletions Makefile
@@ -70,6 +70,11 @@ test: arrow
cd $(ROOT_DIR)/build/release/test && \
ctest

pytest: arrow
$(MAKE) release
cd $(ROOT_DIR)/tools/python_api/test && \
python3 -m pytest -v test_main.py

clean-external:
rm -rf external/build

16 changes: 16 additions & 0 deletions src/common/types/value.cpp
@@ -264,6 +264,14 @@ NodeVal::NodeVal(const NodeVal& other) {
}
}

nodeID_t NodeVal::getNodeID() const {
return idVal->getValue<nodeID_t>();
}

string NodeVal::getLabelName() const {
return labelVal->getValue<string>();
}

string NodeVal::toString() const {
std::string result = "(";
result += idVal->toString();
@@ -281,6 +289,14 @@ RelVal::RelVal(const RelVal& other) {
}
}

nodeID_t RelVal::getSrcNodeID() const {
return srcNodeIDVal->getValue<nodeID_t>();
}

nodeID_t RelVal::getDstNodeID() const {
return dstNodeIDVal->getValue<nodeID_t>();
}

string RelVal::toString() const {
std::string result;
result += "(" + srcNodeIDVal->toString() + ")";
14 changes: 14 additions & 0 deletions src/include/common/types/value.h
@@ -109,6 +109,13 @@ class NodeVal {
properties.emplace_back(key, std::move(value));
}

inline const vector<pair<std::string, unique_ptr<Value>>>& getProperties() const {
return properties;
}

nodeID_t getNodeID() const;
string getLabelName() const;

inline unique_ptr<NodeVal> copy() const { return make_unique<NodeVal>(*this); }

string toString() const;
@@ -129,8 +136,15 @@ class RelVal {
properties.emplace_back(key, std::move(value));
}

inline const vector<pair<std::string, unique_ptr<Value>>>& getProperties() const {
return properties;
}

inline unique_ptr<RelVal> copy() const { return make_unique<RelVal>(*this); }

nodeID_t getSrcNodeID() const;
nodeID_t getDstNodeID() const;

string toString() const;

private:
14 changes: 0 additions & 14 deletions test/CMakeLists.txt
@@ -24,17 +24,3 @@ add_subdirectory(processor)
add_subdirectory(runner)
add_subdirectory(storage)
add_subdirectory(transaction)

function(add_kuzu_python_api_test TEST_NAME FILE_NAME)
add_test(NAME PythonAPI.${TEST_NAME}
COMMAND ${PYTHON_EXECUTABLE} -m pytest ${PROJECT_SOURCE_DIR}/tools/python_api/test/${FILE_NAME}
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/tools/python_api/test)
endfunction()

add_kuzu_python_api_test(DataType test_datatype.py)
add_kuzu_python_api_test(PandaAPI test_df.py)
add_kuzu_python_api_test(Exception test_exception.py)
add_kuzu_python_api_test(GetHeader test_get_header.py)
add_kuzu_python_api_test(Main test_main.py)
add_kuzu_python_api_test(Parameter test_parameter.py)
add_kuzu_python_api_test(WriteToCSV test_write_to_csv.py)
1 change: 1 addition & 0 deletions tools/python_api/requirements_dev.txt
@@ -1,4 +1,5 @@
pybind11>=2.6.0
pytest
pandas
networkx~=3.0.0
numpy
7 changes: 7 additions & 0 deletions tools/python_api/src_cpp/include/py_query_result.h
@@ -30,6 +30,13 @@ class PyQueryResult {

py::list getColumnNames();

void resetIterator();

private:
static py::dict getPyDictFromProperties(
const vector<pair<std::string, unique_ptr<Value>>>& properties);

static py::dict convertNodeIdToPyDict(const nodeID_t& nodeId);

unique_ptr<QueryResult> queryResult;
};
43 changes: 41 additions & 2 deletions tools/python_api/src_cpp/py_query_result.cpp
@@ -18,7 +18,8 @@ void PyQueryResult::initialize(py::handle& m) {
.def("close", &PyQueryResult::close)
.def("getAsDF", &PyQueryResult::getAsDF)
.def("getColumnNames", &PyQueryResult::getColumnNames)
.def("getColumnDataTypes", &PyQueryResult::getColumnDataTypes);
.def("getColumnDataTypes", &PyQueryResult::getColumnDataTypes)
.def("resetIterator", &PyQueryResult::resetIterator);
// PyDateTime_IMPORT is a macro that must be invoked before calling any other cpython datetime
// macros. One could also invoke this in a separate function like constructor. See
// https://docs.python.org/3/c-api/datetime.html for details.
@@ -106,8 +107,25 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) {
}
return move(list);
}
case NODE: {
auto nodeVal = value.getValue<NodeVal>();
auto dict = PyQueryResult::getPyDictFromProperties(nodeVal.getProperties());
dict["_label"] = py::cast(nodeVal.getLabelName());
dict["_id"] = convertNodeIdToPyDict(nodeVal.getNodeID());
return move(dict);
}
case REL: {
auto relVal = value.getValue<RelVal>();
auto dict = PyQueryResult::getPyDictFromProperties(relVal.getProperties());
dict["_src"] = convertNodeIdToPyDict(relVal.getSrcNodeID());
dict["_dst"] = convertNodeIdToPyDict(relVal.getDstNodeID());
return move(dict);
}
case NODE_ID: {
return convertNodeIdToPyDict(value.getValue<nodeID_t>());
}
default:
throw NotImplementedException("Unsupported type2: " + Types::dataTypeToString(dataType));
throw NotImplementedException("Unsupported type: " + Types::dataTypeToString(dataType));
}
}

@@ -132,3 +150,24 @@ py::list PyQueryResult::getColumnNames() {
}
return move(result);
}

void PyQueryResult::resetIterator() {
queryResult->resetIterator();
}

py::dict PyQueryResult::getPyDictFromProperties(
const vector<pair<std::string, unique_ptr<Value>>>& properties) {
py::dict result;
for (auto i = 0u; i < properties.size(); ++i) {
auto& [name, value] = properties[i];
result[py::cast(name)] = convertValueToPyObject(*value);
}
return result;
}

py::dict PyQueryResult::convertNodeIdToPyDict(const nodeID_t& nodeId) {
py::dict idDict;
idDict["offset"] = py::cast(nodeId.offset);
idDict["table"] = py::cast(nodeId.tableID);
return idDict;
}
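With these conversions, a NODE value reaches Python as a dict of its properties plus `_label` and `_id`, a REL value as its properties plus `_src` and `_dst`, and a NODE_ID as an offset/table dict. A sketch of the resulting shapes (property names and ID values below are illustrative, based on the tinysnb test data):

```python
# What a NODE column value looks like on the Python side (illustrative values).
node = {
    "ID": 0, "fName": "Alice",            # node properties
    "_label": "person",                    # node table name
    "_id": {"offset": 0, "table": 0},      # internal node ID
}

# What a REL column value looks like (illustrative values).
rel = {
    "year": 2010,                          # rel properties
    "_src": {"offset": 5, "table": 0},     # source node ID
    "_dst": {"offset": 0, "table": 1},     # destination node ID
}
```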
8 changes: 8 additions & 0 deletions tools/python_api/src_cpp/py_query_result_converter.cpp
@@ -49,6 +49,11 @@ void NPArrayWrapper::appendElement(Value* value) {
((PyObject**)dataBuffer)[numElements] = result;
break;
}
case NODE:
case REL: {
((py::dict*)dataBuffer)[numElements] = PyQueryResult::convertValueToPyObject(*value);
break;
}
case LIST: {
((py::list*)dataBuffer)[numElements] = PyQueryResult::convertValueToPyObject(*value);
break;
@@ -76,6 +81,8 @@ py::dtype NPArrayWrapper::convertToArrayType(const DataType& type) {
dtype = "bool";
break;
}
case NODE:
case REL:
case LIST:
case STRING: {
dtype = "object";
@@ -104,6 +111,7 @@ QueryResultConverter::QueryResultConverter(QueryResult* queryResult) : queryResu
}

py::object QueryResultConverter::toDF() {
queryResult->resetIterator();
while (queryResult->hasNext()) {
auto flatTuple = queryResult->getNext();
for (auto i = 0u; i < columns.size(); i++) {
67 changes: 67 additions & 0 deletions tools/python_api/src_py/query_result.py
@@ -1,3 +1,7 @@
NODE_TYPE = "NODE"
REL_TYPE = "REL"


class QueryResult:
def __init__(self, connection, query_result):
self.connection = connection
@@ -44,3 +48,66 @@ def get_column_data_types(self):
def get_column_names(self):
self.check_for_query_result_close()
return self._query_result.getColumnNames()

def reset_iterator(self):
self.check_for_query_result_close()
self._query_result.resetIterator()

def get_as_networkx(self, directed=True):
self.check_for_query_result_close()
import networkx as nx

if directed:
nx_graph = nx.DiGraph()
else:
nx_graph = nx.Graph()
column_names = self.get_column_names()
column_types = self.get_column_data_types()
column_to_extract = {}

# Iterate over columns and extract nodes and rels, ignoring other columns
for i in range(len(column_names)):
column_name = column_names[i]
column_type = column_types[i]
if column_type in [NODE_TYPE, REL_TYPE]:
column_to_extract[i] = (column_type, column_name)

self.reset_iterator()

nodes = {}
rels = {}
table_to_label_dict = {}

# De-duplicate nodes and rels
while self.has_next():
row = self.get_next()
for i in column_to_extract:
column_type, _ = column_to_extract[i]
if column_type == NODE_TYPE:
_id = row[i]["_id"]
nodes[(_id["table"], _id["offset"])] = row[i]
table_to_label_dict[_id["table"]] = row[i]["_label"]

elif column_type == REL_TYPE:
_src = row[i]["_src"]
_dst = row[i]["_dst"]
rels[(_src["table"], _src["offset"], _dst["table"],
_dst["offset"])] = row[i]

# Add nodes
for node in nodes.values():
_id = node["_id"]
node_id = node['_label'] + "_" + str(_id["offset"])
node[node['_label']] = True
nx_graph.add_node(node_id, **node)

# Add rels
for rel in rels.values():
_src = rel["_src"]
_dst = rel["_dst"]
src_id = str(
table_to_label_dict[_src["table"]]) + "_" + str(_src["offset"])
dst_id = str(
table_to_label_dict[_dst["table"]]) + "_" + str(_dst["offset"])
nx_graph.add_edge(src_id, dst_id, **rel)
return nx_graph
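In the resulting graph, each node is keyed by `<label>_<offset>` and carries its property dict as node attributes (plus a `<label>: True` flag), and each edge carries the rel's dict as edge attributes. A quick consumption sketch, assuming the tinysnb test data (the node key and property names are illustrative):

```python
g = result.get_as_networkx(directed=True)

# Node keys look like "person_0"; attributes are the node dict plus a label flag.
alice = g.nodes["person_0"]
print(alice["fName"], alice["person"])     # e.g. "Alice", True

# Edge attributes are the rel's properties plus its "_src"/"_dst" IDs.
for src, dst, attrs in g.edges(data=True):
    print(src, "->", dst, attrs.get("year"))
```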
13 changes: 12 additions & 1 deletion tools/python_api/test/conftest.py
@@ -2,6 +2,7 @@
import sys
import pytest
import shutil

sys.path.append('../build/')
import kuzu

@@ -18,19 +19,29 @@ def init_tiny_snb(tmp_path):
"age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration "
"INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], PRIMARY "
"KEY (ID))")
conn.execute("COPY person FROM \"../../../dataset/tinysnb/vPerson.csv\" (HEADER=true)")
conn.execute(
"COPY person FROM \"../../../dataset/tinysnb/vPerson.csv\" (HEADER=true)")
conn.execute(
"create rel table knows (FROM person TO person, date DATE, meetTime TIMESTAMP, validInterval INTERVAL, "
"comments STRING[], MANY_MANY);")
conn.execute("COPY knows FROM \"../../../dataset/tinysnb/eKnows.csv\"")
conn.execute("create node table organisation (ID INT64, name STRING, orgCode INT64, mark DOUBLE, score INT64, history STRING, licenseValidInterval INTERVAL, rating DOUBLE, PRIMARY KEY (ID))")
conn.execute(
'COPY organisation FROM "../../../dataset/tinysnb/vOrganisation.csv" (HEADER=true)')
conn.execute(
'create rel table workAt (FROM person TO organisation, year INT64, MANY_ONE)')
conn.execute(
'COPY workAt FROM "../../../dataset/tinysnb/eWorkAt.csv" (HEADER=true)')
return output_path


@pytest.fixture
def establish_connection(init_tiny_snb):
db = kuzu.Database(init_tiny_snb, buffer_pool_size=256 * 1024 * 1024)
conn = kuzu.Connection(db, num_threads=4)
return conn, db


@pytest.fixture
def get_tmp_path(tmp_path):
return str(tmp_path)
43 changes: 43 additions & 0 deletions tools/python_api/test/test_datatype.py
@@ -72,3 +72,46 @@ def test_list_wrap(establish_connection):
assert not result.has_next()
result.close()


def test_node(establish_connection):
conn, db = establish_connection
result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a")
assert result.has_next()
n = result.get_next()
assert(len(n) == 1)
n = n[0]
assert (n['_label'] == 'person')
assert (n['ID'] == 0)
assert (n['fName'] == 'Alice')
assert (n['gender'] == 1)
assert (n['isStudent'] == True)
assert (n['eyeSight'] == 5.0)
assert (n['birthdate'] == datetime.date(1900, 1, 1))
assert (n['registerTime'] == datetime.datetime(2011, 8, 20, 11, 25, 30))
assert (n['lastJobDuration'] == datetime.timedelta(
days=1082, seconds=46920))
assert (n['courseScoresPerTerm'] == [[10, 8], [6, 7, 8]])
assert (n['usedNames'] == ['Aida'])
assert not result.has_next()
result.close()


def test_rel(establish_connection):
conn, db = establish_connection
result = conn.execute(
"MATCH (p:person)-[r:workAt]->(o:organisation) WHERE p.ID = 5 RETURN p, r, o")
assert result.has_next()
n = result.get_next()
assert (len(n) == 3)
p = n[0]
r = n[1]
o = n[2]
assert (p['_label'] == 'person')
assert (p['ID'] == 5)
assert (o['_label'] == 'organisation')
assert (o['ID'] == 6)
assert (r['year'] == 2010)
assert (r['_src'] == p['_id'])
assert (r['_dst'] == o['_id'])
assert not result.has_next()
result.close()