Skip to content

Commit

Permalink
[2/2] support multi-distinct (apache#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
stdpain authored and HappenLee committed Jul 1, 2021
1 parent 6174ec4 commit c841b63
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 22 deletions.
22 changes: 14 additions & 8 deletions be/src/vec/aggregate_functions/aggregate_function_count.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,33 @@
// specific language governing permissions and limitations
// under the License.

#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include <vec/aggregate_functions/aggregate_function_count.h>
#include <vec/aggregate_functions/factory_helpers.h>

#include "vec/aggregate_functions/aggregate_function_simple_factory.h"

namespace doris::vectorized {

AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, const DataTypes & argument_types, const Array & parameters) {
// Creator for the plain `count` aggregate.
// The call shape is checked by the assert* helpers (declared outside this
// view): no parameters are accepted and at most one argument type is allowed.
// Returns a freshly allocated AggregateFunctionCount built from the argument types.
AggregateFunctionPtr createAggregateFunctionCount(const std::string& name,
                                                  const DataTypes& argument_types,
                                                  const Array& parameters) {
    assertNoParameters(name, parameters);
    assertArityAtMost<1>(name, argument_types);

    AggregateFunctionPtr count_function =
            std::make_shared<AggregateFunctionCount>(argument_types);
    return count_function;
}

// void registerAggregateFunctionCount(AggregateFunctionFactory & factory)
// {
// factory.registerFunction("count", createAggregateFunctionCount, AggregateFunctionFactory::CaseInsensitive);
// }
// Creator for the `count` variant registered for the nullable-argument case.
// Mirrors createAggregateFunctionCount: the aggregate takes no parameters and
// at most one argument type, so both are validated before construction.
// Returns a freshly allocated AggregateFunctionCountNotNullUnary.
AggregateFunctionPtr createAggregateFunctionCountNotNullUnary(const std::string& name,
                                                              const DataTypes& argument_types,
                                                              const Array& parameters) {
    // Previously `parameters` was silently ignored here, unlike the sibling creator.
    assertNoParameters(name, parameters);
    assertArityAtMost<1>(name, argument_types);

    return std::make_shared<AggregateFunctionCountNotNullUnary>(argument_types);
}

// Registers both `count` creators with the factory under the same name.
void registerAggregateFunctionCount(AggregateFunctionSimpleFactory& factory) {
    // Third registerFunction argument, passed as `true` for the NotNullUnary creator.
    // NOTE(review): presumably selects the nullable-argument table — confirm
    // against AggregateFunctionSimpleFactory::registerFunction.
    constexpr bool register_as_nullable = true;

    factory.registerFunction("count", createAggregateFunctionCount);
    factory.registerFunction("count", createAggregateFunctionCountNotNullUnary,
                             register_as_nullable);
}

} // namespace

} // namespace doris::vectorized
3 changes: 3 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ class AggregateFunctionCountNotNullUnary final
ErrorCodes::LOGICAL_ERROR);
}

AggregateFunctionCountNotNullUnary(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper(argument_types_, {}) {}

// Name under which this aggregate reports itself.
String getName() const override {
    return "count";
}

// The aggregate's result type: a newly created DataTypeInt64.
DataTypePtr getReturnType() const override {
    DataTypePtr result_type = std::make_shared<DataTypeInt64>();
    return result_type;
}
Expand Down
5 changes: 1 addition & 4 deletions be/src/vec/aggregate_functions/aggregate_function_null.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,7 @@ void registerAggregateFunctionCombinatorNull(AggregateFunctionSimpleFactory& fac
auto nested_function = factory.get(name, transformArguments, params);
return function_combinator->transformAggregateFunction(nested_function, types, params);
};
factory.registerFunction("sum", creator, true);
factory.registerFunction("max", creator, true);
factory.registerFunction("min", creator, true);
factory.registerFunction("avg", creator, true);
factory.registerNullableFunctionCombinator(creator);
}

} // namespace doris::vectorized
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ void registerAggregateFunctionSum(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionCombinatorNull(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionMinMax(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionAvg(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionCount(AggregateFunctionSimpleFactory& factory);
void registerAggregateFunctionsUniq(AggregateFunctionSimpleFactory& factory);

using DataTypePtr = std::shared_ptr<const IDataType>;
using DataTypes = std::vector<DataTypePtr>;
Expand All @@ -53,7 +55,9 @@ class AggregateFunctionSimpleFactory {
public:
// Installs `creator` as the nullable variant of every function currently present
// in `aggregate_functions`, but only where no nullable creator exists yet.
// An entry that already has its own nullable implementation is left untouched;
// an unconditional assignment here would silently replace it.
void registerNullableFunctionCombinator(Creator creator) {
    // Iterate by const reference: each entry holds a name and a creator, and
    // copying both per iteration is unnecessary.
    for (const auto& entity : aggregate_functions) {
        // operator[] default-inserts an empty creator for a missing name, which
        // is exactly the "not registered yet" state tested below (single lookup).
        auto& nullable_creator = nullable_aggregate_functions[entity.first];
        if (nullable_creator == nullptr) {
            nullable_creator = creator;
        }
    }
}

Expand Down Expand Up @@ -92,6 +96,8 @@ class AggregateFunctionSimpleFactory {
registerAggregateFunctionSum(instance);
registerAggregateFunctionMinMax(instance);
registerAggregateFunctionAvg(instance);
registerAggregateFunctionCount(instance);
registerAggregateFunctionsUniq(instance);
registerAggregateFunctionCombinatorNull(instance);
});
return instance;
Expand Down
2 changes: 1 addition & 1 deletion be/src/vec/aggregate_functions/aggregate_function_uniq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void registerAggregateFunctionsUniq(AggregateFunctionSimpleFactory& factory) {
AggregateFunctionCreator creator =
createAggregateFunctionUniq<AggregateFunctionUniqExactData,
AggregateFunctionUniqExactData<String>>;
factory.registerFunction("uniqExact", creator);
factory.registerFunction("multi_distinct_count", creator);
}

} // namespace doris::vectorized
8 changes: 4 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function_uniq.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class AggregateFunctionUniq final

// Name of this aggregate, as supplied by the Data policy type.
String getName() const override {
    return Data::getName();
}

// Result type of the distinct count: DataTypeInt64.
// Only this one definition is kept; a second override returning DataTypeUInt64
// would be a redefinition and would disagree with the ColumnInt64 written by
// insertResultInto().
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeInt64>(); }

void add(AggregateDataPtr place, const IColumn** columns, size_t row_num,
Arena*) const override {
Expand All @@ -104,15 +104,15 @@ class AggregateFunctionUniq final
}

// Writes the set held in the aggregation state at `place` to `buf`.
void serialize(ConstAggregateDataPtr place, std::ostream& buf) const override {
    const auto& state = this->data(place);
    state.set.write(buf);
}

// Reads the set for the aggregation state at `place` from `buf`.
// The Arena argument is unused.
void deserialize(AggregateDataPtr place, std::istream& buf, Arena*) const override {
    auto& state = this->data(place);
    state.set.read(buf);
}

// Appends the number of elements in the state's set to the result column `to`.
// Exactly one value is appended per call, and `to` is treated as a ColumnInt64
// to agree with the DataTypeInt64 result type of this aggregate. Casting the
// same column to ColumnUInt64 as well would be a type mismatch and would emit a
// second, spurious row.
void insertResultInto(ConstAggregateDataPtr place, IColumn& to) const override {
    auto& result_column = assert_cast<ColumnInt64&>(to);
    result_column.getData().push_back(this->data(place).set.size());
}

// Path of the file this definition is compiled from (expansion of __FILE__).
const char* getHeaderFilePath() const override {
    return __FILE__;
}
Expand Down
29 changes: 29 additions & 0 deletions be/src/vec/exec/vaggregation_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "exec/exec_node.h"
#include "runtime/mem_pool.h"
#include "runtime/row_batch.h"
#include "util/defer_op.h"
#include "vec/core/block.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/exprs/vexpr.h"
Expand Down Expand Up @@ -137,6 +138,7 @@ Status AggregationNode::prepare(RuntimeState* state) {
std::placeholders::_3);
}

_executor.close = std::bind<void>(&AggregationNode::_close_without_key, this);
} else {
_agg_data.init(AggregatedDataVariants::Type::serialized);
if (_is_merge) {
Expand All @@ -156,6 +158,7 @@ Status AggregationNode::prepare(RuntimeState* state) {
&AggregationNode::_serialize_with_serialized_key_result, this,
std::placeholders::_1, std::placeholders::_2, std::placeholders::_3);
}
_executor.close = std::bind<void>(&AggregationNode::_close_with_serialized_key, this);
}

return Status::OK();
Expand Down Expand Up @@ -201,6 +204,7 @@ Status AggregationNode::get_next(RuntimeState* state, Block* block, bool* eos) {
// Closes the node: closes the ExecNode base, closes the probe expressions, then
// runs the mode-specific closer that destroys the aggregate states.
// Returns the base-class error if ExecNode::close() fails, otherwise OK.
Status AggregationNode::close(RuntimeState* state) {
    RETURN_IF_ERROR(ExecNode::close(state));
    VExpr::close(_probe_expr_ctxs, state);
    // `_executor.close` is an std::function that is only bound in prepare().
    // Invoking it while unbound (e.g. close() after a prepare() that never got
    // that far) would throw std::bad_function_call, so check it first.
    if (_executor.close) {
        _executor.close();
        // Unbind after use so a repeated close() cannot destroy the same
        // aggregate states a second time.
        _executor.close = nullptr;
    }
    return Status::OK();
}

Expand All @@ -211,6 +215,13 @@ Status AggregationNode::_create_agg_status(AggregateDataPtr data) {
return Status::OK();
}

// Destroys every aggregate state stored in the row buffer `data`.
// State i lives at `data + _offsets_of_aggregate_states[i]` and is destroyed by
// the function of evaluator i. Always returns OK.
Status AggregationNode::_destory_agg_status(AggregateDataPtr data) {
    const size_t evaluator_count = _aggregate_evaluators.size();
    for (size_t idx = 0; idx < evaluator_count; ++idx) {
        auto state = data + _offsets_of_aggregate_states[idx];
        _aggregate_evaluators[idx]->function()->destroy(state);
    }
    return Status::OK();
}

Status AggregationNode::_get_without_key_result(RuntimeState* state, Block* block, bool* eos) {
DCHECK(_agg_data.without_key != nullptr);

Expand Down Expand Up @@ -300,11 +311,14 @@ Status AggregationNode::_merge_without_key(Block* block) {
DCHECK(_agg_data.without_key != nullptr);
std::unique_ptr<char[]> deserialize_buffer(new char[_total_size_of_aggregate_states]);
int rows = block->rows();
_create_agg_status(deserialize_buffer.get());
DeferOp defer([&]() { _destory_agg_status(deserialize_buffer.get()); });
for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
auto column = block->getByPosition(i).column;
if (column->isNullable()) {
column = ((ColumnNullable*)column.get())->getNestedColumnPtr();
}

for (int j = 0; j < rows; ++j) {
std::string data_buffer;
StringRef ref = column->getDataAt(j);
Expand All @@ -323,6 +337,10 @@ Status AggregationNode::_merge_without_key(Block* block) {
return Status::OK();
}

void AggregationNode::_close_without_key() {
_destory_agg_status(_agg_data.without_key);
}

Status AggregationNode::_execute_with_serialized_key(Block* block) {
DCHECK(!_probe_expr_ctxs.empty());
// now we only support serialized key
Expand Down Expand Up @@ -567,4 +585,15 @@ Status AggregationNode::_merge_with_serialized_key(Block* block) {
return Status::OK();
}

void AggregationNode::_close_with_serialized_key() {
DCHECK(_agg_data.serialized != nullptr);

using Method = AggregationMethodSerialized<AggregatedDataWithStringKey>;
using AggState = Method::State;

auto& data = _agg_data.serialized->data;

data.forEachValue([&](const auto& key, auto& mapped) { _destory_agg_status(mapped); });
}

} // namespace doris::vectorized
5 changes: 5 additions & 0 deletions be/src/vec/exec/vaggregation_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,29 @@ class AggregationNode : public ::doris::ExecNode {

private:
Status _create_agg_status(AggregateDataPtr data);
Status _destory_agg_status(AggregateDataPtr data);

Status _get_without_key_result(RuntimeState* state, Block* block, bool* eos);
Status _serialize_without_key(RuntimeState* state, Block* block, bool* eos);
Status _execute_without_key(Block* block);
Status _merge_without_key(Block* block);
void _close_without_key();

Status _get_with_serialized_key_result(RuntimeState* state, Block* block, bool* eos);
Status _serialize_with_serialized_key_result(RuntimeState* state, Block* block, bool* eos);
Status _execute_with_serialized_key(Block* block);
Status _merge_with_serialized_key(Block* block);
void _close_with_serialized_key();

using vectorized_execute = std::function<Status(Block* block)>;
using vectorized_get_result =
std::function<Status(RuntimeState* state, Block* block, bool* eos)>;
using vectorized_closer = std::function<void()>;

// Bundle of the callables that implement the aggregation mode chosen for this node.
// Each member is an std::function; a default-constructed member is empty.
struct executor {
// Processes one input block; signature Status(Block*).
vectorized_execute execute;
// Produces output; signature Status(RuntimeState*, Block*, bool* eos).
vectorized_get_result get_result;
// Releases the aggregate states; bound in prepare() to _close_without_key or
// _close_with_serialized_key and invoked from AggregationNode::close().
vectorized_closer close;
};

executor _executor;
Expand Down
7 changes: 3 additions & 4 deletions be/src/vec/exprs/vectorized_agg_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, M
argument_types.emplace_back(_input_exprs_ctxs[i]->root()->data_type());
child_expr_name.emplace_back(_input_exprs_ctxs[i]->root()->expr_name());
}

_function = AggregateFunctionSimpleFactory::instance().get(_fn.name.function_name,
argument_types, params);
if (_function == nullptr) {
Expand Down Expand Up @@ -103,15 +102,15 @@ void AggFnEvaluator::execute_single_add(Block* block, AggregateDataPtr place, Ar
for (int i = 0; i < _input_exprs_ctxs.size(); ++i) {
int column_id = -1;
_input_exprs_ctxs[i]->execute(block, &column_id);
columns[i] =
block->getByPosition(column_id).column->convertToFullColumnIfConst();
columns[i] = block->getByPosition(column_id).column->convertToFullColumnIfConst();
}
// `convertToFullColumnIfConst()` may return a temporary, so we keep the returned column pointers alive in `columns`
// to make sure they are not destroyed before we call `addBatchSinglePlace`.
// WARNING:
// It is dangerous to call `convertToFullColumnIfConst().get()` to obtain the `const IColumn*` directly.
std::vector<const IColumn*> column_arguments(columns.size());
std::transform(columns.cbegin(), columns.cend(), column_arguments.begin(), [](const auto& ptr) {return ptr.get();});
std::transform(columns.cbegin(), columns.cend(), column_arguments.begin(),
[](const auto& ptr) { return ptr.get(); });
_function->addBatchSinglePlace(block->rows(), place, column_arguments.data(), nullptr);
}

Expand Down
Loading

0 comments on commit c841b63

Please sign in to comment.