Skip to content

Commit

Permalink
Merge pull request #2060 from kuzudb/expose-rel-labels-api
Browse files Browse the repository at this point in the history
Expose rel label for Python and Node.js API
  • Loading branch information
mewim authored Sep 20, 2023
2 parents b576565 + 35b64b9 commit 7bf0d6e
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 4 deletions.
3 changes: 2 additions & 1 deletion tools/nodejs_api/src_cpp/node_util.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include "include/node_util.h"

#include "common/exception/exception.h"
#include "common/types/value/value.h"
#include "common/types/value/nested.h"
#include "common/types/value/node.h"
#include "common/types/value/recursive_rel.h"
#include "common/types/value/rel.h"
#include "common/types/value/value.h"

Napi::Value Util::ConvertToNapiObject(const Value& value, Napi::Env env) {
if (value.isNull()) {
Expand Down Expand Up @@ -119,6 +119,7 @@ Napi::Value Util::ConvertToNapiObject(const Value& value, Napi::Env env) {
}
napiObj.Set("_src", ConvertNodeIdToNapiObject(RelVal::getSrcNodeID(&value), env));
napiObj.Set("_dst", ConvertNodeIdToNapiObject(RelVal::getDstNodeID(&value), env));
napiObj.Set("_label", Napi::String::New(env, RelVal::getLabelName(&value)));
return napiObj;
}
case LogicalTypeID::INTERNAL_ID: {
Expand Down
2 changes: 2 additions & 0 deletions tools/nodejs_api/test/test_data_type.js
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ describe("REL", function () {
assert.equal(rel.grading.length, 2);
assert.equal(rel.grading[0], 2.1);
assert.equal(rel.grading[1], 4.4);
assert.equal(rel._label, "workAt");
assert.approximately(rel.rating, 7.6, EPSILON);
});
});
Expand Down Expand Up @@ -327,6 +328,7 @@ describe("RECURSIVE_REL", function () {
usedAddress: null,
address: null,
note: null,
_label: "studyAt",
_src: {
offset: 0,
table: 0,
Expand Down
1 change: 1 addition & 0 deletions tools/python_api/src_cpp/py_query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) {
py::dict dict;
dict["_src"] = convertNodeIdToPyDict(RelVal::getSrcNodeID(&value));
dict["_dst"] = convertNodeIdToPyDict(RelVal::getDstNodeID(&value));
dict["_label"] = py::cast(RelVal::getLabelName(&value));
auto numProperties = RelVal::getNumProperties(&value);
for (auto i = 0u; i < numProperties; ++i) {
auto key = py::str(RelVal::getPropertyName(&value, i));
Expand Down
2 changes: 2 additions & 0 deletions tools/python_api/test/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def test_rel(establish_connection):
assert (r['year'] == 2010)
assert (r['_src'] == p['_id'])
assert (r['_dst'] == o['_id'])
assert (r['_label'] == 'workAt')
assert not result.has_next()
result.close()

Expand Down Expand Up @@ -179,6 +180,7 @@ def test_recursive_rel(establish_connection):
rel = e["_rels"][0]
excepted_rel = {'_src': {'offset': 0, 'table': 0},
'_dst': {'offset': 0, 'table': 1},
'_label': 'studyAt',
'date': None, 'meetTime': None, 'validInterval': None,
'comments': None, 'year': 2021,
'places': ['wwAewsdndweusd', 'wek'],
Expand Down
10 changes: 7 additions & 3 deletions tools/python_api/test/test_torch_geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_to_torch_geometric_homogeneous_graph(establish_connection):
assert src != dst
assert pos_to_idx[dst] in ground_truth.TINY_SNB_KNOWS_GROUND_TRUTH[pos_to_idx[src]]

assert len(edge_properties) == 4
assert len(edge_properties) == 5
assert 'date' in edge_properties
assert 'meetTime' in edge_properties
assert 'validInterval' in edge_properties
Expand Down Expand Up @@ -310,11 +310,12 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection):
assert src != dst
assert pos_to_idx['person'][dst] in ground_truth.TINY_SNB_KNOWS_GROUND_TRUTH[pos_to_idx['person'][src]]

assert len(edge_properties['person', 'person']) == 4
assert len(edge_properties['person', 'person']) == 5
assert 'date' in edge_properties['person', 'person']
assert 'meetTime' in edge_properties['person', 'person']
assert 'validInterval' in edge_properties['person', 'person']
assert 'comments' in edge_properties['person', 'person']
assert '_label' in edge_properties['person', 'person']
for i in range(3):
src, dst = torch_geometric_data['person', 'person'].edge_index[0][i].item(
), torch_geometric_data['person', 'person'].edge_index[1][i].item()
Expand All @@ -328,6 +329,7 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection):
original_src, original_dst)]['validInterval'] == edge_properties['person', 'person']['validInterval'][i]
assert ground_truth.TINY_SNB_KNOWS_PROPERTIES_GROUND_TRUTH[(
original_src, original_dst)]['comments'] == edge_properties['person', 'person']['comments'][i]
assert edge_properties['person', 'person']['_label'][i] == 'knows'

assert torch_geometric_data['organisation'].ID.shape == torch.Size([2])
assert torch_geometric_data['organisation'].ID.dtype == torch.int64
Expand Down Expand Up @@ -386,14 +388,16 @@ def test_to_torch_geometric_heterogeneous_graph(establish_connection):
assert dst in pos_to_idx['organisation']
assert src != dst
assert pos_to_idx['organisation'][dst] in ground_truth.TINY_SNB_WORKS_AT_GROUND_TRUTH[pos_to_idx['person'][src]]
assert len(edge_properties['person', 'organisation']) == 3
assert len(edge_properties['person', 'organisation']) == 4
assert 'year' in edge_properties['person', 'organisation']
assert '_label' in edge_properties['person', 'organisation']
for i in range(2):
src, dst = torch_geometric_data['person', 'organisation'].edge_index[0][i].item(
), torch_geometric_data['person', 'organisation'].edge_index[1][i].item()
original_src, original_dst = pos_to_idx['person'][src], pos_to_idx['organisation'][dst]
assert ground_truth.TINY_SNB_WORKS_AT_PROPERTIES_GROUND_TRUTH[(
original_src, original_dst)]['year'] == edge_properties['person', 'organisation']['year'][i]
assert edge_properties['person', 'organisation']['_label'][i] == 'workAt'


def test_to_torch_geometric_multi_dimensional_lists(establish_connection):
Expand Down

0 comments on commit 7bf0d6e

Please sign in to comment.