Skip to content

Commit

Permalink
Merge pull request #1213 from kuzudb/get-as-pyg
Browse files Browse the repository at this point in the history
Return unconverted properties for PyG converter
  • Loading branch information
mewim committed Jan 30, 2023
2 parents 89fba7c + 6795214 commit 339b703
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 51 deletions.
69 changes: 46 additions & 23 deletions tools/python_api/src_py/torch_geometric_result_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, query_result):
self.internal_id_to_pos_dict = {}
self.pos_to_primary_key_dict = {}
self.warning_messages = set()
self.ignored_properties = set()
self.unconverted_properties = {}
self.properties_to_extract = self.query_result._get_properties_to_extract()

def __get_node_property_names(self, table_name):
Expand Down Expand Up @@ -89,27 +89,30 @@ def __populate_nodes_dict_and_deduplicte_edges(self):
def __extract_properties_from_node(self, node, label, node_property_names):
import torch
for prop_name in node_property_names:
# Ignore properties that are marked as ignored
if (label, prop_name) in self.ignored_properties:
# If property is already marked as unconverted, then add it directly without further checks
if label in self.unconverted_properties and prop_name in self.unconverted_properties[label]:
self.__add_unconverted_property(node, label, prop_name)
continue

# Read primary key but do not add it to the node properties
if node_property_names[prop_name]["is_primary_key"]:
primary_key = node[prop_name]
continue

# Ignore properties that are not supported by torch_geometric
# Mark properties that are not supported by torch_geometric as unconverted
if node_property_names[prop_name]["type"] not in [Type.INT64.value, Type.DOUBLE.value, Type.BOOL.value]:
self.warning_messages.add(
"Property {}.{} of type {} is not supported by torch_geometric. The property is ignored."
"Property {}.{} of type {} is not supported by torch_geometric. The property is marked as unconverted."
.format(label, prop_name, node_property_names[prop_name]["type"]))
self.__ignore_property(label, prop_name)
self.__mark_property_unconverted(label, prop_name)
self.__add_unconverted_property(node, label, prop_name)
continue
if node[prop_name] is None:
self.warning_messages.add(
"Property {}.{} has a null value. torch_geometric does not support null values. The property is ignored."
"Property {}.{} has a null value. torch_geometric does not support null values. The property is marked as unconverted."
.format(label, prop_name))
self.__ignore_property(label, prop_name)
self.__mark_property_unconverted(label, prop_name)
self.__add_unconverted_property(node, label, prop_name)
continue

if node_property_names[prop_name]['dimension'] == 0:
Expand All @@ -124,18 +127,20 @@ def __extract_properties_from_node(self, node, label, node_property_names):
curr_value = torch.BoolTensor(node[prop_name])
except ValueError:
self.warning_messages.add(
"Property {}.{} cannot be converted to Tensor (likely due to nested list of variable length). The property is ignored."
"Property {}.{} cannot be converted to Tensor (likely due to nested list of variable length). The property is marked as unconverted."
.format(label, prop_name))
self.__ignore_property(label, prop_name)
self.__mark_property_unconverted(label, prop_name)
self.__add_unconverted_property(node, label, prop_name)
continue
# Check if the shape of the property is consistent
if label in self.nodes_dict and prop_name in self.nodes_dict[label]:
# If the shape is inconsistent, then ignore the property
# If the shape is inconsistent, then mark the property as unconverted
if curr_value.shape != self.nodes_dict[label][prop_name][0].shape:
self.warning_messages.add(
"Property {}.{} has an inconsistent shape. The property is ignored."
"Property {}.{} has an inconsistent shape. The property is marked as unconverted."
.format(label, prop_name))
self.__ignore_property(label, prop_name)
self.__mark_property_unconverted(label, prop_name)
self.__add_unconverted_property(node, label, prop_name)
continue

# Create the dictionary for the label if it does not exist
Expand All @@ -152,12 +157,27 @@ def __extract_properties_from_node(self, node, label, node_property_names):
pos = len(self.nodes_dict[label][prop_name]) - 1
return pos, primary_key

def __ignore_property(self, label, prop_name):
self.ignored_properties.add((label, prop_name))
if label in self.nodes_dict and prop_name in self.nodes_dict[label]:
del self.nodes_dict[label][prop_name]
if len(self.nodes_dict[label]) == 0:
del self.nodes_dict[label]
def __add_unconverted_property(self, node, label, prop_name):
self.unconverted_properties[label][prop_name].append(
node[prop_name])

def __mark_property_unconverted(self, label, prop_name):
import torch
if label not in self.unconverted_properties:
self.unconverted_properties[label] = {}
if prop_name not in self.unconverted_properties[label]:
if label in self.nodes_dict and prop_name in self.nodes_dict[label]:
self.unconverted_properties[label][prop_name] = self.nodes_dict[label][prop_name]
del self.nodes_dict[label][prop_name]
if len(self.nodes_dict[label]) == 0:
del self.nodes_dict[label]
for i in range(len(self.unconverted_properties[label][prop_name])):
# If the property is a tensor, convert it back to list (consistent with the original type)
if torch.is_tensor(self.unconverted_properties[label][prop_name][i]):
self.unconverted_properties[label][prop_name][i] = self.unconverted_properties[label][prop_name][i].tolist(
)
else:
self.unconverted_properties[label][prop_name] = []

def __populate_edges_dict(self):
# Post-process edges, map internal ids to positions
Expand All @@ -181,7 +201,7 @@ def __convert_to_torch_geometric(self):
import torch_geometric
if len(self.nodes_dict) == 0:
self.warning_messages.add(
"No nodes found or all nodes were ignored. Returning None.")
"No nodes found or all node properties are not converted. Returning None.")
return None

# If there is only one node type, then convert to torch_geometric.data.Data
Expand Down Expand Up @@ -227,11 +247,14 @@ def __convert_to_torch_geometric(self):
data.edge_index = edge_idx
pos_to_primary_key_dict = self.pos_to_primary_key_dict[
label] if not is_hetero else self.pos_to_primary_key_dict
return data, pos_to_primary_key_dict

unconverted_properties = self.unconverted_properties if is_hetero else self.unconverted_properties[next(
iter(self.unconverted_properties))]
return data, pos_to_primary_key_dict, unconverted_properties

def get_as_torch_geometric(self):
self.__populate_nodes_dict_and_deduplicte_edges()
self.__populate_edges_dict()
data, pos_to_primary_key_dict = self.__convert_to_torch_geometric()
data, pos_to_primary_key_dict, unconverted_properties = self.__convert_to_torch_geometric()
self.__print_warnings()
return data, pos_to_primary_key_dict
return data, pos_to_primary_key_dict, unconverted_properties
Loading

0 comments on commit 339b703

Please sign in to comment.