Skip to content

Commit

Permalink
support multi distinct sum/min/max (apache#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
stdpain authored and HappenLee committed Jul 1, 2021
1 parent 6aa454c commit cd795ac
Show file tree
Hide file tree
Showing 8 changed files with 363 additions and 0 deletions.
1 change: 1 addition & 0 deletions be/src/vec/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ set(EXECUTABLE_OUTPUT_PATH "${BUILD_DIR}/src/vec")
set(VEC_FILES
aggregate_functions/aggregate_function_avg.cpp
aggregate_functions/aggregate_function_count.cpp
aggregate_functions/aggregate_function_distinct.cpp
aggregate_functions/aggregate_function_sum.cpp
aggregate_functions/aggregate_function_min_max.cpp
aggregate_functions/aggregate_function_null.cpp
Expand Down
79 changes: 79 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_distinct.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include "vec/aggregate_functions/aggregate_function_distinct.h"

#include <algorithm>

#include "boost/algorithm/string.hpp"
#include "vec/aggregate_functions/aggregate_function_combinator.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/common/typeid_cast.h"
#include "vec/data_types/data_type_nullable.h"
// #include "registerAggregateFunctions.h"

namespace doris::vectorized {
namespace ErrorCodes {
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}

class AggregateFunctionCombinatorDistinct final : public IAggregateFunctionCombinator {
public:
String getName() const override { return "Distinct"; }

DataTypes transformArguments(const DataTypes& arguments) const override {
if (arguments.empty())
throw Exception(
"Incorrect number of arguments for aggregate function with Distinct suffix",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

return arguments;
}

AggregateFunctionPtr transformAggregateFunction(const AggregateFunctionPtr& nested_function,
const DataTypes& arguments,
const Array& params) const override {
AggregateFunctionPtr res;
if (arguments.size() == 1) {
res.reset(createWithNumericType<AggregateFunctionDistinct,
AggregateFunctionDistinctSingleNumericData>(
*arguments[0], nested_function, arguments));

if (res) return res;

if (arguments[0]->isValueUnambiguouslyRepresentedInContiguousMemoryRegion())
return std::make_shared<AggregateFunctionDistinct<
AggregateFunctionDistinctSingleGenericData<true>>>(nested_function,
arguments);
else
return std::make_shared<AggregateFunctionDistinct<
AggregateFunctionDistinctSingleGenericData<false>>>(nested_function,
arguments);
}

return std::make_shared<
AggregateFunctionDistinct<AggregateFunctionDistinctMultipleGenericData>>(
nested_function, arguments);
}
};

const std::string DISTINCT_FUNCTION_PREFIX = "multi_distinct_";

void registerAggregateFunctionCombinatorDistinct(AggregateFunctionSimpleFactory& factory) {
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const Array& params) {
// 1. we should get not nullable types;
DataTypes nested_types(types.size());
std::transform(types.begin(), types.end(), nested_types.begin(),
[](const auto& e) { return removeNullable(e); });
auto function_combinator = std::make_shared<AggregateFunctionCombinatorDistinct>();
auto transformArguments = function_combinator->transformArguments(nested_types);
if (!boost::algorithm::starts_with(name, DISTINCT_FUNCTION_PREFIX)) {
return AggregateFunctionPtr();
}
auto nested_function_name = name.substr(DISTINCT_FUNCTION_PREFIX.size());
auto nested_function = factory.get(nested_function_name, transformArguments, params);
return function_combinator->transformAggregateFunction(nested_function, types, params);
};
factory.registerDistinctFunctionCombinator(creator, DISTINCT_FUNCTION_PREFIX);
// factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorDistinct>());
}
} // namespace doris::vectorized
203 changes: 203 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_distinct.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
#pragma once

#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/key_holder_helpers.h"

// #include <Columns/ColumnArray.h>
// #include <DataTypes/DataTypeArray.h>
#include "vec/common/aggregation_common.h"
#include "vec/common/assert_cast.h"
#include "vec/common/field_visitors.h"
#include "vec/common/hash_table/hash_set.h"
#include "vec/common/hash_table/hash_table.h"
#include "vec/common/sip_hash.h"
#include "vec/io/io_helper.h"

namespace doris::vectorized {

template <typename T>
struct AggregateFunctionDistinctSingleNumericData {
/// When creating, the hash table must be small.
using Set = HashSetWithStackMemory<T, DefaultHash<T>, 4>;
using Self = AggregateFunctionDistinctSingleNumericData<T>;
Set set;

void add(const IColumn** columns, size_t /* columns_num */, size_t row_num, Arena*) {
const auto& vec = assert_cast<const ColumnVector<T>&>(*columns[0]).getData();
set.insert(vec[row_num]);
}

void merge(const Self& rhs, Arena*) { set.merge(rhs.set); }

void serialize(std::ostream& buf) const { set.write(buf); }

void deserialize(std::istream& buf, Arena*) { set.read(buf); }

MutableColumns getArguments(const DataTypes& argument_types) const {
MutableColumns argument_columns;
argument_columns.emplace_back(argument_types[0]->createColumn());
for (const auto& elem : set) argument_columns[0]->insert(elem.getValue());

return argument_columns;
}
};

struct AggregateFunctionDistinctGenericData {
/// When creating, the hash table must be small.
using Set = HashSetWithSavedHashWithStackMemory<StringRef, StringRefHash, 4>;
using Self = AggregateFunctionDistinctGenericData;
Set set;

void merge(const Self& rhs, Arena* arena) {
Set::LookupResult it;
bool inserted;
for (const auto& elem : rhs.set)
set.emplace(ArenaKeyHolder{elem.getValue(), *arena}, it, inserted);
}

void serialize(std::ostream& buf) const {
writeVarUInt(set.size(), buf);
for (const auto& elem : set) writeStringBinary(elem.getValue(), buf);
}

void deserialize(std::istream& buf, Arena* arena) {
size_t size;
readVarUInt(size, buf);
for (size_t i = 0; i < size; ++i) set.insert(readStringBinaryInto(*arena, buf));
}
};

template <bool is_plain_column>
struct AggregateFunctionDistinctSingleGenericData : public AggregateFunctionDistinctGenericData {
void add(const IColumn** columns, size_t /* columns_num */, size_t row_num, Arena* arena) {
Set::LookupResult it;
bool inserted;
auto key_holder = getKeyHolder<is_plain_column>(*columns[0], row_num, *arena);
set.emplace(key_holder, it, inserted);
}

MutableColumns getArguments(const DataTypes& argument_types) const {
MutableColumns argument_columns;
argument_columns.emplace_back(argument_types[0]->createColumn());
for (const auto& elem : set)
deserializeAndInsert<is_plain_column>(elem.getValue(), *argument_columns[0]);

return argument_columns;
}
};

struct AggregateFunctionDistinctMultipleGenericData : public AggregateFunctionDistinctGenericData {
void add(const IColumn** columns, size_t columns_num, size_t row_num, Arena* arena) {
const char* begin = nullptr;
StringRef value(begin, 0);
for (size_t i = 0; i < columns_num; ++i) {
auto cur_ref = columns[i]->serializeValueIntoArena(row_num, *arena, begin);
value.data = cur_ref.data - value.size;
value.size += cur_ref.size;
}

Set::LookupResult it;
bool inserted;
auto key_holder = SerializedKeyHolder{value, *arena};
set.emplace(key_holder, it, inserted);
}

MutableColumns getArguments(const DataTypes& argument_types) const {
MutableColumns argument_columns(argument_types.size());
for (size_t i = 0; i < argument_types.size(); ++i)
argument_columns[i] = argument_types[i]->createColumn();

for (const auto& elem : set) {
const char* begin = elem.getValue().data;
for (auto& column : argument_columns)
begin = column->deserializeAndInsertFromArena(begin);
}

return argument_columns;
}
};

/** Adaptor for aggregate functions.
* Adding -Distinct suffix to aggregate function
**/
template <typename Data>
class AggregateFunctionDistinct
: public IAggregateFunctionDataHelper<Data, AggregateFunctionDistinct<Data>> {
private:
static constexpr auto prefix_size = sizeof(Data);
AggregateFunctionPtr nested_func;
size_t arguments_num;

AggregateDataPtr getNestedPlace(AggregateDataPtr __restrict place) const noexcept {
return place + prefix_size;
}

ConstAggregateDataPtr getNestedPlace(ConstAggregateDataPtr __restrict place) const noexcept {
return place + prefix_size;
}

public:
AggregateFunctionDistinct(AggregateFunctionPtr nested_func_, const DataTypes& arguments)
: IAggregateFunctionDataHelper<Data, AggregateFunctionDistinct>(
arguments, nested_func_->getParameters()),
nested_func(nested_func_),
arguments_num(arguments.size()) {}

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
Arena* arena) const override {
this->data(place).add(columns, arguments_num, row_num, arena);
}

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena* arena) const override {
this->data(place).merge(this->data(rhs), arena);
}

void serialize(ConstAggregateDataPtr place, std::ostream& buf) const override {
this->data(place).serialize(buf);
}

void deserialize(AggregateDataPtr place, std::istream& buf, Arena* arena) const override {
this->data(place).deserialize(buf, arena);
}

// void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * arena) const override
void insertResultInto(ConstAggregateDataPtr targetplace, IColumn& to) const override {
auto place = const_cast<AggregateDataPtr>(targetplace);
auto arguments = this->data(place).getArguments(this->argument_types);
ColumnRawPtrs arguments_raw(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) arguments_raw[i] = arguments[i].get();

assert(!arguments.empty());
// nested_func->addBatchSinglePlace(arguments[0]->size(), getNestedPlace(place), arguments_raw.data(), arena);
// nested_func->insertResultInto(getNestedPlace(place), to, arena);

nested_func->addBatchSinglePlace(arguments[0]->size(), getNestedPlace(place),
arguments_raw.data(), nullptr);
nested_func->insertResultInto(getNestedPlace(place), to);
}

size_t sizeOfData() const override { return prefix_size + nested_func->sizeOfData(); }

void create(AggregateDataPtr place) const override {
new (place) Data;
nested_func->create(getNestedPlace(place));
}

void destroy(AggregateDataPtr place) const noexcept override {
this->data(place).~Data();
nested_func->destroy(getNestedPlace(place));
}

String getName() const override { return nested_func->getName() + "Distinct"; }

DataTypePtr getReturnType() const override { return nested_func->getReturnType(); }

bool allocatesMemoryInArena() const override { return true; }

const char* getHeaderFilePath() const override { return __FILE__; }

// AggregateFunctionPtr getNestedFunction() const override { return nested_func; }
};

} // namespace doris::vectorized
15 changes: 15 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ void registerAggregateFunctionMinMax(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionAvg(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionCount(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionsUniq(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionCombinatorDistinct(AggregateFunctionSimpleFactory& factory);

using DataTypePtr = std::shared_ptr<const IDataType>;
using DataTypes = std::vector<DataTypePtr>;
Expand All @@ -61,6 +62,19 @@ class AggregateFunctionSimpleFactory {
}
}

void registerDistinctFunctionCombinator(Creator creator, const std::string& prefix) {
std::vector<std::string> need_insert;
for (auto entity : aggregate_functions) {
std::string target_value = prefix + entity.first;
if (aggregate_functions[target_value] == nullptr) {
need_insert.emplace_back(std::move(target_value));
}
}
for (const auto& function_name : need_insert) {
aggregate_functions[function_name] = creator;
}
}

void registerFunction(const std::string& name, Creator creator, bool nullable = false) {
if (nullable) {
nullable_aggregate_functions[name] = creator;
Expand Down Expand Up @@ -98,6 +112,7 @@ class AggregateFunctionSimpleFactory {
registerAggregateFunctionAvg(instance);
registerAggregateFunctionCount(instance);
registerAggregateFunctionsUniq(instance);
registerAggregateFunctionCombinatorDistinct(instance);
registerAggregateFunctionCombinatorNull(instance);
});
return instance;
Expand Down
14 changes: 14 additions & 0 deletions be/src/vec/aggregate_functions/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,20 @@ static IAggregateFunction* createWithNumericType(const IDataType& argument_type,
return nullptr;
}

template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
typename... TArgs>
static IAggregateFunction* createWithNumericType(const IDataType& argument_type, TArgs&&... args) {
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return new AggregateFunctionTemplate<Data<TYPE>>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
// if (which.idx == TypeIndex::Enum8) return new AggregateFunctionTemplate<Data<Int8>>(std::forward<TArgs>(args)...);
// if (which.idx == TypeIndex::Enum16) return new AggregateFunctionTemplate<Data<Int16>>(std::forward<TArgs>(args)...);
return nullptr;
}

template <template <typename, typename> class AggregateFunctionTemplate,
template <typename> class Data, typename... TArgs>
static IAggregateFunction* createWithUnsignedIntegerType(const IDataType& argument_type,
Expand Down
28 changes: 28 additions & 0 deletions be/src/vec/aggregate_functions/key_holder_helpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once

#include "vec/columns/column.h"
#include "vec/common/hash_table/hash_table_key_holder.h"

namespace doris::vectorized {

template <bool is_plain_column = false>
static auto getKeyHolder(const IColumn& column, size_t row_num, Arena& arena) {
if constexpr (is_plain_column) {
return ArenaKeyHolder{column.getDataAt(row_num), arena};
} else {
const char* begin = nullptr;
StringRef serialized = column.serializeValueIntoArena(row_num, arena, begin);
assert(serialized.data != nullptr);
return SerializedKeyHolder{serialized, arena};
}
}

template <bool is_plain_column>
static void deserializeAndInsert(StringRef str, IColumn& data_to) {
if constexpr (is_plain_column)
data_to.insertData(str.data, str.size);
else
data_to.deserializeAndInsertFromArena(str.data);
}

} // namespace doris::vectorized
12 changes: 12 additions & 0 deletions be/src/vec/common/hash_table/hash_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,15 @@ template <typename Key, typename Hash = DefaultHash<Key>, typename Grower = Hash
typename Allocator = HashTableAllocator>
using HashSetWithSavedHash =
HashSetTable<Key, HashSetCellWithSavedHash<Key, Hash>, Hash, Grower, Allocator>;

template <typename Key, typename Hash, size_t initial_size_degree>
using HashSetWithStackMemory =
HashSet<Key, Hash, HashTableGrower<initial_size_degree>,
HashTableAllocatorWithStackMemory<(1ULL << initial_size_degree) *
sizeof(HashTableCell<Key, Hash>)>>;

template <typename Key, typename Hash, size_t initial_size_degree>
using HashSetWithSavedHashWithStackMemory = HashSetWithSavedHash<
Key, Hash, HashTableGrower<initial_size_degree>,
HashTableAllocatorWithStackMemory<(1ULL << initial_size_degree) *
sizeof(HashSetCellWithSavedHash<Key, Hash>)>>;
Loading

0 comments on commit cd795ac

Please sign in to comment.