Support use of QueryResult as a context manager (#3009)
alexander-beedie committed Mar 9, 2024
1 parent c554a20 commit 45c5aa9
Showing 3 changed files with 135 additions and 46 deletions.
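For context, a minimal usage sketch of the feature this commit adds; the kuzu.Database / kuzu.Connection setup is illustrative boilerplate (the database path is a placeholder), not part of the diff:

    import kuzu

    db = kuzu.Database("./demo_db")  # placeholder path
    conn = kuzu.Connection(db)

    # QueryResult can now be used as a context manager: __exit__ calls close(),
    # so the underlying result is released as soon as the block ends.
    with conn.execute("MATCH (a:person) RETURN a.ID;") as result:
        while result.has_next():
            print(result.get_next())

    assert result.is_closed  # closed deterministically, without relying on __del__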
113 changes: 86 additions & 27 deletions tools/python_api/src_py/query_result.py
@@ -23,6 +23,12 @@ def __init__(self, connection, query_result):
         self._query_result = query_result
         self.is_closed = False
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
     def __del__(self):
         self.close()
 
Expand Down Expand Up @@ -73,18 +79,22 @@ def close(self):
Close the query result.
"""

if self.is_closed:
return
self._query_result.close()
# Allows the connection to be garbage collected if the query result
# is closed manually by the user.
self.connection = None
self.is_closed = True
if not self.is_closed:
# Allows the connection to be garbage collected if the query result
# is closed manually by the user.
self._query_result.close()
self.connection = None
self.is_closed = True

def get_as_df(self):
"""
Get the query result as a Pandas DataFrame.
See Also
--------
get_as_pl : Get the query result as a Polars DataFrame.
get_as_arrow : Get the query result as a PyArrow Table.
Returns
-------
pandas.DataFrame
@@ -102,14 +112,26 @@ def get_as_pl(self):
         """
         Get the query result as a Polars DataFrame.
 
+        See Also
+        --------
+        get_as_df : Get the query result as a Pandas DataFrame.
+        get_as_arrow : Get the query result as a PyArrow Table.
+
         Returns
         -------
         polars.DataFrame
             Query result as a Polars DataFrame.
         """
 
         import polars as pl
-        return pl.from_arrow(data=self.get_as_arrow(10_000))
+
+        target_n_elems = (
+            10_000_000  # adaptive chunk_size; target 10m elements per chunk
+        )
+        target_chunk_size = max(target_n_elems // len(self.get_column_names()), 10)
+        return pl.from_arrow(
+            data=self.get_as_arrow(chunk_size=target_chunk_size),
+        )
 
     def get_as_arrow(self, chunk_size):
         """
@@ -120,6 +142,11 @@ def get_as_arrow(self, chunk_size):
         chunk_size : int
             Number of rows to include in each chunk.
 
+        See Also
+        --------
+        get_as_pl : Get the query result as a Polars DataFrame.
+        get_as_df : Get the query result as a Pandas DataFrame.
+
         Returns
         -------
         pyarrow.Table
@@ -159,6 +186,26 @@ def get_column_names(self):
         self.check_for_query_result_close()
         return self._query_result.getColumnNames()
 
+    def get_schema(self):
+        """
+        Get the column schema of the query result.
+
+        Returns
+        -------
+        dict
+            Schema of the query result.
+        """
+
+        self.check_for_query_result_close()
+        return {
+            name: dtype
+            for name, dtype in zip(
+                self._query_result.getColumnNames(),
+                self._query_result.getColumnDataTypes(),
+            )
+        }
+
     def reset_iterator(self):
         """
         Reset the iterator of the query result.
@@ -203,7 +250,9 @@ def get_as_networkx(self, directed=True):
         table_primary_key_dict = {}
 
         def encode_node_id(node, table_primary_key_dict):
-            return node['_label'] + "_" + str(node[table_primary_key_dict[node['_label']]])
+            return (
+                node["_label"] + "_" + str(node[table_primary_key_dict[node["_label"]]])
+            )
 
         # De-duplicate nodes and rels
         while self.has_next():
@@ -218,36 +267,42 @@ def encode_node_id(node, table_primary_key_dict):
                 elif column_type == Type.REL.value:
                     _src = row[i]["_src"]
                     _dst = row[i]["_dst"]
-                    rels[(_src["table"], _src["offset"], _dst["table"],
-                          _dst["offset"])] = row[i]
+                    rels[
+                        (_src["table"], _src["offset"], _dst["table"], _dst["offset"])
+                    ] = row[i]
 
                 elif column_type == Type.RECURSIVE_REL.value:
-                    for node in row[i]['_nodes']:
+                    for node in row[i]["_nodes"]:
                         _id = node["_id"]
                         nodes[(_id["table"], _id["offset"])] = node
                         table_to_label_dict[_id["table"]] = node["_label"]
-                    for rel in row[i]['_rels']:
+                    for rel in row[i]["_rels"]:
                         for key in rel:
                             if rel[key] is None:
                                 del rel[key]
                         _src = rel["_src"]
                         _dst = rel["_dst"]
-                        rels[(_src["table"], _src["offset"], _dst["table"],
-                              _dst["offset"])] = rel
+                        rels[
+                            (
+                                _src["table"],
+                                _src["offset"],
+                                _dst["table"],
+                                _dst["offset"],
+                            )
+                        ] = rel
 
         # Add nodes
         for node in nodes.values():
             _id = node["_id"]
-            node_id = node['_label'] + "_" + str(_id["offset"])
-            if node['_label'] not in table_primary_key_dict:
-                props = self.connection._get_node_property_names(
-                    node['_label'])
+            node_id = node["_label"] + "_" + str(_id["offset"])
+            if node["_label"] not in table_primary_key_dict:
+                props = self.connection._get_node_property_names(node["_label"])
                 for prop_name in props:
-                    if props[prop_name]['is_primary_key']:
-                        table_primary_key_dict[node['_label']] = prop_name
+                    if props[prop_name]["is_primary_key"]:
+                        table_primary_key_dict[node["_label"]] = prop_name
                         break
                 node_id = encode_node_id(node, table_primary_key_dict)
-            node[node['_label']] = True
+            node[node["_label"]] = True
             nx_graph.add_node(node_id, **node)
 
         # Add rels
Expand All @@ -270,7 +325,11 @@ def _get_properties_to_extract(self):
for i in range(len(column_names)):
column_name = column_names[i]
column_type = column_types[i]
if column_type in [Type.NODE.value, Type.REL.value, Type.RECURSIVE_REL.value]:
if column_type in [
Type.NODE.value,
Type.REL.value,
Type.RECURSIVE_REL.value,
]:
properties_to_extract[i] = (column_type, column_name)
return properties_to_extract

@@ -280,7 +339,7 @@ def get_as_torch_geometric(self):
         torch_geometric.data.Data or torch_geometric.data.HeteroData.
 
         For node conversion, numerical and boolean properties are directly converted into tensor and
-        stored in Data/HeteroData. For properties that cannot be converted into tensor automatically 
+        stored in Data/HeteroData. For properties that cannot be converted into tensor automatically
         (please refer to the notes below for more detail), they are returned as unconverted_properties.
 
         For rel conversion, each rel is converted into an edge_index tensor directly. Edge properties are returned
@@ -290,9 +349,9 @@ def get_as_torch_geometric(self):
         - If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted
           automatically.
         - If a node property contains a null value, it cannot be converted automatically.
-        - If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be 
+        - If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be
           converted automatically.
-        - If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length 
+        - If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length
           is 6 for one node but 5 for another node), it cannot be converted automatically.
 
         Additional conversion rules:
@@ -363,7 +422,7 @@ def get_num_tuples(self):
         -------
         int
             Number of tuples.
-        """ 
+        """
         self.check_for_query_result_close()
         return self._query_result.getNumTuples()
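As a quick sanity check of the adaptive chunk size that get_as_pl now computes (a standalone sketch, not part of the commit — it simply replays the arithmetic from the diff above):

    # Each Arrow chunk targets ~10 million elements, so the row count per chunk
    # shrinks as the column count grows; max(..., 10) keeps a floor of 10 rows.
    target_n_elems = 10_000_000
    for n_columns in (1, 4, 1_000, 5_000_000):
        print(n_columns, max(target_n_elems // n_columns, 10))
    # 1 -> 10_000_000 rows/chunk, 4 -> 2_500_000, 1_000 -> 10_000,
    # 5_000_000 (a hypothetically wide result) -> the 10-row floor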
58 changes: 39 additions & 19 deletions tools/python_api/test/test_get_header.py
@@ -1,26 +1,46 @@
 def test_get_column_names(establish_connection):
     conn, db = establish_connection
-    result = conn.execute("MATCH (a:person)-[e:knows]->(b:person) RETURN a.fName, e.date, b.ID;")
-    column_names = result.get_column_names()
-    assert column_names[0] == 'a.fName'
-    assert column_names[1] == 'e.date'
-    assert column_names[2] == 'b.ID'
-    result.close()
+    with conn.execute(
+        "MATCH (a:person)-[e:knows]->(b:person) RETURN a.fName, e.date, b.ID;"
+    ) as result:
+        column_names = result.get_column_names()
+        assert column_names[0] == 'a.fName'
+        assert column_names[1] == 'e.date'
+        assert column_names[2] == 'b.ID'
 
 
 def test_get_column_data_types(establish_connection):
     conn, db = establish_connection
-    result = conn.execute(
+    with conn.execute(
         "MATCH (p:person) RETURN p.ID, p.fName, p.isStudent, p.eyeSight, p.birthdate, p.registerTime, "
-        "p.lastJobDuration, p.workedHours, p.courseScoresPerTerm;")
-    column_data_types = result.get_column_data_types()
-    assert column_data_types[0] == 'INT64'
-    assert column_data_types[1] == 'STRING'
-    assert column_data_types[2] == 'BOOL'
-    assert column_data_types[3] == 'DOUBLE'
-    assert column_data_types[4] == 'DATE'
-    assert column_data_types[5] == 'TIMESTAMP'
-    assert column_data_types[6] == 'INTERVAL'
-    assert column_data_types[7] == 'INT64[]'
-    assert column_data_types[8] == 'INT64[][]'
-    result.close()
+        "p.lastJobDuration, p.workedHours, p.courseScoresPerTerm;"
+    ) as result:
+        column_data_types = result.get_column_data_types()
+        assert column_data_types[0] == 'INT64'
+        assert column_data_types[1] == 'STRING'
+        assert column_data_types[2] == 'BOOL'
+        assert column_data_types[3] == 'DOUBLE'
+        assert column_data_types[4] == 'DATE'
+        assert column_data_types[5] == 'TIMESTAMP'
+        assert column_data_types[6] == 'INTERVAL'
+        assert column_data_types[7] == 'INT64[]'
+        assert column_data_types[8] == 'INT64[][]'
+
+
+def test_get_schema(establish_connection):
+    conn, db = establish_connection
+    with conn.execute(
+        "MATCH (p:person) RETURN p.ID, p.fName, p.isStudent, p.eyeSight, p.birthdate, p.registerTime, "
+        "p.lastJobDuration, p.workedHours, p.courseScoresPerTerm;"
+    ) as result:
+        assert result.get_schema() == {
+            'p.ID': 'INT64',
+            'p.fName': 'STRING',
+            'p.isStudent': 'BOOL',
+            'p.eyeSight': 'DOUBLE',
+            'p.birthdate': 'DATE',
+            'p.registerTime': 'TIMESTAMP',
+            'p.lastJobDuration': 'INTERVAL',
+            'p.workedHours': 'INT64[]',
+            'p.courseScoresPerTerm': 'INT64[][]'
+        }
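The establish_connection fixture used throughout these tests is defined elsewhere in the test suite (a conftest.py, presumably). A minimal sketch of its assumed shape — the dataset loading that populates the person table exists in the repo but is elided and hypothetical here:

    import kuzu
    import pytest

    @pytest.fixture
    def establish_connection(tmp_path):
        db = kuzu.Database(str(tmp_path / "test_db"))
        conn = kuzu.Connection(db)
        # ... schema creation and COPY of the 'person' test data would go here ...
        yield conn, db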
10 changes: 10 additions & 0 deletions tools/python_api/test/test_query_result.py
@@ -24,3 +24,13 @@ def test_explain(establish_connection):
     result = conn.execute("EXPLAIN MATCH (a:person) WHERE a.ID = 0 RETURN a")
     assert result.get_num_tuples() == 1
     result.close()
+
+def test_context_manager(establish_connection):
+    conn, db = establish_connection
+    with conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a") as result:
+        assert result.get_num_tuples() == 1
+        assert result.get_compiling_time() > 0
+
+    # context exit guarantees an immediate 'close' of the underlying QueryResult
+    # (no need to wait for __del__, which may never actually be called)
+    assert result.is_closed

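Beyond the context-manager path tested above, the reworked close() in query_result.py is also safe to call repeatedly; a small sketch of that guarantee, assuming an open conn as in the tests:

    result = conn.execute("MATCH (a:person) WHERE a.ID = 0 RETURN a")
    result.close()
    result.close()  # no-op: the `if not self.is_closed` guard skips the second close
    assert result.is_closed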