Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rust Union support #2193

Merged
merged 1 commit into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/storage/store/struct_node_column.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@

void StructNodeColumn::append(ColumnChunk* columnChunk, uint64_t nodeGroupIdx) {
NodeColumn::append(columnChunk, nodeGroupIdx);
assert(columnChunk->getDataType().getLogicalTypeID() == LogicalTypeID::STRUCT);
assert(columnChunk->getDataType().getPhysicalType() == PhysicalTypeID::STRUCT);
auto structColumnChunk = static_cast<StructColumnChunk*>(columnChunk);
for (auto i = 0u; i < childColumns.size(); i++) {
childColumns[i]->append(structColumnChunk->getChild(i), nodeGroupIdx);
Expand All @@ -74,12 +74,12 @@
}
}

void StructNodeColumn::rollbackInMemory() {
NodeColumn::rollbackInMemory();
for (const auto& childColumn : childColumns) {
childColumn->rollbackInMemory();

Check warning on line 80 in src/storage/store/struct_node_column.cpp

View check run for this annotation

Codecov / codecov/patch

src/storage/store/struct_node_column.cpp#L77-L80

Added lines #L77 - L80 were not covered by tests
}
}

Check warning on line 82 in src/storage/store/struct_node_column.cpp

View check run for this annotation

Codecov / codecov/patch

src/storage/store/struct_node_column.cpp#L82

Added line #L82 was not covered by tests

} // namespace storage
} // namespace kuzu
3 changes: 2 additions & 1 deletion tools/rust_api/include/kuzu_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ std::unique_ptr<kuzu::common::LogicalType> create_logical_type_var_list(
std::unique_ptr<kuzu::common::LogicalType> create_logical_type_fixed_list(
std::unique_ptr<kuzu::common::LogicalType> childType, uint64_t numElements);
std::unique_ptr<kuzu::common::LogicalType> create_logical_type_struct(
const rust::Vec<rust::String>& fieldNames, std::unique_ptr<TypeListBuilder> fieldTypes);
kuzu::common::LogicalTypeID typeID, const rust::Vec<rust::String>& fieldNames,
std::unique_ptr<TypeListBuilder> fieldTypes);
std::unique_ptr<kuzu::common::LogicalType> create_logical_type_map(
std::unique_ptr<kuzu::common::LogicalType> keyType,
std::unique_ptr<kuzu::common::LogicalType> valueType);
Expand Down
1 change: 1 addition & 0 deletions tools/rust_api/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ pub(crate) mod ffi {
num_elements: u64,
) -> UniquePtr<LogicalType>;
/// Builds a struct-like `LogicalType` tagged with `type_id`, pairing each entry of
/// `field_names` with the type at the same position in `types`.
///
/// The callers visible in this change pass `LogicalTypeID::STRUCT` or
/// `LogicalTypeID::UNION` as `type_id`.
fn create_logical_type_struct(
type_id: LogicalTypeID,
field_names: &Vec<String>,
types: UniquePtr<TypeListBuilder>,
) -> UniquePtr<LogicalType>;
Expand Down
4 changes: 2 additions & 2 deletions tools/rust_api/src/kuzu_rs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ std::unique_ptr<LogicalType> create_logical_type_fixed_list(
std::make_unique<FixedListTypeInfo>(std::move(childType), numElements));
}

std::unique_ptr<kuzu::common::LogicalType> create_logical_type_struct(
std::unique_ptr<kuzu::common::LogicalType> create_logical_type_struct(LogicalTypeID typeID,
const rust::Vec<rust::String>& fieldNames, std::unique_ptr<TypeListBuilder> fieldTypes) {
std::vector<std::unique_ptr<StructField>> fields;
for (auto i = 0u; i < fieldNames.size(); i++) {
fields.push_back(std::make_unique<StructField>(
std::string(fieldNames[i]), std::move(fieldTypes->types[i])));
}
return std::make_unique<LogicalType>(
LogicalTypeID::STRUCT, std::make_unique<kuzu::common::StructTypeInfo>(std::move(fields)));
typeID, std::make_unique<kuzu::common::StructTypeInfo>(std::move(fields)));
}

std::unique_ptr<kuzu::common::LogicalType> create_logical_type_map(
Expand Down
30 changes: 29 additions & 1 deletion tools/rust_api/src/logical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,14 @@ pub enum LogicalType {
/// Corresponds to [Value::Rel](crate::value::Value::Rel)
Rel,
RecursiveRel,
/// Corresponds to [Value::Map](crate::value::Value::Map)
Map {
key_type: Box<LogicalType>,
value_type: Box<LogicalType>,
},
Union {
types: Vec<(String, LogicalType)>,
},
}

impl From<&ffi::Value> for LogicalType {
Expand Down Expand Up @@ -141,6 +145,18 @@ impl From<&ffi::LogicalType> for LogicalType {
value_type: Box::new(value_type),
}
}
LogicalTypeID::UNION => {
let names = ffi::logical_type_get_struct_field_names(logical_type);
let types = ffi::logical_type_get_struct_field_types(logical_type);
LogicalType::Union {
types: names
.into_iter()
// Skip the tag field
.skip(1)
.zip(types.into_iter().skip(1).map(Into::<LogicalType>::into))
.collect(),
}
}
// Should be unreachable, as cxx will check that the LogicalTypeID enum matches the one
// on the C++ side.
x => panic!("Unsupported type {:?}", x),
Expand Down Expand Up @@ -187,7 +203,18 @@ impl From<&LogicalType> for cxx::UniquePtr<ffi::LogicalType> {
names.push(name.clone());
builder.pin_mut().insert(typ.into());
}
ffi::create_logical_type_struct(&names, builder)
ffi::create_logical_type_struct(ffi::LogicalTypeID::STRUCT, &names, builder)
}
LogicalType::Union { types } => {
let mut builder = ffi::create_type_list();
let mut names = vec![];
names.push("tag".to_string());
builder.pin_mut().insert((&LogicalType::Int64).into());
for (name, typ) in types {
names.push(name.clone());
builder.pin_mut().insert(typ.into());
}
ffi::create_logical_type_struct(ffi::LogicalTypeID::UNION, &names, builder)
}
LogicalType::Map {
key_type,
Expand Down Expand Up @@ -227,6 +254,7 @@ impl LogicalType {
LogicalType::Rel => LogicalTypeID::REL,
LogicalType::RecursiveRel => LogicalTypeID::RECURSIVE_REL,
LogicalType::Map { .. } => LogicalTypeID::MAP,
LogicalType::Union { .. } => LogicalTypeID::UNION,
}
}
}
109 changes: 101 additions & 8 deletions tools/rust_api/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ impl From<(u64, u64)> for InternalID {

/// Data types supported by Kùzu
///
/// Also see <https://kuzudb.com/docs/cypher/data-types/overview.html>
/// Also see <https://kuzudb.com/docusaurus/cypher/data-types/overview.html>
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
Null(LogicalType),
Expand All @@ -206,33 +206,33 @@ pub enum Value {
/// Stored internally as the number of days since 1970-01-01 as a 32-bit signed integer, which
/// allows for a wider range of dates to be stored than can be represented by time::Date
///
/// <https://kuzudb.com/docs/cypher/data-types/date.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/date.html>
Date(time::Date),
/// May be signed or unsigned.
///
/// Nanosecond precision of time::Duration (if available) will not be preserved when passed to
/// queries, and results will always have at most microsecond precision.
///
/// <https://kuzudb.com/docs/cypher/data-types/interval.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/interval.html>
Interval(time::Duration),
/// Stored internally as the number of microseconds since 1970-01-01
/// Nanosecond precision of SystemTime (if available) will not be preserved when used.
///
/// <https://kuzudb.com/docs/cypher/data-types/timestamp.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/timestamp.html>
Timestamp(time::OffsetDateTime),
InternalID(InternalID),
/// <https://kuzudb.com/docs/cypher/data-types/string.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/string.html>
String(String),
Blob(Vec<u8>),
// TODO: Enforce type of contents
// LogicalType is necessary so that we can pass the correct type to the C++ API if the list is empty.
/// These must contain elements which are all the given type.
/// <https://kuzudb.com/docs/cypher/data-types/list.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/list.html>
VarList(LogicalType, Vec<Value>),
/// These must contain elements which are all the same type.
/// <https://kuzudb.com/docs/cypher/data-types/list.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/list.html>
FixedList(LogicalType, Vec<Value>),
/// <https://kuzudb.com/docs/cypher/data-types/struct.html>
/// <https://kuzudb.com/docusaurus/cypher/data-types/struct.html>
Struct(Vec<(String, Value)>),
Node(NodeVal),
Rel(RelVal),
Expand All @@ -244,7 +244,13 @@ pub enum Value {
/// Sequence of Rels which make up the RecursiveRel
rels: Vec<RelVal>,
},
/// <https://kuzudb.com/docusaurus/cypher/data-types/map>
Map((LogicalType, LogicalType), Vec<(Value, Value)>),
/// <https://kuzudb.com/docusaurus/cypher/data-types/union>
Union {
types: Vec<(String, LogicalType)>,
value: Box<Value>,
},
}

fn display_list<T: std::fmt::Display>(f: &mut fmt::Formatter<'_>, list: &Vec<T>) -> fmt::Result {
Expand Down Expand Up @@ -312,6 +318,7 @@ impl std::fmt::Display for Value {
display_list(f, rels)?;
write!(f, "}}")
}
Value::Union { types: _, value } => write!(f, "{value}"),
}
}
}
Expand Down Expand Up @@ -360,6 +367,9 @@ impl From<&Value> for LogicalType {
key_type: Box::new(key_type.clone()),
value_type: Box::new(value_type.clone()),
},
Value::Union { types, value: _ } => LogicalType::Union {
types: types.clone(),
},
}
}
}
Expand Down Expand Up @@ -542,6 +552,19 @@ impl TryFrom<&ffi::Value> for Value {
rels: rels.collect(),
})
}
LogicalTypeID::UNION => {
let types =
if let LogicalType::Union { types } = ffi::value_get_data_type(value).into() {
types
} else {
unreachable!()
};
let value: Value = ffi::value_get_child(value, 0).try_into()?;
Ok(Value::Union {
types,
value: Box::new(value),
})
}
// TODO(bmwinger): Better error message for types which are unsupported
x => panic!("Unsupported type {:?}", x),
}
Expand Down Expand Up @@ -673,6 +696,15 @@ impl TryInto<cxx::UniquePtr<ffi::Value>> for Value {
Value::RecursiveRel { .. } => {
Err(crate::Error::ReadOnlyType(LogicalType::RecursiveRel))
}
Value::Union { types, value } => {
let mut builder = ffi::create_list();
builder.pin_mut().insert((*value).try_into()?);

Ok(ffi::get_list_value(
(&LogicalType::Union { types }).into(),
builder,
))
}
}
}
}
Expand Down Expand Up @@ -870,6 +902,7 @@ mod tests {
convert_rel_type: LogicalType::Rel,
convert_recursive_rel_type: LogicalType::RecursiveRel,
convert_map_type: LogicalType::Map { key_type: Box::new(LogicalType::Interval), value_type: Box::new(LogicalType::Rel) },
convert_union_type: LogicalType::Union { types: vec![("Num".to_string(), LogicalType::Int8), ("duration".to_string(), LogicalType::Interval), ("string".to_string(), LogicalType::String)] },
}

value_tests! {
Expand Down Expand Up @@ -898,6 +931,10 @@ mod tests {
convert_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]),
convert_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }),
convert_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]),
convert_union: Value::Union {
types: vec![("Num".to_string(), LogicalType::Int8), ("duration".to_string(), LogicalType::Interval)],
value: Box::new(Value::Int8(-127))
},
}

display_tests! {
Expand Down Expand Up @@ -926,6 +963,10 @@ mod tests {
display_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }),
// Node and Rel Cannot be easily created on the C++ side
display_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]),
display_union: Value::Union {
types: vec![("Num".to_string(), LogicalType::Int8), ("duration".to_string(), LogicalType::Interval)],
value: Box::new(Value::Int8(-127))
},
}

database_tests! {
Expand All @@ -934,6 +975,10 @@ mod tests {
// db_var_list_int: Value::VarList(LogicalType::Int64, vec![0i64.into(), 1i64.into(), 2i64.into()]), "INT64[]",
// db_map: Value::Map((LogicalType::String, LogicalType::Int64), vec![(Value::String("key".to_string()), Value::Int64(24))]), "MAP(STRING,INT64)",
// db_fixed_list: Value::FixedList(LogicalType::Int64, vec![1i64.into(), 2i64.into(), 3i64.into()]), "INT64[3]",
// db_union: Value::Union {
// types: vec![("Num".to_string(), LogicalType::Int8), ("duration".to_string(), LogicalType::Interval)],
// value: Box::new(Value::Int8(-127))
// }, "UNION(Num INT8, duration INTERVAL)",
db_struct:
Value::Struct(vec![("item".to_string(), "Knife".into()), ("count".to_string(), 1.into())]),
"STRUCT(item STRING, count INT32)",
Expand Down Expand Up @@ -1138,4 +1183,52 @@ mod tests {
temp_dir.close()?;
Ok(())
}

/// Loads UNION values from a CSV file into a node table and checks that both the
/// reported column type and the values read back match what was written.
#[test]
fn test_union() -> Result<()> {
    use std::fs::File;
    use std::io::Write;

    let temp_dir = tempfile::tempdir()?;
    let db = Database::new(temp_dir.path(), SystemConfig::default())?;
    let conn = Connection::new(&db)?;
    conn.query(
        "CREATE NODE TABLE demo(a SERIAL, b UNION(num INT64, str STRING), PRIMARY KEY(a));",
    )?;

    // One row per union member; the assertions below expect `1` to come back as the
    // INT64 member and `aa` as the STRING member.
    let mut file = File::create(temp_dir.path().join("demo.csv"))?;
    file.write_all(b"1\naa\n")?;

    // Use forward slashes instead of backslashes on Windows, as they may not be
    // supported by the query parser.
    let data_dir = temp_dir.path().display().to_string().replace("\\", "/");
    conn.query(&format!("COPY demo from '{}/demo.csv';", data_dir))?;

    let result = conn.query("MATCH (d:demo) RETURN d.b;")?;
    let types = vec![
        ("num".to_string(), LogicalType::Int64),
        ("str".to_string(), LogicalType::String),
    ];

    // The column must be reported as a union over exactly the declared members.
    let expected_column_types = vec![LogicalType::Union {
        types: types.clone(),
    }];
    assert_eq!(result.get_column_data_types(), expected_column_types);

    // Each returned row holds a single column, so take its only value.
    let expected_values = vec![
        Value::Union {
            types: types.clone(),
            value: Box::new(Value::Int64(1)),
        },
        Value::Union {
            types: types.clone(),
            value: Box::new(Value::String("aa".to_string())),
        },
    ];
    let results: Vec<Value> = result.map(|mut row| row.pop().unwrap()).collect();
    assert_eq!(results, expected_values);

    temp_dir.close()?;
    Ok(())
}
}
Loading