Commit 81c2b98
fix:initial

vibhatha committed Jul 13, 2023 (1 parent: ec84f3e)
Showing 6 changed files with 195 additions and 5 deletions.
36 changes: 36 additions & 0 deletions cpp/src/arrow/acero/plan_test.cc
@@ -35,12 +35,16 @@
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/matchers.h"
#include "arrow/testing/random.h"
#include "arrow/type_fwd.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
#include "arrow/util/thread_pool.h"
#include "arrow/util/vector.h"

// TODO: remove after debug
#include "arrow/api.h"

using testing::Contains;
using testing::ElementsAre;
using testing::ElementsAreArray;
@@ -1440,6 +1444,38 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) {
AssertExecBatchesEqualIgnoringOrder(result.schema, result.batches, exp_batches);
}

TEST(ExecPlanExecution, ScalarSourceScalarDistinctAggSink) {
  BatchesWithSchema scalar_data;
  // NOTE: time32(SECOND) holds a time of day in [0, 86400); the "d" values
  // below look like epoch seconds and would fail validation as time32.
  scalar_data.batches = {
      ExecBatchFromJSON(
          {int32(), boolean(), arrow::timestamp(arrow::TimeUnit::SECOND),
           arrow::time32(arrow::TimeUnit::SECOND)},
          {ArgShape::ARRAY, ArgShape::SCALAR, ArgShape::ARRAY, ArgShape::ARRAY},
          "[[5, false, 1609459200, 1672444800], [5, false, 1609545600, 1679884800], "
          "[5, false, 1609632000, 1672444800]]"),
      ExecBatchFromJSON(
          {int32(), boolean(), arrow::timestamp(arrow::TimeUnit::SECOND),
           arrow::time32(arrow::TimeUnit::SECOND)},
          "[[5, true, 1609545600, 1679884800], [6, false, 1609459200, 1690406400], "
          "[7, true, 1609459200, 1704067200]]")};
  scalar_data.schema =
      schema({field("a", int32()), field("b", boolean()),
              field("c", arrow::timestamp(arrow::TimeUnit::SECOND)),
              field("d", arrow::time32(arrow::TimeUnit::SECOND))});
  std::cout << "Data OK" << std::endl;
  // index can't be tested as it's order-dependent
  // mode/quantile can't be tested as they're technically vector kernels
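  // Each Aggregate below is {function name, options, target field, output name}.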
  Declaration plan = Declaration::Sequence(
      {{"source", SourceNodeOptions{scalar_data.schema,
                                    scalar_data.gen(/*parallel=*/false, /*slow=*/false)}},
       {"aggregate", AggregateNodeOptions{
                         /*aggregates=*/{{"distinct", nullptr, "d", "distinct(b)"}}}}});

  // auto exp_batches = {
  //     ExecBatchFromJSON(
  //         {boolean(), boolean(), int64(), int64(), float64(), int64(), float64(),
  //          int64(), float64(), float64()},
  //         {ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR,
  //          ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR,
  //          ArgShape::ARRAY, ArgShape::SCALAR},
  //         R"([[false, true, 6, 6, 5.5, 26250, 0.7637626158259734, 33, 5.0, 0.5833333333333334]])"),
  // };
  ASSERT_OK_AND_ASSIGN(auto res, DeclarationToTable(std::move(plan)));
  std::cout << res->ToString() << std::endl;
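  // Expected distinct values of "d" given the batches above (order unspecified):
  // 1672444800, 1679884800, 1690406400, 1704067200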
}

TEST(ExecPlanExecution, ScalarSourceStandaloneNullaryScalarAggSink) {
BatchesWithSchema scalar_data;
scalar_data.batches = {
129 changes: 129 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -15,16 +15,22 @@
// specific language governing permissions and limitations
// under the License.

#include "arrow/array/concatenate.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/kernels/aggregate_basic_internal.h"
#include "arrow/compute/kernels/aggregate_internal.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/compute/kernels/util_internal.h"
#include "arrow/result.h"
#include "arrow/util/cpu_info.h"
#include "arrow/util/hashing.h"

#include <memory>

// TODO: remove after debug
#include <iostream>

namespace arrow {
namespace compute {
namespace internal {
@@ -150,6 +156,121 @@ Result<std::unique_ptr<KernelState>> CountInit(KernelContext*,
return std::make_unique<CountImpl>(static_cast<const CountOptions&>(*args.options));
}

// ----------------------------------------------------------------------
// Distinct implementations

/// TODO: template this like DistinctCountImpl: handling scalar inputs needs an
/// ArrayBuilder, and constructing one requires the concrete ArrowType.

/// TODO: also check whether we can simply buffer the input arrays, concatenate
/// them at the end, and compute the unique values once; that would remove the
/// complicated logic in the merge function (see the sketch after this struct).
struct DistinctImpl : public ScalarAggregator {
  DistinctImpl() = default;

  Status Consume(KernelContext* ctx, const ExecSpan& batch) override {
    if (batch[0].is_array()) {
      std::cout << "Array" << std::endl;
      const ArraySpan& input = batch[0].array;
      ARROW_ASSIGN_OR_RAISE(auto unique_array, arrow::compute::Unique(input.ToArray()));
      if (this->array) {
        // The same state can receive several batches, so fold each batch's
        // unique values into those accumulated so far instead of overwriting.
        ARROW_ASSIGN_OR_RAISE(
            auto merged,
            arrow::Concatenate({this->array, unique_array}, ctx->memory_pool()));
        ARROW_ASSIGN_OR_RAISE(unique_array, arrow::compute::Unique(merged));
      }
      this->array = std::move(unique_array);
      std::cout << "this array " << this->array->ToString() << std::endl;
    } else {
      std::cout << "Scalar" << std::endl;
      const Scalar& input = *batch[0].scalar;
      std::cout << input.ToString() << std::endl;
      return Status::NotImplemented("Scalar support not implemented");
    }
    return Status::OK();
  }

  Status MergeFrom(KernelContext* ctx, KernelState&& src) override {
    const auto& other_state = checked_cast<const DistinctImpl&>(src);
    auto this_array = this->array;
    auto other_array = other_state.array;
    if (other_array && this_array) {
      // Both sides are individually unique but may overlap, so the
      // concatenation has to be deduplicated again.
      ARROW_ASSIGN_OR_RAISE(
          auto merged_array,
          arrow::Concatenate({this_array, other_array}, ctx->memory_pool()));
      ARROW_ASSIGN_OR_RAISE(auto unique_array, arrow::compute::Unique(merged_array));
      this->array = std::move(unique_array);
      std::cout << "Merge Array(other_array && this_array): " << this->array->ToString()
                << std::endl;
    } else if (other_array) {
      // Only the other state has seen data; its values are already unique.
      this->array = std::move(other_array);
      std::cout << "Merge Array(other_array): " << this->array->ToString() << std::endl;
    }
    // If only this state has data, it is already unique and nothing changes.
    return Status::OK();
  }

  Status Finalize(KernelContext* ctx, Datum* out) override {
    std::cout << "Finalize" << std::endl;
    const auto& state = checked_cast<const DistinctImpl&>(*ctx->state());
    if (state.array) {
      *out = Datum(state.array);
    }
    return Status::OK();
  }

  // Unique values seen so far; null until the first array batch is consumed.
  std::shared_ptr<Array> array;
};
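
// A minimal sketch (not part of this commit) of the concatenate-at-the-end
// alternative raised in the TODO above; the name DistinctAccumulateImpl is
// hypothetical. Consume only buffers chunks, and a single Concatenate + Unique
// pass in Finalize replaces the per-branch logic in MergeFrom.
struct DistinctAccumulateImpl : public ScalarAggregator {
  Status Consume(KernelContext*, const ExecSpan& batch) override {
    if (!batch[0].is_array()) {
      return Status::NotImplemented("Scalar support not implemented");
    }
    // Buffer the chunk; deduplication is deferred to Finalize.
    chunks.push_back(batch[0].array.ToArray());
    return Status::OK();
  }

  Status MergeFrom(KernelContext*, KernelState&& src) override {
    auto& other = checked_cast<DistinctAccumulateImpl&>(src);
    chunks.insert(chunks.end(), other.chunks.begin(), other.chunks.end());
    return Status::OK();
  }

  Status Finalize(KernelContext* ctx, Datum* out) override {
    if (chunks.empty()) return Status::OK();
    ARROW_ASSIGN_OR_RAISE(auto merged, arrow::Concatenate(chunks, ctx->memory_pool()));
    ARROW_ASSIGN_OR_RAISE(auto uniques, arrow::compute::Unique(merged));
    *out = Datum(std::move(uniques));
    return Status::OK();
  }

  ArrayVector chunks;
};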

Result<std::unique_ptr<KernelState>> DistinctInit(KernelContext*,
const KernelInitArgs& args) {
return std::make_unique<DistinctImpl>();
}


void AddDistinctKernel(InputType type, OutputType out_type,
                       ScalarAggregateFunction* func) {
  AddAggKernel(KernelSignature::Make({type}, out_type), DistinctInit, func);
}

void AddDistinctKernels(ScalarAggregateFunction* func) {
  // Boolean
  AddDistinctKernel(boolean(), boolean(), func);
  // Number
  AddDistinctKernel(int8(), int8(), func);
  AddDistinctKernel(int16(), int16(), func);
  AddDistinctKernel(int32(), int32(), func);
  AddDistinctKernel(int64(), int64(), func);
  AddDistinctKernel(uint8(), uint8(), func);
  AddDistinctKernel(uint16(), uint16(), func);
  AddDistinctKernel(uint32(), uint32(), func);
  AddDistinctKernel(uint64(), uint64(), func);
  AddDistinctKernel(float16(), float16(), func);
  AddDistinctKernel(float32(), float32(), func);
  AddDistinctKernel(float64(), float64(), func);
  // Date
  AddDistinctKernel(date32(), date32(), func);
  AddDistinctKernel(date64(), date64(), func);
  // Time (note: time64 only supports MICRO and NANO units)
  // AddDistinctKernel(match::SameTypeId(Type::TIME32), arrow::time32(arrow::TimeUnit::SECOND), func);
  // AddDistinctKernel(match::SameTypeId(Type::TIME32), arrow::time32(arrow::TimeUnit::MILLI), func);
  // AddDistinctKernel(match::SameTypeId(Type::TIME64), arrow::time64(arrow::TimeUnit::MICRO), func);
  // AddDistinctKernel(match::SameTypeId(Type::TIME64), arrow::time64(arrow::TimeUnit::NANO), func);
  // Timestamp & Duration
  AddDistinctKernel(match::SameTypeId(Type::TIMESTAMP), arrow::timestamp(arrow::TimeUnit::SECOND), func);
  AddDistinctKernel(match::SameTypeId(Type::TIMESTAMP), arrow::timestamp(arrow::TimeUnit::MILLI), func);
  AddDistinctKernel(match::SameTypeId(Type::TIMESTAMP), arrow::timestamp(arrow::TimeUnit::MICRO), func);
  AddDistinctKernel(match::SameTypeId(Type::TIMESTAMP), arrow::timestamp(arrow::TimeUnit::NANO), func);
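  // NOTE: the four registrations above all match any TIMESTAMP input but each
  // pins a fixed output unit; an OutputType resolver that echoes the first
  // argument's type would likely be needed for the output unit to follow the
  // input (and to avoid overlapping kernel signatures).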
  // AddDistinctKernel(match::SameTypeId(Type::DURATION), match::SameTypeId(Type::DURATION), func);
  // Interval
  // AddDistinctKernel(month_interval(), month_interval(), func);
  // AddDistinctKernel(day_time_interval(), day_time_interval(), func);
  // AddDistinctKernel(month_day_nano_interval(), month_day_nano_interval(), func);
  // Binary & String
  // AddDistinctKernel(match::BinaryLike(), match::BinaryLike(), func);
  // AddDistinctKernel(match::LargeBinaryLike(), match::LargeBinaryLike(), func);
  // Fixed binary & Decimal
  // AddDistinctKernel(match::FixedSizeBinaryLike(), match::FixedSizeBinaryLike(), func);
}

// ----------------------------------------------------------------------
// Distinct Count implementation

@@ -1012,6 +1133,9 @@ const FunctionDoc index_doc{"Find the index of the first occurrence of a given v
{"array"},
"IndexOptions",
/*options_required=*/true};
const FunctionDoc distinct_doc{"Select the distinct values of an array",
                               ("All unique values of the input are returned."),
                               {"array"}};

} // namespace

@@ -1183,6 +1307,11 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
{fixed_size_binary(1), decimal128(1, 0), decimal256(1, 0), null()},
int64(), func.get());
DCHECK_OK(registry->AddFunction(std::move(func)));
  // distinct
  func = std::make_shared<ScalarAggregateFunction>("distinct", Arity::Unary(),
                                                   distinct_doc);
  AddDistinctKernels(func.get());
  DCHECK_OK(registry->AddFunction(std::move(func)));
}
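
// A hedged usage sketch (not part of this commit): once registered, "distinct"
// is reachable through CallFunction. ArrayFromJSON is the Arrow test helper
// from "arrow/testing/gtest_util.h", so this sketch belongs in a test file.
Status RunDistinctExample() {
  std::shared_ptr<Array> arr = ArrayFromJSON(int32(), "[5, 5, 6, 7, 5]");
  ARROW_ASSIGN_OR_RAISE(Datum out, CallFunction("distinct", {arr}));
  // Expected distinct values (order unspecified): 5, 6, 7
  std::cout << out.make_array()->ToString() << std::endl;
  return Status::OK();
}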

} // namespace internal
1 change: 0 additions & 1 deletion python/pyarrow/acero.py
@@ -294,7 +294,6 @@ def _sort_source(table_or_dataset, sort_keys, output_type=Table, **kwargs):


def _group_by(table, aggregates, keys):
    decl = Declaration.from_sequence([
        Declaration("table_source", TableSourceNodeOptions(table)),
        Declaration("aggregate", AggregateNodeOptions(aggregates, keys=keys))
8 changes: 7 additions & 1 deletion python/pyarrow/table.pxi
@@ -5294,8 +5294,14 @@ list[tuple(str, str, FunctionOptions)]
        # Ensure aggregate function is hash_ if needed
        if len(self.keys) > 0 and not func.startswith("hash_"):
            func = "hash_" + func
        import pyarrow.compute as pc
        if len(self.keys) == 0 and func.startswith("hash_"):
            # With no grouping keys, fall back to the scalar variant of the
            # aggregate when one is registered (e.g. "hash_distinct" -> "distinct").
            scalar_func = func[5:]
            try:
                pc.get_function(scalar_func)
                func = scalar_func
            except KeyError:
                # No scalar counterpart (e.g. "hash_list"): keep the hash name and
                # let the aggregate node report a meaningful error.
                pass
        # Determine output field name
        func_nohash = func if not func.startswith("hash_") else func[5:]
        if len(target) == 0:
5 changes: 5 additions & 0 deletions python/pyarrow/tests/test_acero.py
@@ -195,6 +195,11 @@ def test_aggregate_scalar(table_source):
    )
    with pytest.raises(ValueError, match="is a hash aggregate function"):
        _ = decl.to_table()

aggr_opts = AggregateNodeOptions([("a", "hash_list", None, "a_list")])
decl = Declaration.from_sequence(
[table_source, Declaration("aggregate", aggr_opts)]
)


def test_aggregate_hash():
21 changes: 18 additions & 3 deletions python/pyarrow/tests/test_table.py
@@ -2426,9 +2426,24 @@ def test_numpy_asarray(constructor):

def test_aggregate_hash_functions():
    table = pa.table({'key': ['a', 'a', 'b', 'b', 'a'], 'value': [11, 112, 0, 1, 2]})
    res = table.group_by(['key']).aggregate([('value', 'one')])
    res = table.group_by([]).aggregate([('value', 'distinct')])
    print(res)

    # "hash_list" has no scalar counterpart, so it cannot run without keys.
    with pytest.raises(pa.lib.ArrowInvalid) as excinfo:
        table.group_by([]).aggregate([(['value'], 'hash_list')])
    assert ("The provided function (hash_list) is a hash aggregate function."
            in str(excinfo.value))

def test_scalar_aggregate_distinct_functions():
    table = pa.table({'key': ['a', 'a', 'b', 'b', 'a'], 'value': [11, 11, 10, 12, 12]})

    # The scalar "distinct" kernel registered in C++ should be discoverable.
    func = pc.get_function("distinct")
    print(func)
    # With no grouping keys, "distinct" resolves to the scalar aggregate.
    res = table.group_by([]).aggregate([(['value'], 'distinct')])
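    # Distinct of [11, 11, 10, 12, 12] is {10, 11, 12}; the output order is
    # unspecified, so a deterministic check could sort first, e.g. (assuming
    # the default "value_distinct" output column name):
    # assert sorted(res["value_distinct"].to_pylist()) == [10, 11, 12]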



