Skip to content

Commit

Permalink
fix: initial scalar distinct version
Browse files Browse the repository at this point in the history
  • Loading branch information
vibhatha committed Jul 18, 2023
1 parent 0a6671a commit b18056e
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 145 deletions.
122 changes: 40 additions & 82 deletions cpp/src/arrow/acero/plan_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <gmock/gmock-matchers.h>

#include <algorithm>
#include <functional>
#include <memory>

Expand Down Expand Up @@ -1444,20 +1445,21 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) {
AssertExecBatchesEqualIgnoringOrder(result.schema, result.batches, exp_batches);
}

/// TODO: parametrize this test case for various types and data
/// 1. Parameterize the types
/// 2. Parameterize the input data

class ScalarAggregateDistinctTest
class ScalarAggregateDistinctArrayTest
: public ::testing::TestWithParam<
std::tuple<std::shared_ptr<arrow::DataType>, std::shared_ptr<arrow::DataType>,
std::string, std::string, std::string>> {};

// Value-parameterized fixture used by the WithScalar test of the "distinct"
// aggregate. Each parameter is a 7-element tuple: two data types followed by
// five JSON strings. WithScalar reads them, in order, as the key type, the
// value type, four input batches, and the expected output batch.
class ScalarAggregateDistinctScalarTest
: public ::testing::TestWithParam<
std::tuple<std::shared_ptr<arrow::DataType>, std::shared_ptr<arrow::DataType>,
std::string, std::string, std::string, std::string, std::string>> {};

// Order of parameters: Key Type, Value Type, Batch 1, Batch 2, Output
static std::vector<
std::tuple<std::shared_ptr<arrow::DataType>, std::shared_ptr<arrow::DataType>,
std::string, std::string, std::string>>
GetDistinctBatch() {
GetDistinctBatchOfArray() {
return std::vector<
std::tuple<std::shared_ptr<arrow::DataType>, std::shared_ptr<arrow::DataType>,
std::string, std::string, std::string>>{
Expand All @@ -1484,38 +1486,35 @@ GetDistinctBatch() {
// Order of parameters: Key Type, Value Type, Batch 1, Batch 2, Batch 3, Batch 4, Output
static std::vector<
std::tuple<std::shared_ptr<arrow::DataType>, std::shared_ptr<arrow::DataType>,
std::string, std::string, std::string>>
GetDistinctBatchScalar() {
std::string, std::string, std::string, std::string, std::string>>
GetDistinctBatchOfScalar() {
return std::vector<
std::tuple<std::shared_ptr<arrow::DataType>, std::shared_ptr<arrow::DataType>,
std::string, std::string, std::string>>{
{arrow::int32(), arrow::boolean(), "[[1, false], [2, false], [3, false]]",
"[[4, true], [5, true], [6, false]]", "[[false, true]]"},
{arrow::int32(), arrow::int32(), "[[1, 10], [2, 20], [3, 10]]",
"[[4, 10], [5, 30], [6, 20]]", "[[10, 20, 30]]"},
{arrow::int32(), arrow::timestamp(arrow::TimeUnit::SECOND),
"[[1, 1609459200], [2, 1609545600], [3, 1609632000]]",
"[[4, 1609545600], [5, 1609459200], [6, 1609459200]]",
std::string, std::string, std::string, std::string, std::string>>{
{arrow::int32(), arrow::boolean(), "[[1, false]]", "[[2, false]]", "[[3, true]]",
"[[4, false]]", "[[false], [true]]"},
{arrow::int32(), arrow::int32(), "[[1, 10]]", "[[2, 20]]", "[[3, 10]]", "[[4, 30]]",
"[[10], [20], [30]]"},
{arrow::int32(), arrow::timestamp(arrow::TimeUnit::SECOND), "[[1, 1609459200]]",
"[[2, 1609545600]]", "[[3, 1609632000]]", "[[4, 1609545600]]",
"[[1609459200], [1609545600], [1609632000]]"},
{arrow::int32(), arrow::time32(arrow::TimeUnit::SECOND),
"[[1, 48615], [2, 48615], [3, 48735]]", "[[4, 48615], [5, 48735], [6, 48675]]",
"[[48615], [48675], [48735]]"},
{arrow::int32(), arrow::duration(arrow::TimeUnit::SECOND),
"[[1, 20], [2, 40], [3, 20]]", "[[4, 60], [5, 40], [6, 30]]",
"[[20], [30], [40], [60]]"},
{arrow::int32(), arrow::utf8(),
std::string(R"([[1, "abcd"], [2, "efgh"], [3, "abcd"]])"),
std::string(R"([[4, "efgh"], [5, "hijk"], [6, "abcd"]])"),
{arrow::int32(), arrow::time32(arrow::TimeUnit::SECOND), "[[1, 48615]]",
"[[2, 48675]]", "[[3, 48675]]", "[[4, 48735]]", "[[48615], [48675], [48735]]"},
{arrow::int32(), arrow::duration(arrow::TimeUnit::SECOND), "[[1, 20]]", "[[2, 60]]",
"[[3, 20]]", "[[4, 60]]", "[[20], [60]]"},
{arrow::int32(), arrow::utf8(), std::string(R"([[1, "abcd"]])"),
std::string(R"([[2, "efgh"]])"), std::string(R"([[3, "hijk"]])"),
std::string(R"([[4, "efgh"]])"),
std::string(R"([["abcd"], ["efgh"], ["hijk"]])")}};
}

INSTANTIATE_TEST_SUITE_P(ScalarAggregatorArrayTest, ScalarAggregateDistinctTest,
::testing::ValuesIn(GetDistinctBatch()));
INSTANTIATE_TEST_SUITE_P(ScalarAggregatorArrayTest, ScalarAggregateDistinctArrayTest,
::testing::ValuesIn(GetDistinctBatchOfArray()));

INSTANTIATE_TEST_SUITE_P(ScalarAggregatorScalarTest, ScalarAggregateDistinctTest,
::testing::ValuesIn(GetDistinctBatchScalar()));
INSTANTIATE_TEST_SUITE_P(ScalarAggregatorScalarTest, ScalarAggregateDistinctScalarTest,
::testing::ValuesIn(GetDistinctBatchOfScalar()));

TEST_P(ScalarAggregateDistinctTest, WithArray) {
TEST_P(ScalarAggregateDistinctArrayTest, WithArray) {
BatchesWithSchema scalar_data;
auto param = GetParam();
std::shared_ptr<arrow::DataType> key_type = std::get<0>(param);
Expand All @@ -1542,77 +1541,36 @@ TEST_P(ScalarAggregateDistinctTest, WithArray) {
AssertExecBatchesEqualIgnoringOrder(result.schema, result.batches, exp_batches);
}

TEST_P(ScalarAggregateDistinctTest, WithScalar) {
TEST_P(ScalarAggregateDistinctScalarTest, WithScalar) {
BatchesWithSchema scalar_data;
auto param = GetParam();
std::shared_ptr<arrow::DataType> key_type = std::get<0>(param);
std::shared_ptr<arrow::DataType> val_type = std::get<1>(param);
std::string first_batch = std::get<2>(param);
std::string second_batch = std::get<3>(param);
std::string output_batch = std::get<4>(param);
std::string third_batch = std::get<4>(param);
std::string fourth_batch = std::get<5>(param);
std::string output_batch = std::get<6>(param);

scalar_data.batches = {
ExecBatchFromJSON({key_type, val_type}, {ArgShape::SCALAR, ArgShape::SCALAR},
first_batch),
ExecBatchFromJSON({key_type, val_type}, {ArgShape::SCALAR, ArgShape::SCALAR},
second_batch),
ExecBatchFromJSON({key_type, val_type}, {ArgShape::SCALAR, ArgShape::SCALAR},
third_batch),
ExecBatchFromJSON({key_type, val_type}, {ArgShape::SCALAR, ArgShape::SCALAR},
fourth_batch),
};

scalar_data.schema = schema({field("key", key_type), field("value", val_type)});
Declaration plan = Declaration::Sequence(
{{"source", SourceNodeOptions{scalar_data.schema,
scalar_data.gen(/*parallel=*/false, /*slow=*/false)}},
{"aggregate", AggregateNodeOptions{/*aggregates=*/{
{"distinct", nullptr, "value", "distinct(value)"}}}}});
// ASSERT_OK_AND_ASSIGN(auto result, DeclarationToExecBatches(std::move(plan)));
// auto exp_batches = {ExecBatchFromJSON({val_type}, {ArgShape::ARRAY}, output_batch)};
// AssertExecBatchesEqualIgnoringOrder(result.schema, result.batches, exp_batches);
ASSERT_OK_AND_ASSIGN(auto result, DeclarationToTable(std::move(plan)));
std::cout << result->ToString() << std::endl;
}

TEST(ExecPlanExecution, ScalarSourceScalarDistinctAggSink) {
BatchesWithSchema scalar_data;
scalar_data.batches = {
ExecBatchFromJSON({int32(), boolean(), arrow::timestamp(arrow::TimeUnit::SECOND),
arrow::time32(arrow::TimeUnit::SECOND),
arrow::duration(arrow::TimeUnit::SECOND), int32()},
{ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR,
ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR},
"[[5, false, 1609459200, 48615, 60, 10], [5, false, 1609545600, "
"48615, 20, 20], [5, false, 1609632000, 48735, 20, 10]]"),
ExecBatchFromJSON({int32(), boolean(), arrow::timestamp(arrow::TimeUnit::SECOND),
arrow::time32(arrow::TimeUnit::SECOND),
arrow::duration(arrow::TimeUnit::SECOND), int32()},
"[[5, true, 1609545600, 48735, 20, 10], [6, false, 1609459200, "
"48855, 40, 30], [7, true, 1609459200, 48735, 60, 20]]")};
scalar_data.schema =
schema({field("a", int32()), field("b", boolean()),
field("c", arrow::timestamp(arrow::TimeUnit::SECOND)),
field("d", arrow::time32(arrow::TimeUnit::SECOND)),
field("e", arrow::duration(arrow::TimeUnit::SECOND)), field("f", int32())});
std::cout << "Data OK" << std::endl;
// index can't be tested as it's order-dependent
// mode/quantile can't be tested as they're technically vector kernels
Declaration plan = Declaration::Sequence(
{{"source", SourceNodeOptions{scalar_data.schema,
scalar_data.gen(/*parallel=*/false, /*slow=*/false)}},
{"aggregate",
AggregateNodeOptions{/*aggregates=*/{{"distinct", nullptr, "f", "distinct"}}}}});

// auto exp_batches = {
// ExecBatchFromJSON(
// {boolean(), boolean(), int64(), int64(), float64(), int64(), float64(),
// int64(),
// float64(), float64()},
// {ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR,
// ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR,
// ArgShape::ARRAY, ArgShape::SCALAR},
// R"([[false, true, 6, 6, 5.5, 26250, 0.7637626158259734, 33, 5.0,
// 0.5833333333333334]])"),
// };
ASSERT_OK_AND_ASSIGN(auto res, DeclarationToTable(std::move(plan)));
std::cout << res->ToString() << std::endl;
ASSERT_OK_AND_ASSIGN(auto result, DeclarationToExecBatches(std::move(plan)));
auto exp_batches = {ExecBatchFromJSON({val_type}, {ArgShape::ARRAY}, output_batch)};
AssertExecBatchesEqualIgnoringOrder(result.schema, result.batches, exp_batches);
}

TEST(ExecPlanExecution, ScalarSourceStandaloneNullaryScalarAggSink) {
Expand Down
76 changes: 13 additions & 63 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,52 +163,38 @@ Result<std::unique_ptr<KernelState>> CountInit(KernelContext*,
// ----------------------------------------------------------------------
// Distinct implementations

/// TODO: like in DistinctCountImpl we have to template this one
/// the reason is handling scalars would need a ArrayBuilder
/// and that need the ArrowType to be passed

/// TODO: also check whether can we take a vector of arrays and concatenate them
/// at the end and find the unique. This could reduce the complicated logic in the merge
/// function.
struct DistinctImpl : public ScalarAggregator {
Status Consume(KernelContext* ctx, const ExecSpan& batch) override {
if (batch[0].is_array()) {
const ArraySpan& input = batch[0].array;
ARROW_ASSIGN_OR_RAISE(auto unique_array, arrow::compute::Unique(input.ToArray()))
this->array = std::move(unique_array);
this->arrays.push_back(input.ToArray());
} else {
/// TODO: complete this feature
return Status::NotImplemented("Distinct aggregate doesn't support scalar values");
const Scalar& input = *batch[0].scalar;
std::shared_ptr<arrow::Array> scalar_array;
ARROW_ASSIGN_OR_RAISE(scalar_array,
arrow::MakeArrayFromScalar(input, 1, ctx->memory_pool()));
this->arrays.push_back(scalar_array);
}
return Status::OK();
}

Status MergeFrom(KernelContext* ctx, KernelState&& src) override {
const auto& other_state = checked_cast<const DistinctImpl&>(src);
auto this_array = this->array;
auto other_array = other_state.array;
if (other_array && this_array) {
ARROW_ASSIGN_OR_RAISE(
auto merged_array,
arrow::Concatenate({this_array, other_array}, ctx->memory_pool()));
this->array = std::move(merged_array);
} else {
auto target_array = other_array ? other_array : this_array;
ARROW_ASSIGN_OR_RAISE(auto unique_array, arrow::compute::Unique(target_array));
this->array = std::move(unique_array);
for (const auto& array : other_state.arrays) {
this->arrays.push_back(array);
}
return Status::OK();
}

Status Finalize(KernelContext* ctx, Datum* out) override {
const auto& state = checked_cast<const DistinctImpl&>(*ctx->state());
if (state.array) {
*out = Datum(state.array);
}
ARROW_ASSIGN_OR_RAISE(auto concatenated,
arrow::Concatenate(this->arrays, ctx->memory_pool()));
ARROW_ASSIGN_OR_RAISE(auto unique_array, arrow::compute::Unique(concatenated));
*out = Datum(unique_array);
return Status::OK();
}

std::shared_ptr<Array> array;
std::vector<std::shared_ptr<Array>> arrays;
};

Result<std::unique_ptr<KernelState>> DistinctInit(KernelContext*,
Expand Down Expand Up @@ -236,42 +222,6 @@ void AddDistinctKernel(const std::vector<std::shared_ptr<DataType>> types,
}
}

// void AddDistinctKernels(ScalarAggregateFunction* func) {
// // Boolean
// AddDistinctKernel(boolean(), func);
// // Number
// AddDistinctKernel(int8(), func);
// AddDistinctKernel(int16(), func);
// AddDistinctKernel(int32(), func);
// AddDistinctKernel(int64(), func);
// AddDistinctKernel(uint8(), func);
// AddDistinctKernel(uint16(), func);
// AddDistinctKernel(uint32(), func);
// AddDistinctKernel(uint64(), func);
// AddDistinctKernel(float16(), func);
// AddDistinctKernel(float32(), func);
// AddDistinctKernel(float64(), func);
// // Date
// AddDistinctKernel(date32(), func);
// AddDistinctKernel(date64(), func);
// // Time
// AddDistinctKernel(match::SameTypeId(Type::TIME32), func);
// AddDistinctKernel(match::SameTypeId(Type::TIME64), func);
// // // Timestamp
// AddDistinctKernel(match::SameTypeId(Type::TIMESTAMP), func);
// // Duration
// AddDistinctKernel(match::SameTypeId(Type::DURATION), func);
// // // Interval
// AddDistinctKernel(month_interval(), func);
// AddDistinctKernel(day_time_interval(), func);
// AddDistinctKernel(month_day_nano_interval(), func);
// // Binary & String
// AddDistinctKernel(match::BinaryLike(), func);
// AddDistinctKernel(match::LargeBinaryLike(), func);
// // Fixed binary & Decimal
// AddDistinctKernel(match::FixedSizeBinaryLike(), func);
// }

void AddDistinctKernels(ScalarAggregateFunction* func) {
AddDistinctKernel({null(), boolean()}, func);
AddDistinctKernel(NumericTypes(), func);
Expand Down

0 comments on commit b18056e

Please sign in to comment.