Skip to content

Commit

Permalink
Add Python binding for NODE & REL types; output query results to NetworkX
Browse files Browse the repository at this point in the history
  • Loading branch information
mewim committed Jan 18, 2023
1 parent c9653c0 commit 0b82dfe
Show file tree
Hide file tree
Showing 12 changed files with 560 additions and 7 deletions.
14 changes: 14 additions & 0 deletions src/include/common/types/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ class NodeVal {
properties.emplace_back(key, std::move(value));
}

// Read-only view of the node's (property name, value) pairs, in insertion order.
inline const vector<pair<std::string, unique_ptr<Value>>>& getProperties() const {
return properties;
}

// Internal node identifier value (wraps offset + table ID).
inline const Value& getNodeID() const { return *idVal; }
// Value holding the node's label/table name.
inline const Value& getLabelName() const { return *labelVal; }

// Copy of this NodeVal (depth depends on Value's copy semantics).
inline unique_ptr<NodeVal> copy() const { return make_unique<NodeVal>(*this); }

string toString() const;
Expand All @@ -129,10 +136,17 @@ class RelVal {
properties.emplace_back(key, std::move(value));
}

// Read-only view of the relationship's (property name, value) pairs.
inline const vector<pair<std::string, unique_ptr<Value>>>& getProperties() const {
return properties;
}

// Copy of this RelVal (depth depends on Value's copy semantics).
inline unique_ptr<RelVal> copy() const { return make_unique<RelVal>(*this); }

string toString() const;

// Internal node ID of the relationship's source endpoint.
inline const Value& getSrcNodeID() const { return *srcNodeIDVal; }
// Internal node ID of the relationship's destination endpoint.
inline const Value& getDstNodeID() const { return *dstNodeIDVal; }

private:
unique_ptr<Value> srcNodeIDVal;
unique_ptr<Value> dstNodeIDVal;
Expand Down
3 changes: 2 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ add_kuzu_python_api_test(DataType test_datatype.py)
add_kuzu_python_api_test(PandaAPI test_df.py)
add_kuzu_python_api_test(Exception test_exception.py)
add_kuzu_python_api_test(GetHeader test_get_header.py)
add_kuzu_python_api_test(Main test_main.py)
add_kuzu_python_api_test(NetworkXAPI test_networkx.py)
add_kuzu_python_api_test(Parameter test_parameter.py)
add_kuzu_python_api_test(QueryResultClose test_query_result_close.py)
add_kuzu_python_api_test(WriteToCSV test_write_to_csv.py)
1 change: 1 addition & 0 deletions tools/python_api/requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pybind11>=2.6.0
pytest
pandas
networkx~=3.0.0
numpy
7 changes: 7 additions & 0 deletions tools/python_api/src_cpp/include/py_query_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ class PyQueryResult {

py::list getColumnNames();

void resetIterator();

private:
static py::dict getPyDictFromProperties(
const vector<pair<std::string, unique_ptr<Value>>>& properties);

static py::dict convertNodeIdToPyDict(const Value& nodeIdVal);

unique_ptr<QueryResult> queryResult;
};
44 changes: 42 additions & 2 deletions tools/python_api/src_cpp/py_query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ void PyQueryResult::initialize(py::handle& m) {
.def("close", &PyQueryResult::close)
.def("getAsDF", &PyQueryResult::getAsDF)
.def("getColumnNames", &PyQueryResult::getColumnNames)
.def("getColumnDataTypes", &PyQueryResult::getColumnDataTypes);
.def("getColumnDataTypes", &PyQueryResult::getColumnDataTypes)
.def("resetIterator", &PyQueryResult::resetIterator);
// PyDateTime_IMPORT is a macro that must be invoked before calling any other cpython datetime
// macros. One could also invoke this in a separate function like constructor. See
// https://docs.python.org/3/c-api/datetime.html for details.
Expand Down Expand Up @@ -106,8 +107,25 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) {
}
return move(list);
}
case NODE: {
auto nodeVal = value.getValue<NodeVal>();
auto dict = PyQueryResult::getPyDictFromProperties(nodeVal.getProperties());
dict["_label"] = py::cast(nodeVal.getLabelName().getValue<string>());
dict["_id"] = convertNodeIdToPyDict(nodeVal.getNodeID());
return move(dict);
}
case REL: {
auto relVal = value.getValue<RelVal>();
auto dict = PyQueryResult::getPyDictFromProperties(relVal.getProperties());
dict["_src"] = convertNodeIdToPyDict(relVal.getSrcNodeID());
dict["_dst"] = convertNodeIdToPyDict(relVal.getDstNodeID());
return move(dict);
}
case NODE_ID: {
return convertNodeIdToPyDict(value);
}
default:
throw NotImplementedException("Unsupported type2: " + Types::dataTypeToString(dataType));
throw NotImplementedException("Unsupported type: " + Types::dataTypeToString(dataType));
}
}

Expand All @@ -132,3 +150,25 @@ py::list PyQueryResult::getColumnNames() {
}
return move(result);
}

// Rewinds the wrapped QueryResult so its tuples can be iterated again
// (exposed to Python as resetIterator(); used before NetworkX/DataFrame export).
void PyQueryResult::resetIterator() {
queryResult->resetIterator();
}

// Builds a Python dict {property name -> converted value} from a NODE/REL
// property list. Each Value is converted via convertValueToPyObject.
py::dict PyQueryResult::getPyDictFromProperties(
    const vector<pair<std::string, unique_ptr<Value>>>& properties) {
    py::dict result;
    // Range-for with a const-reference structured binding replaces the manual
    // index loop; no copies of the (name, value) pairs are made.
    for (const auto& [name, value] : properties) {
        result[py::cast(name)] = convertValueToPyObject(*value);
    }
    return result;
}

// Exposes an internal node identifier to Python as {"offset": ..., "table": ...}.
py::dict PyQueryResult::convertNodeIdToPyDict(const Value& nodeIdVal) {
    auto internalId = nodeIdVal.getValue<nodeID_t>();
    py::dict result;
    result["offset"] = py::cast(internalId.offset);
    result["table"] = py::cast(internalId.tableID);
    return result;
}
8 changes: 8 additions & 0 deletions tools/python_api/src_cpp/py_query_result_converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ void NPArrayWrapper::appendElement(Value* value) {
((PyObject**)dataBuffer)[numElements] = result;
break;
}
case NODE:
case REL: {
((py::dict*)dataBuffer)[numElements] = PyQueryResult::convertValueToPyObject(*value);
break;
}
case LIST: {
((py::list*)dataBuffer)[numElements] = PyQueryResult::convertValueToPyObject(*value);
break;
Expand Down Expand Up @@ -76,6 +81,8 @@ py::dtype NPArrayWrapper::convertToArrayType(const DataType& type) {
dtype = "bool";
break;
}
case NODE:
case REL:
case LIST:
case STRING: {
dtype = "object";
Expand Down Expand Up @@ -104,6 +111,7 @@ QueryResultConverter::QueryResultConverter(QueryResult* queryResult) : queryResu
}

py::object QueryResultConverter::toDF() {
queryResult->resetIterator();
while (queryResult->hasNext()) {
auto flatTuple = queryResult->getNext();
for (auto i = 0u; i < columns.size(); i++) {
Expand Down
67 changes: 67 additions & 0 deletions tools/python_api/src_py/query_result.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
NODE_TYPE = "NODE"
REL_TYPE = "REL"


class QueryResult:
def __init__(self, connection, query_result):
self.connection = connection
Expand Down Expand Up @@ -44,3 +48,66 @@ def get_column_data_types(self):
def get_column_names(self):
self.check_for_query_result_close()
return self._query_result.getColumnNames()

def reset_iterator(self):
    # Rewind the native result iterator so tuples can be consumed again
    # (e.g. after a prior get_next()/get_as_df() pass). Raises if closed.
    self.check_for_query_result_close()
    self._query_result.resetIterator()

def get_as_networkx(self, directed=True):
    """Convert the NODE/REL columns of this query result into a NetworkX graph.

    Non-NODE/REL columns are ignored. Nodes are keyed as "<label>_<offset>"
    and carry their properties (plus ``_label``/``_id`` and a ``<label>=True``
    marker attribute); rels become edges carrying their properties.

    :param directed: build an ``nx.DiGraph`` when True (default),
                     otherwise an ``nx.Graph``.
    :return: the populated NetworkX graph.
    """
    self.check_for_query_result_close()
    import networkx as nx

    nx_graph = nx.DiGraph() if directed else nx.Graph()
    column_names = self.get_column_names()
    column_types = self.get_column_data_types()

    # Columns holding NODE or REL values; everything else is ignored.
    column_to_extract = {
        i: (col_type, col_name)
        for i, (col_name, col_type) in enumerate(zip(column_names, column_types))
        if col_type in (NODE_TYPE, REL_TYPE)
    }

    self.reset_iterator()

    nodes = {}
    rels = {}
    table_to_label_dict = {}

    # De-duplicate nodes and rels by their internal (table, offset) IDs.
    while self.has_next():
        row = self.get_next()
        for i, (column_type, _) in column_to_extract.items():
            if column_type == NODE_TYPE:
                _id = row[i]["_id"]
                nodes[(_id["table"], _id["offset"])] = row[i]
                table_to_label_dict[_id["table"]] = row[i]["_label"]
            elif column_type == REL_TYPE:
                _src = row[i]["_src"]
                _dst = row[i]["_dst"]
                rels[(_src["table"], _src["offset"],
                      _dst["table"], _dst["offset"])] = row[i]

    # Add nodes.
    for node in nodes.values():
        _id = node["_id"]
        node_id = node["_label"] + "_" + str(_id["offset"])
        # Mark the label as a boolean attribute for easy filtering.
        node[node["_label"]] = True
        nx_graph.add_node(node_id, **node)

    # Add rels. Skip any rel whose endpoint node was not returned in a NODE
    # column: its label is unknown, so no stable node ID can be formed
    # (previously this raised a KeyError).
    for rel in rels.values():
        _src = rel["_src"]
        _dst = rel["_dst"]
        if (_src["table"] not in table_to_label_dict
                or _dst["table"] not in table_to_label_dict):
            continue
        src_id = str(
            table_to_label_dict[_src["table"]]) + "_" + str(_src["offset"])
        dst_id = str(
            table_to_label_dict[_dst["table"]]) + "_" + str(_dst["offset"])
        nx_graph.add_edge(src_id, dst_id, **rel)
    return nx_graph
13 changes: 12 additions & 1 deletion tools/python_api/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import pytest
import shutil

sys.path.append('../build/')
import kuzu

Expand All @@ -18,19 +19,29 @@ def init_tiny_snb(tmp_path):
"age INT64, eyeSight DOUBLE, birthdate DATE, registerTime TIMESTAMP, lastJobDuration "
"INTERVAL, workedHours INT64[], usedNames STRING[], courseScoresPerTerm INT64[][], PRIMARY "
"KEY (ID))")
conn.execute("COPY person FROM \"../../../dataset/tinysnb/vPerson.csv\" (HEADER=true)")
conn.execute(
"COPY person FROM \"../../../dataset/tinysnb/vPerson.csv\" (HEADER=true)")
conn.execute(
"create rel table knows (FROM person TO person, date DATE, meetTime TIMESTAMP, validInterval INTERVAL, "
"comments STRING[], MANY_MANY);")
conn.execute("COPY knows FROM \"../../../dataset/tinysnb/eKnows.csv\"")
conn.execute("create node table organisation (ID INT64, name STRING, orgCode INT64, mark DOUBLE, score INT64, history STRING, licenseValidInterval INTERVAL, rating DOUBLE, PRIMARY KEY (ID))")
conn.execute(
'COPY organisation FROM "../../../dataset/tinysnb/vOrganisation.csv" (HEADER=true)')
conn.execute(
'create rel table workAt (FROM person TO organisation, year INT64, MANY_ONE)')
conn.execute(
'COPY workAt FROM "../../../dataset/tinysnb/eWorkAt.csv" (HEADER=true)')
return output_path


@pytest.fixture
def establish_connection(init_tiny_snb):
    # Open the freshly-populated tiny-snb database with a 256 MiB buffer pool
    # and hand tests a 4-thread connection alongside the database handle.
    db = kuzu.Database(init_tiny_snb, buffer_pool_size=256 * 1024 * 1024)
    conn = kuzu.Connection(db, num_threads=4)
    return conn, db


@pytest.fixture
def get_tmp_path(tmp_path):
    # pytest's tmp_path is a pathlib.Path; tests want a plain string.
    return str(tmp_path)
43 changes: 43 additions & 0 deletions tools/python_api/test/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,46 @@ def test_list_wrap(establish_connection):
assert not result.has_next()
result.close()


def test_node(establish_connection):
    """A NODE column materializes as a dict of properties plus _label/_id."""
    conn, db = establish_connection
    result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a")
    assert result.has_next()
    # One column returned; avoid shadowing by unpacking into a new name.
    row = result.get_next()
    assert len(row) == 1
    n = row[0]
    assert n['_label'] == 'person'
    assert n['ID'] == 0
    assert n['fName'] == 'Alice'
    assert n['gender'] == 1
    # `is True` instead of `== True` (PEP 8 / E712) — value must be a bool.
    assert n['isStudent'] is True
    assert n['eyeSight'] == 5.0
    assert n['birthdate'] == datetime.date(1900, 1, 1)
    assert n['registerTime'] == datetime.datetime(2011, 8, 20, 11, 25, 30)
    assert n['lastJobDuration'] == datetime.timedelta(
        days=1082, seconds=46920)
    assert n['courseScoresPerTerm'] == [[10, 8], [6, 7, 8]]
    assert n['usedNames'] == ['Aida']
    assert not result.has_next()
    result.close()


def test_rel(establish_connection):
    """A REL column carries its properties together with _src/_dst node IDs."""
    conn, db = establish_connection
    result = conn.execute(
        "MATCH (p:person)-[r:workAt]->(o:organisation) WHERE p.ID = 5 RETURN p, r, o")
    assert result.has_next()
    row = result.get_next()
    assert len(row) == 3
    person, rel, org = row
    assert person['_label'] == 'person'
    assert person['ID'] == 5
    assert org['_label'] == 'organisation'
    assert org['ID'] == 6
    assert rel['year'] == 2010
    # The rel's endpoints must match the IDs of the returned nodes.
    assert rel['_src'] == person['_id']
    assert rel['_dst'] == org['_id']
    assert not result.has_next()
    result.close()
Loading

0 comments on commit 0b82dfe

Please sign in to comment.