Skip to content

Commit

Permalink
Add tensor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mewim committed Jan 25, 2023
1 parent f2e8d28 commit 0a32933
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __extract_properties_from_node(self, node, label, node_property_names):
if node_property_names[prop_name]['type'] == Type.INT64.value:
curr_value = torch.LongTensor(node[prop_name])
elif node_property_names[prop_name]['type'] == Type.DOUBLE.value:
curr_value = torch.DoubleTensor(node[prop_name])
curr_value = torch.FloatTensor(node[prop_name])
elif node_property_names[prop_name]['type'] == Type.BOOL.value:
curr_value = torch.BoolTensor(node[prop_name])
except ValueError:
Expand Down
70 changes: 70 additions & 0 deletions tools/python_api/test/test_torch_geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,39 @@
7: [6],
}

TENSOR_LIST_GROUND_TRUTH = {
0: {
'boolTensor': [True, False],
'doubleTensor': [[0.1, 0.2], [0.3, 0.4]],
'intTensor': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
},
3: {
'boolTensor': [True, False],
'doubleTensor': [[0.1, 0.2], [0.3, 0.4]],
'intTensor': [[[3, 4], [5, 6]], [[7, 8], [9, 10]]]
},
4: {
'boolTensor': [False, True],
'doubleTensor': [[0.4, 0.8], [0.7, 0.6]],
'intTensor': [[[5, 6], [7, 8]], [[9, 10], [11, 12]]]
},
5: {
'boolTensor': [True, True],
'doubleTensor': [[0.4, 0.9], [0.5, 0.2]],
'intTensor': [[[7, 8], [9, 10]], [[11, 12], [13, 14]]]
},
6: {
'boolTensor': [False, True],
'doubleTensor': [[0.2, 0.4], [0.5, 0.1]],
'intTensor': [[[9, 10], [11, 12]], [[13, 14], [15, 16]]]
},
8: {
'boolTensor': [False, True],
'doubleTensor': [[0.6, 0.4], [0.6, 0.1]],
'intTensor': [[[11, 12], [13, 14]], [[15, 16], [17, 18]]]
}
}


def test_to_torch_geometric_nodes_only(establish_connection):
conn, _ = establish_connection
Expand Down Expand Up @@ -378,3 +411,40 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection):
assert dst in pos_to_idx['organisation']
assert src != dst
assert pos_to_idx['organisation'][dst] in TINY_SNB_WORKS_AT_GROUND_TRUTH[pos_to_idx['person'][src]]


def test_multi_dimensonal_lists_to_torch_geometric(establish_connection):
conn, _ = establish_connection
query = "MATCH (t:tensor) RETURN t"

res = conn.execute(query)
with warnings.catch_warnings(record=True) as ws:
torch_geometric_data, pos_to_idx = res.get_as_torch_geometric()
assert len(ws) == 1
assert str(ws[0].message) == "Property tensor.oneDimInt has a null value. torch_geometric does not support null values. The property is ignored."

bool_list = []
float_list = []
int_list = []

for i in range(len(pos_to_idx)):
idx = pos_to_idx[i]
bool_list.append(TENSOR_LIST_GROUND_TRUTH[idx]['boolTensor'])
float_list.append(TENSOR_LIST_GROUND_TRUTH[idx]['doubleTensor'])
int_list.append(TENSOR_LIST_GROUND_TRUTH[idx]['intTensor'])

bool_tensor = torch.tensor(bool_list, dtype=torch.bool)
float_tensor = torch.tensor(float_list, dtype=torch.float32)
int_tensor = torch.tensor(int_list, dtype=torch.int64)

assert torch_geometric_data.boolTensor.shape == bool_tensor.shape
assert torch_geometric_data.boolTensor.dtype == bool_tensor.dtype
assert torch.all(torch_geometric_data.boolTensor == bool_tensor)

assert torch_geometric_data.doubleTensor.shape == float_tensor.shape
assert torch_geometric_data.doubleTensor.dtype == float_tensor.dtype
assert torch.all(torch_geometric_data.doubleTensor == float_tensor)

assert torch_geometric_data.intTensor.shape == int_tensor.shape
assert torch_geometric_data.intTensor.dtype == int_tensor.dtype
assert torch.all(torch_geometric_data.intTensor == int_tensor)

0 comments on commit 0a32933

Please sign in to comment.