Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PyG test errors #1239

Merged
merged 1 commit into from
Feb 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tools/python_api/src_py/torch_geometric_result_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ def __extract_properties_from_node(self, node, label, node_property_names):
pos = None
import torch
for prop_name in node_property_names:
# Read primary key
if node_property_names[prop_name]["is_primary_key"]:
primary_key = node[prop_name]

# If property is already marked as unconverted, then add it directly without further checks
if label in self.unconverted_properties and prop_name in self.unconverted_properties[label]:
pos = self.__add_unconverted_property(node, label, prop_name)
continue

# Read primary key
if node_property_names[prop_name]["is_primary_key"]:
primary_key = node[prop_name]

# Mark properties that are not supported by torch_geometric as unconverted
if node_property_names[prop_name]["type"] not in [Type.INT64.value, Type.DOUBLE.value, Type.BOOL.value]:
self.warning_messages.add(
Expand Down
40 changes: 38 additions & 2 deletions tools/python_api/test/test_torch_geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,12 @@ def test_to_torch_geometric_nodes_only(establish_connection):
for w in ws:
assert str(w.message) in warnings_ground_truth

assert torch_geometric_data.ID.shape == torch.Size([8])
assert torch_geometric_data.ID.dtype == torch.int64
for i in range(8):
assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i]
]['ID'] == torch_geometric_data.ID[i].item()

assert torch_geometric_data.gender.shape == torch.Size([8])
assert torch_geometric_data.gender.dtype == torch.int64
for i in range(8):
Expand Down Expand Up @@ -319,6 +325,12 @@ def test_to_torch_geometric_homogeneous_graph(establish_connection):
for w in ws:
assert str(w.message) in warnings_ground_truth

assert torch_geometric_data.ID.shape == torch.Size([7])
assert torch_geometric_data.ID.dtype == torch.int64
for i in range(7):
assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx[i]
]['ID'] == torch_geometric_data.ID[i].item()

assert torch_geometric_data.gender.shape == torch.Size([7])
assert torch_geometric_data.gender.dtype == torch.int64
for i in range(7):
Expand Down Expand Up @@ -414,6 +426,12 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection):
for w in ws:
assert str(w.message) in warnings_ground_truth

assert torch_geometric_data['person'].ID.shape == torch.Size([4])
assert torch_geometric_data['person'].ID.dtype == torch.int64
for i in range(4):
assert TINY_SNB_PERSONS_GROUND_TRUTH[pos_to_idx['person'][i]
]['ID'] == torch_geometric_data['person'].ID[i].item()

assert torch_geometric_data['person'].gender.shape == torch.Size([4])
assert torch_geometric_data['person'].gender.dtype == torch.int64
for i in range(4):
Expand Down Expand Up @@ -481,6 +499,12 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection):
assert src != dst
assert pos_to_idx['person'][dst] in TINY_SNB_KNOWS_GROUND_TRUTH[pos_to_idx['person'][src]]

assert torch_geometric_data['organisation'].ID.shape == torch.Size([2])
assert torch_geometric_data['organisation'].ID.dtype == torch.int64
for i in range(2):
assert TINY_SNB_ORGANISATIONS_GROUND_TRUTH[pos_to_idx['organisation'][i]
]['ID'] == torch_geometric_data['organisation'].ID[i].item()

assert torch_geometric_data['organisation'].orgCode.shape == torch.Size([
2])
assert torch_geometric_data['organisation'].orgCode.dtype == torch.int64
Expand Down Expand Up @@ -558,6 +582,11 @@ def test_to_torch_geometric_multi_dimensonal_lists(establish_connection):
float_tensor = torch.tensor(float_list, dtype=torch.float32)
int_tensor = torch.tensor(int_list, dtype=torch.int64)

assert torch_geometric_data.ID.shape == torch.Size([len(pos_to_idx)])
assert torch_geometric_data.ID.dtype == torch.int64
for i in range(len(pos_to_idx)):
assert torch_geometric_data.ID[i].item() == pos_to_idx[i]

assert torch_geometric_data.boolTensor.shape == bool_tensor.shape
assert torch_geometric_data.boolTensor.dtype == bool_tensor.dtype
assert torch.all(torch_geometric_data.boolTensor == bool_tensor)
Expand All @@ -583,8 +612,9 @@ def test_to_torch_geometric_no_properties_converted(establish_connection):
res = conn.execute(query)
with warnings.catch_warnings(record=True) as ws:
torch_geometric_data, pos_to_idx, unconverted_properties = res.get_as_torch_geometric()
assert len(ws) == 2
assert len(ws) == 3
warnings_ground_truth = set([
"Property personLongString.name of type STRING is not supported by torch_geometric. The property is marked as unconverted.",
"Property personLongString.spouse of type STRING is not supported by torch_geometric. The property is marked as unconverted.",
"No nodes found or all node properties are not converted."])
for w in ws:
Expand All @@ -602,9 +632,15 @@ def test_to_torch_geometric_no_properties_converted(establish_connection):
assert pos_to_idx['personLongString'][dst] in PERSONLONGSTRING_KNOWS_GROUND_TRUTH[pos_to_idx['personLongString'][src]]

assert len(unconverted_properties) == 1
assert len(unconverted_properties['personLongString']) == 1
assert len(unconverted_properties['personLongString']) == 2

assert 'spouse' in unconverted_properties['personLongString']
assert len(unconverted_properties['personLongString']['spouse']) == 2
for i in range(2):
assert PERSONLONGSTRING_GROUND_TRUTH[pos_to_idx['personLongString'][i]
]['spouse'] == unconverted_properties['personLongString']['spouse'][i]

assert 'name' in unconverted_properties['personLongString']
for i in range(2):
assert PERSONLONGSTRING_GROUND_TRUTH[pos_to_idx['personLongString'][i]
]['name'] == unconverted_properties['personLongString']['name'][i]