diff --git a/tools/python_api/src_py/torch_geometric_result_converter.py b/tools/python_api/src_py/torch_geometric_result_converter.py index 339f60f3920..99874c15061 100644 --- a/tools/python_api/src_py/torch_geometric_result_converter.py +++ b/tools/python_api/src_py/torch_geometric_result_converter.py @@ -119,7 +119,7 @@ def __extract_properties_from_node(self, node, label, node_property_names): if node_property_names[prop_name]['type'] == Type.INT64.value: curr_value = torch.LongTensor(node[prop_name]) elif node_property_names[prop_name]['type'] == Type.DOUBLE.value: - curr_value = torch.DoubleTensor(node[prop_name]) + curr_value = torch.FloatTensor(node[prop_name]) elif node_property_names[prop_name]['type'] == Type.BOOL.value: curr_value = torch.BoolTensor(node[prop_name]) except ValueError: diff --git a/tools/python_api/test/test_torch_geometric.py b/tools/python_api/test/test_torch_geometric.py index f1062c8ec6c..fb9780e4831 100644 --- a/tools/python_api/test/test_torch_geometric.py +++ b/tools/python_api/test/test_torch_geometric.py @@ -168,6 +168,39 @@ 7: [6], } +TENSOR_LIST_GROUND_TRUTH = { + 0: { + 'boolTensor': [True, False], + 'doubleTensor': [[0.1, 0.2], [0.3, 0.4]], + 'intTensor': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + }, + 3: { + 'boolTensor': [True, False], + 'doubleTensor': [[0.1, 0.2], [0.3, 0.4]], + 'intTensor': [[[3, 4], [5, 6]], [[7, 8], [9, 10]]] + }, + 4: { + 'boolTensor': [False, True], + 'doubleTensor': [[0.4, 0.8], [0.7, 0.6]], + 'intTensor': [[[5, 6], [7, 8]], [[9, 10], [11, 12]]] + }, + 5: { + 'boolTensor': [True, True], + 'doubleTensor': [[0.4, 0.9], [0.5, 0.2]], + 'intTensor': [[[7, 8], [9, 10]], [[11, 12], [13, 14]]] + }, + 6: { + 'boolTensor': [False, True], + 'doubleTensor': [[0.2, 0.4], [0.5, 0.1]], + 'intTensor': [[[9, 10], [11, 12]], [[13, 14], [15, 16]]] + }, + 8: { + 'boolTensor': [False, True], + 'doubleTensor': [[0.6, 0.4], [0.6, 0.1]], + 'intTensor': [[[11, 12], [13, 14]], [[15, 16], [17, 18]]] + } +} + def test_to_torch_geometric_nodes_only(establish_connection): conn, _ = establish_connection @@ -378,3 +411,40 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection): assert dst in pos_to_idx['organisation'] assert src != dst assert pos_to_idx['organisation'][dst] in TINY_SNB_WORKS_AT_GROUND_TRUTH[pos_to_idx['person'][src]] + + +def test_multi_dimensonal_lists_to_torch_geometric(establish_connection): + conn, _ = establish_connection + query = "MATCH (t:tensor) RETURN t" + + res = conn.execute(query) + with warnings.catch_warnings(record=True) as ws: + torch_geometric_data, pos_to_idx = res.get_as_torch_geometric() + assert len(ws) == 1 + assert str(ws[0].message) == "Property tensor.oneDimInt has a null value. torch_geometric does not support null values. The property is ignored." + + bool_list = [] + float_list = [] + int_list = [] + + for i in range(len(pos_to_idx)): + idx = pos_to_idx[i] + bool_list.append(TENSOR_LIST_GROUND_TRUTH[idx]['boolTensor']) + float_list.append(TENSOR_LIST_GROUND_TRUTH[idx]['doubleTensor']) + int_list.append(TENSOR_LIST_GROUND_TRUTH[idx]['intTensor']) + + bool_tensor = torch.tensor(bool_list, dtype=torch.bool) + float_tensor = torch.tensor(float_list, dtype=torch.float32) + int_tensor = torch.tensor(int_list, dtype=torch.int64) + + assert torch_geometric_data.boolTensor.shape == bool_tensor.shape + assert torch_geometric_data.boolTensor.dtype == bool_tensor.dtype + assert torch.all(torch_geometric_data.boolTensor == bool_tensor) + + assert torch_geometric_data.doubleTensor.shape == float_tensor.shape + assert torch_geometric_data.doubleTensor.dtype == float_tensor.dtype + assert torch.all(torch_geometric_data.doubleTensor == float_tensor) + + assert torch_geometric_data.intTensor.shape == int_tensor.shape + assert torch_geometric_data.intTensor.dtype == int_tensor.dtype + assert torch.all(torch_geometric_data.intTensor == int_tensor)