Skip to content

Commit

Permalink
Merge pull request #1766 from kuzudb/rust-node-update
Browse files Browse the repository at this point in the history
Update Node and Rel in the Rust API to use the new interface
  • Loading branch information
andyfengHKU committed Jul 6, 2023
2 parents 3c3e6a8 + 425dc51 commit 10f5c58
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 126 deletions.
38 changes: 16 additions & 22 deletions tools/rust_api/include/kuzu_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,33 +91,34 @@ std::unique_ptr<std::vector<kuzu::common::LogicalType>> query_result_column_data
rust::Vec<rust::String> query_result_column_names(const kuzu::main::QueryResult& query_result);

/* NodeVal/RelVal */
template <typename T>
struct PropertyList {
const std::vector<std::pair<std::string, std::unique_ptr<kuzu::common::Value>>>& properties;
const kuzu::common::Value& value;

size_t size() const { return properties.size(); }
size_t size() const { return T::getNumProperties(&value); }
rust::String get_name(size_t index) const {
return rust::String(this->properties[index].first);
return rust::String(T::getPropertyName(&value, index));
}
const kuzu::common::Value& get_value(size_t index) const {
return *this->properties[index].second.get();
return *T::getPropertyValueReference(&value, index);
}
};

template<typename T>
rust::String value_get_label_name(const T& val) {
return val.getLabelName();
}
template<typename T>
std::unique_ptr<PropertyList> value_get_properties(const T& val) {
return std::make_unique<PropertyList>(val.getProperties());
}
using NodeValuePropertyList = PropertyList<kuzu::common::NodeVal>;
using RelValuePropertyList = PropertyList<kuzu::common::RelVal>;

rust::String node_value_get_label_name(const kuzu::common::Value& val);
rust::String rel_value_get_label_name(const kuzu::common::Value& val);

std::unique_ptr<NodeValuePropertyList> node_value_get_properties(const kuzu::common::Value& val);
std::unique_ptr<RelValuePropertyList> rel_value_get_properties(const kuzu::common::Value& val);

/* NodeVal */
std::array<uint64_t, 2> node_value_get_node_id(const kuzu::common::NodeVal& val);
std::array<uint64_t, 2> node_value_get_node_id(const kuzu::common::Value& val);

/* RelVal */
std::array<uint64_t, 2> rel_value_get_src_id(const kuzu::common::RelVal& val);
std::array<uint64_t, 2> rel_value_get_dst_id(const kuzu::common::RelVal& val);
std::array<uint64_t, 2> rel_value_get_src_id(const kuzu::common::Value& val);
std::array<uint64_t, 2> rel_value_get_dst_id(const kuzu::common::Value& val);

/* FlatTuple */
const kuzu::common::Value& flat_tuple_get_value(
Expand Down Expand Up @@ -149,10 +150,6 @@ std::unique_ptr<kuzu::common::Value> create_value_interval(
std::unique_ptr<kuzu::common::Value> create_value_null(
std::unique_ptr<kuzu::common::LogicalType> typ);
std::unique_ptr<kuzu::common::Value> create_value_internal_id(uint64_t offset, uint64_t table);
std::unique_ptr<kuzu::common::Value> create_value_node(
std::unique_ptr<kuzu::common::Value> id_val, std::unique_ptr<kuzu::common::Value> label_val);
std::unique_ptr<kuzu::common::Value> create_value_rel(std::unique_ptr<kuzu::common::Value> src_id,
std::unique_ptr<kuzu::common::Value> dst_id, std::unique_ptr<kuzu::common::Value> label_val);

template<typename T>
std::unique_ptr<kuzu::common::Value> create_value(const T value) {
Expand All @@ -169,7 +166,4 @@ std::unique_ptr<kuzu::common::Value> get_list_value(
std::unique_ptr<kuzu::common::LogicalType> typ, std::unique_ptr<ValueListBuilder> value);
std::unique_ptr<ValueListBuilder> create_list();

void value_add_property(kuzu::common::Value& val, const rust::String& name,
std::unique_ptr<kuzu::common::Value> property);

} // namespace kuzu_rs
7 changes: 5 additions & 2 deletions tools/rust_api/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::logical_type::LogicalType;
use std::fmt;

pub enum Error {
Expand All @@ -7,6 +8,8 @@ pub enum Error {
FailedQuery(String),
/// Message produced by kuzu when a query fails to prepare
FailedPreparedStatement(String),
/// Message produced when you attempt to pass read-only types over the FFI boundary
ReadOnlyType(LogicalType),
}

impl std::fmt::Display for Error {
Expand All @@ -16,6 +19,7 @@ impl std::fmt::Display for Error {
CxxException(cxx) => write!(f, "{cxx}"),
FailedQuery(message) => write!(f, "Query execution failed: {message}"),
FailedPreparedStatement(message) => write!(f, "Query execution failed: {message}"),
ReadOnlyType(typ) => write!(f, "Attempted to pass read only type {:?} over ffi!", typ),
}
}
}
Expand All @@ -31,8 +35,7 @@ impl std::error::Error for Error {
use Error::*;
match self {
CxxException(cxx) => Some(cxx),
FailedQuery(_) => None,
FailedPreparedStatement(_) => None,
_ => None,
}
}
}
Expand Down
50 changes: 14 additions & 36 deletions tools/rust_api/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,6 @@ pub(crate) mod ffi {
fn getValue(&self) -> f64;

fn value_get_string(value: &Value) -> String;
#[rust_name = "value_get_node_val"]
fn value_get_unique(value: &Value) -> UniquePtr<NodeVal>;
#[rust_name = "value_get_rel_val"]
fn value_get_unique(value: &Value) -> UniquePtr<RelVal>;
fn value_get_interval_secs(value: &Value) -> i64;
fn value_get_interval_micros(value: &Value) -> i32;
fn value_get_timestamp_micros(value: &Value) -> i64;
Expand Down Expand Up @@ -281,22 +277,21 @@ pub(crate) mod ffi {
fn create_value_date(value: i64) -> UniquePtr<Value>;
fn create_value_interval(months: i32, days: i32, micros: i64) -> UniquePtr<Value>;
fn create_value_internal_id(offset: u64, table: u64) -> UniquePtr<Value>;
fn create_value_node(
id_val: UniquePtr<Value>,
label_val: UniquePtr<Value>,
) -> UniquePtr<Value>;
fn create_value_rel(
src_id: UniquePtr<Value>,
dst_id: UniquePtr<Value>,
label_val: UniquePtr<Value>,
) -> UniquePtr<Value>;

fn value_add_property(value: Pin<&mut Value>, name: &String, property: UniquePtr<Value>);
fn node_value_get_properties(node_value: &Value) -> UniquePtr<NodeValuePropertyList>;
fn node_value_get_node_id(value: &Value) -> [u64; 2];
fn node_value_get_label_name(value: &Value) -> String;

fn rel_value_get_properties(node_value: &Value) -> UniquePtr<RelValuePropertyList>;
fn rel_value_get_label_name(value: &Value) -> String;

fn rel_value_get_src_id(value: &Value) -> [u64; 2];
fn rel_value_get_dst_id(value: &Value) -> [u64; 2];
}

#[namespace = "kuzu_rs"]
unsafe extern "C++" {
type PropertyList<'a>;
type NodeValuePropertyList<'a>;

fn size<'a>(&'a self) -> usize;
fn get_name<'a>(&'a self, index: usize) -> String;
Expand All @@ -305,27 +300,10 @@ pub(crate) mod ffi {

#[namespace = "kuzu_rs"]
unsafe extern "C++" {
#[namespace = "kuzu::common"]
type NodeVal;

#[rust_name = "node_value_get_properties"]
fn value_get_properties(node_value: &NodeVal) -> UniquePtr<PropertyList>;
fn node_value_get_node_id(value: &NodeVal) -> [u64; 2];
#[rust_name = "node_value_get_label_name"]
fn value_get_label_name(value: &NodeVal) -> String;
}
type RelValuePropertyList<'a>;

#[namespace = "kuzu_rs"]
unsafe extern "C++" {
#[namespace = "kuzu::common"]
type RelVal;

#[rust_name = "rel_value_get_properties"]
fn value_get_properties(node_value: &RelVal) -> UniquePtr<PropertyList>;
#[rust_name = "rel_value_get_label_name"]
fn value_get_label_name(value: &RelVal) -> String;

fn rel_value_get_src_id(value: &RelVal) -> [u64; 2];
fn rel_value_get_dst_id(value: &RelVal) -> [u64; 2];
fn size<'a>(&'a self) -> usize;
fn get_name<'a>(&'a self, index: usize) -> String;
fn get_value<'a>(&'a self, index: usize) -> &'a Value;
}
}
51 changes: 21 additions & 30 deletions tools/rust_api/src/kuzu_rs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,32 @@ rust::Vec<rust::String> query_result_column_names(const kuzu::main::QueryResult&
return names;
}

std::array<uint64_t, 2> node_value_get_node_id(const kuzu::common::NodeVal& val) {
auto internalID = val.getNodeID();
std::unique_ptr<NodeValuePropertyList> node_value_get_properties(const Value& val) {
return std::make_unique<NodeValuePropertyList>(val);
}
std::unique_ptr<RelValuePropertyList> rel_value_get_properties(const Value& val) {
return std::make_unique<RelValuePropertyList>(val);
}

rust::String node_value_get_label_name(const kuzu::common::Value& val) {
return rust::String(kuzu::common::NodeVal::getLabelName(&val));
}

rust::String rel_value_get_label_name(const kuzu::common::Value& val) {
return rust::String(kuzu::common::RelVal::getLabelName(&val));
}

std::array<uint64_t, 2> node_value_get_node_id(const kuzu::common::Value& val) {
auto internalID = kuzu::common::NodeVal::getNodeID(&val);
return std::array{internalID.offset, internalID.tableID};
}

std::array<uint64_t, 2> rel_value_get_src_id(const kuzu::common::RelVal& val) {
auto internalID = val.getSrcNodeID();
std::array<uint64_t, 2> rel_value_get_src_id(const kuzu::common::Value& val) {
auto internalID = kuzu::common::RelVal::getSrcNodeID(&val);
return std::array{internalID.offset, internalID.tableID};
}
std::array<uint64_t, 2> rel_value_get_dst_id(const kuzu::common::RelVal& val) {
auto internalID = val.getDstNodeID();
std::array<uint64_t, 2> rel_value_get_dst_id(const kuzu::common::Value& val) {
auto internalID = kuzu::common::RelVal::getDstNodeID(&val);
return std::array{internalID.offset, internalID.tableID};
}

Expand Down Expand Up @@ -225,17 +240,6 @@ std::unique_ptr<kuzu::common::Value> create_value_null(
std::unique_ptr<kuzu::common::Value> create_value_internal_id(uint64_t offset, uint64_t table) {
return std::make_unique<kuzu::common::Value>(kuzu::common::internalID_t(offset, table));
}
std::unique_ptr<Value> create_value_node(
std::unique_ptr<Value> id_val, std::unique_ptr<Value> label_val) {
return std::make_unique<Value>(
std::make_unique<kuzu::common::NodeVal>(std::move(id_val), std::move(label_val)));
}

std::unique_ptr<kuzu::common::Value> create_value_rel(std::unique_ptr<kuzu::common::Value> src_id,
std::unique_ptr<kuzu::common::Value> dst_id, std::unique_ptr<kuzu::common::Value> label_val) {
return std::make_unique<Value>(std::make_unique<kuzu::common::RelVal>(
std::move(src_id), std::move(dst_id), std::move(label_val)));
}

std::unique_ptr<kuzu::common::Value> get_list_value(
std::unique_ptr<kuzu::common::LogicalType> typ, std::unique_ptr<ValueListBuilder> value) {
Expand All @@ -250,17 +254,4 @@ std::unique_ptr<TypeListBuilder> create_type_list() {
return std::make_unique<TypeListBuilder>();
}

void value_add_property(kuzu::common::Value& val, const rust::String& name,
std::unique_ptr<kuzu::common::Value> property) {
if (val.getDataType().getLogicalTypeID() == kuzu::common::LogicalTypeID::NODE) {
kuzu::common::NodeVal& nodeVal = val.getValueReference<kuzu::common::NodeVal>();
nodeVal.addProperty(std::string(name), std::move(property));
} else if (val.getDataType().getLogicalTypeID() == kuzu::common::LogicalTypeID::REL) {
kuzu::common::RelVal& relVal = val.getValueReference<kuzu::common::RelVal>();
relVal.addProperty(std::string(name), std::move(property));
} else {
throw std::runtime_error("Internal Error! Adding property to type without properties!");
}
}

} // namespace kuzu_rs
48 changes: 12 additions & 36 deletions tools/rust_api/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,25 +381,23 @@ impl TryFrom<&ffi::Value> for Value {
Ok(Value::Struct(result))
}
LogicalTypeID::NODE => {
let ffi_node_val = ffi::value_get_node_val(value);
let id = ffi::node_value_get_node_id(ffi_node_val.as_ref().unwrap());
let id = ffi::node_value_get_node_id(value);
let id = InternalID {
offset: id[0],
table_id: id[1],
};
let label = ffi::node_value_get_label_name(ffi_node_val.as_ref().unwrap());
let label = ffi::node_value_get_label_name(value);
let mut node_val = NodeVal::new(id, label);
let properties = ffi::node_value_get_properties(ffi_node_val.as_ref().unwrap());
let properties = ffi::node_value_get_properties(value);
for i in 0..properties.size() {
node_val
.add_property(properties.get_name(i), properties.get_value(i).try_into()?);
}
Ok(Value::Node(node_val))
}
LogicalTypeID::REL => {
let ffi_rel_val = ffi::value_get_rel_val(value);
let src_node = ffi::rel_value_get_src_id(ffi_rel_val.as_ref().unwrap());
let dst_node = ffi::rel_value_get_dst_id(ffi_rel_val.as_ref().unwrap());
let src_node = ffi::rel_value_get_src_id(value);
let dst_node = ffi::rel_value_get_dst_id(value);
let src_node = InternalID {
offset: src_node[0],
table_id: src_node[1],
Expand All @@ -408,9 +406,9 @@ impl TryFrom<&ffi::Value> for Value {
offset: dst_node[0],
table_id: dst_node[1],
};
let label = ffi::rel_value_get_label_name(ffi_rel_val.as_ref().unwrap());
let label = ffi::rel_value_get_label_name(value);
let mut rel_val = RelVal::new(src_node, dst_node, label);
let properties = ffi::rel_value_get_properties(ffi_rel_val.as_ref().unwrap());
let properties = ffi::rel_value_get_properties(value);
for i in 0..properties.size() {
rel_val
.add_property(properties.get_name(i), properties.get_value(i).try_into()?);
Expand Down Expand Up @@ -518,27 +516,8 @@ impl TryInto<cxx::UniquePtr<ffi::Value>> for Value {
Value::InternalID(value) => {
Ok(ffi::create_value_internal_id(value.offset, value.table_id))
}
Value::Node(value) => {
let mut node = ffi::create_value_node(
Value::InternalID(value.id).try_into()?,
Value::String(value.label).try_into()?,
);
for (name, property) in value.properties {
ffi::value_add_property(node.pin_mut(), &name, property.try_into()?);
}
Ok(node)
}
Value::Rel(value) => {
let mut rel = ffi::create_value_rel(
Value::InternalID(value.src_node).try_into()?,
Value::InternalID(value.dst_node).try_into()?,
Value::String(value.label).try_into()?,
);
for (name, property) in value.properties {
ffi::value_add_property(rel.pin_mut(), &name, property.try_into()?);
}
Ok(rel)
}
Value::Node(_) => Err(crate::Error::ReadOnlyType(LogicalType::Node)),
Value::Rel(_) => Err(crate::Error::ReadOnlyType(LogicalType::Rel)),
}
}
}
Expand Down Expand Up @@ -640,7 +619,7 @@ mod tests {
($($name:ident: $value:expr,)*) => {
$(
#[test]
/// Tests that the values are correctly converted into kuzu::common::Value and back
/// Tests that the values display the same via the rust API as via the C++ API
fn $name() -> Result<()> {
let rust_value: Value = $value;
let value: cxx::UniquePtr<ffi::Value> = rust_value.clone().try_into()?;
Expand Down Expand Up @@ -722,8 +701,6 @@ mod tests {
}),
convert_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]),
convert_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }),
convert_node: Value::Node(NodeVal::new(InternalID { table_id: 0, offset: 0 }, "Test Label".to_string())),
convert_rel: Value::Rel(RelVal::new(InternalID { table_id: 0, offset: 0 }, InternalID { table_id: 1, offset: 0 }, "Test Label".to_string())),
}

display_tests! {
Expand All @@ -733,7 +710,7 @@ mod tests {
display_int16: Value::Int16(1),
display_int32: Value::Int32(2),
display_int64: Value::Int64(3),
// Float, doble, interval and timestamp have display differences which we probably don't want to
// Float, double, interval and timestamp have display differences which we probably don't want to
// reconcile
display_date: Value::Date(date!(2023-06-13)),
display_string: Value::String("Hello World".to_string()),
Expand All @@ -743,8 +720,7 @@ mod tests {
}),
display_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]),
display_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }),
display_node: Value::Node(NodeVal::new(InternalID { table_id: 0, offset: 0 }, "Test Label".to_string())),
display_rel: Value::Rel(RelVal::new(InternalID { table_id: 0, offset: 0 }, InternalID { table_id: 1, offset: 0 }, "Test Label".to_string())),
// Node and Rel Cannot be easily created on the C++ side
}

database_tests! {
Expand Down

0 comments on commit 10f5c58

Please sign in to comment.