Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented support for RecursiveRel in the Rust API #1813

Merged
merged 1 commit into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tools/rust_api/include/kuzu_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ std::array<uint64_t, 2> node_value_get_node_id(const kuzu::common::Value& val);
std::array<uint64_t, 2> rel_value_get_src_id(const kuzu::common::Value& val);
std::array<uint64_t, 2> rel_value_get_dst_id(const kuzu::common::Value& val);

/* RecursiveRel */
const kuzu::common::Value &recursive_rel_get_nodes(const kuzu::common::Value &val);
const kuzu::common::Value &recursive_rel_get_rels(const kuzu::common::Value &val);

/* FlatTuple */
const kuzu::common::Value& flat_tuple_get_value(
const kuzu::processor::FlatTuple& flatTuple, uint32_t index);
Expand Down
3 changes: 3 additions & 0 deletions tools/rust_api/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ pub(crate) mod ffi {

fn rel_value_get_src_id(value: &Value) -> [u64; 2];
fn rel_value_get_dst_id(value: &Value) -> [u64; 2];

fn recursive_rel_get_nodes(value: &Value) -> &Value;
fn recursive_rel_get_rels(value: &Value) -> &Value;
}

#[namespace = "kuzu_rs"]
Expand Down
7 changes: 7 additions & 0 deletions tools/rust_api/src/kuzu_rs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,11 @@ std::unique_ptr<TypeListBuilder> create_type_list() {
return std::make_unique<TypeListBuilder>();
}

const kuzu::common::Value &recursive_rel_get_nodes(const kuzu::common::Value &val) {
return *kuzu::common::RecursiveRelVal::getNodes(&val);
}
const kuzu::common::Value &recursive_rel_get_rels(const kuzu::common::Value &val) {
return *kuzu::common::RecursiveRelVal::getRels(&val);
}

} // namespace kuzu_rs
14 changes: 11 additions & 3 deletions tools/rust_api/src/logical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,23 @@ pub enum LogicalType {
/// Correponds to [Value::Blob](crate::value::Value::Blob)
Blob,
/// Correponds to [Value::VarList](crate::value::Value::VarList)
VarList { child_type: Box<LogicalType> },
VarList {
child_type: Box<LogicalType>,
},
/// Correponds to [Value::FixedList](crate::value::Value::FixedList)
FixedList {
child_type: Box<LogicalType>,
num_elements: u64,
},
/// Correponds to [Value::Struct](crate::value::Value::Struct)
Struct { fields: Vec<(String, LogicalType)> },
Struct {
fields: Vec<(String, LogicalType)>,
},
/// Correponds to [Value::Node](crate::value::Value::Node)
Node,
/// Correponds to [Value::Rel](crate::value::Value::Rel)
Rel,
RecursiveRel,
}

impl From<&ffi::Value> for LogicalType {
Expand Down Expand Up @@ -96,6 +101,7 @@ impl From<&ffi::LogicalType> for LogicalType {
}
LogicalTypeID::NODE => LogicalType::Node,
LogicalTypeID::REL => LogicalType::Rel,
LogicalTypeID::RECURSIVE_REL => LogicalType::RecursiveRel,
// Should be unreachable, as cxx will check that the LogicalTypeID enum matches the one
// on the C++ side.
x => panic!("Unsupported type {:?}", x),
Expand All @@ -121,7 +127,8 @@ impl From<&LogicalType> for cxx::UniquePtr<ffi::LogicalType> {
| LogicalType::String
| LogicalType::Blob
| LogicalType::Node
| LogicalType::Rel => ffi::create_logical_type(typ.id()),
| LogicalType::Rel
| LogicalType::RecursiveRel => ffi::create_logical_type(typ.id()),
LogicalType::VarList { child_type } => {
ffi::create_logical_type_var_list(child_type.as_ref().into())
}
Expand Down Expand Up @@ -165,6 +172,7 @@ impl LogicalType {
LogicalType::Struct { .. } => LogicalTypeID::STRUCT,
LogicalType::Node => LogicalTypeID::NODE,
LogicalType::Rel => LogicalTypeID::REL,
LogicalType::RecursiveRel => LogicalTypeID::RECURSIVE_REL,
}
}
}
157 changes: 134 additions & 23 deletions tools/rust_api/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ pub struct NodeVal {
}

impl NodeVal {
pub fn new(id: InternalID, label: String) -> Self {
pub fn new<I: Into<InternalID>, S: Into<String>>(id: I, label: S) -> Self {
NodeVal {
id,
label,
id: id.into(),
label: label.into(),
properties: vec![],
}
}
Expand All @@ -63,8 +63,8 @@ impl NodeVal {
/// # Arguments
/// * `key`: The name of the property
/// * `value`: The value of the property
pub fn add_property(&mut self, key: String, value: Value) {
self.properties.push((key, value));
pub fn add_property<S: Into<String>, V: Into<Value>>(&mut self, key: S, value: V) {
self.properties.push((key.into(), value.into()));
}

/// Returns all properties of the NodeVal
Expand Down Expand Up @@ -106,11 +106,11 @@ pub struct RelVal {
}

impl RelVal {
pub fn new(src_node: InternalID, dst_node: InternalID, label: String) -> Self {
pub fn new<I: Into<InternalID>, S: Into<String>>(src_node: I, dst_node: I, label: S) -> Self {
RelVal {
src_node,
dst_node,
label,
src_node: src_node.into(),
dst_node: dst_node.into(),
label: label.into(),
properties: vec![],
}
}
Expand Down Expand Up @@ -177,6 +177,15 @@ impl Ord for InternalID {
}
}

impl From<(u64, u64)> for InternalID {
fn from(value: (u64, u64)) -> Self {
InternalID {
offset: value.0,
table_id: value.1,
}
}
}

/// Data types supported by Kùzu
///
/// Also see <https://kuzudb.com/docs/cypher/data-types/overview.html>
Expand Down Expand Up @@ -222,6 +231,25 @@ pub enum Value {
Struct(Vec<(String, Value)>),
Node(NodeVal),
Rel(RelVal),
RecursiveRel {
/// Interior nodes in the Sequence of Rels
///
/// Does not include the starting or ending Node.
nodes: Vec<NodeVal>,
/// Sequence of Rels which make up the RecursiveRel
rels: Vec<RelVal>,
},
}

fn display_list<T: std::fmt::Display>(f: &mut fmt::Formatter<'_>, list: &Vec<T>) -> fmt::Result {
write!(f, "[")?;
for (i, value) in list.iter().enumerate() {
write!(f, "{}", value)?;
if i != list.len() - 1 {
write!(f, ",")?;
}
}
write!(f, "]")
}

impl std::fmt::Display for Value {
Expand All @@ -236,16 +264,7 @@ impl std::fmt::Display for Value {
Value::String(x) => write!(f, "{x}"),
Value::Blob(x) => write!(f, "{x:x?}"),
Value::Null(_) => write!(f, ""),
Value::VarList(_, x) | Value::FixedList(_, x) => {
write!(f, "[")?;
for (i, value) in x.iter().enumerate() {
write!(f, "{}", value)?;
if i != x.len() - 1 {
write!(f, ",")?;
}
}
write!(f, "]")
}
Value::VarList(_, x) | Value::FixedList(_, x) => display_list(f, x),
// Note: These don't match kuzu's toString, but we probably don't want them to
Value::Interval(x) => write!(f, "{x}"),
Value::Timestamp(x) => write!(f, "{x}"),
Expand All @@ -264,6 +283,14 @@ impl std::fmt::Display for Value {
Value::Node(x) => write!(f, "{x}"),
Value::Rel(x) => write!(f, "{x}"),
Value::InternalID(x) => write!(f, "{x}"),
Value::RecursiveRel { nodes, rels } => {
write!(f, "{{")?;
write!(f, "_NODES: ")?;
display_list(f, nodes)?;
write!(f, ", _RELS: ")?;
display_list(f, rels)?;
write!(f, "}}")
}
}
}
}
Expand Down Expand Up @@ -302,6 +329,7 @@ impl From<&Value> for LogicalType {
Value::InternalID(_) => LogicalType::InternalID,
Value::Node(_) => LogicalType::Node,
Value::Rel(_) => LogicalType::Rel,
Value::RecursiveRel { .. } => LogicalType::RecursiveRel,
}
}
}
Expand Down Expand Up @@ -400,8 +428,10 @@ impl TryFrom<&ffi::Value> for Value {
let mut node_val = NodeVal::new(id, label);
let properties = ffi::node_value_get_properties(value);
for i in 0..properties.size() {
node_val
.add_property(properties.get_name(i), properties.get_value(i).try_into()?);
node_val.add_property(
properties.get_name(i),
TryInto::<Value>::try_into(properties.get_value(i))?,
);
}
Ok(Value::Node(node_val))
}
Expand Down Expand Up @@ -432,8 +462,38 @@ impl TryFrom<&ffi::Value> for Value {
table_id: internal_id[1],
}))
}
// Should be unreachable, as cxx will check that the LogicalTypeID enum matches the one
// on the C++ side.
LogicalTypeID::RECURSIVE_REL => {
let nodes: Value = ffi::recursive_rel_get_nodes(value).try_into()?;
let rels: Value = ffi::recursive_rel_get_rels(value).try_into()?;
let nodes = if let Value::VarList(LogicalType::Node, nodes) = nodes {
nodes.into_iter().map(|x| {
if let Value::Node(x) = x {
x
} else {
unreachable!()
}
})
} else {
panic!("Unexpected value in RecursiveRel's rels: {}", rels)
};
let rels = if let Value::VarList(LogicalType::Rel, rels) = rels {
rels.into_iter().map(|x| {
if let Value::Rel(x) = x {
x
} else {
unreachable!()
}
})
} else {
panic!("Unexpected value in RecursiveRel's rels: {}", rels)
};

Ok(Value::RecursiveRel {
nodes: nodes.collect(),
rels: rels.collect(),
})
}
// TODO(bmwinger): Better error message for types which are unsupported
x => panic!("Unsupported type {:?}", x),
}
}
Expand Down Expand Up @@ -532,6 +592,9 @@ impl TryInto<cxx::UniquePtr<ffi::Value>> for Value {
}
Value::Node(_) => Err(crate::Error::ReadOnlyType(LogicalType::Node)),
Value::Rel(_) => Err(crate::Error::ReadOnlyType(LogicalType::Rel)),
Value::RecursiveRel { .. } => {
Err(crate::Error::ReadOnlyType(LogicalType::RecursiveRel))
}
}
}
}
Expand Down Expand Up @@ -695,6 +758,7 @@ mod tests {
convert_node_type: LogicalType::Node,
convert_internal_id_type: LogicalType::InternalID,
convert_rel_type: LogicalType::Rel,
convert_recursive_rel_type: LogicalType::RecursiveRel,
}

value_tests! {
Expand Down Expand Up @@ -858,6 +922,53 @@ mod tests {
Ok(())
}

#[test]
fn test_recursive_rel() -> Result<()> {
let temp_dir = tempfile::TempDir::new()?;
let db = Database::new(temp_dir.path(), 0)?;
let conn = Connection::new(&db)?;
conn.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));")?;
conn.query("CREATE REL TABLE knows(FROM Person TO Person);")?;
conn.query("CREATE (:Person {name: \"Alice\", age: 25});")?;
conn.query("CREATE (:Person {name: \"Bob\", age: 25});")?;
conn.query("CREATE (:Person {name: \"Eve\", age: 25});")?;
conn.query(
"MATCH (p1:Person), (p2:Person)
WHERE p1.name = \"Alice\" AND p2.name = \"Bob\"
CREATE (p1)-[:knows]->(p2);",
)?;
conn.query(
"MATCH (p1:Person), (p2:Person)
WHERE p1.name = \"Bob\" AND p2.name = \"Eve\"
CREATE (p1)-[:knows]->(p2);",
)?;
let result = conn
.query(
"MATCH (a:Person)-[e*2..2]->(b:Person)
WHERE a.name = 'Alice'
RETURN e, b.name;",
)?
.next()
.unwrap();
assert_eq!(result[1], Value::String("Eve".to_string()));
assert_eq!(
result[0],
Value::RecursiveRel {
nodes: vec![NodeVal {
id: (1, 0).into(),
label: "Person".into(),
properties: vec![("name".into(), "Bob".into()), ("age".into(), 25i64.into())]
},],
rels: vec![
RelVal::new((0, 0), (1, 0), "knows"),
RelVal::new((1, 0), (2, 0), "knows"),
],
}
);
temp_dir.close()?;
Ok(())
}

#[test]
/// Test that null values are read correctly by the API
fn test_null() -> Result<()> {
Expand Down
Loading