Skip to content

Commit

Permalink
Rust Union support
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminwinger committed Oct 12, 2023
1 parent a5e8e62 commit 4b06c5e
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/storage/store/struct_node_column.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void StructNodeColumn::writeInternal(

void StructNodeColumn::append(ColumnChunk* columnChunk, uint64_t nodeGroupIdx) {
NodeColumn::append(columnChunk, nodeGroupIdx);
assert(columnChunk->getDataType().getLogicalTypeID() == LogicalTypeID::STRUCT);
assert(columnChunk->getDataType().getPhysicalType() == PhysicalTypeID::STRUCT);
auto structColumnChunk = static_cast<StructColumnChunk*>(columnChunk);
for (auto i = 0u; i < childColumns.size(); i++) {
childColumns[i]->append(structColumnChunk->getChild(i), nodeGroupIdx);
Expand Down
3 changes: 2 additions & 1 deletion tools/rust_api/include/kuzu_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ std::unique_ptr<kuzu::common::LogicalType> create_logical_type_var_list(
std::unique_ptr<kuzu::common::LogicalType> create_logical_type_fixed_list(
std::unique_ptr<kuzu::common::LogicalType> childType, uint64_t numElements);
std::unique_ptr<kuzu::common::LogicalType> create_logical_type_struct(
const rust::Vec<rust::String>& fieldNames, std::unique_ptr<TypeListBuilder> fieldTypes);
kuzu::common::LogicalTypeID typeID, const rust::Vec<rust::String>& fieldNames,
std::unique_ptr<TypeListBuilder> fieldTypes);
std::unique_ptr<kuzu::common::LogicalType> create_logical_type_map(
std::unique_ptr<kuzu::common::LogicalType> keyType,
std::unique_ptr<kuzu::common::LogicalType> valueType);
Expand Down
1 change: 1 addition & 0 deletions tools/rust_api/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ pub(crate) mod ffi {
num_elements: u64,
) -> UniquePtr<LogicalType>;
fn create_logical_type_struct(
type_id: LogicalTypeID,
field_names: &Vec<String>,
types: UniquePtr<TypeListBuilder>,
) -> UniquePtr<LogicalType>;
Expand Down
4 changes: 2 additions & 2 deletions tools/rust_api/src/kuzu_rs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ std::unique_ptr<LogicalType> create_logical_type_fixed_list(
std::make_unique<FixedListTypeInfo>(std::move(childType), numElements));
}

std::unique_ptr<kuzu::common::LogicalType> create_logical_type_struct(
std::unique_ptr<kuzu::common::LogicalType> create_logical_type_struct(LogicalTypeID typeID,
const rust::Vec<rust::String>& fieldNames, std::unique_ptr<TypeListBuilder> fieldTypes) {
std::vector<std::unique_ptr<StructField>> fields;
for (auto i = 0u; i < fieldNames.size(); i++) {
fields.push_back(std::make_unique<StructField>(
std::string(fieldNames[i]), std::move(fieldTypes->types[i])));
}
return std::make_unique<LogicalType>(
LogicalTypeID::STRUCT, std::make_unique<kuzu::common::StructTypeInfo>(std::move(fields)));
typeID, std::make_unique<kuzu::common::StructTypeInfo>(std::move(fields)));
}

std::unique_ptr<kuzu::common::LogicalType> create_logical_type_map(
Expand Down
30 changes: 29 additions & 1 deletion tools/rust_api/src/logical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,14 @@ pub enum LogicalType {
/// Corresponds to [Value::Rel](crate::value::Value::Rel)
Rel,
RecursiveRel,
/// Corresponds to [Value::Map](crate::value::Value::Map)
Map {
key_type: Box<LogicalType>,
value_type: Box<LogicalType>,
},
Union {
types: Vec<(String, LogicalType)>,
},
}

impl From<&ffi::Value> for LogicalType {
Expand Down Expand Up @@ -141,6 +145,18 @@ impl From<&ffi::LogicalType> for LogicalType {
value_type: Box::new(value_type),
}
}
LogicalTypeID::UNION => {
let names = ffi::logical_type_get_struct_field_names(logical_type);
let types = ffi::logical_type_get_struct_field_types(logical_type);
LogicalType::Union {
types: names
.into_iter()
// Skip the tag field
.skip(1)
.zip(types.into_iter().skip(1).map(Into::<LogicalType>::into))
.collect(),
}
}
// Should be unreachable, as cxx will check that the LogicalTypeID enum matches the one
// on the C++ side.
x => panic!("Unsupported type {:?}", x),
Expand Down Expand Up @@ -187,7 +203,18 @@ impl From<&LogicalType> for cxx::UniquePtr<ffi::LogicalType> {
names.push(name.clone());
builder.pin_mut().insert(typ.into());
}
ffi::create_logical_type_struct(&names, builder)
ffi::create_logical_type_struct(ffi::LogicalTypeID::STRUCT, &names, builder)
}
LogicalType::Union { types } => {
let mut builder = ffi::create_type_list();
let mut names = vec![];
names.push("tag".to_string());
builder.pin_mut().insert((&LogicalType::Int64).into());
for (name, typ) in types {
names.push(name.clone());
builder.pin_mut().insert(typ.into());
}
ffi::create_logical_type_struct(ffi::LogicalTypeID::UNION, &names, builder)
}
LogicalType::Map {
key_type,
Expand Down Expand Up @@ -227,6 +254,7 @@ impl LogicalType {
LogicalType::Rel => LogicalTypeID::REL,
LogicalType::RecursiveRel => LogicalTypeID::RECURSIVE_REL,
LogicalType::Map { .. } => LogicalTypeID::MAP,
LogicalType::Union { .. } => LogicalTypeID::UNION,
}
}
}
109 changes: 101 additions & 8 deletions tools/rust_api/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ impl From<(u64, u64)> for InternalID {

/// Data types supported by Kùzu
///
/// Also see <https://kuzudb.com/docs/cypher/data-types/overview.html>
/// Also see <https://kuzudb.com/docusaurus/cypher/data-types/overview.html>
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
Null(LogicalType),
Expand All @@ -206,33 +206,33 @@ pub enum Value {
/// Stored internally as the number of days since 1970-01-01 as a 32-bit signed integer, which
/// allows for a wider range of dates to be stored than can be represented by time::Date
///
/// <https://kuzudb.com/docs/cypher/data-types/date.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/date.html>
Date(time::Date),
/// May be signed or unsigned.
///
/// Nanosecond precision of time::Duration (if available) will not be preserved when passed to
/// queries, and results will always have at most microsecond precision.
///
/// <https://kuzudb.com/docs/cypher/data-types/interval.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/interval.html>
Interval(time::Duration),
/// Stored internally as the number of microseconds since 1970-01-01
/// Nanosecond precision of SystemTime (if available) will not be preserved when used.
///
/// <https://kuzudb.com/docs/cypher/data-types/timestamp.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/timestamp.html>
Timestamp(time::OffsetDateTime),
InternalID(InternalID),
/// <https://kuzudb.com/docs/cypher/data-types/string.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/string.html>
String(String),
Blob(Vec<u8>),
// TODO: Enforce type of contents
// LogicalType is necessary so that we can pass the correct type to the C++ API if the list is empty.
/// These must contain elements which are all the given type.
/// <https://kuzudb.com/docs/cypher/data-types/list.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/list.html>
VarList(LogicalType, Vec<Value>),
/// These must contain elements which are all the same type.
/// <https://kuzudb.com/docs/cypher/data-types/list.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/list.html>
FixedList(LogicalType, Vec<Value>),
/// <https://kuzudb.com/docs/cypher/data-types/struct.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/struct.html>
Struct(Vec<(String, Value)>),
Node(NodeVal),
Rel(RelVal),
Expand All @@ -244,7 +244,13 @@ pub enum Value {
/// Sequence of Rels which make up the RecursiveRel
rels: Vec<RelVal>,
},
/// <https://kuzudb.com/docusaurus/cypher/data-types/map>
Map((LogicalType, LogicalType), Vec<(Value, Value)>),
/// <https://kuzudb.com/docusaurus/cypher/data-types/union>
Union {
types: Vec<(String, LogicalType)>,
value: Box<Value>,
},
}

fn display_list<T: std::fmt::Display>(f: &mut fmt::Formatter<'_>, list: &Vec<T>) -> fmt::Result {
Expand Down Expand Up @@ -312,6 +318,7 @@ impl std::fmt::Display for Value {
display_list(f, rels)?;
write!(f, "}}")
}
Value::Union { types: _, value } => write!(f, "{value}"),
}
}
}
Expand Down Expand Up @@ -360,6 +367,9 @@ impl From<&Value> for LogicalType {
key_type: Box::new(key_type.clone()),
value_type: Box::new(value_type.clone()),
},
Value::Union { types, value: _ } => LogicalType::Union {
types: types.clone(),
},
}
}
}
Expand Down Expand Up @@ -542,6 +552,19 @@ impl TryFrom<&ffi::Value> for Value {
rels: rels.collect(),
})
}
LogicalTypeID::UNION => {
let types =
if let LogicalType::Union { types } = ffi::value_get_data_type(value).into() {
types
} else {
unreachable!()
};
let value: Value = ffi::value_get_child(value, 0).try_into()?;
Ok(Value::Union {
types,
value: Box::new(value),
})
}
// TODO(bmwinger): Better error message for types which are unsupported
x => panic!("Unsupported type {:?}", x),
}
Expand Down Expand Up @@ -673,6 +696,15 @@ impl TryInto<cxx::UniquePtr<ffi::Value>> for Value {
Value::RecursiveRel { .. } => {
Err(crate::Error::ReadOnlyType(LogicalType::RecursiveRel))
}
Value::Union { types, value } => {
let mut builder = ffi::create_list();
builder.pin_mut().insert((*value).try_into()?);

Ok(ffi::get_list_value(
(&LogicalType::Union { types }).into(),
builder,
))
}
}
}
}
Expand Down Expand Up @@ -870,6 +902,7 @@ mod tests {
convert_rel_type: LogicalType::Rel,
convert_recursive_rel_type: LogicalType::RecursiveRel,
convert_map_type: LogicalType::Map { key_type: Box::new(LogicalType::Interval), value_type: Box::new(LogicalType::Rel) },
convert_union_type: LogicalType::Union { types: vec![("Num".to_string(), LogicalType::Int8), ("duration".to_string(), LogicalType::Interval), ("string".to_string(), LogicalType::String)] },
}

value_tests! {
Expand Down Expand Up @@ -898,6 +931,10 @@ mod tests {
convert_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]),
convert_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }),
convert_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]),
convert_union: Value::Union {
types: vec![("Num".to_string(), LogicalType::Int8), ("duration".to_string(), LogicalType::Interval)],
value: Box::new(Value::Int8(-127))
},
}

display_tests! {
Expand Down Expand Up @@ -926,6 +963,10 @@ mod tests {
display_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }),
// Node and Rel Cannot be easily created on the C++ side
display_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]),
display_union: Value::Union {
types: vec![("Num".to_string(), LogicalType::Int8), ("duration".to_string(), LogicalType::Interval)],
value: Box::new(Value::Int8(-127))
},
}

database_tests! {
Expand All @@ -934,6 +975,10 @@ mod tests {
// db_var_list_int: Value::VarList(LogicalType::Int64, vec![0i64.into(), 1i64.into(), 2i64.into()]), "INT64[]",
// db_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]), "MAP(STRING,INT64)",
// db_fixed_list: Value::FixedList(LogicalType::Int64, vec![1i64.into(), 2i64.into(), 3i64.into()]), "INT64[3]",
// db_union: Value::Union {
// types: vec![("Num".to_string(), LogicalType::Int8), ("duration".to_string(), LogicalType::Interval)],
// value: Box::new(Value::Int8(-127))
// }, "UNION(Num INT8, duration INTERVAL)",
db_struct:
Value::Struct(vec![("item".to_string(), "Knife".into()), ("count".to_string(), 1.into())]),
"STRUCT(item STRING, count INT32)",
Expand Down Expand Up @@ -1138,4 +1183,52 @@ mod tests {
temp_dir.close()?;
Ok(())
}

#[test]
/// Tests that UNION values loaded into the database are read back with the
/// expected member types and values.
///
/// The test creates a node table with a `UNION(num INT64, str STRING)` column,
/// writes a two-row CSV file (`1` and `aa`) into a temporary directory, loads it
/// with `COPY`, and then checks both the reported column type and the returned
/// values of `MATCH (d:demo) RETURN d.b`.
///
/// # Errors
///
/// Propagates any error from creating the temporary directory, the database,
/// the connection, the CSV file, or from running the queries.
fn test_union() -> Result<()> {
    use std::fs::File;
    use std::io::Write;
    let temp_dir = tempfile::tempdir()?;
    let db = Database::new(temp_dir.path(), SystemConfig::default())?;
    let conn = Connection::new(&db)?;
    conn.query(
        "CREATE NODE TABLE demo(a SERIAL, b UNION(num INT64, str STRING), PRIMARY KEY(a));",
    )?;
    let mut file = File::create(temp_dir.path().join("demo.csv"))?;
    file.write_all(b"1\naa\n")?;
    conn.query(&format!(
        "COPY demo from '{}/demo.csv';",
        // Use forward-slashes instead of backslashes on windows, as they may not be supported
        // by the query parser.
        // `Path::display()` only yields a `Display` adapter, so it has to be turned into a
        // `String` before `replace` can be called on it.
        temp_dir.path().display().to_string().replace('\\', "/")
    ))?;
    let result = conn.query("MATCH (d:demo) RETURN d.b;")?;
    let types = vec![
        ("num".to_string(), LogicalType::Int64),
        ("str".to_string(), LogicalType::String),
    ];
    assert_eq!(
        result.get_column_data_types(),
        vec![LogicalType::Union {
            types: types.clone()
        }],
    );
    // Each result row has a single column; take it out of the row.
    let results: Vec<Value> = result.map(|mut x| x.pop().unwrap()).collect();
    assert_eq!(
        results,
        vec![
            Value::Union {
                types: types.clone(),
                value: Box::new(Value::Int64(1))
            },
            Value::Union {
                types: types.clone(),
                value: Box::new(Value::String("aa".to_string()))
            },
        ]
    );
    temp_dir.close()?;
    Ok(())
}
}

0 comments on commit 4b06c5e

Please sign in to comment.