Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Map to rust API #2176

Merged
merged 1 commit into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -810,23 +810,30 @@ std::unique_ptr<LogicalType> LogicalTypeUtils::parseStructType(const std::string
LogicalTypeID::STRUCT, std::make_unique<StructTypeInfo>(parseStructTypeInfo(trimmedStr)));
}

std::unique_ptr<LogicalType> MapType::createMapType(
std::unique_ptr<LogicalType> keyType, std::unique_ptr<LogicalType> valueType) {
std::vector<std::unique_ptr<StructField>> structFields;
structFields.push_back(
std::make_unique<StructField>(InternalKeyword::MAP_KEY, std::move(keyType)));
structFields.push_back(
std::make_unique<StructField>(InternalKeyword::MAP_VALUE, std::move(valueType)));
auto childType = std::make_unique<LogicalType>(
LogicalTypeID::STRUCT, std::make_unique<StructTypeInfo>(std::move(structFields)));
return std::make_unique<LogicalType>(
LogicalTypeID::MAP, std::make_unique<VarListTypeInfo>(std::move(childType)));
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseMapType(const std::string& trimmedStr) {
auto leftBracketPos = trimmedStr.find('(');
auto rightBracketPos = trimmedStr.find_last_of(')');
if (leftBracketPos == std::string::npos || rightBracketPos == std::string::npos) {
throw Exception("Cannot parse struct type: " + trimmedStr);
throw Exception("Cannot parse map type: " + trimmedStr);
}
auto mapTypeStr = trimmedStr.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1);
auto keyValueTypes = StringUtils::split(mapTypeStr, ",");
std::vector<std::unique_ptr<StructField>> structFields;
structFields.emplace_back(std::make_unique<StructField>(InternalKeyword::MAP_KEY,
std::make_unique<LogicalType>(dataTypeFromString(keyValueTypes[0]))));
structFields.emplace_back(std::make_unique<StructField>(InternalKeyword::MAP_VALUE,
std::make_unique<LogicalType>(dataTypeFromString(keyValueTypes[1]))));
auto childType = std::make_unique<LogicalType>(
LogicalTypeID::STRUCT, std::make_unique<StructTypeInfo>(std::move(structFields)));
return std::make_unique<LogicalType>(
LogicalTypeID::MAP, std::make_unique<VarListTypeInfo>(std::move(childType)));
return MapType::createMapType(
std::make_unique<LogicalType>(dataTypeFromString(keyValueTypes[0])),
std::make_unique<LogicalType>(dataTypeFromString(keyValueTypes[1])));
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseUnionType(const std::string& trimmedStr) {
Expand Down
2 changes: 2 additions & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,8 @@ struct StructType {
};

struct MapType {
static std::unique_ptr<LogicalType> createMapType(
std::unique_ptr<LogicalType> keyType, std::unique_ptr<LogicalType> valueType);
static inline LogicalType* getKeyType(const LogicalType* type) {
assert(type->getLogicalTypeID() == LogicalTypeID::MAP);
return StructType::getFieldTypes(VarListType::getChildType(type))[0];
Expand Down
3 changes: 3 additions & 0 deletions tools/rust_api/include/kuzu_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ std::unique_ptr<kuzu::common::LogicalType> create_logical_type_fixed_list(
std::unique_ptr<kuzu::common::LogicalType> childType, uint64_t numElements);
std::unique_ptr<kuzu::common::LogicalType> create_logical_type_struct(
const rust::Vec<rust::String>& fieldNames, std::unique_ptr<TypeListBuilder> fieldTypes);
std::unique_ptr<kuzu::common::LogicalType> create_logical_type_map(
std::unique_ptr<kuzu::common::LogicalType> keyType,
std::unique_ptr<kuzu::common::LogicalType> valueType);

const kuzu::common::LogicalType& logical_type_get_var_list_child_type(
const kuzu::common::LogicalType& logicalType);
Expand Down
4 changes: 4 additions & 0 deletions tools/rust_api/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ pub(crate) mod ffi {
field_names: &Vec<String>,
types: UniquePtr<TypeListBuilder>,
) -> UniquePtr<LogicalType>;
fn create_logical_type_map(
keyType: UniquePtr<LogicalType>,
valueType: UniquePtr<LogicalType>,
) -> UniquePtr<LogicalType>;

fn logical_type_get_var_list_child_type(value: &LogicalType) -> &LogicalType;
fn logical_type_get_fixed_list_child_type(value: &LogicalType) -> &LogicalType;
Expand Down
5 changes: 5 additions & 0 deletions tools/rust_api/src/kuzu_rs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ std::unique_ptr<kuzu::common::LogicalType> create_logical_type_struct(
LogicalTypeID::STRUCT, std::make_unique<kuzu::common::StructTypeInfo>(std::move(fields)));
}

std::unique_ptr<kuzu::common::LogicalType> create_logical_type_map(
std::unique_ptr<LogicalType> keyType, std::unique_ptr<LogicalType> valueType) {
return kuzu::common::MapType::createMapType(std::move(keyType), std::move(valueType));
}

const LogicalType& logical_type_get_var_list_child_type(const LogicalType& logicalType) {
return *kuzu::common::VarListType::getChildType(&logicalType);
}
Expand Down
29 changes: 29 additions & 0 deletions tools/rust_api/src/logical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ pub enum LogicalType {
/// Correponds to [Value::Rel](crate::value::Value::Rel)
Rel,
RecursiveRel,
Map {
key_type: Box<LogicalType>,
value_type: Box<LogicalType>,
},
}

impl From<&ffi::Value> for LogicalType {
Expand Down Expand Up @@ -117,6 +121,26 @@ impl From<&ffi::LogicalType> for LogicalType {
LogicalTypeID::NODE => LogicalType::Node,
LogicalTypeID::REL => LogicalType::Rel,
LogicalTypeID::RECURSIVE_REL => LogicalType::RecursiveRel,
LogicalTypeID::MAP => {
let child_types = ffi::logical_type_get_var_list_child_type(logical_type);
let types = ffi::logical_type_get_struct_field_types(child_types);
let key_type = types
.as_ref()
.unwrap()
.get(0)
.expect(
"First element of map type list should be the key type, but list was empty",
)
.into();
let value_type = types.as_ref().unwrap()
.get(1)
.expect("Second element of map type list should be the value type, but list did not have two elements")
.into();
LogicalType::Map {
key_type: Box::new(key_type),
value_type: Box::new(value_type),
}
}
// Should be unreachable, as cxx will check that the LogicalTypeID enum matches the one
// on the C++ side.
x => panic!("Unsupported type {:?}", x),
Expand Down Expand Up @@ -165,6 +189,10 @@ impl From<&LogicalType> for cxx::UniquePtr<ffi::LogicalType> {
}
ffi::create_logical_type_struct(&names, builder)
}
LogicalType::Map {
key_type,
value_type,
} => ffi::create_logical_type_map(key_type.as_ref().into(), value_type.as_ref().into()),
}
}
}
Expand Down Expand Up @@ -198,6 +226,7 @@ impl LogicalType {
LogicalType::Node => LogicalTypeID::NODE,
LogicalType::Rel => LogicalTypeID::REL,
LogicalType::RecursiveRel => LogicalTypeID::RECURSIVE_REL,
LogicalType::Map { .. } => LogicalTypeID::MAP,
}
}
}
80 changes: 71 additions & 9 deletions tools/rust_api/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ pub enum Value {
/// Sequence of Rels which make up the RecursiveRel
rels: Vec<RelVal>,
},
Map((LogicalType, LogicalType), Vec<(Value, Value)>),
}

fn display_list<T: std::fmt::Display>(f: &mut fmt::Formatter<'_>, list: &Vec<T>) -> fmt::Result {
Expand Down Expand Up @@ -290,6 +291,16 @@ impl std::fmt::Display for Value {
}
write!(f, "}}")
}
Value::Map(_, x) => {
write!(f, "{{")?;
for (i, (name, value)) in x.iter().enumerate() {
write!(f, "{}={}", name, value)?;
if i != x.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, "}}")
}
Value::Node(x) => write!(f, "{x}"),
Value::Rel(x) => write!(f, "{x}"),
Value::InternalID(x) => write!(f, "{x}"),
Expand Down Expand Up @@ -345,6 +356,10 @@ impl From<&Value> for LogicalType {
Value::Node(_) => LogicalType::Node,
Value::Rel(_) => LogicalType::Rel,
Value::RecursiveRel { .. } => LogicalType::RecursiveRel,
Value::Map((key_type, value_type), _) => LogicalType::Map {
key_type: Box::new(key_type.clone()),
value_type: Box::new(value_type.clone()),
},
}
}
}
Expand Down Expand Up @@ -433,6 +448,25 @@ impl TryFrom<&ffi::Value> for Value {
}
Ok(Value::Struct(result))
}
LogicalTypeID::MAP => {
let mut result = vec![];
for index in 0..ffi::value_get_children_size(value) {
let pair = ffi::value_get_child(value, index);
result.push((
ffi::value_get_child(pair, 0).try_into()?,
ffi::value_get_child(pair, 1).try_into()?,
));
}
if let LogicalType::Map {
key_type,
value_type,
} = value.into()
{
Ok(Value::Map((*key_type, *value_type), result))
} else {
unreachable!()
}
}
LogicalTypeID::NODE => {
let id = ffi::node_value_get_node_id(value);
let id = InternalID {
Expand Down Expand Up @@ -573,6 +607,30 @@ impl TryInto<cxx::UniquePtr<ffi::Value>> for Value {
builder,
))
}
Value::Map((key_type, value_type), values) => {
let mut builder = ffi::create_list();
let list_type = LogicalType::Struct {
fields: vec![
("KEY".to_string(), key_type.clone()),
("VALUE".to_string(), value_type.clone()),
],
};
for (key, value) in values {
let mut pair = ffi::create_list();
pair.pin_mut().insert(key.try_into()?);
pair.pin_mut().insert(value.try_into()?);
let pair_value = ffi::get_list_value((&list_type).into(), pair);
builder.pin_mut().insert(pair_value);
}
Ok(ffi::get_list_value(
(&LogicalType::Map {
key_type: Box::new(key_type),
value_type: Box::new(value_type),
})
.into(),
builder,
))
}
Value::FixedList(typ, value) => {
let mut builder = ffi::create_list();
let len = value.len();
Expand Down Expand Up @@ -811,6 +869,7 @@ mod tests {
convert_internal_id_type: LogicalType::InternalID,
convert_rel_type: LogicalType::Rel,
convert_recursive_rel_type: LogicalType::RecursiveRel,
convert_map_type: LogicalType::Map { key_type: Box::new(LogicalType::Interval), value_type: Box::new(LogicalType::Rel) },
}

value_tests! {
Expand Down Expand Up @@ -838,6 +897,7 @@ mod tests {
}),
convert_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]),
convert_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }),
convert_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]),
}

display_tests! {
Expand Down Expand Up @@ -865,21 +925,23 @@ mod tests {
display_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]),
display_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }),
// Node and Rel Cannot be easily created on the C++ side
display_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]),
}

database_tests! {
// Passing these values as arguments is not yet implemented in kuzu:
// db_struct:
// Value::Struct(vec![("item".to_string(), "Knife".into()), ("count".to_string(), 1.into())]),
// "STRUCT(item STRING, count INT32)",
// db_fixed_list: Value::FixedList(LogicalType::String, vec!["Alice".into(), "Bob".into()]), "STRING[2]",
// db_null_string: Value::Null(LogicalType::String), "STRING",
// db_null_int: Value::Null(LogicalType::Int64), "INT64",
// db_null_list: Value::Null(LogicalType::VarList {
// child_type: Box::new(LogicalType::FixedList { child_type: Box::new(LogicalType::Int16), num_elements: 3 })
// }), "INT16[3][]",
// db_var_list_string: Value::VarList(LogicalType::String, vec!["Alice".into(), "Bob".into()]), "STRING[]",
// db_var_list_int: Value::VarList(LogicalType::Int64, vec![0i64.into(), 1i64.into(), 2i64.into()]), "INT64[]",
// db_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]), "MAP(STRING,INT64)",
// db_fixed_list: Value::FixedList(LogicalType::Int64, vec![1i64.into(), 2i64.into(), 3i64.into()]), "INT64[3]",
db_struct:
Value::Struct(vec![("item".to_string(), "Knife".into()), ("count".to_string(), 1.into())]),
"STRUCT(item STRING, count INT32)",
db_null_string: Value::Null(LogicalType::String), "STRING",
db_null_int: Value::Null(LogicalType::Int64), "INT64",
db_null_list: Value::Null(LogicalType::VarList {
child_type: Box::new(LogicalType::FixedList { child_type: Box::new(LogicalType::Int16), num_elements: 3 })
}), "INT16[3][]",
db_int8: Value::Int8(0), "INT8",
db_int16: Value::Int16(1), "INT16",
db_int32: Value::Int32(2), "INT32",
Expand Down
Loading