Skip to content

Commit

Permalink
Tidy up function results with nulls present.
Browse files Browse the repository at this point in the history
  • Loading branch information
lriggs committed Nov 8, 2023
1 parent be455ab commit 7c7a939
Showing 1 changed file with 31 additions and 22 deletions.
53 changes: 31 additions & 22 deletions cpp/src/gandiva/array_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
template <typename Type>
Type* array_remove_template(int64_t context_ptr, const Type* entry_buf,
int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity,
Type remove_data,
Type remove_data, bool remove_data_valid,
int64_t loop_var, int64_t validity_index_var,
bool* valid_row, int32_t* out_len, int32_t** valid_ptr)
{
Expand All @@ -45,7 +45,7 @@ Type* array_remove_template(int64_t context_ptr, const Type* entry_buf,
std::vector<bool> outValid;
for (int i = 0; i < entry_len; i++) {
Type entry_item = *(entry_buf + i);
if (entry_item == remove_data) {
if (remove_data_valid && entry_item == remove_data) {
//Do not add the item to remove.
} else if (!arrow::bit_util::GetBit(reinterpret_cast<const uint8_t*>(entry_validityAdjusted), validityBitIndex + i)) {
outValid.push_back(false);
Expand All @@ -72,21 +72,24 @@ Type* array_remove_template(int64_t context_ptr, const Type* entry_buf,
uint8_t* ret = gdv_fn_context_arena_malloc(context_ptr, outBufferLength);
memcpy(ret, newInts.data(), outBufferLength);
*valid_row = true;
if (!combined_row_validity) {

//Return null if the input array is null or the data to remove is null.
if (!combined_row_validity || !remove_data_valid) {
*out_len = 0;
*valid_row = false; //this one is what works for the top level validity.
}

*valid_ptr = reinterpret_cast<int32_t*>(validRet);
return reinterpret_cast<Type*>(ret);
}

template <typename Type>
bool array_contains_template(const Type* entry_buf,
int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity,
Type contains_data,
Type contains_data, bool contains_data_valid,
int64_t loop_var, int64_t validity_index_var,
bool* valid_row) {
if (!combined_row_validity) {
if (!combined_row_validity || !contains_data_valid) {
*valid_row = false;
return false;
}
Expand All @@ -95,107 +98,113 @@ bool array_contains_template(const Type* entry_buf,
const int32_t* entry_validityAdjusted = entry_validity - (loop_var );
int64_t validityBitIndex = validity_index_var - entry_len;

bool found_null_in_data = false;
for (int i = 0; i < entry_len; i++) {
if (!arrow::bit_util::GetBit(reinterpret_cast<const uint8_t*>(entry_validityAdjusted), validityBitIndex + i)) {
found_null_in_data = true;
continue;
}
Type entry_item = *(entry_buf + i);
if (entry_item == contains_data) {
if (contains_data_valid && entry_item == contains_data) {
return true;
}
}
//If there is null in the input and the item is not found the result is null.
if (found_null_in_data) {
*valid_row = false;
}
return false;
}

extern "C" {

bool array_int32_contains_int32(int64_t context_ptr, const int32_t* entry_buf,
int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity,
int32_t contains_data, bool entry_validWhat,
int32_t contains_data, bool contains_data_valid,
int64_t loop_var, int64_t validity_index_var,
bool* valid_row) {
return array_contains_template<int32_t>(entry_buf, entry_len, entry_validity,
combined_row_validity, contains_data,
combined_row_validity, contains_data, contains_data_valid,
loop_var, validity_index_var, valid_row);
}

bool array_int64_contains_int64(int64_t context_ptr, const int64_t* entry_buf,
int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity,
int64_t contains_data, bool entry_validWhat,
int64_t contains_data, bool contains_data_valid,
int64_t loop_var, int64_t validity_index_var,
bool* valid_row) {
return array_contains_template<int64_t>(entry_buf, entry_len, entry_validity,
combined_row_validity, contains_data,
combined_row_validity, contains_data, contains_data_valid,
loop_var, validity_index_var, valid_row);
}

bool array_float32_contains_float32(int64_t context_ptr, const float* entry_buf,
int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity,
float contains_data, bool entry_validWhat,
float contains_data, bool contains_data_valid,
int64_t loop_var, int64_t validity_index_var,
bool* valid_row) {
return array_contains_template<float>(entry_buf, entry_len, entry_validity,
combined_row_validity, contains_data,
combined_row_validity, contains_data, contains_data_valid,
loop_var, validity_index_var, valid_row);
}

bool array_float64_contains_float64(int64_t context_ptr, const double* entry_buf,
int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity,
double contains_data, bool entry_validWhat,
double contains_data, bool contains_data_valid,
int64_t loop_var, int64_t validity_index_var,
bool* valid_row) {
return array_contains_template<double>(entry_buf, entry_len, entry_validity,
combined_row_validity, contains_data,
combined_row_validity, contains_data, contains_data_valid,
loop_var, validity_index_var, valid_row);
}



int32_t* array_int32_remove(int64_t context_ptr, const int32_t* entry_buf,
int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity,
int32_t remove_data, bool entry_validWhat,
int32_t remove_data, bool remove_data_valid,
int64_t loop_var, int64_t validity_index_var,
bool* valid_row, int32_t* out_len, int32_t** valid_ptr) {
return array_remove_template<int32_t>(context_ptr, entry_buf,
entry_len, entry_validity, combined_row_validity,
remove_data,
remove_data, remove_data_valid,
loop_var, validity_index_var,
valid_row, out_len, valid_ptr);
}

int64_t* array_int64_remove(int64_t context_ptr, const int64_t* entry_buf,
int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity,
int64_t remove_data, bool entry_validWhat,
int64_t remove_data, bool remove_data_valid,
int64_t loop_var, int64_t validity_index_var,
bool* valid_row, int32_t* out_len, int32_t** valid_ptr){
return array_remove_template<int64_t>(context_ptr, entry_buf,
entry_len, entry_validity, combined_row_validity,
remove_data,
remove_data, remove_data_valid,
loop_var, validity_index_var,
valid_row, out_len, valid_ptr);
}

float* array_float32_remove(int64_t context_ptr, const float* entry_buf,
int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity,
float remove_data, bool entry_validWhat,
float remove_data, bool remove_data_valid,
int64_t loop_var, int64_t validity_index_var,
bool* valid_row, int32_t* out_len, int32_t** valid_ptr){
return array_remove_template<float>(context_ptr, entry_buf,
entry_len, entry_validity, combined_row_validity,
remove_data,
remove_data, remove_data_valid,
loop_var, validity_index_var,
valid_row, out_len, valid_ptr);
}


double* array_float64_remove(int64_t context_ptr, const double* entry_buf,
int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity,
double remove_data, bool entry_validWhat,
double remove_data, bool remove_data_valid,
int64_t loop_var, int64_t validity_index_var,
bool* valid_row, int32_t* out_len, int32_t** valid_ptr){
return array_remove_template<double>(context_ptr, entry_buf,
entry_len, entry_validity, combined_row_validity,
remove_data,
remove_data, remove_data_valid,
loop_var, validity_index_var,
valid_row, out_len, valid_ptr);
}
Expand Down

0 comments on commit 7c7a939

Please sign in to comment.