Skip to content

Commit

Permalink
Fix nested extension types (jorgecarleitao#1334)
Browse files Browse the repository at this point in the history
Various fixes for nested Extension types
  • Loading branch information
John Hughes authored and ritchie46 committed Mar 29, 2023
1 parent 49e3c59 commit 54b3ca1
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 13 deletions.
11 changes: 4 additions & 7 deletions src/array/growable/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::sync::Arc;
use crate::{
array::{Array, StructArray},
bitmap::MutableBitmap,
datatypes::DataType,
};

use super::{
Expand All @@ -25,6 +24,8 @@ impl<'a> GrowableStruct<'a> {
/// # Panics
/// If `arrays` is empty.
pub fn new(arrays: Vec<&'a StructArray>, mut use_validity: bool, capacity: usize) -> Self {
assert!(!arrays.is_empty());

// if any of the arrays has nulls, insertions from any array requires setting bits
// as there is at least one array with nulls.
if arrays.iter().any(|array| array.null_count() > 0) {
Expand Down Expand Up @@ -68,11 +69,7 @@ impl<'a> GrowableStruct<'a> {
let values = std::mem::take(&mut self.values);
let values = values.into_iter().map(|mut x| x.as_box()).collect();

StructArray::new(
DataType::Struct(self.arrays[0].fields().to_vec()),
values,
validity.into(),
)
StructArray::new(self.arrays[0].data_type().clone(), values, validity.into())
}
}

Expand Down Expand Up @@ -121,7 +118,7 @@ impl<'a> From<GrowableStruct<'a>> for StructArray {
let values = val.values.into_iter().map(|mut x| x.as_box()).collect();

StructArray::new(
DataType::Struct(val.arrays[0].fields().to_vec()),
val.arrays[0].data_type().clone(),
values,
val.validity.into(),
)
Expand Down
2 changes: 1 addition & 1 deletion src/array/struct_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl StructArray {

/// Creates an empty [`StructArray`].
pub fn new_empty(data_type: DataType) -> Self {
if let DataType::Struct(fields) = &data_type {
if let DataType::Struct(fields) = &data_type.to_logical_type() {
let values = fields
.iter()
.map(|field| new_empty_array(field.data_type().clone()))
Expand Down
2 changes: 1 addition & 1 deletion src/array/union/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ impl UnionArray {

/// Creates a new empty [`UnionArray`].
pub fn new_empty(data_type: DataType) -> Self {
if let DataType::Union(f, _, mode) = &data_type {
if let DataType::Union(f, _, mode) = data_type.to_logical_type() {
let fields = f
.iter()
.map(|x| new_empty_array(x.data_type().clone()))
Expand Down
36 changes: 33 additions & 3 deletions tests/it/array/growable/list.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use arrow2::array::{
growable::{Growable, GrowableList},
ListArray, MutableListArray, MutablePrimitiveArray, TryExtend,
use arrow2::{
array::{
growable::{Growable, GrowableList},
Array, ListArray, MutableListArray, MutablePrimitiveArray, TryExtend,
},
datatypes::DataType,
};

fn create_list_array(data: Vec<Option<Vec<Option<i32>>>>) -> ListArray<i32> {
Expand All @@ -9,6 +12,33 @@ fn create_list_array(data: Vec<Option<Vec<Option<i32>>>>) -> ListArray<i32> {
array.into()
}

#[test]
fn extension() {
let data = vec![
Some(vec![Some(1i32), Some(2), Some(3)]),
Some(vec![Some(4), Some(5)]),
Some(vec![Some(6i32), Some(7), Some(8)]),
];

let array = create_list_array(data);

let data_type =
DataType::Extension("ext".to_owned(), Box::new(array.data_type().clone()), None);
let array_ext = ListArray::new(
data_type,
array.offsets().clone(),
array.values().clone(),
array.validity().cloned(),
);

let mut a = GrowableList::new(vec![&array_ext], false, 0);
a.extend(0, 0, 1);

let result: ListArray<i32> = a.into();
assert_eq!(array_ext.data_type(), result.data_type());
dbg!(result);
}

#[test]
fn basic() {
let data = vec![
Expand Down
27 changes: 26 additions & 1 deletion tests/it/array/growable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mod utf8;

use arrow2::array::growable::make_growable;
use arrow2::array::*;
use arrow2::datatypes::DataType;
use arrow2::datatypes::{DataType, Field};

#[test]
fn test_make_growable() {
Expand All @@ -37,11 +37,36 @@ fn test_make_growable() {
let array =
FixedSizeBinaryArray::new(DataType::FixedSizeBinary(2), b"abcd".to_vec().into(), None);
make_growable(&[&array], false, 2);
}

#[test]
fn test_make_growable_extension() {
let array = DictionaryArray::try_from_keys(
Int32Array::from_slice([1, 0]),
Int32Array::from_slice([1, 2]).boxed(),
)
.unwrap();
make_growable(&[&array], false, 2);

let data_type = DataType::Extension("ext".to_owned(), Box::new(DataType::Int32), None);
let array = Int32Array::from_slice([1, 2]).to(data_type.clone());
let array_grown = make_growable(&[&array], false, 2).as_box();
assert_eq!(array_grown.data_type(), &data_type);

let data_type = DataType::Extension(
"ext".to_owned(),
Box::new(DataType::Struct(vec![Field::new(
"a",
DataType::Int32,
false,
)])),
None,
);
let array = StructArray::new(
data_type.clone(),
vec![Int32Array::from_slice([1, 2]).boxed()],
None,
);
let array_grown = make_growable(&[&array], false, 2).as_box();
assert_eq!(array_grown.data_type(), &data_type);
}
36 changes: 36 additions & 0 deletions tests/it/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ fn empty() {
DataType::Utf8,
DataType::Binary,
DataType::List(Box::new(Field::new("a", DataType::Binary, true))),
DataType::List(Box::new(Field::new(
"a",
DataType::Extension("ext".to_owned(), Box::new(DataType::Int32), None),
true,
))),
DataType::Union(
vec![Field::new("a", DataType::Binary, true)],
None,
Expand All @@ -68,11 +73,42 @@ fn empty() {
None,
UnionMode::Dense,
),
DataType::Struct(vec![Field::new("a", DataType::Int32, true)]),
];
let a = datatypes.into_iter().all(|x| new_empty_array(x).len() == 0);
assert!(a);
}

#[test]
fn empty_extension() {
let datatypes = vec![
DataType::Int32,
DataType::Float64,
DataType::Utf8,
DataType::Binary,
DataType::List(Box::new(Field::new("a", DataType::Binary, true))),
DataType::Union(
vec![Field::new("a", DataType::Binary, true)],
None,
UnionMode::Sparse,
),
DataType::Union(
vec![Field::new("a", DataType::Binary, true)],
None,
UnionMode::Dense,
),
DataType::Struct(vec![Field::new("a", DataType::Int32, true)]),
];
let a = datatypes
.into_iter()
.map(|dt| DataType::Extension("ext".to_owned(), Box::new(dt), None))
.all(|x| {
let a = new_empty_array(x);
a.len() == 0 && matches!(a.data_type(), DataType::Extension(_, _, _))
});
assert!(a);
}

#[test]
fn test_clone() {
let datatypes = vec![
Expand Down

0 comments on commit 54b3ca1

Please sign in to comment.