diff --git a/tools/python_api/src_py/query_result.py b/tools/python_api/src_py/query_result.py index 0b34342ef7..fb14ed7346 100644 --- a/tools/python_api/src_py/query_result.py +++ b/tools/python_api/src_py/query_result.py @@ -22,7 +22,7 @@ class QueryResult: close() Close the query result. - + get_as_df() Get the query result as a Pandas DataFrame. @@ -35,10 +35,14 @@ class QueryResult: get_column_names() Get the names of the columns in the query result. + get_as_networkx(directed=True) + Converts the nodes and rels in query result into a NetworkX graph representation. + get_as_torch_geometric() - Get the query result as a PyTorch Geometric Data object. + Converts the nodes and rels in query result into a PyTorch Geometric graph representation + torch_geometric.data.Data or torch_geometric.data.HeteroData. """ - + def __init__(self, connection, query_result): """ Parameters @@ -200,19 +204,22 @@ def reset_iterator(self): def get_as_networkx(self, directed=True): """ - Get the query result as a NetworkX graph. + Convert the nodes and rels in query result into a NetworkX directed or undirected graph + with the following rules: + Columns with data type other than node or rel will be ignored. + Duplicated nodes and rels will be converted only once. Parameters ---------- directed : bool Whether the graph should be directed. Defaults to True. - + Returns ------- networkx.DiGraph or networkx.Graph Query result as a NetworkX graph. """ - + self.check_for_query_result_close() import networkx as nx @@ -253,7 +260,8 @@ def encode_node_id(node, table_primary_key_dict): _id = node["_id"] node_id = node['_label'] + "_" + str(_id["offset"]) if node['_label'] not in table_primary_key_dict: - props = self.connection._get_node_property_names(node['_label']) + props = self.connection._get_node_property_names( + node['_label']) for prop_name in props: if props[prop_name]['is_primary_key']: table_primary_key_dict[node['_label']] = prop_name @@ -288,12 +296,43 @@ def _get_properties_to_extract(self): def get_as_torch_geometric(self): """ - Get the query result as a PyTorch Geometric graph. + Converts the nodes and rels in query result into a PyTorch Geometric graph representation + torch_geometric.data.Data or torch_geometric.data.HeteroData. + + For node conversion, numerical and boolean properties are directly converted into tensor and + stored in Data/HeteroData. For properties cannot be converted into tensor automatically + (please refer to the notes below for more detail), they are returned as unconverted_properties. + + For rel conversion, rel is converted into edge_index tensor director. Edge properties are returned + as edge_properties. + + Node properties that cannot be converted into tensor automatically: + - If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted + automatically. + - If a node property contains a null value, it cannot be converted automatically. + - If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be + converted automatically. + - If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length + is 6 for one node but 5 for another node), it cannot be converted automatically. + + Additional conversion rules: + - Columns with data type other than node or rel will be ignored. + - Duplicated nodes and rels will be converted only once. Returns ------- torch_geometric.data.Data or torch_geometric.data.HeteroData - Query result as a PyTorch Geometric graph. + Query result as a PyTorch Geometric graph. Containing numeric or boolean node properties + and edge_index tensor. + dict + A dictionary that maps the positional offset of each node in Data/HeteroData to its primary + key in the database. + dict + A dictionary contains node properties that cannot be converted into tensor automatically. The + order of values for each property is aligned with nodes in Data/HeteroData. + dict + A dictionary contains edge properties. The order of values for each property is aligned with + edge_index in Data/HeteroData. """ self.check_for_query_result_close()