From 35b64b907f2fd29e2a90707cf6a194f1a4c0e69d Mon Sep 17 00:00:00 2001 From: Chang Liu Date: Wed, 20 Sep 2023 17:54:18 +0800 Subject: [PATCH] Expose rel label for Python and Node.js API --- tools/nodejs_api/src_cpp/node_util.cpp | 3 ++- tools/nodejs_api/test/test_data_type.js | 2 ++ tools/python_api/src_cpp/py_query_result.cpp | 1 + tools/python_api/test/test_datatype.py | 2 ++ tools/python_api/test/test_torch_geometric.py | 10 +++++++--- 5 files changed, 14 insertions(+), 4 deletions(-) diff --git a/tools/nodejs_api/src_cpp/node_util.cpp b/tools/nodejs_api/src_cpp/node_util.cpp index 6a0617f650..e3dbcdc368 100644 --- a/tools/nodejs_api/src_cpp/node_util.cpp +++ b/tools/nodejs_api/src_cpp/node_util.cpp @@ -1,11 +1,11 @@ #include "include/node_util.h" #include "common/exception/exception.h" -#include "common/types/value/value.h" #include "common/types/value/nested.h" #include "common/types/value/node.h" #include "common/types/value/recursive_rel.h" #include "common/types/value/rel.h" +#include "common/types/value/value.h" Napi::Value Util::ConvertToNapiObject(const Value& value, Napi::Env env) { if (value.isNull()) { @@ -119,6 +119,7 @@ Napi::Value Util::ConvertToNapiObject(const Value& value, Napi::Env env) { } napiObj.Set("_src", ConvertNodeIdToNapiObject(RelVal::getSrcNodeID(&value), env)); napiObj.Set("_dst", ConvertNodeIdToNapiObject(RelVal::getDstNodeID(&value), env)); + napiObj.Set("_label", Napi::String::New(env, RelVal::getLabelName(&value))); return napiObj; } case LogicalTypeID::INTERNAL_ID: { diff --git a/tools/nodejs_api/test/test_data_type.js b/tools/nodejs_api/test/test_data_type.js index a341345256..781c8a3846 100644 --- a/tools/nodejs_api/test/test_data_type.js +++ b/tools/nodejs_api/test/test_data_type.js @@ -294,6 +294,7 @@ describe("REL", function () { assert.equal(rel.grading.length, 2); assert.equal(rel.grading[0], 2.1); assert.equal(rel.grading[1], 4.4); + assert.equal(rel._label, "workAt"); assert.approximately(rel.rating, 7.6, EPSILON); }); }); @@ -327,6 +328,7 @@ describe("RECURSIVE_REL", function () { usedAddress: null, address: null, note: null, + _label: "studyAt", _src: { offset: 0, table: 0, diff --git a/tools/python_api/src_cpp/py_query_result.cpp b/tools/python_api/src_cpp/py_query_result.cpp index 53e0ff78a7..2fc3b59ad3 100644 --- a/tools/python_api/src_cpp/py_query_result.cpp +++ b/tools/python_api/src_cpp/py_query_result.cpp @@ -169,6 +169,7 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) { py::dict dict; dict["_src"] = convertNodeIdToPyDict(RelVal::getSrcNodeID(&value)); dict["_dst"] = convertNodeIdToPyDict(RelVal::getDstNodeID(&value)); + dict["_label"] = py::cast(RelVal::getLabelName(&value)); auto numProperties = RelVal::getNumProperties(&value); for (auto i = 0u; i < numProperties; ++i) { auto key = py::str(RelVal::getPropertyName(&value, i)); diff --git a/tools/python_api/test/test_datatype.py b/tools/python_api/test/test_datatype.py index 0ef2fd46b1..17b410157b 100644 --- a/tools/python_api/test/test_datatype.py +++ b/tools/python_api/test/test_datatype.py @@ -141,6 +141,7 @@ def test_rel(establish_connection): assert (r['year'] == 2010) assert (r['_src'] == p['_id']) assert (r['_dst'] == o['_id']) + assert (r['_label'] == 'workAt') assert not result.has_next() result.close() @@ -179,6 +180,7 @@ def test_recursive_rel(establish_connection): rel = e["_rels"][0] excepted_rel = {'_src': {'offset': 0, 'table': 0}, '_dst': {'offset': 0, 'table': 1}, + '_label': 'studyAt', 'date': None, 'meetTime': None, 'validInterval': None, 'comments': None, 'year': 2021, 'places': ['wwAewsdndweusd', 'wek'], diff --git a/tools/python_api/test/test_torch_geometric.py b/tools/python_api/test/test_torch_geometric.py index 2f6426fddd..f5e5c807bb 100644 --- a/tools/python_api/test/test_torch_geometric.py +++ b/tools/python_api/test/test_torch_geometric.py @@ -188,7 +188,7 @@ def test_to_torch_geometric_homogeneous_graph(establish_connection): assert src != dst assert pos_to_idx[dst] in ground_truth.TINY_SNB_KNOWS_GROUND_TRUTH[pos_to_idx[src]] - assert len(edge_properties) == 4 + assert len(edge_properties) == 5 assert 'date' in edge_properties assert 'meetTime' in edge_properties assert 'validInterval' in edge_properties @@ -310,11 +310,12 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection): assert src != dst assert pos_to_idx['person'][dst] in ground_truth.TINY_SNB_KNOWS_GROUND_TRUTH[pos_to_idx['person'][src]] - assert len(edge_properties['person', 'person']) == 4 + assert len(edge_properties['person', 'person']) == 5 assert 'date' in edge_properties['person', 'person'] assert 'meetTime' in edge_properties['person', 'person'] assert 'validInterval' in edge_properties['person', 'person'] assert 'comments' in edge_properties['person', 'person'] + assert '_label' in edge_properties['person', 'person'] for i in range(3): src, dst = torch_geometric_data['person', 'person'].edge_index[0][i].item( ), torch_geometric_data['person', 'person'].edge_index[1][i].item() @@ -328,6 +329,7 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection): original_src, original_dst)]['validInterval'] == edge_properties['person', 'person']['validInterval'][i] assert ground_truth.TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[( original_src, original_dst)]['comments'] == edge_properties['person', 'person']['comments'][i] + assert edge_properties['person', 'person']['_label'][i] == 'knows' assert torch_geometric_data['organisation'].ID.shape == torch.Size([2]) assert torch_geometric_data['organisation'].ID.dtype == torch.int64 @@ -386,14 +388,16 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection): assert dst in pos_to_idx['organisation'] assert src != dst assert pos_to_idx['organisation'][dst] in ground_truth.TINY_SNB_WORKS_AT_GROUND_TRUTH[pos_to_idx['person'][src]] - assert len(edge_properties['person', 'organisation']) == 3 + assert len(edge_properties['person', 'organisation']) == 4 assert 'year' in edge_properties['person', 'organisation'] + assert '_label' in edge_properties['person', 'organisation'] for i in range(2): src, dst = torch_geometric_data['person', 'organisation'].edge_index[0][i].item( ), torch_geometric_data['person', 'organisation'].edge_index[1][i].item() original_src, original_dst = pos_to_idx['person'][src], pos_to_idx['organisation'][dst] assert ground_truth.TINY_SNB_WORKS_AT_PROPERTIES_GROUND_TRUTH[( original_src, original_dst)]['year'] == edge_properties['person', 'organisation']['year'][i] + assert edge_properties['person', 'organisation']['_label'][i] == 'workAt' def test_to_torch_geometric_multi_dimensional_lists(establish_connection):