Skip to content

Commit

Permalink
Add NetworkX tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
mewim committed Jan 18, 2023
1 parent 038e1c5 commit 4ee3ee7
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 6 deletions.
2 changes: 2 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,7 @@ add_kuzu_python_api_test(DataType test_datatype.py)
add_kuzu_python_api_test(PandaAPI test_df.py)
add_kuzu_python_api_test(Exception test_exception.py)
add_kuzu_python_api_test(GetHeader test_get_header.py)
add_kuzu_python_api_test(NetworkXAPI test_networkx.py)
add_kuzu_python_api_test(Parameter test_parameter.py)
add_kuzu_python_api_test(QueryResultClose test_query_result_close.py)
add_kuzu_python_api_test(WriteToCSV test_write_to_csv.py)
11 changes: 8 additions & 3 deletions tools/python_api/src_py/query_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def get_as_networkx(self, directed=True):
column_types = self.get_column_data_types()
column_to_extract = {}

# Iterate over columns and extract nodes and rels, ignoring other columns
for i in range(len(column_names)):
column_name = column_names[i]
column_type = column_types[i]
Expand All @@ -75,6 +76,7 @@ def get_as_networkx(self, directed=True):

nodes = {}
rels = {}
table_to_label_dict = {}

# De-duplicate nodes and rels
while self.has_next():
Expand All @@ -84,6 +86,7 @@ def get_as_networkx(self, directed=True):
if column_type == NODE_TYPE:
_id = row[i]["_id"]
nodes[(_id["table"], _id["offset"])] = row[i]
table_to_label_dict[_id["table"]] = row[i]["_label"]

elif column_type == REL_TYPE:
_src = row[i]["_src"]
Expand All @@ -94,15 +97,17 @@ def get_as_networkx(self, directed=True):
# Add nodes
for node in nodes.values():
_id = node["_id"]
node_id = str(_id["table"]) + "_" + str(_id["offset"])
node_id = node['_label'] + "_" + str(_id["offset"])
node[node['_label']] = True
nx_graph.add_node(node_id, **node)

# Add rels
for rel in rels.values():
_src = rel["_src"]
_dst = rel["_dst"]
src_id = str(_src["table"]) + "_" + str(_src["offset"])
dst_id = str(_dst["table"]) + "_" + str(_dst["offset"])
src_id = str(
table_to_label_dict[_src["table"]]) + "_" + str(_src["offset"])
dst_id = str(
table_to_label_dict[_dst["table"]]) + "_" + str(_dst["offset"])
nx_graph.add_edge(src_id, dst_id, **rel)
return nx_graph
7 changes: 4 additions & 3 deletions tools/python_api/test/test_main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import pytest

from test_datatype import *
from test_parameter import *
from test_exception import *
from test_df import *
from test_write_to_csv import *
from test_exception import *
from test_get_header import *
from test_networkx import *
from test_parameter import *
from test_query_result_close import *
from test_write_to_csv import *

# When this file is executed directly, run pytest on it (the star-imports above
# bring every test module's tests into this namespace) and exit with pytest's
# return code.
if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__]))
228 changes: 228 additions & 0 deletions tools/python_api/test/test_networkx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
import datetime
from pandas import Timestamp, Timedelta


def test_to_networkx_node(establish_connection):
    """Convert a node-only query result to NetworkX and verify every person.

    The query returns only ``person`` nodes. The resulting graph must contain
    eight nodes, each keyed as ``<label>_<offset>``, and each node's
    attributes are compared positionally against the expected columns below.
    """
    conn, _ = establish_connection
    result = conn.execute("MATCH (p:person) return p")
    graph = result.get_as_networkx()
    node_entries = list(graph.nodes(data=True))
    assert len(node_entries) == 8

    date = datetime.date
    # Column-oriented expectations: entry i of every list describes node i.
    expected = {
        'ID': [0, 2, 3, 5, 7, 8, 9, 10],
        'fName': [
            "Alice",
            "Bob",
            "Carol",
            "Dan",
            "Elizabeth",
            "Farooq",
            "Greg",
            "Hubert Blaine Wolfeschlegelsteinhausenbergerdorff",
        ],
        'gender': [1, 2, 1, 2, 1, 2, 2, 2],
        'isStudent': [True, True, False, False, False, True, False, False],
        'eyeSight': [5.0, 5.1, 5.0, 4.8, 4.7, 4.5, 4.9, 4.9],
        'birthdate': [
            date(1900, 1, 1),
            date(1900, 1, 1),
            date(1940, 6, 22),
            date(1950, 7, 23),
            date(1980, 10, 26),
            date(1980, 10, 26),
            date(1980, 10, 26),
            date(1990, 11, 27),
        ],
        'registerTime': [
            Timestamp('2011-08-20 11:25:30'),
            Timestamp('2008-11-03 15:25:30.000526'),
            Timestamp('1911-08-20 02:32:21'),
            Timestamp('2031-11-30 12:25:30'),
            Timestamp('1976-12-23 11:21:42'),
            Timestamp('1972-07-31 13:22:30.678559'),
            Timestamp('1976-12-23 04:41:42'),
            Timestamp('2023-02-21 13:25:30'),
        ],
        'lastJobDuration': [
            Timedelta('1082 days 13:02:00'),
            Timedelta('3750 days 13:00:00.000024'),
            Timedelta('2 days 00:24:11'),
            Timedelta('3750 days 13:00:00.000024'),
            Timedelta('2 days 00:24:11'),
            Timedelta('0 days 00:18:00.024000'),
            Timedelta('3750 days 13:00:00.000024'),
            Timedelta('1082 days 13:02:00'),
        ],
        'workedHours': [
            [10, 5],
            [12, 8],
            [4, 5],
            [1, 9],
            [2],
            [3, 4, 5, 6, 7],
            [1],
            [10, 11, 12, 3, 4, 5, 6, 7],
        ],
        'usedNames': [
            ["Aida"],
            ['Bobby'],
            ['Carmen', 'Fred'],
            ['Wolfeschlegelstein', 'Daniel'],
            ['Ein'],
            ['Fesdwe'],
            ['Grad'],
            ['Ad', 'De', 'Hi', 'Kye', 'Orlan'],
        ],
        'courseScoresPerTerm': [
            [[10, 8], [6, 7, 8]],
            [[8, 9], [9, 10]],
            [[8, 10]],
            [[7, 4], [8, 8], [9]],
            [[6], [7], [8]],
            [[8]],
            [[10]],
            [[7], [10], [6, 7]],
        ],
        '_label': ['person'] * 8,
    }

    for position, (node_id, attrs) in enumerate(node_entries):
        # Graph keys are built from the node label and its internal offset.
        assert node_id == "%s_%d" % (attrs['_label'], attrs['_id']['offset'])
        for key, column in expected.items():
            assert attrs[key] == column[position]


def test_networkx_undirected(establish_connection):
    """Convert a ``knows`` sub-graph to an undirected NetworkX graph.

    Checks that the graph is undirected, holds four person nodes with the
    expected attributes, and that its six edges connect every pair of
    distinct nodes without any self-loop.
    """
    conn, _ = establish_connection
    result = conn.execute(
        "MATCH (p1:person)-[r:knows]->(p2:person) WHERE p1.ID <= 3 RETURN p1, r, p2")

    graph = result.get_as_networkx(directed=False)
    assert not graph.is_directed()

    node_entries = list(graph.nodes(data=True))
    assert len(node_entries) == 4

    edge_entries = list(graph.edges(data=True))

    date = datetime.date
    # Column-oriented expectations for every person in the dataset; nodes are
    # matched to a column index through their 'ID' property.
    expected = {
        'ID': [0, 2, 3, 5, 7, 8, 9, 10],
        'fName': [
            "Alice",
            "Bob",
            "Carol",
            "Dan",
            "Elizabeth",
            "Farooq",
            "Greg",
            "Hubert Blaine Wolfeschlegelsteinhausenbergerdorff",
        ],
        'gender': [1, 2, 1, 2, 1, 2, 2, 2],
        'isStudent': [True, True, False, False, False, True, False, False],
        'eyeSight': [5.0, 5.1, 5.0, 4.8, 4.7, 4.5, 4.9, 4.9],
        'birthdate': [
            date(1900, 1, 1),
            date(1900, 1, 1),
            date(1940, 6, 22),
            date(1950, 7, 23),
            date(1980, 10, 26),
            date(1980, 10, 26),
            date(1980, 10, 26),
            date(1990, 11, 27),
        ],
        'registerTime': [
            Timestamp('2011-08-20 11:25:30'),
            Timestamp('2008-11-03 15:25:30.000526'),
            Timestamp('1911-08-20 02:32:21'),
            Timestamp('2031-11-30 12:25:30'),
            Timestamp('1976-12-23 11:21:42'),
            Timestamp('1972-07-31 13:22:30.678559'),
            Timestamp('1976-12-23 04:41:42'),
            Timestamp('2023-02-21 13:25:30'),
        ],
        'lastJobDuration': [
            Timedelta('1082 days 13:02:00'),
            Timedelta('3750 days 13:00:00.000024'),
            Timedelta('2 days 00:24:11'),
            Timedelta('3750 days 13:00:00.000024'),
            Timedelta('2 days 00:24:11'),
            Timedelta('0 days 00:18:00.024000'),
            Timedelta('3750 days 13:00:00.000024'),
            Timedelta('1082 days 13:02:00'),
        ],
        'workedHours': [
            [10, 5],
            [12, 8],
            [4, 5],
            [1, 9],
            [2],
            [3, 4, 5, 6, 7],
            [1],
            [10, 11, 12, 3, 4, 5, 6, 7],
        ],
        'usedNames': [
            ["Aida"],
            ['Bobby'],
            ['Carmen', 'Fred'],
            ['Wolfeschlegelstein', 'Daniel'],
            ['Ein'],
            ['Fesdwe'],
            ['Grad'],
            ['Ad', 'De', 'Hi', 'Kye', 'Orlan'],
        ],
        'courseScoresPerTerm': [
            [[10, 8], [6, 7, 8]],
            [[8, 9], [9, 10]],
            [[8, 10]],
            [[7, 4], [8, 8], [9]],
            [[6], [7], [8]],
            [[8]],
            [[10]],
            [[7], [10], [6, 7]],
        ],
        '_label': ['person'] * 8,
    }

    # Graph keys are built from the node label and its internal offset.
    for node_id, attrs in node_entries:
        assert node_id == "%s_%d" % (attrs['_label'], attrs['_id']['offset'])

    # Every returned node must correspond to a known person and match it fully.
    for _, attrs in node_entries:
        matched = False
        for row, person_id in enumerate(expected['ID']):
            if attrs['ID'] == person_id:
                matched = True
                for key, column in expected.items():
                    assert attrs[key] == column[row]
        assert matched

    assert len(edge_entries) == 6
    # This should be a complete graph, so we check that an edge exists between
    # every pair of distinct nodes and that there are no self-loops.
    node_ids = [node_id for node_id, _ in node_entries]
    for position, first in enumerate(node_ids):
        assert not graph.has_edge(first, first)
        for second in node_ids[position + 1:]:
            assert graph.has_edge(first, second)


def _assert_labelled_nodes_match(nodes, label, expected):
    """Assert that every node tagged with *label* matches a row of *expected*.

    ``nodes`` is a sequence of ``(node_id, attributes)`` pairs. A node is
    considered to carry *label* when its attribute dict contains *label* as a
    key. ``expected`` is column-oriented: ``expected[key][i]`` is the value of
    property ``key`` for the i-th known entity, and rows are located through
    the ``'ID'`` column.
    """
    for _, attrs in nodes:
        if label not in attrs:
            continue
        found = False
        for row, expected_id in enumerate(expected['ID']):
            if attrs['ID'] != expected_id:
                continue
            found = True
            for key, column in expected.items():
                assert attrs[key] == column[row]
        assert found


def test_networkx_directed(establish_connection):
    """Convert a ``workAt`` sub-graph to a directed NetworkX graph.

    Checks that the graph is directed, that it holds the two employed persons
    and their organisation with the expected attributes, and that exactly the
    two expected person -> organisation edges exist.
    """
    conn, _ = establish_connection
    res = conn.execute(
        "MATCH (p:person)-[r:workAt]->(o:organisation) RETURN p, r, o")

    nx_graph = res.get_as_networkx(directed=True)
    assert nx_graph.is_directed()

    nodes = list(nx_graph.nodes(data=True))
    assert len(nodes) == 3

    ground_truth_p = {
        'ID': [5, 7],
        'fName': ["Dan", "Elizabeth"],
        'gender': [2, 1],
        'isStudent': [False, False],
        'eyeSight': [4.8, 4.7],
        'birthdate': [
            datetime.date(1950, 7, 23),
            datetime.date(1980, 10, 26),
        ],
        'registerTime': [
            Timestamp('2031-11-30 12:25:30'),
            Timestamp('1976-12-23 11:21:42'),
        ],
        'lastJobDuration': [
            Timedelta('3750 days 13:00:00.000024'),
            Timedelta('2 days 00:24:11'),
        ],
        'workedHours': [[1, 9], [2]],
        'usedNames': [['Wolfeschlegelstein', 'Daniel'], ['Ein']],
        'courseScoresPerTerm': [
            [[7, 4], [8, 8], [9]],
            [[6], [7], [8]],
        ],
        '_label': ['person', 'person'],
    }

    ground_truth_o = {
        'ID': [6],
        'name': ['DEsWork'],
        'orgCode': [824],
        'mark': [4.1],
        'score': [7],
        'history': ['2 years 4 hours 22 us 34 minutes'],
        'licenseValidInterval': [
            Timedelta(days=3, seconds=36000, microseconds=100000)],
        'rating': [0.52],
        '_label': ['organisation'],
    }

    # Graph keys are built from the node label and its internal offset.
    for (node_id, node) in nodes:
        assert node_id == "%s_%d" % (node['_label'], node['_id']['offset'])

    _assert_labelled_nodes_match(nodes, 'person', ground_truth_p)
    _assert_labelled_nodes_match(nodes, 'organisation', ground_truth_o)

    nodes_dict = dict(nx_graph.nodes(data=True))
    edges = list(nx_graph.edges(data=True))
    assert len(edges) == 2

    years_ground_truth = [2010, 2015]
    for (src, dst, edge) in edges:
        assert nodes_dict[dst]['_label'] == 'organisation'
        assert nodes_dict[dst]['ID'] == 6

        assert nodes_dict[src]['_label'] == 'person'
        assert nodes_dict[src]['ID'] in [5, 7]
        assert edge['year'] in years_ground_truth

        # If the edge is found, remove it from ground truth so we can check
        # that all edges were found and no extra edges were found
        years_ground_truth.remove(edge['year'])
        del nodes_dict[src]

    assert len(years_ground_truth) == 0
    assert len(nodes_dict) == 1  # Only the organisation node should be left

0 comments on commit 4ee3ee7

Please sign in to comment.