Skip to content

Commit

Permalink
implement fast, binary-search simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
larry98 committed Oct 4, 2024
1 parent 2bef6fd commit 0469a8f
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 46 deletions.
206 changes: 167 additions & 39 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
#include <unordered_map>
#include <unordered_set>

#include "arrow/array/concatenate.h"
#include "arrow/chunked_array.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/compute/expression_internal.h"
#include "arrow/compute/function_internal.h"
#include "arrow/compute/kernels/set_lookup_internal.h"
#include "arrow/compute/util.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
Expand All @@ -38,6 +40,7 @@
#include "arrow/util/string.h"
#include "arrow/util/value_parsing.h"
#include "arrow/util/vector.h"
#include "arrow/visit_array_inline.h"

namespace arrow {

Expand Down Expand Up @@ -1351,45 +1354,140 @@ Result<Expression> SimplifyIsValidGuarantee(Expression expr,
/// potential complications with null matching behavior. This is ok for the
/// predicate pushdown use case because the overall aim is to simplify to an
/// unsatisfiable expression.
Result<Datum> SimplifyIsInValueSet(Datum value_set, const Inequality& guarantee,
                                   SetLookupOptions::NullMatchingBehavior null_matching) {
  // With a nullable guarantee, EMIT_NULL and INCONCLUSIVE leave the value set
  // untouched.
  const bool untouched_when_nullable =
      null_matching == SetLookupOptions::EMIT_NULL ||
      null_matching == SetLookupOptions::INCONCLUSIVE;
  if (guarantee.nullable && untouched_when_nullable) return value_set;

  // INCONCLUSIVE additionally leaves the value set untouched when it holds a
  // null itself.
  if (null_matching == SetLookupOptions::INCONCLUSIVE) {
    ARROW_ASSIGN_OR_RAISE(Datum null_mask, IsNull(value_set));
    ARROW_ASSIGN_OR_RAISE(Datum has_null, Any(null_mask));
    if (has_null.scalar_as<BooleanScalar>().value) return value_set;
  }

  // Null slots survive the filter only for MATCH with a nullable guarantee;
  // every other combination reaching this point drops them.
  const FilterOptions::NullSelectionBehavior null_selection =
      (null_matching == SetLookupOptions::MATCH && guarantee.nullable)
          ? FilterOptions::EMIT_NULL
          : FilterOptions::DROP;

  // Evaluate `value_set <cmp> bound` and keep the entries selected by the mask.
  const std::string comparison_name = Comparison::GetName(guarantee.cmp);
  DCHECK_NE(comparison_name, "na");
  const std::vector<Datum> comparison_args{value_set, guarantee.bound};
  ARROW_ASSIGN_OR_RAISE(Datum keep_mask, CallFunction(comparison_name, comparison_args));
  const FilterOptions filter_options(null_selection);
  return Filter(value_set, keep_mask, filter_options);
}
/// Narrows the value set of an `is_in` call against one inequality guarantee.
///
/// Use the static `Simplify` entry point: it dispatches on the concrete array
/// type of the value set through `VisitArrayInline` and returns the narrowed
/// value set.
struct IsInValueSetSimplifier {
  /// Visitor hook invoked with the concrete array type `T` of `value_set`;
  /// stores the simplified value set in `result`.
  ///
  /// A single overload with `if constexpr` is used rather than a pair of
  /// SFINAE overloads: two function templates that both take `const T&` cannot
  /// be ordered against each other, so the call would be ambiguous for every
  /// type accepted by the constrained one.
  ///
  /// NOTE(review): binary arrays are expected to derive from `FlatArray` too,
  /// so a separate `BaseBinaryArray` check should be unnecessary -- confirm.
  /// NOTE(review): `SimplifyOptimized<T>` must be instantiable for every
  /// `FlatArray` subclass the visitor dispatches on -- confirm.
  template <typename T>
  Status Visit(const T&) {
    if constexpr (std::is_base_of_v<FlatArray, T>) {
      if (enable_fast_simplification) {
        Result<std::shared_ptr<Array>> simplified = SimplifyOptimized<T>();
        if (simplified.ok()) {
          result = simplified.MoveValueUnsafe();
          return Status::OK();
        }
        // Any failure of the fast path is deliberately discarded; the linear
        // scan below handles every case.
      }
    }
    ARROW_ASSIGN_OR_RAISE(result, SimplifyBasic());
    return Status::OK();
  }

  /// Simplify the value set using a linear scan filter.
  ///
  /// \return the filtered value set, or `value_set` unchanged when the null
  /// matching behavior prevents simplification
  Result<std::shared_ptr<Array>> SimplifyBasic() {
    // Initialized so that the variable is never read indeterminate, even if
    // `null_matching` holds a value outside the enumerators handled below.
    FilterOptions::NullSelectionBehavior null_selection = FilterOptions::DROP;
    switch (null_matching) {
      case SetLookupOptions::MATCH:
        null_selection =
            guarantee.nullable ? FilterOptions::EMIT_NULL : FilterOptions::DROP;
        break;
      case SetLookupOptions::SKIP:
        null_selection = FilterOptions::DROP;
        break;
      case SetLookupOptions::EMIT_NULL:
        if (guarantee.nullable) return value_set;
        null_selection = FilterOptions::DROP;
        break;
      case SetLookupOptions::INCONCLUSIVE:
        if (guarantee.nullable) return value_set;
        ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(value_set));
        ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null));
        if (any_null.scalar_as<BooleanScalar>().value) return value_set;
        null_selection = FilterOptions::DROP;
        break;
    }

    // Evaluate `value_set <cmp> bound` and keep the selected entries.
    std::string func_name = Comparison::GetName(guarantee.cmp);
    DCHECK_NE(func_name, "na");
    std::vector<Datum> args{value_set, guarantee.bound};
    ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args));
    FilterOptions filter_options(null_selection);
    ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set,
                          Filter(value_set, filter_mask, filter_options));
    return simplified_value_set.make_array();
  }

  /// Simplify the value set using binary search.
  ///
  /// Returns an error (which `Visit` turns into a fallback to `SimplifyBasic`)
  /// when the guarantee is nullable, when null matching is INCONCLUSIVE, or
  /// when the bound cannot be cast to the value set type.
  ///
  /// \pre `value_set` is sorted
  /// \pre `value_set` contains no duplicates
  /// \pre `value_set` contains no nulls
  template <typename T>
  Result<std::shared_ptr<Array>> SimplifyOptimized() {
    if (guarantee.nullable) return Status::Invalid();
    if (null_matching == SetLookupOptions::INCONCLUSIVE) return Status::Invalid();

    // An empty value set cannot be narrowed any further. Returning here also
    // keeps the search below from reading element 0 of an empty array, which
    // happens once an earlier guarantee has already removed every value.
    if (value_set->length() == 0) return value_set;

    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar_bound,
                          guarantee.bound.scalar()->CastTo(value_set->type()));
    auto bound = internal::UnboxScalar<typename T::TypeClass>::Unbox(*scalar_bound);

    // Downcast once instead of once per probe.
    const std::shared_ptr<T> typed_values = checked_pointer_cast<T>(value_set);
    auto compare = [&](int64_t i) -> Comparison::type {
      DCHECK(value_set->IsValid(i));
      auto value = typed_values->GetView(i);
      return value == bound ? Comparison::EQUAL
             : value < bound ? Comparison::LESS
                             : Comparison::GREATER;
    };

    // Find the last index whose value is <= bound; `lo` stays 0 when there is
    // no such value.
    int64_t lo = 0;
    int64_t hi = value_set->length();
    while (lo + 1 < hi) {
      const int64_t mid = lo + (hi - lo) / 2;
      if (compare(mid) & Comparison::LESS_EQUAL) {
        lo = mid;
      } else {
        hi = mid;
      }
    }

    // `pivot` is the index of the first value >= bound; `found` tells whether
    // that value equals the bound.
    const Comparison::type cmp = compare(lo);
    const int64_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0);
    const bool found = cmp == Comparison::EQUAL;

    switch (guarantee.cmp) {
      case Comparison::EQUAL:
        return value_set->Slice(pivot, found ? 1 : 0);
      case Comparison::LESS:
        return value_set->Slice(0, pivot);
      case Comparison::LESS_EQUAL:
        return value_set->Slice(0, pivot + (found ? 1 : 0));
      case Comparison::GREATER:
        return value_set->Slice(pivot + (found ? 1 : 0));
      case Comparison::GREATER_EQUAL:
        return value_set->Slice(pivot);
      case Comparison::NOT_EQUAL:
      case Comparison::NA:
        break;
    }
    // Reached for comparisons that do not describe a contiguous range, and for
    // any out-of-range enum value, so every path returns a value.
    DCHECK(false);
    return Status::Invalid("Invalid comparison");
  }

  /// Simplify `value_set` against `guarantee`.
  ///
  /// \param value_set the current `is_in` value set
  /// \param guarantee the inequality guarantee to apply
  /// \param null_matching the null matching behavior of the `is_in` call
  /// \param enable_fast_simplification whether binary search may be attempted;
  ///        the caller must then ensure the preconditions of `SimplifyOptimized`
  /// \return the simplified value set
  static Result<std::shared_ptr<Array>> Simplify(
      std::shared_ptr<Array> value_set, const Inequality& guarantee,
      SetLookupOptions::NullMatchingBehavior null_matching,
      bool enable_fast_simplification) {
    IsInValueSetSimplifier simplifier{value_set, guarantee, null_matching,
                                      enable_fast_simplification, nullptr};
    RETURN_NOT_OK(VisitArrayInline(*value_set, &simplifier));
    return simplifier.result;
  }

  /// Value set to simplify.
  std::shared_ptr<Array> value_set;
  /// Guarantee to simplify against; held by reference, so it must outlive the
  /// simplifier.
  const Inequality& guarantee;
  /// Null matching behavior of the `is_in` call.
  SetLookupOptions::NullMatchingBehavior null_matching;
  /// Whether `Visit` may attempt `SimplifyOptimized`.
  bool enable_fast_simplification;
  /// Output slot written by `Visit`.
  std::shared_ptr<Array> result;
};

/// Simplify an `is_in` call against a list of inequality guarantees.
///
/// Simplification is done across all guarantee conjunction members at once to
/// avoid the cost of repeatedly binding the simplified expression, which is
/// linear in the size of the `is_in` value set.
///
/// Returns a simplified expression, or nullopt if no simplification occurred.
Result<std::optional<Expression>> SimplifyIsInWithGuarantees(
const Expression::Call* is_in_call,
const std::vector<Expression>& guarantee_conjunction_members) {
Expand All @@ -1401,18 +1499,48 @@ Result<std::optional<Expression>> SimplifyIsInWithGuarantees(
const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]);
if (!lhs.field_ref()) return std::nullopt;

Datum simplified_value_set = options->value_set;
std::vector<Inequality> guarantees;
for (const Expression& guarantee : guarantee_conjunction_members) {
std::optional<Inequality> inequality = Inequality::ExtractOne(guarantee);
if (!inequality) continue;
if (inequality->target != *lhs.field_ref()) continue;
ARROW_ASSIGN_OR_RAISE(simplified_value_set,
SimplifyIsInValueSet(simplified_value_set, *inequality,
options->null_matching_behavior));
guarantees.emplace_back(std::move(*inequality));
}

bool guaranteed_non_nullable =
std::any_of(guarantees.begin(), guarantees.end(),
[](const Inequality& guarantee) { return !guarantee.nullable; });

std::shared_ptr<Array> simplified_value_set;
bool enable_fast_simplification = false;
if (guaranteed_non_nullable &&
options->null_matching_behavior != SetLookupOptions::INCONCLUSIVE) {
auto state =
checked_pointer_cast<internal::SetLookupStateBase>(is_in_call->kernel_state);
simplified_value_set = state->sorted_and_unique_value_set;
enable_fast_simplification = static_cast<bool>(simplified_value_set);
}
if (!simplified_value_set) {
if (options->value_set.is_array()) {
simplified_value_set = options->value_set.make_array();
} else if (options->value_set.is_chunked_array()) {
ARROW_ASSIGN_OR_RAISE(simplified_value_set,
Concatenate(options->value_set.chunked_array()->chunks()));
} else {
return Status::Invalid("`is_in` value set must be an array or chunked array");
}
}

for (Inequality& guarantee : guarantees) {
if (guaranteed_non_nullable) guarantee.nullable = false;
ARROW_ASSIGN_OR_RAISE(simplified_value_set, IsInValueSetSimplifier::Simplify(
simplified_value_set, guarantee,
options->null_matching_behavior,
enable_fast_simplification));
}

if (simplified_value_set.length() == 0) return literal(false);
if (simplified_value_set.length() == options->value_set.length()) return std::nullopt;
if (simplified_value_set->length() == 0) return literal(false);
if (simplified_value_set->length() == options->value_set.length()) return std::nullopt;

ExecContext exec_context;
Expression::Call simplified_call;
Expand Down
48 changes: 41 additions & 7 deletions cpp/src/arrow/compute/kernels/scalar_set_lookup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
// under the License.

#include "arrow/array/array_base.h"
#include "arrow/array/concatenate.h"
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/cast.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/compute/kernels/set_lookup_internal.h"
#include "arrow/compute/kernels/util_internal.h"
#include "arrow/type.h"
#include "arrow/util/bit_util.h"
Expand All @@ -34,10 +37,26 @@ using internal::HashTraits;
namespace compute::internal {
namespace {

// Common base of the set lookup kernel states: exposes the value set type
// without going through the templated state type.
struct SetLookupStateBase : KernelState {
  std::shared_ptr<DataType> value_set_type;
};
/// Return the distinct values of `value_set` in sorted order, without a null.
///
/// A chunked value set is concatenated into a single array first.
template <typename T>
Result<std::shared_ptr<Array>> SortAndUnique(std::shared_ptr<T> value_set) {
  if constexpr (!std::is_same_v<T, ChunkedArray>) {
    // Deduplicate first, then reorder the survivors with nulls placed last.
    ARROW_ASSIGN_OR_RAISE(value_set, Unique(std::move(value_set)));
    const SortOptions nulls_last({}, NullPlacement::AtEnd);
    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> order,
                          SortIndices(value_set, nulls_last));
    ARROW_ASSIGN_OR_RAISE(
        value_set, Take(*value_set, *order, TakeOptions(/*bounds_check=*/false)));

    // After `Unique` at most one null remains, and the sort put it in the
    // final slot, so trimming that slot removes every null.
    const int64_t count = value_set->length();
    if (count > 0 && value_set->IsNull(count - 1)) {
      value_set = value_set->Slice(0, count - 1);
    }
    return value_set;
  } else {
    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> flattened,
                          Concatenate(value_set->chunks()));
    return SortAndUnique(flattened);
  }
}

template <typename Type>
struct SetLookupState : public SetLookupStateBase {
Expand Down Expand Up @@ -209,6 +228,21 @@ struct InitStateVisitor {
}

Result<std::unique_ptr<KernelState>> GetResult() {
if (!options.value_set.is_arraylike()) {
return Status::Invalid("Set lookup value set must be Array or ChunkedArray");
}

// The sorted and unique value set needs to be derived from the value set
// before casting occurs.
std::shared_ptr<Array> sorted_and_unique_value_set;
if (options.value_set.is_chunked_array()) {
sorted_and_unique_value_set =
SortAndUnique(options.value_set.chunked_array()).ValueOr(nullptr);
} else {
sorted_and_unique_value_set =
SortAndUnique(options.value_set.make_array()).ValueOr(nullptr);
}

if (arg_type.id() == Type::TIMESTAMP &&
options.value_set.type()->id() == Type::TIMESTAMP) {
// Other types will fail when casting, so no separate check is needed
Expand All @@ -228,9 +262,7 @@ struct InitStateVisitor {
" vs ", *options.value_set.type());
}

if (!options.value_set.is_arraylike()) {
return Status::Invalid("Set lookup value set must be Array or ChunkedArray");
} else if (!options.value_set.type()->Equals(*arg_type)) {
if (!options.value_set.type()->Equals(*arg_type)) {
auto cast_result =
Cast(options.value_set, CastOptions::Safe(arg_type.GetSharedPtr()),
ctx->exec_context());
Expand All @@ -252,6 +284,8 @@ struct InitStateVisitor {
}

RETURN_NOT_OK(VisitTypeInline(*options.value_set.type(), this));
checked_cast<SetLookupStateBase*>(result.get())->sorted_and_unique_value_set =
sorted_and_unique_value_set;
return std::move(result);
}
};
Expand Down
36 changes: 36 additions & 0 deletions cpp/src/arrow/compute/kernels/set_lookup_internal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include "arrow/compute/kernel.h"

namespace arrow::compute::internal {

/// Shared base of the kernel states used by `is_in` and `index_in`.
struct SetLookupStateBase : KernelState {
  /// Type of the value set, reachable without the templated state type.
  std::shared_ptr<DataType> value_set_type;

  /// Value set prepared for fast simplification of `is_in` expressions.
  ///
  /// May be null.
  ///
  /// \invariant sorted and contains no null or duplicate values
  std::shared_ptr<Array> sorted_and_unique_value_set;
};

} // namespace arrow::compute::internal

0 comments on commit 0469a8f

Please sign in to comment.