Skip to content

Commit

Permalink
Add support for Map to rust API
Browse files Browse the repository at this point in the history
Also re-enabled some database tests which are now supported.
  • Loading branch information
benjaminwinger committed Oct 11, 2023
1 parent 3cf3539 commit 5f2615b
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 19 deletions.
27 changes: 17 additions & 10 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -810,23 +810,30 @@ std::unique_ptr<LogicalType> LogicalTypeUtils::parseStructType(const std::string
LogicalTypeID::STRUCT, std::make_unique<StructTypeInfo>(parseStructTypeInfo(trimmedStr)));
}

std::unique_ptr<LogicalType> MapType::createMapType(
std::unique_ptr<LogicalType> keyType, std::unique_ptr<LogicalType> valueType) {
std::vector<std::unique_ptr<StructField>> structFields;
structFields.push_back(
std::make_unique<StructField>(InternalKeyword::MAP_KEY, std::move(keyType)));
structFields.push_back(
std::make_unique<StructField>(InternalKeyword::MAP_VALUE, std::move(valueType)));
auto childType = std::make_unique<LogicalType>(
LogicalTypeID::STRUCT, std::make_unique<StructTypeInfo>(std::move(structFields)));
return std::make_unique<LogicalType>(
LogicalTypeID::MAP, std::make_unique<VarListTypeInfo>(std::move(childType)));
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseMapType(const std::string& trimmedStr) {
auto leftBracketPos = trimmedStr.find('(');
auto rightBracketPos = trimmedStr.find_last_of(')');
if (leftBracketPos == std::string::npos || rightBracketPos == std::string::npos) {
throw Exception("Cannot parse struct type: " + trimmedStr);
throw Exception("Cannot parse map type: " + trimmedStr);
}
auto mapTypeStr = trimmedStr.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1);
auto keyValueTypes = StringUtils::split(mapTypeStr, ",");
std::vector<std::unique_ptr<StructField>> structFields;
structFields.emplace_back(std::make_unique<StructField>(InternalKeyword::MAP_KEY,
std::make_unique<LogicalType>(dataTypeFromString(keyValueTypes[0]))));
structFields.emplace_back(std::make_unique<StructField>(InternalKeyword::MAP_VALUE,
std::make_unique<LogicalType>(dataTypeFromString(keyValueTypes[1]))));
auto childType = std::make_unique<LogicalType>(
LogicalTypeID::STRUCT, std::make_unique<StructTypeInfo>(std::move(structFields)));
return std::make_unique<LogicalType>(
LogicalTypeID::MAP, std::make_unique<VarListTypeInfo>(std::move(childType)));
return MapType::createMapType(
std::make_unique<LogicalType>(dataTypeFromString(keyValueTypes[0])),
std::make_unique<LogicalType>(dataTypeFromString(keyValueTypes[1])));
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseUnionType(const std::string& trimmedStr) {
Expand Down
2 changes: 2 additions & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,8 @@ struct StructType {
};

struct MapType {
static std::unique_ptr<LogicalType> createMapType(
std::unique_ptr<LogicalType> keyType, std::unique_ptr<LogicalType> valueType);
static inline LogicalType* getKeyType(const LogicalType* type) {
assert(type->getLogicalTypeID() == LogicalTypeID::MAP);
return StructType::getFieldTypes(VarListType::getChildType(type))[0];
Expand Down
3 changes: 3 additions & 0 deletions tools/rust_api/include/kuzu_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ std::unique_ptr<kuzu::common::LogicalType> create_logical_type_fixed_list(
std::unique_ptr<kuzu::common::LogicalType> childType, uint64_t numElements);
std::unique_ptr<kuzu::common::LogicalType> create_logical_type_struct(
const rust::Vec<rust::String>& fieldNames, std::unique_ptr<TypeListBuilder> fieldTypes);
std::unique_ptr<kuzu::common::LogicalType> create_logical_type_map(
std::unique_ptr<kuzu::common::LogicalType> keyType,
std::unique_ptr<kuzu::common::LogicalType> valueType);

const kuzu::common::LogicalType& logical_type_get_var_list_child_type(
const kuzu::common::LogicalType& logicalType);
Expand Down
4 changes: 4 additions & 0 deletions tools/rust_api/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ pub(crate) mod ffi {
field_names: &Vec<String>,
types: UniquePtr<TypeListBuilder>,
) -> UniquePtr<LogicalType>;
fn create_logical_type_map(
keyType: UniquePtr<LogicalType>,
valueType: UniquePtr<LogicalType>,
) -> UniquePtr<LogicalType>;

fn logical_type_get_var_list_child_type(value: &LogicalType) -> &LogicalType;
fn logical_type_get_fixed_list_child_type(value: &LogicalType) -> &LogicalType;
Expand Down
5 changes: 5 additions & 0 deletions tools/rust_api/src/kuzu_rs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ std::unique_ptr<kuzu::common::LogicalType> create_logical_type_struct(
LogicalTypeID::STRUCT, std::make_unique<kuzu::common::StructTypeInfo>(std::move(fields)));
}

std::unique_ptr<kuzu::common::LogicalType> create_logical_type_map(
std::unique_ptr<LogicalType> keyType, std::unique_ptr<LogicalType> valueType) {
return kuzu::common::MapType::createMapType(std::move(keyType), std::move(valueType));
}

const LogicalType& logical_type_get_var_list_child_type(const LogicalType& logicalType) {
return *kuzu::common::VarListType::getChildType(&logicalType);
}
Expand Down
29 changes: 29 additions & 0 deletions tools/rust_api/src/logical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ pub enum LogicalType {
/// Correponds to [Value::Rel](crate::value::Value::Rel)
Rel,
RecursiveRel,
Map {
key_type: Box<LogicalType>,
value_type: Box<LogicalType>,
},
}

impl From<&ffi::Value> for LogicalType {
Expand Down Expand Up @@ -117,6 +121,26 @@ impl From<&ffi::LogicalType> for LogicalType {
LogicalTypeID::NODE => LogicalType::Node,
LogicalTypeID::REL => LogicalType::Rel,
LogicalTypeID::RECURSIVE_REL => LogicalType::RecursiveRel,
LogicalTypeID::MAP => {
let child_types = ffi::logical_type_get_var_list_child_type(logical_type);
let types = ffi::logical_type_get_struct_field_types(child_types);
let key_type = types
.as_ref()
.unwrap()
.get(0)
.expect(
"First element of map type list should be the key type, but list was empty",
)
.into();
let value_type = types.as_ref().unwrap()
.get(1)
.expect("Second element of map type list should be the value type, but list did not have two elements")
.into();
LogicalType::Map {
key_type: Box::new(key_type),
value_type: Box::new(value_type),
}
}
// Should be unreachable, as cxx will check that the LogicalTypeID enum matches the one
// on the C++ side.
x => panic!("Unsupported type {:?}", x),
Expand Down Expand Up @@ -165,6 +189,10 @@ impl From<&LogicalType> for cxx::UniquePtr<ffi::LogicalType> {
}
ffi::create_logical_type_struct(&names, builder)
}
LogicalType::Map {
key_type,
value_type,
} => ffi::create_logical_type_map(key_type.as_ref().into(), value_type.as_ref().into()),
}
}
}
Expand Down Expand Up @@ -198,6 +226,7 @@ impl LogicalType {
LogicalType::Node => LogicalTypeID::NODE,
LogicalType::Rel => LogicalTypeID::REL,
LogicalType::RecursiveRel => LogicalTypeID::RECURSIVE_REL,
LogicalType::Map { .. } => LogicalTypeID::MAP,
}
}
}
80 changes: 71 additions & 9 deletions tools/rust_api/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ pub enum Value {
/// Sequence of Rels which make up the RecursiveRel
rels: Vec<RelVal>,
},
Map((LogicalType, LogicalType), Vec<(Value, Value)>),
}

fn display_list<T: std::fmt::Display>(f: &mut fmt::Formatter<'_>, list: &Vec<T>) -> fmt::Result {
Expand Down Expand Up @@ -290,6 +291,16 @@ impl std::fmt::Display for Value {
}
write!(f, "}}")
}
Value::Map(_, x) => {
write!(f, "{{")?;
for (i, (name, value)) in x.iter().enumerate() {
write!(f, "{}={}", name, value)?;
if i != x.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, "}}")
}
Value::Node(x) => write!(f, "{x}"),
Value::Rel(x) => write!(f, "{x}"),
Value::InternalID(x) => write!(f, "{x}"),
Expand Down Expand Up @@ -345,6 +356,10 @@ impl From<&Value> for LogicalType {
Value::Node(_) => LogicalType::Node,
Value::Rel(_) => LogicalType::Rel,
Value::RecursiveRel { .. } => LogicalType::RecursiveRel,
Value::Map((key_type, value_type), _) => LogicalType::Map {
key_type: Box::new(key_type.clone()),
value_type: Box::new(value_type.clone()),
},
}
}
}
Expand Down Expand Up @@ -433,6 +448,25 @@ impl TryFrom<&ffi::Value> for Value {
}
Ok(Value::Struct(result))
}
LogicalTypeID::MAP => {
let mut result = vec![];
for index in 0..ffi::value_get_children_size(value) {
let pair = ffi::value_get_child(value, index);
result.push((
ffi::value_get_child(pair, 0).try_into()?,
ffi::value_get_child(pair, 1).try_into()?,
));
}
if let LogicalType::Map {
key_type,
value_type,
} = value.into()
{
Ok(Value::Map((*key_type, *value_type), result))
} else {
unreachable!()
}
}
LogicalTypeID::NODE => {
let id = ffi::node_value_get_node_id(value);
let id = InternalID {
Expand Down Expand Up @@ -573,6 +607,30 @@ impl TryInto<cxx::UniquePtr<ffi::Value>> for Value {
builder,
))
}
Value::Map((key_type, value_type), values) => {
let mut builder = ffi::create_list();
let list_type = LogicalType::Struct {
fields: vec![
("KEY".to_string(), key_type.clone()),
("VALUE".to_string(), value_type.clone()),
],
};
for (key, value) in values {
let mut pair = ffi::create_list();
pair.pin_mut().insert(key.try_into()?);
pair.pin_mut().insert(value.try_into()?);
let pair_value = ffi::get_list_value((&list_type).into(), pair);
builder.pin_mut().insert(pair_value);
}
Ok(ffi::get_list_value(
(&LogicalType::Map {
key_type: Box::new(key_type),
value_type: Box::new(value_type),
})
.into(),
builder,
))
}
Value::FixedList(typ, value) => {
let mut builder = ffi::create_list();
let len = value.len();
Expand Down Expand Up @@ -811,6 +869,7 @@ mod tests {
convert_internal_id_type: LogicalType::InternalID,
convert_rel_type: LogicalType::Rel,
convert_recursive_rel_type: LogicalType::RecursiveRel,
convert_map_type: LogicalType::Map { key_type: Box::new(LogicalType::Interval), value_type: Box::new(LogicalType::Rel) },
}

value_tests! {
Expand Down Expand Up @@ -838,6 +897,7 @@ mod tests {
}),
convert_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]),
convert_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }),
convert_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]),
}

display_tests! {
Expand Down Expand Up @@ -865,21 +925,23 @@ mod tests {
display_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]),
display_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }),
// Node and Rel Cannot be easily created on the C++ side
display_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]),
}

database_tests! {
// Passing these values as arguments is not yet implemented in kuzu:
// db_struct:
// Value::Struct(vec![("item".to_string(), "Knife".into()), ("count".to_string(), 1.into())]),
// "STRUCT(item STRING, count INT32)",
// db_fixed_list: Value::FixedList(LogicalType::String, vec!["Alice".into(), "Bob".into()]), "STRING[2]",
// db_null_string: Value::Null(LogicalType::String), "STRING",
// db_null_int: Value::Null(LogicalType::Int64), "INT64",
// db_null_list: Value::Null(LogicalType::VarList {
// child_type: Box::new(LogicalType::FixedList { child_type: Box::new(LogicalType::Int16), num_elements: 3 })
// }), "INT16[3][]",
// db_var_list_string: Value::VarList(LogicalType::String, vec!["Alice".into(), "Bob".into()]), "STRING[]",
// db_var_list_int: Value::VarList(LogicalType::Int64, vec![0i64.into(), 1i64.into(), 2i64.into()]), "INT64[]",
// db_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]), "MAP(STRING,INT64)",
// db_fixed_list: Value::FixedList(LogicalType::Int64, vec![1i64.into(), 2i64.into(), 3i64.into()]), "INT64[3]",
db_struct:
Value::Struct(vec![("item".to_string(), "Knife".into()), ("count".to_string(), 1.into())]),
"STRUCT(item STRING, count INT32)",
db_null_string: Value::Null(LogicalType::String), "STRING",
db_null_int: Value::Null(LogicalType::Int64), "INT64",
db_null_list: Value::Null(LogicalType::VarList {
child_type: Box::new(LogicalType::FixedList { child_type: Box::new(LogicalType::Int16), num_elements: 3 })
}), "INT16[3][]",
db_int8: Value::Int8(0), "INT8",
db_int16: Value::Int16(1), "INT16",
db_int32: Value::Int32(2), "INT32",
Expand Down

0 comments on commit 5f2615b

Please sign in to comment.