Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Node and Rel in the Rust API to use the new interface #1766

Merged
merged 1 commit into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 16 additions & 22 deletions tools/rust_api/include/kuzu_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,33 +91,34 @@ std::unique_ptr<std::vector<kuzu::common::LogicalType>> query_result_column_data
rust::Vec<rust::String> query_result_column_names(const kuzu::main::QueryResult& query_result);

/* NodeVal/RelVal */
template <typename T>
struct PropertyList {
const std::vector<std::pair<std::string, std::unique_ptr<kuzu::common::Value>>>& properties;
const kuzu::common::Value& value;

size_t size() const { return properties.size(); }
size_t size() const { return T::getNumProperties(&value); }
rust::String get_name(size_t index) const {
return rust::String(this->properties[index].first);
return rust::String(T::getPropertyName(&value, index));
}
const kuzu::common::Value& get_value(size_t index) const {
return *this->properties[index].second.get();
return *T::getPropertyValueReference(&value, index);
}
};

template<typename T>
rust::String value_get_label_name(const T& val) {
return val.getLabelName();
}
template<typename T>
std::unique_ptr<PropertyList> value_get_properties(const T& val) {
return std::make_unique<PropertyList>(val.getProperties());
}
using NodeValuePropertyList = PropertyList<kuzu::common::NodeVal>;
using RelValuePropertyList = PropertyList<kuzu::common::RelVal>;

rust::String node_value_get_label_name(const kuzu::common::Value& val);
rust::String rel_value_get_label_name(const kuzu::common::Value& val);

std::unique_ptr<NodeValuePropertyList> node_value_get_properties(const kuzu::common::Value& val);
std::unique_ptr<RelValuePropertyList> rel_value_get_properties(const kuzu::common::Value& val);

/* NodeVal */
std::array<uint64_t, 2> node_value_get_node_id(const kuzu::common::NodeVal& val);
std::array<uint64_t, 2> node_value_get_node_id(const kuzu::common::Value& val);

/* RelVal */
std::array<uint64_t, 2> rel_value_get_src_id(const kuzu::common::RelVal& val);
std::array<uint64_t, 2> rel_value_get_dst_id(const kuzu::common::RelVal& val);
std::array<uint64_t, 2> rel_value_get_src_id(const kuzu::common::Value& val);
std::array<uint64_t, 2> rel_value_get_dst_id(const kuzu::common::Value& val);

/* FlatTuple */
const kuzu::common::Value& flat_tuple_get_value(
Expand Down Expand Up @@ -149,10 +150,6 @@ std::unique_ptr<kuzu::common::Value> create_value_interval(
std::unique_ptr<kuzu::common::Value> create_value_null(
std::unique_ptr<kuzu::common::LogicalType> typ);
std::unique_ptr<kuzu::common::Value> create_value_internal_id(uint64_t offset, uint64_t table);
std::unique_ptr<kuzu::common::Value> create_value_node(
std::unique_ptr<kuzu::common::Value> id_val, std::unique_ptr<kuzu::common::Value> label_val);
std::unique_ptr<kuzu::common::Value> create_value_rel(std::unique_ptr<kuzu::common::Value> src_id,
std::unique_ptr<kuzu::common::Value> dst_id, std::unique_ptr<kuzu::common::Value> label_val);

template<typename T>
std::unique_ptr<kuzu::common::Value> create_value(const T value) {
Expand All @@ -169,7 +166,4 @@ std::unique_ptr<kuzu::common::Value> get_list_value(
std::unique_ptr<kuzu::common::LogicalType> typ, std::unique_ptr<ValueListBuilder> value);
std::unique_ptr<ValueListBuilder> create_list();

void value_add_property(kuzu::common::Value& val, const rust::String& name,
std::unique_ptr<kuzu::common::Value> property);

} // namespace kuzu_rs
7 changes: 5 additions & 2 deletions tools/rust_api/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::logical_type::LogicalType;
use std::fmt;

pub enum Error {
Expand All @@ -7,6 +8,8 @@ pub enum Error {
FailedQuery(String),
/// Message produced by kuzu when a query fails to prepare
FailedPreparedStatement(String),
/// Message produced when you attempt to pass read-only types over the FFI boundary
ReadOnlyType(LogicalType),
}

impl std::fmt::Display for Error {
Expand All @@ -16,6 +19,7 @@ impl std::fmt::Display for Error {
CxxException(cxx) => write!(f, "{cxx}"),
FailedQuery(message) => write!(f, "Query execution failed: {message}"),
FailedPreparedStatement(message) => write!(f, "Query execution failed: {message}"),
ReadOnlyType(typ) => write!(f, "Attempted to pass read only type {:?} over ffi!", typ),
}
}
}
Expand All @@ -31,8 +35,7 @@ impl std::error::Error for Error {
use Error::*;
match self {
CxxException(cxx) => Some(cxx),
FailedQuery(_) => None,
FailedPreparedStatement(_) => None,
_ => None,
}
}
}
Expand Down
50 changes: 14 additions & 36 deletions tools/rust_api/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,6 @@ pub(crate) mod ffi {
fn getValue(&self) -> f64;

fn value_get_string(value: &Value) -> String;
#[rust_name = "value_get_node_val"]
fn value_get_unique(value: &Value) -> UniquePtr<NodeVal>;
#[rust_name = "value_get_rel_val"]
fn value_get_unique(value: &Value) -> UniquePtr<RelVal>;
fn value_get_interval_secs(value: &Value) -> i64;
fn value_get_interval_micros(value: &Value) -> i32;
fn value_get_timestamp_micros(value: &Value) -> i64;
Expand Down Expand Up @@ -281,22 +277,21 @@ pub(crate) mod ffi {
fn create_value_date(value: i64) -> UniquePtr<Value>;
fn create_value_interval(months: i32, days: i32, micros: i64) -> UniquePtr<Value>;
fn create_value_internal_id(offset: u64, table: u64) -> UniquePtr<Value>;
fn create_value_node(
id_val: UniquePtr<Value>,
label_val: UniquePtr<Value>,
) -> UniquePtr<Value>;
fn create_value_rel(
src_id: UniquePtr<Value>,
dst_id: UniquePtr<Value>,
label_val: UniquePtr<Value>,
) -> UniquePtr<Value>;

fn value_add_property(value: Pin<&mut Value>, name: &String, property: UniquePtr<Value>);
fn node_value_get_properties(node_value: &Value) -> UniquePtr<NodeValuePropertyList>;
fn node_value_get_node_id(value: &Value) -> [u64; 2];
fn node_value_get_label_name(value: &Value) -> String;

fn rel_value_get_properties(node_value: &Value) -> UniquePtr<RelValuePropertyList>;
fn rel_value_get_label_name(value: &Value) -> String;

fn rel_value_get_src_id(value: &Value) -> [u64; 2];
fn rel_value_get_dst_id(value: &Value) -> [u64; 2];
}

#[namespace = "kuzu_rs"]
unsafe extern "C++" {
type PropertyList<'a>;
type NodeValuePropertyList<'a>;

fn size<'a>(&'a self) -> usize;
fn get_name<'a>(&'a self, index: usize) -> String;
Expand All @@ -305,27 +300,10 @@ pub(crate) mod ffi {

#[namespace = "kuzu_rs"]
unsafe extern "C++" {
#[namespace = "kuzu::common"]
type NodeVal;

#[rust_name = "node_value_get_properties"]
fn value_get_properties(node_value: &NodeVal) -> UniquePtr<PropertyList>;
fn node_value_get_node_id(value: &NodeVal) -> [u64; 2];
#[rust_name = "node_value_get_label_name"]
fn value_get_label_name(value: &NodeVal) -> String;
}
type RelValuePropertyList<'a>;

#[namespace = "kuzu_rs"]
unsafe extern "C++" {
#[namespace = "kuzu::common"]
type RelVal;

#[rust_name = "rel_value_get_properties"]
fn value_get_properties(node_value: &RelVal) -> UniquePtr<PropertyList>;
#[rust_name = "rel_value_get_label_name"]
fn value_get_label_name(value: &RelVal) -> String;

fn rel_value_get_src_id(value: &RelVal) -> [u64; 2];
fn rel_value_get_dst_id(value: &RelVal) -> [u64; 2];
fn size<'a>(&'a self) -> usize;
fn get_name<'a>(&'a self, index: usize) -> String;
fn get_value<'a>(&'a self, index: usize) -> &'a Value;
}
}
51 changes: 21 additions & 30 deletions tools/rust_api/src/kuzu_rs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,32 @@ rust::Vec<rust::String> query_result_column_names(const kuzu::main::QueryResult&
return names;
}

std::array<uint64_t, 2> node_value_get_node_id(const kuzu::common::NodeVal& val) {
auto internalID = val.getNodeID();
std::unique_ptr<NodeValuePropertyList> node_value_get_properties(const Value& val) {
return std::make_unique<NodeValuePropertyList>(val);
}
std::unique_ptr<RelValuePropertyList> rel_value_get_properties(const Value& val) {
return std::make_unique<RelValuePropertyList>(val);
}

rust::String node_value_get_label_name(const kuzu::common::Value& val) {
return rust::String(kuzu::common::NodeVal::getLabelName(&val));
}

rust::String rel_value_get_label_name(const kuzu::common::Value& val) {
return rust::String(kuzu::common::RelVal::getLabelName(&val));
}

std::array<uint64_t, 2> node_value_get_node_id(const kuzu::common::Value& val) {
auto internalID = kuzu::common::NodeVal::getNodeID(&val);
return std::array{internalID.offset, internalID.tableID};
}

std::array<uint64_t, 2> rel_value_get_src_id(const kuzu::common::RelVal& val) {
auto internalID = val.getSrcNodeID();
std::array<uint64_t, 2> rel_value_get_src_id(const kuzu::common::Value& val) {
auto internalID = kuzu::common::RelVal::getSrcNodeID(&val);
return std::array{internalID.offset, internalID.tableID};
}
std::array<uint64_t, 2> rel_value_get_dst_id(const kuzu::common::RelVal& val) {
auto internalID = val.getDstNodeID();
std::array<uint64_t, 2> rel_value_get_dst_id(const kuzu::common::Value& val) {
auto internalID = kuzu::common::RelVal::getDstNodeID(&val);
return std::array{internalID.offset, internalID.tableID};
}

Expand Down Expand Up @@ -225,17 +240,6 @@ std::unique_ptr<kuzu::common::Value> create_value_null(
std::unique_ptr<kuzu::common::Value> create_value_internal_id(uint64_t offset, uint64_t table) {
return std::make_unique<kuzu::common::Value>(kuzu::common::internalID_t(offset, table));
}
std::unique_ptr<Value> create_value_node(
std::unique_ptr<Value> id_val, std::unique_ptr<Value> label_val) {
return std::make_unique<Value>(
std::make_unique<kuzu::common::NodeVal>(std::move(id_val), std::move(label_val)));
}

std::unique_ptr<kuzu::common::Value> create_value_rel(std::unique_ptr<kuzu::common::Value> src_id,
std::unique_ptr<kuzu::common::Value> dst_id, std::unique_ptr<kuzu::common::Value> label_val) {
return std::make_unique<Value>(std::make_unique<kuzu::common::RelVal>(
std::move(src_id), std::move(dst_id), std::move(label_val)));
}

std::unique_ptr<kuzu::common::Value> get_list_value(
std::unique_ptr<kuzu::common::LogicalType> typ, std::unique_ptr<ValueListBuilder> value) {
Expand All @@ -250,17 +254,4 @@ std::unique_ptr<TypeListBuilder> create_type_list() {
return std::make_unique<TypeListBuilder>();
}

void value_add_property(kuzu::common::Value& val, const rust::String& name,
std::unique_ptr<kuzu::common::Value> property) {
if (val.getDataType().getLogicalTypeID() == kuzu::common::LogicalTypeID::NODE) {
kuzu::common::NodeVal& nodeVal = val.getValueReference<kuzu::common::NodeVal>();
nodeVal.addProperty(std::string(name), std::move(property));
} else if (val.getDataType().getLogicalTypeID() == kuzu::common::LogicalTypeID::REL) {
kuzu::common::RelVal& relVal = val.getValueReference<kuzu::common::RelVal>();
relVal.addProperty(std::string(name), std::move(property));
} else {
throw std::runtime_error("Internal Error! Adding property to type without properties!");
}
}

} // namespace kuzu_rs
48 changes: 12 additions & 36 deletions tools/rust_api/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,25 +381,23 @@ impl TryFrom<&ffi::Value> for Value {
Ok(Value::Struct(result))
}
LogicalTypeID::NODE => {
let ffi_node_val = ffi::value_get_node_val(value);
let id = ffi::node_value_get_node_id(ffi_node_val.as_ref().unwrap());
let id = ffi::node_value_get_node_id(value);
let id = InternalID {
offset: id[0],
table_id: id[1],
};
let label = ffi::node_value_get_label_name(ffi_node_val.as_ref().unwrap());
let label = ffi::node_value_get_label_name(value);
let mut node_val = NodeVal::new(id, label);
let properties = ffi::node_value_get_properties(ffi_node_val.as_ref().unwrap());
let properties = ffi::node_value_get_properties(value);
for i in 0..properties.size() {
node_val
.add_property(properties.get_name(i), properties.get_value(i).try_into()?);
}
Ok(Value::Node(node_val))
}
LogicalTypeID::REL => {
let ffi_rel_val = ffi::value_get_rel_val(value);
let src_node = ffi::rel_value_get_src_id(ffi_rel_val.as_ref().unwrap());
let dst_node = ffi::rel_value_get_dst_id(ffi_rel_val.as_ref().unwrap());
let src_node = ffi::rel_value_get_src_id(value);
let dst_node = ffi::rel_value_get_dst_id(value);
let src_node = InternalID {
offset: src_node[0],
table_id: src_node[1],
Expand All @@ -408,9 +406,9 @@ impl TryFrom<&ffi::Value> for Value {
offset: dst_node[0],
table_id: dst_node[1],
};
let label = ffi::rel_value_get_label_name(ffi_rel_val.as_ref().unwrap());
let label = ffi::rel_value_get_label_name(value);
let mut rel_val = RelVal::new(src_node, dst_node, label);
let properties = ffi::rel_value_get_properties(ffi_rel_val.as_ref().unwrap());
let properties = ffi::rel_value_get_properties(value);
for i in 0..properties.size() {
rel_val
.add_property(properties.get_name(i), properties.get_value(i).try_into()?);
Expand Down Expand Up @@ -518,27 +516,8 @@ impl TryInto<cxx::UniquePtr<ffi::Value>> for Value {
Value::InternalID(value) => {
Ok(ffi::create_value_internal_id(value.offset, value.table_id))
}
Value::Node(value) => {
let mut node = ffi::create_value_node(
Value::InternalID(value.id).try_into()?,
Value::String(value.label).try_into()?,
);
for (name, property) in value.properties {
ffi::value_add_property(node.pin_mut(), &name, property.try_into()?);
}
Ok(node)
}
Value::Rel(value) => {
let mut rel = ffi::create_value_rel(
Value::InternalID(value.src_node).try_into()?,
Value::InternalID(value.dst_node).try_into()?,
Value::String(value.label).try_into()?,
);
for (name, property) in value.properties {
ffi::value_add_property(rel.pin_mut(), &name, property.try_into()?);
}
Ok(rel)
}
Value::Node(_) => Err(crate::Error::ReadOnlyType(LogicalType::Node)),
Value::Rel(_) => Err(crate::Error::ReadOnlyType(LogicalType::Rel)),
}
}
}
Expand Down Expand Up @@ -640,7 +619,7 @@ mod tests {
($($name:ident: $value:expr,)*) => {
$(
#[test]
/// Tests that the values are correctly converted into kuzu::common::Value and back
/// Tests that the values display the same via the rust API as via the C++ API
fn $name() -> Result<()> {
let rust_value: Value = $value;
let value: cxx::UniquePtr<ffi::Value> = rust_value.clone().try_into()?;
Expand Down Expand Up @@ -722,8 +701,6 @@ mod tests {
}),
convert_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]),
convert_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }),
convert_node: Value::Node(NodeVal::new(InternalID { table_id: 0, offset: 0 }, "Test Label".to_string())),
convert_rel: Value::Rel(RelVal::new(InternalID { table_id: 0, offset: 0 }, InternalID { table_id: 1, offset: 0 }, "Test Label".to_string())),
}

display_tests! {
Expand All @@ -733,7 +710,7 @@ mod tests {
display_int16: Value::Int16(1),
display_int32: Value::Int32(2),
display_int64: Value::Int64(3),
// Float, doble, interval and timestamp have display differences which we probably don't want to
// Float, double, interval and timestamp have display differences which we probably don't want to
// reconcile
display_date: Value::Date(date!(2023-06-13)),
display_string: Value::String("Hello World".to_string()),
Expand All @@ -743,8 +720,7 @@ mod tests {
}),
display_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]),
display_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }),
display_node: Value::Node(NodeVal::new(InternalID { table_id: 0, offset: 0 }, "Test Label".to_string())),
display_rel: Value::Rel(RelVal::new(InternalID { table_id: 0, offset: 0 }, InternalID { table_id: 1, offset: 0 }, "Test Label".to_string())),
// Node and Rel Cannot be easily created on the C++ side
}

database_tests! {
Expand Down
Loading