Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return unconverted properties for PyG converter #1213

Merged
merged 1 commit into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 46 additions & 23 deletions tools/python_api/src_py/torch_geometric_result_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, query_result):
self.internal_id_to_pos_dict = {}
self.pos_to_primary_key_dict = {}
self.warning_messages = set()
self.ignored_properties = set()
self.unconverted_properties = {}
self.properties_to_extract = self.query_result._get_properties_to_extract()

def __get_node_property_names(self, table_name):
Expand Down Expand Up @@ -89,27 +89,30 @@ def __populate_nodes_dict_and_deduplicte_edges(self):
def __extract_properties_from_node(self, node, label, node_property_names):
import torch
for prop_name in node_property_names:
# Ignore properties that are marked as ignored
if (label, prop_name) in self.ignored_properties:
# If property is already marked as unconverted, then add it directly without further checks
if label in self.unconverted_properties and prop_name in self.unconverted_properties[label]:
self.__add_unconverted_property(node, label, prop_name)
continue

# Read primary key but do not add it to the node properties
if node_property_names[prop_name]["is_primary_key"]:
primary_key = node[prop_name]
continue

# Ignore properties that are not supported by torch_geometric
# Mark properties that are not supported by torch_geometric as unconverted
if node_property_names[prop_name]["type"] not in [Type.INT64.value, Type.DOUBLE.value, Type.BOOL.value]:
self.warning_messages.add(
"Property {}.{} of type {} is not supported by torch_geometric. The property is ignored."
"Property {}.{} of type {} is not supported by torch_geometric. The property is marked as unconverted."
.format(label, prop_name, node_property_names[prop_name]["type"]))
self.__ignore_property(label, prop_name)
self.__mark_property_unconverted(label, prop_name)
self.__add_unconverted_property(node, label, prop_name)
continue
if node[prop_name] is None:
self.warning_messages.add(
"Property {}.{} has a null value. torch_geometric does not support null values. The property is ignored."
"Property {}.{} has a null value. torch_geometric does not support null values. The property is marked as unconverted."
.format(label, prop_name))
self.__ignore_property(label, prop_name)
self.__mark_property_unconverted(label, prop_name)
self.__add_unconverted_property(node, label, prop_name)
continue

if node_property_names[prop_name]['dimension'] == 0:
Expand All @@ -124,18 +127,20 @@ def __extract_properties_from_node(self, node, label, node_property_names):
curr_value = torch.BoolTensor(node[prop_name])
except ValueError:
self.warning_messages.add(
"Property {}.{} cannot be converted to Tensor (likely due to nested list of variable length). The property is ignored."
"Property {}.{} cannot be converted to Tensor (likely due to nested list of variable length). The property is marked as unconverted."
.format(label, prop_name))
self.__ignore_property(label, prop_name)
self.__mark_property_unconverted(label, prop_name)
self.__add_unconverted_property(node, label, prop_name)
continue
# Check if the shape of the property is consistent
if label in self.nodes_dict and prop_name in self.nodes_dict[label]:
# If the shape is inconsistent, then ignore the property
# If the shape is inconsistent, then mark the property as unconverted
if curr_value.shape != self.nodes_dict[label][prop_name][0].shape:
self.warning_messages.add(
"Property {}.{} has an inconsistent shape. The property is ignored."
"Property {}.{} has an inconsistent shape. The property is marked as unconverted."
.format(label, prop_name))
self.__ignore_property(label, prop_name)
self.__mark_property_unconverted(label, prop_name)
self.__add_unconverted_property(node, label, prop_name)
continue

# Create the dictionary for the label if it does not exist
Expand All @@ -152,12 +157,27 @@ def __extract_properties_from_node(self, node, label, node_property_names):
pos = len(self.nodes_dict[label][prop_name]) - 1
return pos, primary_key

def __ignore_property(self, label, prop_name):
self.ignored_properties.add((label, prop_name))
if label in self.nodes_dict and prop_name in self.nodes_dict[label]:
del self.nodes_dict[label][prop_name]
if len(self.nodes_dict[label]) == 0:
del self.nodes_dict[label]
def __add_unconverted_property(self, node, label, prop_name):
self.unconverted_properties[label][prop_name].append(
node[prop_name])

def __mark_property_unconverted(self, label, prop_name):
import torch
if label not in self.unconverted_properties:
self.unconverted_properties[label] = {}
if prop_name not in self.unconverted_properties[label]:
if label in self.nodes_dict and prop_name in self.nodes_dict[label]:
self.unconverted_properties[label][prop_name] = self.nodes_dict[label][prop_name]
del self.nodes_dict[label][prop_name]
if len(self.nodes_dict[label]) == 0:
del self.nodes_dict[label]
for i in range(len(self.unconverted_properties[label][prop_name])):
# If the property is a tensor, convert it back to list (consistent with the original type)
if torch.is_tensor(self.unconverted_properties[label][prop_name][i]):
self.unconverted_properties[label][prop_name][i] = self.unconverted_properties[label][prop_name][i].tolist(
)
else:
self.unconverted_properties[label][prop_name] = []

def __populate_edges_dict(self):
# Post-process edges, map internal ids to positions
Expand All @@ -181,7 +201,7 @@ def __convert_to_torch_geometric(self):
import torch_geometric
if len(self.nodes_dict) == 0:
self.warning_messages.add(
"No nodes found or all nodes were ignored. Returning None.")
"No nodes found or all node properties are not converted. Returning None.")
return None

# If there is only one node type, then convert to torch_geometric.data.Data
Expand Down Expand Up @@ -227,11 +247,14 @@ def __convert_to_torch_geometric(self):
data.edge_index = edge_idx
pos_to_primary_key_dict = self.pos_to_primary_key_dict[
label] if not is_hetero else self.pos_to_primary_key_dict
return data, pos_to_primary_key_dict

unconverted_properties = self.unconverted_properties if is_hetero else self.unconverted_properties[next(
iter(self.unconverted_properties))]
return data, pos_to_primary_key_dict, unconverted_properties

def get_as_torch_geometric(self):
self.__populate_nodes_dict_and_deduplicte_edges()
self.__populate_edges_dict()
data, pos_to_primary_key_dict = self.__convert_to_torch_geometric()
data, pos_to_primary_key_dict, unconverted_properties = self.__convert_to_torch_geometric()
self.__print_warnings()
return data, pos_to_primary_key_dict
return data, pos_to_primary_key_dict, unconverted_properties
Loading