Skip to content

Commit

Permalink
Merge pull request #1777 from kuzudb/recursive-rel-api-binding
Browse files Browse the repository at this point in the history
Add API bindings for recursive rel type
  • Loading branch information
mewim committed Jul 7, 2023
2 parents d6527aa + 266918c commit 7ffffbd
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 15 deletions.
8 changes: 8 additions & 0 deletions src/c_api/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,14 @@ kuzu_value* kuzu_value_get_struct_field_value(kuzu_value* value, uint64_t index)
return kuzu_value_get_list_element(value, index);
}

kuzu_value* kuzu_value_get_recursive_rel_node_list(kuzu_value* value) {
return kuzu_value_get_list_element(value, 0);
}

kuzu_value* kuzu_value_get_recursive_rel_rel_list(kuzu_value* value) {
return kuzu_value_get_list_element(value, 1);
}

kuzu_logical_type* kuzu_value_get_data_type(kuzu_value* value) {
auto* c_data_type = (kuzu_logical_type*)malloc(sizeof(kuzu_logical_type));
c_data_type->_data_type = new LogicalType(static_cast<Value*>(value->_value)->getDataType());
Expand Down
19 changes: 15 additions & 4 deletions src/include/c_api/kuzu.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,23 +615,34 @@ KUZU_C_API kuzu_value* kuzu_value_get_list_element(kuzu_value* value, uint64_t i
*/
KUZU_C_API uint64_t kuzu_value_get_struct_num_fields(kuzu_value* value);
/**
* @brief Returns the field name at index of the given struct value. The value must be of type
* STRUCT.
* @brief Returns the field name at index of the given struct value. The value must be of physical
* type STRUCT (STRUCT, NODE, REL, RECURSIVE_REL, UNION).
* @param value The STRUCT value to get field name.
* @param index The index of the field name to return.
*/
KUZU_C_API char* kuzu_value_get_struct_field_name(kuzu_value* value, uint64_t index);
/**
* @brief Returns the field value at index of the given struct value. The value must be of type
* STRUCT.
* @brief Returns the field value at index of the given struct value. The value must be of physical
* type STRUCT (STRUCT, NODE, REL, RECURSIVE_REL, UNION).
* @param value The STRUCT value to get field value.
* @param index The index of the field value to return.
*/
KUZU_C_API kuzu_value* kuzu_value_get_struct_field_value(kuzu_value* value, uint64_t index);
/*
* @brief Returns the list of nodes for recursive rel value. The value must be of type
* RECURSIVE_REL.
*/
KUZU_C_API kuzu_value* kuzu_value_get_recursive_rel_node_list(kuzu_value* value);

/*
* @brief Returns the list of rels for recursive rel value. The value must be of type RECURSIVE_REL.
*/
KUZU_C_API kuzu_value* kuzu_value_get_recursive_rel_rel_list(kuzu_value* value);
/**
* @brief Returns internal type of the given value.
* @param value The value to return.
*/

KUZU_C_API kuzu_logical_type* kuzu_value_get_data_type(kuzu_value* value);
/**
* @brief Returns the boolean value of the given value. The value must be of type BOOL.
Expand Down
2 changes: 1 addition & 1 deletion tools/java_api/src/jni/kuzu_java.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ JNIEXPORT jlong JNICALL Java_com_kuzudb_KuzuNative_kuzu_1value_1create_1value(
v = new Value(str);
env->ReleaseStringUTFChars(value, str);
} else if (env->IsInstanceOf(val, env->FindClass("com/kuzudb/KuzuInternalID"))) {
jfieldID fieldID = env->GetFieldID(val_class, "table_id", "J");
jfieldID fieldID = env->GetFieldID(val_class, "tableId", "J");
long table_id = static_cast<long>(env->GetLongField(val, fieldID));
fieldID = env->GetFieldID(val_class, "offset", "J");
long offset = static_cast<long>(env->GetLongField(val, fieldID));
Expand Down
6 changes: 3 additions & 3 deletions tools/java_api/src/main/java/com/kuzudb/KuzuInternalID.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package com.kuzudb;

public class KuzuInternalID {
public long table_id;
public long tableId;
public long offset;

public KuzuInternalID(long table_id, long offset) {
this.table_id = table_id;
public KuzuInternalID(long tableId, long offset) {
this.tableId = tableId;
this.offset = offset;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.kuzudb;

public class KuzuValueRecursiveRelUtil {
public static KuzuValue getNodeList(KuzuValue value) throws KuzuObjectRefDestroyedException {
return KuzuValueStructUtil.getValueByIndex(value, 0);
}

public static KuzuValue getRelList(KuzuValue value) throws KuzuObjectRefDestroyedException {
return KuzuValueStructUtil.getValueByIndex(value, 1);
}
}
45 changes: 40 additions & 5 deletions tools/java_api/src/test/java/com/kuzudb/test/ValueTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ void ValueCreateInternalID() throws KuzuObjectRefDestroyedException {
assertFalse(value.isOwnedByCPP());
assertEquals(value.getDataType().getID(), KuzuDataTypeID.INTERNAL_ID);
KuzuInternalID id = value.getValue();
assertEquals(id.table_id, 1);
assertEquals(id.tableId, 1);
assertEquals(id.offset, 123);
value.destroy();
}
Expand Down Expand Up @@ -539,7 +539,7 @@ void ValueGetBlob() throws KuzuObjectRefDestroyedException {
assertTrue(value.isOwnedByCPP());
assertFalse(value.isNull());

byte[] bytes = value.getValue();
byte[] bytes = value.getValue();
assertTrue(bytes.length == 4);
assertTrue(bytes[0] == (byte) 0xAA);
assertTrue(bytes[1] == (byte) 0xBB);
Expand Down Expand Up @@ -584,7 +584,7 @@ void NodeValGetID() throws KuzuObjectRefDestroyedException {
assertFalse(value.isNull());

KuzuInternalID id = KuzuValueNodeUtil.getID(value);
assertEquals(id.table_id, 0);
assertEquals(id.tableId, 0);
assertEquals(id.offset, 0);
value.destroy();
flatTuple.destroy();
Expand Down Expand Up @@ -678,11 +678,11 @@ void RelValGetIDsAndLabel() throws KuzuObjectRefDestroyedException {
assertFalse(value.isNull());

KuzuInternalID srcId = KuzuValueRelUtil.getSrcID(value);
assertEquals(srcId.table_id, 0);
assertEquals(srcId.tableId, 0);
assertEquals(srcId.offset, 0);

KuzuInternalID dstId = KuzuValueRelUtil.getDstID(value);
assertEquals(dstId.table_id, 0);
assertEquals(dstId.tableId, 0);
assertEquals(dstId.offset, 1);

String label = KuzuValueRelUtil.getLabelName(value);
Expand Down Expand Up @@ -829,4 +829,39 @@ void StructValGetValueByIndex() throws KuzuObjectRefDestroyedException {
flatTuple.destroy();
result.destroy();
}

@Test
void RecursiveRelGetNodeAndRelList() throws KuzuObjectRefDestroyedException {
KuzuQueryResult result = conn.query("MATCH (a:person)-[e*1..1]->(b:organisation) WHERE a.fName = 'Alice' RETURN e;");
assertTrue(result.isSuccess());
assertTrue(result.hasNext());
KuzuFlatTuple flatTuple = result.getNext();
KuzuValue value = flatTuple.getValue(0);
assertTrue(value.isOwnedByCPP());

KuzuValue nodeList = KuzuValueRecursiveRelUtil.getNodeList(value);
assertTrue(nodeList.isOwnedByCPP());
assertEquals(KuzuValueListUtil.getListSize(nodeList), 0);
nodeList.destroy();

KuzuValue relList = KuzuValueRecursiveRelUtil.getRelList(value);
assertTrue(relList.isOwnedByCPP());
assertEquals(KuzuValueListUtil.getListSize(relList), 1);

KuzuValue rel = KuzuValueListUtil.getListElement(relList, 0);
assertTrue(rel.isOwnedByCPP());
KuzuInternalID srcId = KuzuValueRelUtil.getSrcID(rel);
assertEquals(srcId.tableId, 0);
assertEquals(srcId.offset, 0);

KuzuInternalID dstId = KuzuValueRelUtil.getDstID(rel);
assertEquals(dstId.tableId, 1);
assertEquals(dstId.offset, 0);

rel.destroy();
relList.destroy();
value.destroy();
flatTuple.destroy();
result.destroy();
}
}
9 changes: 8 additions & 1 deletion tools/nodejs_api/src_cpp/node_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ Napi::Value Util::ConvertToNapiObject(const Value& value, Napi::Env env) {
}
return napiArray;
}
case LogicalTypeID::STRUCT: {
case LogicalTypeID::STRUCT:
case LogicalTypeID::UNION: {
auto childrenNames = StructType::getFieldNames(&dataType);
auto napiObj = Napi::Object::New(env);
auto& structVal = value.getListValReference();
Expand All @@ -77,6 +78,12 @@ Napi::Value Util::ConvertToNapiObject(const Value& value, Napi::Env env) {
}
return napiObj;
}
case LogicalTypeID::RECURSIVE_REL: {
auto napiObj = Napi::Object::New(env);
napiObj.Set("_nodes", ConvertToNapiObject(*value.getListValReference()[0], env));
napiObj.Set("_rels", ConvertToNapiObject(*value.getListValReference()[1], env));
return napiObj;
}
case LogicalTypeID::NODE: {
Napi::Object napiObj = Napi::Object::New(env);
auto numProperties = NodeVal::getNumProperties(&value);
Expand Down
42 changes: 42 additions & 0 deletions tools/nodejs_api/test/test_data_type.js
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,45 @@ describe("REL", function () {
assert.approximately(rel.rating, 7.6, EPSILON);
});
});

describe("RECURSIVE_REL", function () {
it("should convert RECURSIVE_REL type", async function () {
const queryResult = await conn.query(
"MATCH (a:person)-[e*1..1]->(b:organisation) WHERE a.fName = 'Alice' RETURN e;"
);
const result = await queryResult.getAll();
assert.equal(result.length, 1);
assert.exists(result[0]["e"]);
const e = result[0]["e"];
assert.deepEqual(e, {
_nodes: [],
_rels: [
{
date: null,
meetTime: null,
validInterval: null,
comments: null,
year: 2021,
places: ["wwAewsdndweusd", "wek"],
length: 5,
grading: null,
rating: null,
location: null,
times: null,
data: null,
usedAddress: null,
address: null,
note: null,
_src: {
offset: 0,
table: 0,
},
_dst: {
offset: 0,
table: 1,
},
},
],
});
});
});
10 changes: 9 additions & 1 deletion tools/python_api/src_cpp/py_query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) {
}
return std::move(list);
}
case LogicalTypeID::STRUCT: {
case LogicalTypeID::STRUCT:
case LogicalTypeID::UNION: {
auto fieldNames = StructType::getFieldNames(&dataType);
py::dict dict;
auto& structVals = value.getListValReference();
Expand All @@ -142,6 +143,13 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) {
}
return dict;
}
case LogicalTypeID::RECURSIVE_REL: {
py::dict dict;
auto& structVals = value.getListValReference();
dict["_nodes"] = convertValueToPyObject(*structVals[0]);
dict["_rels"] = convertValueToPyObject(*structVals[1]);
return dict;
}
case LogicalTypeID::NODE: {
py::dict dict;
dict["_label"] = py::cast(NodeVal::getLabelName(&value));
Expand Down
27 changes: 27 additions & 0 deletions tools/python_api/test/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,30 @@ def test_struct(establish_connection):
assert (description['film'] == datetime.date(2013, 2, 22))
assert not result.has_next()
result.close()


def test_recursive_rel(establish_connection):
conn, db = establish_connection
result = conn.execute(
"MATCH (a:person)-[e*1..1]->(b:organisation) WHERE a.fName = 'Alice' RETURN e;"
)
assert result.has_next()
n = result.get_next()
assert (len(n) == 1)
e = n[0]
assert ("_nodes" in e)
assert ("_rels" in e)
assert (len(e["_nodes"]) == 0)
assert (len(e["_rels"]) == 1)
rel = e["_rels"][0]
excepted_rel = {'_src': {'offset': 0, 'table': 0},
'_dst': {'offset': 0, 'table': 1},
'date': None, 'meetTime': None, 'validInterval': None,
'comments': None, 'year': 2021,
'places': ['wwAewsdndweusd', 'wek'],
'length': 5, 'grading': None, 'rating': None,
'location': None, 'times': None, 'data': None,
'usedAddress': None, 'address': None, 'note': None}
assert (rel == excepted_rel)
assert not result.has_next()
result.close()

0 comments on commit 7ffffbd

Please sign in to comment.