diff --git a/tools/rust_api/include/kuzu_rs.h b/tools/rust_api/include/kuzu_rs.h index 9f66d0068b..8c8f0e5eb0 100644 --- a/tools/rust_api/include/kuzu_rs.h +++ b/tools/rust_api/include/kuzu_rs.h @@ -122,6 +122,10 @@ std::array node_value_get_node_id(const kuzu::common::Value& val); std::array rel_value_get_src_id(const kuzu::common::Value& val); std::array rel_value_get_dst_id(const kuzu::common::Value& val); +/* RecursiveRel */ +const kuzu::common::Value &recursive_rel_get_nodes(const kuzu::common::Value &val); +const kuzu::common::Value &recursive_rel_get_rels(const kuzu::common::Value &val); + /* FlatTuple */ const kuzu::common::Value& flat_tuple_get_value( const kuzu::processor::FlatTuple& flatTuple, uint32_t index); diff --git a/tools/rust_api/src/ffi.rs b/tools/rust_api/src/ffi.rs index 287c0c5172..066e84f594 100644 --- a/tools/rust_api/src/ffi.rs +++ b/tools/rust_api/src/ffi.rs @@ -287,6 +287,9 @@ pub(crate) mod ffi { fn rel_value_get_src_id(value: &Value) -> [u64; 2]; fn rel_value_get_dst_id(value: &Value) -> [u64; 2]; + + fn recursive_rel_get_nodes(value: &Value) -> &Value; + fn recursive_rel_get_rels(value: &Value) -> &Value; } #[namespace = "kuzu_rs"] diff --git a/tools/rust_api/src/kuzu_rs.cpp b/tools/rust_api/src/kuzu_rs.cpp index 4881030227..9d8bf95758 100644 --- a/tools/rust_api/src/kuzu_rs.cpp +++ b/tools/rust_api/src/kuzu_rs.cpp @@ -255,4 +255,11 @@ std::unique_ptr create_type_list() { return std::make_unique(); } +const kuzu::common::Value &recursive_rel_get_nodes(const kuzu::common::Value &val) { + return *kuzu::common::RecursiveRelVal::getNodes(&val); +} +const kuzu::common::Value &recursive_rel_get_rels(const kuzu::common::Value &val) { + return *kuzu::common::RecursiveRelVal::getRels(&val); +} + } // namespace kuzu_rs diff --git a/tools/rust_api/src/logical_type.rs b/tools/rust_api/src/logical_type.rs index 85ecdd050e..61a4f4ba23 100644 --- a/tools/rust_api/src/logical_type.rs +++ b/tools/rust_api/src/logical_type.rs @@ -35,18 +35,23 @@ pub enum LogicalType { /// Correponds to [Value::Blob](crate::value::Value::Blob) Blob, /// Correponds to [Value::VarList](crate::value::Value::VarList) - VarList { child_type: Box }, + VarList { + child_type: Box, + }, /// Correponds to [Value::FixedList](crate::value::Value::FixedList) FixedList { child_type: Box, num_elements: u64, }, /// Correponds to [Value::Struct](crate::value::Value::Struct) - Struct { fields: Vec<(String, LogicalType)> }, + Struct { + fields: Vec<(String, LogicalType)>, + }, /// Correponds to [Value::Node](crate::value::Value::Node) Node, /// Correponds to [Value::Rel](crate::value::Value::Rel) Rel, + RecursiveRel, } impl From<&ffi::Value> for LogicalType { @@ -96,6 +101,7 @@ impl From<&ffi::LogicalType> for LogicalType { } LogicalTypeID::NODE => LogicalType::Node, LogicalTypeID::REL => LogicalType::Rel, + LogicalTypeID::RECURSIVE_REL => LogicalType::RecursiveRel, // Should be unreachable, as cxx will check that the LogicalTypeID enum matches the one // on the C++ side. x => panic!("Unsupported type {:?}", x), @@ -121,7 +127,8 @@ impl From<&LogicalType> for cxx::UniquePtr { | LogicalType::String | LogicalType::Blob | LogicalType::Node - | LogicalType::Rel => ffi::create_logical_type(typ.id()), + | LogicalType::Rel + | LogicalType::RecursiveRel => ffi::create_logical_type(typ.id()), LogicalType::VarList { child_type } => { ffi::create_logical_type_var_list(child_type.as_ref().into()) } @@ -165,6 +172,7 @@ impl LogicalType { LogicalType::Struct { .. } => LogicalTypeID::STRUCT, LogicalType::Node => LogicalTypeID::NODE, LogicalType::Rel => LogicalTypeID::REL, + LogicalType::RecursiveRel => LogicalTypeID::RECURSIVE_REL, } } } diff --git a/tools/rust_api/src/value.rs b/tools/rust_api/src/value.rs index e9ce9e65ee..811b9b1aea 100644 --- a/tools/rust_api/src/value.rs +++ b/tools/rust_api/src/value.rs @@ -43,10 +43,10 @@ pub struct NodeVal { } impl NodeVal { - pub fn new(id: InternalID, label: String) -> Self { + pub fn new, S: Into>(id: I, label: S) -> Self { NodeVal { - id, - label, + id: id.into(), + label: label.into(), properties: vec![], } } @@ -63,8 +63,8 @@ impl NodeVal { /// # Arguments /// * `key`: The name of the property /// * `value`: The value of the property - pub fn add_property(&mut self, key: String, value: Value) { - self.properties.push((key, value)); + pub fn add_property, V: Into>(&mut self, key: S, value: V) { + self.properties.push((key.into(), value.into())); } /// Returns all properties of the NodeVal @@ -106,11 +106,11 @@ pub struct RelVal { } impl RelVal { - pub fn new(src_node: InternalID, dst_node: InternalID, label: String) -> Self { + pub fn new, S: Into>(src_node: I, dst_node: I, label: S) -> Self { RelVal { - src_node, - dst_node, - label, + src_node: src_node.into(), + dst_node: dst_node.into(), + label: label.into(), properties: vec![], } } @@ -177,6 +177,15 @@ impl Ord for InternalID { } } +impl From<(u64, u64)> for InternalID { + fn from(value: (u64, u64)) -> Self { + InternalID { + offset: value.0, + table_id: value.1, + } + } +} + /// Data types supported by Kùzu /// /// Also see @@ -222,6 +231,25 @@ pub enum Value { Struct(Vec<(String, Value)>), Node(NodeVal), Rel(RelVal), + RecursiveRel { + /// Interior nodes in the Sequence of Rels + /// + /// Does not include the starting or ending Node. + nodes: Vec, + /// Sequence of Rels which make up the RecursiveRel + rels: Vec, + }, +} + +fn display_list(f: &mut fmt::Formatter<'_>, list: &Vec) -> fmt::Result { + write!(f, "[")?; + for (i, value) in list.iter().enumerate() { + write!(f, "{}", value)?; + if i != list.len() - 1 { + write!(f, ",")?; + } + } + write!(f, "]") } impl std::fmt::Display for Value { @@ -236,16 +264,7 @@ impl std::fmt::Display for Value { Value::String(x) => write!(f, "{x}"), Value::Blob(x) => write!(f, "{x:x?}"), Value::Null(_) => write!(f, ""), - Value::VarList(_, x) | Value::FixedList(_, x) => { - write!(f, "[")?; - for (i, value) in x.iter().enumerate() { - write!(f, "{}", value)?; - if i != x.len() - 1 { - write!(f, ",")?; - } - } - write!(f, "]") - } + Value::VarList(_, x) | Value::FixedList(_, x) => display_list(f, x), // Note: These don't match kuzu's toString, but we probably don't want them to Value::Interval(x) => write!(f, "{x}"), Value::Timestamp(x) => write!(f, "{x}"), @@ -264,6 +283,14 @@ impl std::fmt::Display for Value { Value::Node(x) => write!(f, "{x}"), Value::Rel(x) => write!(f, "{x}"), Value::InternalID(x) => write!(f, "{x}"), + Value::RecursiveRel { nodes, rels } => { + write!(f, "{{")?; + write!(f, "_NODES: ")?; + display_list(f, nodes)?; + write!(f, ", _RELS: ")?; + display_list(f, rels)?; + write!(f, "}}") + } } } } @@ -302,6 +329,7 @@ impl From<&Value> for LogicalType { Value::InternalID(_) => LogicalType::InternalID, Value::Node(_) => LogicalType::Node, Value::Rel(_) => LogicalType::Rel, + Value::RecursiveRel { .. } => LogicalType::RecursiveRel, } } } @@ -400,8 +428,10 @@ impl TryFrom<&ffi::Value> for Value { let mut node_val = NodeVal::new(id, label); let properties = ffi::node_value_get_properties(value); for i in 0..properties.size() { - node_val - .add_property(properties.get_name(i), properties.get_value(i).try_into()?); + node_val.add_property( + properties.get_name(i), + TryInto::::try_into(properties.get_value(i))?, + ); } Ok(Value::Node(node_val)) } @@ -432,8 +462,38 @@ impl TryFrom<&ffi::Value> for Value { table_id: internal_id[1], })) } - // Should be unreachable, as cxx will check that the LogicalTypeID enum matches the one - // on the C++ side. + LogicalTypeID::RECURSIVE_REL => { + let nodes: Value = ffi::recursive_rel_get_nodes(value).try_into()?; + let rels: Value = ffi::recursive_rel_get_rels(value).try_into()?; + let nodes = if let Value::VarList(LogicalType::Node, nodes) = nodes { + nodes.into_iter().map(|x| { + if let Value::Node(x) = x { + x + } else { + unreachable!() + } + }) + } else { + panic!("Unexpected value in RecursiveRel's rels: {}", rels) + }; + let rels = if let Value::VarList(LogicalType::Rel, rels) = rels { + rels.into_iter().map(|x| { + if let Value::Rel(x) = x { + x + } else { + unreachable!() + } + }) + } else { + panic!("Unexpected value in RecursiveRel's rels: {}", rels) + }; + + Ok(Value::RecursiveRel { + nodes: nodes.collect(), + rels: rels.collect(), + }) + } + // TODO(bmwinger): Better error message for types which are unsupported x => panic!("Unsupported type {:?}", x), } } @@ -532,6 +592,9 @@ impl TryInto> for Value { } Value::Node(_) => Err(crate::Error::ReadOnlyType(LogicalType::Node)), Value::Rel(_) => Err(crate::Error::ReadOnlyType(LogicalType::Rel)), + Value::RecursiveRel { .. } => { + Err(crate::Error::ReadOnlyType(LogicalType::RecursiveRel)) + } } } } @@ -695,6 +758,7 @@ mod tests { convert_node_type: LogicalType::Node, convert_internal_id_type: LogicalType::InternalID, convert_rel_type: LogicalType::Rel, + convert_recursive_rel_type: LogicalType::RecursiveRel, } value_tests! { @@ -858,6 +922,53 @@ mod tests { Ok(()) } + #[test] + fn test_recursive_rel() -> Result<()> { + let temp_dir = tempfile::TempDir::new()?; + let db = Database::new(temp_dir.path(), 0)?; + let conn = Connection::new(&db)?; + conn.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));")?; + conn.query("CREATE REL TABLE knows(FROM Person TO Person);")?; + conn.query("CREATE (:Person {name: \"Alice\", age: 25});")?; + conn.query("CREATE (:Person {name: \"Bob\", age: 25});")?; + conn.query("CREATE (:Person {name: \"Eve\", age: 25});")?; + conn.query( + "MATCH (p1:Person), (p2:Person) + WHERE p1.name = \"Alice\" AND p2.name = \"Bob\" + CREATE (p1)-[:knows]->(p2);", + )?; + conn.query( + "MATCH (p1:Person), (p2:Person) + WHERE p1.name = \"Bob\" AND p2.name = \"Eve\" + CREATE (p1)-[:knows]->(p2);", + )?; + let result = conn + .query( + "MATCH (a:Person)-[e*2..2]->(b:Person) + WHERE a.name = 'Alice' + RETURN e, b.name;", + )? + .next() + .unwrap(); + assert_eq!(result[1], Value::String("Eve".to_string())); + assert_eq!( + result[0], + Value::RecursiveRel { + nodes: vec![NodeVal { + id: (1, 0).into(), + label: "Person".into(), + properties: vec![("name".into(), "Bob".into()), ("age".into(), 25i64.into())] + },], + rels: vec![ + RelVal::new((0, 0), (1, 0), "knows"), + RelVal::new((1, 0), (2, 0), "knows"), + ], + } + ); + temp_dir.close()?; + Ok(()) + } + #[test] /// Test that null values are read correctly by the API fn test_null() -> Result<()> {