Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API bindings for recursive rel type #1777

Merged
merged 1 commit into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/c_api/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,14 @@ kuzu_value* kuzu_value_get_struct_field_value(kuzu_value* value, uint64_t index)
return kuzu_value_get_list_element(value, index);
}

kuzu_value* kuzu_value_get_recursive_rel_node_list(kuzu_value* value) {
return kuzu_value_get_list_element(value, 0);
}

kuzu_value* kuzu_value_get_recursive_rel_rel_list(kuzu_value* value) {
return kuzu_value_get_list_element(value, 1);
}

kuzu_logical_type* kuzu_value_get_data_type(kuzu_value* value) {
auto* c_data_type = (kuzu_logical_type*)malloc(sizeof(kuzu_logical_type));
c_data_type->_data_type = new LogicalType(static_cast<Value*>(value->_value)->getDataType());
Expand Down
19 changes: 15 additions & 4 deletions src/include/c_api/kuzu.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,23 +615,34 @@ KUZU_C_API kuzu_value* kuzu_value_get_list_element(kuzu_value* value, uint64_t i
*/
KUZU_C_API uint64_t kuzu_value_get_struct_num_fields(kuzu_value* value);
/**
* @brief Returns the field name at index of the given struct value. The value must be of type
* STRUCT.
* @brief Returns the field name at index of the given struct value. The value must be of physical
* type STRUCT (STRUCT, NODE, REL, RECURSIVE_REL, UNION).
* @param value The STRUCT value to get field name.
* @param index The index of the field name to return.
*/
KUZU_C_API char* kuzu_value_get_struct_field_name(kuzu_value* value, uint64_t index);
/**
* @brief Returns the field value at index of the given struct value. The value must be of type
* STRUCT.
* @brief Returns the field value at index of the given struct value. The value must be of physical
* type STRUCT (STRUCT, NODE, REL, RECURSIVE_REL, UNION).
* @param value The STRUCT value to get field value.
* @param index The index of the field value to return.
*/
KUZU_C_API kuzu_value* kuzu_value_get_struct_field_value(kuzu_value* value, uint64_t index);
/*
* @brief Returns the list of nodes for recursive rel value. The value must be of type
* RECURSIVE_REL.
*/
KUZU_C_API kuzu_value* kuzu_value_get_recursive_rel_node_list(kuzu_value* value);

/*
* @brief Returns the list of rels for recursive rel value. The value must be of type RECURSIVE_REL.
*/
KUZU_C_API kuzu_value* kuzu_value_get_recursive_rel_rel_list(kuzu_value* value);
/**
* @brief Returns internal type of the given value.
* @param value The value to return.
*/

KUZU_C_API kuzu_logical_type* kuzu_value_get_data_type(kuzu_value* value);
/**
* @brief Returns the boolean value of the given value. The value must be of type BOOL.
Expand Down
2 changes: 1 addition & 1 deletion tools/java_api/src/jni/kuzu_java.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ JNIEXPORT jlong JNICALL Java_com_kuzudb_KuzuNative_kuzu_1value_1create_1value(
v = new Value(str);
env->ReleaseStringUTFChars(value, str);
} else if (env->IsInstanceOf(val, env->FindClass("com/kuzudb/KuzuInternalID"))) {
jfieldID fieldID = env->GetFieldID(val_class, "table_id", "J");
jfieldID fieldID = env->GetFieldID(val_class, "tableId", "J");
long table_id = static_cast<long>(env->GetLongField(val, fieldID));
fieldID = env->GetFieldID(val_class, "offset", "J");
long offset = static_cast<long>(env->GetLongField(val, fieldID));
Expand Down
6 changes: 3 additions & 3 deletions tools/java_api/src/main/java/com/kuzudb/KuzuInternalID.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package com.kuzudb;

public class KuzuInternalID {
public long table_id;
public long tableId;
public long offset;

public KuzuInternalID(long table_id, long offset) {
this.table_id = table_id;
public KuzuInternalID(long tableId, long offset) {
this.tableId = tableId;
this.offset = offset;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.kuzudb;

public class KuzuValueRecursiveRelUtil {
public static KuzuValue getNodeList(KuzuValue value) throws KuzuObjectRefDestroyedException {
return KuzuValueStructUtil.getValueByIndex(value, 0);
}

public static KuzuValue getRelList(KuzuValue value) throws KuzuObjectRefDestroyedException {
return KuzuValueStructUtil.getValueByIndex(value, 1);
}
}
45 changes: 40 additions & 5 deletions tools/java_api/src/test/java/com/kuzudb/test/ValueTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ void ValueCreateInternalID() throws KuzuObjectRefDestroyedException {
assertFalse(value.isOwnedByCPP());
assertEquals(value.getDataType().getID(), KuzuDataTypeID.INTERNAL_ID);
KuzuInternalID id = value.getValue();
assertEquals(id.table_id, 1);
assertEquals(id.tableId, 1);
assertEquals(id.offset, 123);
value.destroy();
}
Expand Down Expand Up @@ -539,7 +539,7 @@ void ValueGetBlob() throws KuzuObjectRefDestroyedException {
assertTrue(value.isOwnedByCPP());
assertFalse(value.isNull());

byte[] bytes = value.getValue();
byte[] bytes = value.getValue();
assertTrue(bytes.length == 4);
assertTrue(bytes[0] == (byte) 0xAA);
assertTrue(bytes[1] == (byte) 0xBB);
Expand Down Expand Up @@ -584,7 +584,7 @@ void NodeValGetID() throws KuzuObjectRefDestroyedException {
assertFalse(value.isNull());

KuzuInternalID id = KuzuValueNodeUtil.getID(value);
assertEquals(id.table_id, 0);
assertEquals(id.tableId, 0);
assertEquals(id.offset, 0);
value.destroy();
flatTuple.destroy();
Expand Down Expand Up @@ -678,11 +678,11 @@ void RelValGetIDsAndLabel() throws KuzuObjectRefDestroyedException {
assertFalse(value.isNull());

KuzuInternalID srcId = KuzuValueRelUtil.getSrcID(value);
assertEquals(srcId.table_id, 0);
assertEquals(srcId.tableId, 0);
assertEquals(srcId.offset, 0);

KuzuInternalID dstId = KuzuValueRelUtil.getDstID(value);
assertEquals(dstId.table_id, 0);
assertEquals(dstId.tableId, 0);
assertEquals(dstId.offset, 1);

String label = KuzuValueRelUtil.getLabelName(value);
Expand Down Expand Up @@ -829,4 +829,39 @@ void StructValGetValueByIndex() throws KuzuObjectRefDestroyedException {
flatTuple.destroy();
result.destroy();
}

@Test
void RecursiveRelGetNodeAndRelList() throws KuzuObjectRefDestroyedException {
KuzuQueryResult result = conn.query("MATCH (a:person)-[e*1..1]->(b:organisation) WHERE a.fName = 'Alice' RETURN e;");
assertTrue(result.isSuccess());
assertTrue(result.hasNext());
KuzuFlatTuple flatTuple = result.getNext();
KuzuValue value = flatTuple.getValue(0);
assertTrue(value.isOwnedByCPP());

KuzuValue nodeList = KuzuValueRecursiveRelUtil.getNodeList(value);
assertTrue(nodeList.isOwnedByCPP());
assertEquals(KuzuValueListUtil.getListSize(nodeList), 0);
nodeList.destroy();

KuzuValue relList = KuzuValueRecursiveRelUtil.getRelList(value);
assertTrue(relList.isOwnedByCPP());
assertEquals(KuzuValueListUtil.getListSize(relList), 1);

KuzuValue rel = KuzuValueListUtil.getListElement(relList, 0);
assertTrue(rel.isOwnedByCPP());
KuzuInternalID srcId = KuzuValueRelUtil.getSrcID(rel);
assertEquals(srcId.tableId, 0);
assertEquals(srcId.offset, 0);

KuzuInternalID dstId = KuzuValueRelUtil.getDstID(rel);
assertEquals(dstId.tableId, 1);
assertEquals(dstId.offset, 0);

rel.destroy();
relList.destroy();
value.destroy();
flatTuple.destroy();
result.destroy();
}
}
9 changes: 8 additions & 1 deletion tools/nodejs_api/src_cpp/node_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ Napi::Value Util::ConvertToNapiObject(const Value& value, Napi::Env env) {
}
return napiArray;
}
case LogicalTypeID::STRUCT: {
case LogicalTypeID::STRUCT:
case LogicalTypeID::UNION: {
auto childrenNames = StructType::getFieldNames(&dataType);
auto napiObj = Napi::Object::New(env);
auto& structVal = value.getListValReference();
Expand All @@ -77,6 +78,12 @@ Napi::Value Util::ConvertToNapiObject(const Value& value, Napi::Env env) {
}
return napiObj;
}
case LogicalTypeID::RECURSIVE_REL: {
auto napiObj = Napi::Object::New(env);
napiObj.Set("_nodes", ConvertToNapiObject(*value.getListValReference()[0], env));
napiObj.Set("_rels", ConvertToNapiObject(*value.getListValReference()[1], env));
return napiObj;
}
case LogicalTypeID::NODE: {
Napi::Object napiObj = Napi::Object::New(env);
auto numProperties = NodeVal::getNumProperties(&value);
Expand Down
42 changes: 42 additions & 0 deletions tools/nodejs_api/test/test_data_type.js
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,45 @@ describe("REL", function () {
assert.approximately(rel.rating, 7.6, EPSILON);
});
});

describe("RECURSIVE_REL", function () {
it("should convert RECURSIVE_REL type", async function () {
const queryResult = await conn.query(
"MATCH (a:person)-[e*1..1]->(b:organisation) WHERE a.fName = 'Alice' RETURN e;"
);
const result = await queryResult.getAll();
assert.equal(result.length, 1);
assert.exists(result[0]["e"]);
const e = result[0]["e"];
assert.deepEqual(e, {
_nodes: [],
_rels: [
{
date: null,
meetTime: null,
validInterval: null,
comments: null,
year: 2021,
places: ["wwAewsdndweusd", "wek"],
length: 5,
grading: null,
rating: null,
location: null,
times: null,
data: null,
usedAddress: null,
address: null,
note: null,
_src: {
offset: 0,
table: 0,
},
_dst: {
offset: 0,
table: 1,
},
},
],
});
});
});
10 changes: 9 additions & 1 deletion tools/python_api/src_cpp/py_query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) {
}
return std::move(list);
}
case LogicalTypeID::STRUCT: {
case LogicalTypeID::STRUCT:
case LogicalTypeID::UNION: {
auto fieldNames = StructType::getFieldNames(&dataType);
py::dict dict;
auto& structVals = value.getListValReference();
Expand All @@ -142,6 +143,13 @@ py::object PyQueryResult::convertValueToPyObject(const Value& value) {
}
return dict;
}
case LogicalTypeID::RECURSIVE_REL: {
py::dict dict;
auto& structVals = value.getListValReference();
dict["_nodes"] = convertValueToPyObject(*structVals[0]);
dict["_rels"] = convertValueToPyObject(*structVals[1]);
return dict;
}
case LogicalTypeID::NODE: {
py::dict dict;
dict["_label"] = py::cast(NodeVal::getLabelName(&value));
Expand Down
27 changes: 27 additions & 0 deletions tools/python_api/test/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,30 @@ def test_struct(establish_connection):
assert (description['film'] == datetime.date(2013, 2, 22))
assert not result.has_next()
result.close()


def test_recursive_rel(establish_connection):
conn, db = establish_connection
result = conn.execute(
"MATCH (a:person)-[e*1..1]->(b:organisation) WHERE a.fName = 'Alice' RETURN e;"
)
assert result.has_next()
n = result.get_next()
assert (len(n) == 1)
e = n[0]
assert ("_nodes" in e)
assert ("_rels" in e)
assert (len(e["_nodes"]) == 0)
assert (len(e["_rels"]) == 1)
rel = e["_rels"][0]
excepted_rel = {'_src': {'offset': 0, 'table': 0},
'_dst': {'offset': 0, 'table': 1},
'date': None, 'meetTime': None, 'validInterval': None,
'comments': None, 'year': 2021,
'places': ['wwAewsdndweusd', 'wek'],
'length': 5, 'grading': None, 'rating': None,
'location': None, 'times': None, 'data': None,
'usedAddress': None, 'address': None, 'note': None}
assert (rel == excepted_rel)
assert not result.has_next()
result.close()
Loading