diff --git a/tools/rust_api/include/kuzu_rs.h b/tools/rust_api/include/kuzu_rs.h index 32d6f81450..b52b1bdf6a 100644 --- a/tools/rust_api/include/kuzu_rs.h +++ b/tools/rust_api/include/kuzu_rs.h @@ -91,33 +91,34 @@ std::unique_ptr> query_result_column_data rust::Vec query_result_column_names(const kuzu::main::QueryResult& query_result); /* NodeVal/RelVal */ +template struct PropertyList { - const std::vector>>& properties; + const kuzu::common::Value& value; - size_t size() const { return properties.size(); } + size_t size() const { return T::getNumProperties(&value); } rust::String get_name(size_t index) const { - return rust::String(this->properties[index].first); + return rust::String(T::getPropertyName(&value, index)); } const kuzu::common::Value& get_value(size_t index) const { - return *this->properties[index].second.get(); + return *T::getPropertyValueReference(&value, index); } }; -template -rust::String value_get_label_name(const T& val) { - return val.getLabelName(); -} -template -std::unique_ptr value_get_properties(const T& val) { - return std::make_unique(val.getProperties()); -} +using NodeValuePropertyList = PropertyList; +using RelValuePropertyList = PropertyList; + +rust::String node_value_get_label_name(const kuzu::common::Value& val); +rust::String rel_value_get_label_name(const kuzu::common::Value& val); + +std::unique_ptr node_value_get_properties(const kuzu::common::Value& val); +std::unique_ptr rel_value_get_properties(const kuzu::common::Value& val); /* NodeVal */ -std::array node_value_get_node_id(const kuzu::common::NodeVal& val); +std::array node_value_get_node_id(const kuzu::common::Value& val); /* RelVal */ -std::array rel_value_get_src_id(const kuzu::common::RelVal& val); -std::array rel_value_get_dst_id(const kuzu::common::RelVal& val); +std::array rel_value_get_src_id(const kuzu::common::Value& val); +std::array rel_value_get_dst_id(const kuzu::common::Value& val); /* FlatTuple */ const kuzu::common::Value& flat_tuple_get_value( @@ -149,10 +150,6 @@ std::unique_ptr create_value_interval( std::unique_ptr create_value_null( std::unique_ptr typ); std::unique_ptr create_value_internal_id(uint64_t offset, uint64_t table); -std::unique_ptr create_value_node( - std::unique_ptr id_val, std::unique_ptr label_val); -std::unique_ptr create_value_rel(std::unique_ptr src_id, - std::unique_ptr dst_id, std::unique_ptr label_val); template std::unique_ptr create_value(const T value) { @@ -169,7 +166,4 @@ std::unique_ptr get_list_value( std::unique_ptr typ, std::unique_ptr value); std::unique_ptr create_list(); -void value_add_property(kuzu::common::Value& val, const rust::String& name, - std::unique_ptr property); - } // namespace kuzu_rs diff --git a/tools/rust_api/src/error.rs b/tools/rust_api/src/error.rs index cd8701f21e..bfbc10ec4d 100644 --- a/tools/rust_api/src/error.rs +++ b/tools/rust_api/src/error.rs @@ -1,3 +1,4 @@ +use crate::logical_type::LogicalType; use std::fmt; pub enum Error { @@ -7,6 +8,8 @@ pub enum Error { FailedQuery(String), /// Message produced by kuzu when a query fails to prepare FailedPreparedStatement(String), + /// Message produced when you attempt to pass read-only types over the FFI boundary + ReadOnlyType(LogicalType), } impl std::fmt::Display for Error { @@ -16,6 +19,7 @@ impl std::fmt::Display for Error { CxxException(cxx) => write!(f, "{cxx}"), FailedQuery(message) => write!(f, "Query execution failed: {message}"), FailedPreparedStatement(message) => write!(f, "Query execution failed: {message}"), + ReadOnlyType(typ) => write!(f, "Attempted to pass read only type {:?} over ffi!", typ), } } } @@ -31,8 +35,7 @@ impl std::error::Error for Error { use Error::*; match self { CxxException(cxx) => Some(cxx), - FailedQuery(_) => None, - FailedPreparedStatement(_) => None, + _ => None, } } } diff --git a/tools/rust_api/src/ffi.rs b/tools/rust_api/src/ffi.rs index 319c42af1d..4366d4b4cd 100644 --- a/tools/rust_api/src/ffi.rs +++ b/tools/rust_api/src/ffi.rs @@ -246,10 +246,6 @@ pub(crate) mod ffi { fn getValue(&self) -> f64; fn value_get_string(value: &Value) -> String; - #[rust_name = "value_get_node_val"] - fn value_get_unique(value: &Value) -> UniquePtr; - #[rust_name = "value_get_rel_val"] - fn value_get_unique(value: &Value) -> UniquePtr; fn value_get_interval_secs(value: &Value) -> i64; fn value_get_interval_micros(value: &Value) -> i32; fn value_get_timestamp_micros(value: &Value) -> i64; @@ -281,22 +277,21 @@ pub(crate) mod ffi { fn create_value_date(value: i64) -> UniquePtr; fn create_value_interval(months: i32, days: i32, micros: i64) -> UniquePtr; fn create_value_internal_id(offset: u64, table: u64) -> UniquePtr; - fn create_value_node( - id_val: UniquePtr, - label_val: UniquePtr, - ) -> UniquePtr; - fn create_value_rel( - src_id: UniquePtr, - dst_id: UniquePtr, - label_val: UniquePtr, - ) -> UniquePtr; - fn value_add_property(value: Pin<&mut Value>, name: &String, property: UniquePtr); + fn node_value_get_properties(node_value: &Value) -> UniquePtr; + fn node_value_get_node_id(value: &Value) -> [u64; 2]; + fn node_value_get_label_name(value: &Value) -> String; + + fn rel_value_get_properties(node_value: &Value) -> UniquePtr; + fn rel_value_get_label_name(value: &Value) -> String; + + fn rel_value_get_src_id(value: &Value) -> [u64; 2]; + fn rel_value_get_dst_id(value: &Value) -> [u64; 2]; } #[namespace = "kuzu_rs"] unsafe extern "C++" { - type PropertyList<'a>; + type NodeValuePropertyList<'a>; fn size<'a>(&'a self) -> usize; fn get_name<'a>(&'a self, index: usize) -> String; @@ -305,27 +300,10 @@ pub(crate) mod ffi { #[namespace = "kuzu_rs"] unsafe extern "C++" { - #[namespace = "kuzu::common"] - type NodeVal; - - #[rust_name = "node_value_get_properties"] - fn value_get_properties(node_value: &NodeVal) -> UniquePtr; - fn node_value_get_node_id(value: &NodeVal) -> [u64; 2]; - #[rust_name = "node_value_get_label_name"] - fn value_get_label_name(value: &NodeVal) -> String; - } + type RelValuePropertyList<'a>; - #[namespace = "kuzu_rs"] - unsafe extern "C++" { - #[namespace = "kuzu::common"] - type RelVal; - - #[rust_name = "rel_value_get_properties"] - fn value_get_properties(node_value: &RelVal) -> UniquePtr; - #[rust_name = "rel_value_get_label_name"] - fn value_get_label_name(value: &RelVal) -> String; - - fn rel_value_get_src_id(value: &RelVal) -> [u64; 2]; - fn rel_value_get_dst_id(value: &RelVal) -> [u64; 2]; + fn size<'a>(&'a self) -> usize; + fn get_name<'a>(&'a self, index: usize) -> String; + fn get_value<'a>(&'a self, index: usize) -> &'a Value; } } diff --git a/tools/rust_api/src/kuzu_rs.cpp b/tools/rust_api/src/kuzu_rs.cpp index 15788d7c2a..18db538c5f 100644 --- a/tools/rust_api/src/kuzu_rs.cpp +++ b/tools/rust_api/src/kuzu_rs.cpp @@ -142,17 +142,32 @@ rust::Vec query_result_column_names(const kuzu::main::QueryResult& return names; } -std::array node_value_get_node_id(const kuzu::common::NodeVal& val) { - auto internalID = val.getNodeID(); +std::unique_ptr node_value_get_properties(const Value& val) { + return std::make_unique(val); +} +std::unique_ptr rel_value_get_properties(const Value& val) { + return std::make_unique(val); +} + +rust::String node_value_get_label_name(const kuzu::common::Value& val) { + return rust::String(kuzu::common::NodeVal::getLabelName(&val)); +} + +rust::String rel_value_get_label_name(const kuzu::common::Value& val) { + return rust::String(kuzu::common::RelVal::getLabelName(&val)); +} + +std::array node_value_get_node_id(const kuzu::common::Value& val) { + auto internalID = kuzu::common::NodeVal::getNodeID(&val); return std::array{internalID.offset, internalID.tableID}; } -std::array rel_value_get_src_id(const kuzu::common::RelVal& val) { - auto internalID = val.getSrcNodeID(); +std::array rel_value_get_src_id(const kuzu::common::Value& val) { + auto internalID = kuzu::common::RelVal::getSrcNodeID(&val); return std::array{internalID.offset, internalID.tableID}; } -std::array rel_value_get_dst_id(const kuzu::common::RelVal& val) { - auto internalID = val.getDstNodeID(); +std::array rel_value_get_dst_id(const kuzu::common::Value& val) { + auto internalID = kuzu::common::RelVal::getDstNodeID(&val); return std::array{internalID.offset, internalID.tableID}; } @@ -225,17 +240,6 @@ std::unique_ptr create_value_null( std::unique_ptr create_value_internal_id(uint64_t offset, uint64_t table) { return std::make_unique(kuzu::common::internalID_t(offset, table)); } -std::unique_ptr create_value_node( - std::unique_ptr id_val, std::unique_ptr label_val) { - return std::make_unique( - std::make_unique(std::move(id_val), std::move(label_val))); -} - -std::unique_ptr create_value_rel(std::unique_ptr src_id, - std::unique_ptr dst_id, std::unique_ptr label_val) { - return std::make_unique(std::make_unique( - std::move(src_id), std::move(dst_id), std::move(label_val))); -} std::unique_ptr get_list_value( std::unique_ptr typ, std::unique_ptr value) { @@ -250,17 +254,4 @@ std::unique_ptr create_type_list() { return std::make_unique(); } -void value_add_property(kuzu::common::Value& val, const rust::String& name, - std::unique_ptr property) { - if (val.getDataType().getLogicalTypeID() == kuzu::common::LogicalTypeID::NODE) { - kuzu::common::NodeVal& nodeVal = val.getValueReference(); - nodeVal.addProperty(std::string(name), std::move(property)); - } else if (val.getDataType().getLogicalTypeID() == kuzu::common::LogicalTypeID::REL) { - kuzu::common::RelVal& relVal = val.getValueReference(); - relVal.addProperty(std::string(name), std::move(property)); - } else { - throw std::runtime_error("Internal Error! Adding property to type without properties!"); - } -} - } // namespace kuzu_rs diff --git a/tools/rust_api/src/value.rs b/tools/rust_api/src/value.rs index 3eccf4dbcd..b73c556d33 100644 --- a/tools/rust_api/src/value.rs +++ b/tools/rust_api/src/value.rs @@ -381,15 +381,14 @@ impl TryFrom<&ffi::Value> for Value { Ok(Value::Struct(result)) } LogicalTypeID::NODE => { - let ffi_node_val = ffi::value_get_node_val(value); - let id = ffi::node_value_get_node_id(ffi_node_val.as_ref().unwrap()); + let id = ffi::node_value_get_node_id(value); let id = InternalID { offset: id[0], table_id: id[1], }; - let label = ffi::node_value_get_label_name(ffi_node_val.as_ref().unwrap()); + let label = ffi::node_value_get_label_name(value); let mut node_val = NodeVal::new(id, label); - let properties = ffi::node_value_get_properties(ffi_node_val.as_ref().unwrap()); + let properties = ffi::node_value_get_properties(value); for i in 0..properties.size() { node_val .add_property(properties.get_name(i), properties.get_value(i).try_into()?); @@ -397,9 +396,8 @@ impl TryFrom<&ffi::Value> for Value { Ok(Value::Node(node_val)) } LogicalTypeID::REL => { - let ffi_rel_val = ffi::value_get_rel_val(value); - let src_node = ffi::rel_value_get_src_id(ffi_rel_val.as_ref().unwrap()); - let dst_node = ffi::rel_value_get_dst_id(ffi_rel_val.as_ref().unwrap()); + let src_node = ffi::rel_value_get_src_id(value); + let dst_node = ffi::rel_value_get_dst_id(value); let src_node = InternalID { offset: src_node[0], table_id: src_node[1], @@ -408,9 +406,9 @@ impl TryFrom<&ffi::Value> for Value { offset: dst_node[0], table_id: dst_node[1], }; - let label = ffi::rel_value_get_label_name(ffi_rel_val.as_ref().unwrap()); + let label = ffi::rel_value_get_label_name(value); let mut rel_val = RelVal::new(src_node, dst_node, label); - let properties = ffi::rel_value_get_properties(ffi_rel_val.as_ref().unwrap()); + let properties = ffi::rel_value_get_properties(value); for i in 0..properties.size() { rel_val .add_property(properties.get_name(i), properties.get_value(i).try_into()?); @@ -518,27 +516,8 @@ impl TryInto> for Value { Value::InternalID(value) => { Ok(ffi::create_value_internal_id(value.offset, value.table_id)) } - Value::Node(value) => { - let mut node = ffi::create_value_node( - Value::InternalID(value.id).try_into()?, - Value::String(value.label).try_into()?, - ); - for (name, property) in value.properties { - ffi::value_add_property(node.pin_mut(), &name, property.try_into()?); - } - Ok(node) - } - Value::Rel(value) => { - let mut rel = ffi::create_value_rel( - Value::InternalID(value.src_node).try_into()?, - Value::InternalID(value.dst_node).try_into()?, - Value::String(value.label).try_into()?, - ); - for (name, property) in value.properties { - ffi::value_add_property(rel.pin_mut(), &name, property.try_into()?); - } - Ok(rel) - } + Value::Node(_) => Err(crate::Error::ReadOnlyType(LogicalType::Node)), + Value::Rel(_) => Err(crate::Error::ReadOnlyType(LogicalType::Rel)), } } } @@ -640,7 +619,7 @@ mod tests { ($($name:ident: $value:expr,)*) => { $( #[test] - /// Tests that the values are correctly converted into kuzu::common::Value and back + /// Tests that the values display the same via the rust API as via the C++ API fn $name() -> Result<()> { let rust_value: Value = $value; let value: cxx::UniquePtr = rust_value.clone().try_into()?; @@ -722,8 +701,6 @@ mod tests { }), convert_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]), convert_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }), - convert_node: Value::Node(NodeVal::new(InternalID { table_id: 0, offset: 0 }, "Test Label".to_string())), - convert_rel: Value::Rel(RelVal::new(InternalID { table_id: 0, offset: 0 }, InternalID { table_id: 1, offset: 0 }, "Test Label".to_string())), } display_tests! { @@ -733,7 +710,7 @@ mod tests { display_int16: Value::Int16(1), display_int32: Value::Int32(2), display_int64: Value::Int64(3), - // Float, doble, interval and timestamp have display differences which we probably don't want to + // Float, double, interval and timestamp have display differences which we probably don't want to // reconcile display_date: Value::Date(date!(2023-06-13)), display_string: Value::String("Hello World".to_string()), @@ -743,8 +720,7 @@ mod tests { }), display_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]), display_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }), - display_node: Value::Node(NodeVal::new(InternalID { table_id: 0, offset: 0 }, "Test Label".to_string())), - display_rel: Value::Rel(RelVal::new(InternalID { table_id: 0, offset: 0 }, InternalID { table_id: 1, offset: 0 }, "Test Label".to_string())), + // Node and Rel Cannot be easily created on the C++ side } database_tests! {