Skip to content

Commit

Permalink
Add collect agg function
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Feb 17, 2023
1 parent a879779 commit 6749ec0
Show file tree
Hide file tree
Showing 25 changed files with 422 additions and 107 deletions.
10 changes: 8 additions & 2 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,14 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
if (children.empty()) {
uniqueExpressionName = binder->getUniqueExpressionName(uniqueExpressionName);
}
return make_shared<AggregateFunctionExpression>(DataType(function->returnTypeID),
std::move(children), function->aggregateFunction->clone(), uniqueExpressionName);
DataType returnType;
if (function->bindFunc) {
function->bindFunc(childrenTypes, function, returnType);
} else {
returnType = DataType(function->returnTypeID);
}
return make_shared<AggregateFunctionExpression>(returnType, std::move(children),
function->aggregateFunction->clone(), uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::staticEvaluate(const std::string& functionName,
Expand Down
8 changes: 8 additions & 0 deletions src/function/aggregate_function.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "function/aggregate/aggregate_function.h"

#include "function/aggregate/avg.h"
#include "function/aggregate/collect.h"
#include "function/aggregate/count.h"
#include "function/aggregate/count_star.h"
#include "function/aggregate/min_max.h"
Expand Down Expand Up @@ -69,6 +70,13 @@ std::unique_ptr<AggregateFunction> AggregateFunctionUtil::getMaxFunction(
return getMinMaxFunction<operation::GreaterThan>(inputType, isDistinct);
}

std::unique_ptr<AggregateFunction> AggregateFunctionUtil::getCollectFunction(
const common::DataType& inputType, bool isDistinct) {
return std::make_unique<AggregateFunction>(CollectFunction::initialize,
CollectFunction::updateAll, CollectFunction::updatePos, CollectFunction::combine,
CollectFunction::finalize, inputType, isDistinct);
}

template<typename FUNC>
std::unique_ptr<AggregateFunction> AggregateFunctionUtil::getMinMaxFunction(
const DataType& inputType, bool isDistinct) {
Expand Down
32 changes: 23 additions & 9 deletions src/function/built_in_aggregate_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "function/aggregate/built_in_aggregate_functions.h"

#include "function/aggregate/collect.h"

using namespace kuzu::common;

namespace kuzu {
Expand Down Expand Up @@ -35,10 +37,10 @@ uint32_t BuiltInAggregateFunctions::getFunctionCost(const std::vector<DataType>&
isDistinct != function->isDistinct) {
return UINT32_MAX;
}
// Currently all aggregate functions takes either 0 or 1 parameter. Therefore we do not allow
// any implicit cast and require a perfect match.
for (auto i = 0u; i < inputTypes.size(); ++i) {
if (inputTypes[i].typeID != function->parameterTypeIDs[i]) {
if (function->parameterTypeIDs[i] == ANY) {
continue;
} else if (inputTypes[i].typeID != function->parameterTypeIDs[i]) {
return UINT32_MAX;
}
}
Expand Down Expand Up @@ -70,13 +72,14 @@ void BuiltInAggregateFunctions::registerAggregateFunctions() {
registerAvg();
registerMin();
registerMax();
registerCollect();
}

void BuiltInAggregateFunctions::registerCountStar() {
std::vector<std::unique_ptr<AggregateFunctionDefinition>> definitions;
definitions.push_back(std::make_unique<AggregateFunctionDefinition>(COUNT_STAR_FUNC_NAME,
std::vector<DataTypeID>{}, INT64, AggregateFunctionUtil::getCountStarFunction(), false));
aggregateFunctions.insert({COUNT_STAR_FUNC_NAME, move(definitions)});
aggregateFunctions.insert({COUNT_STAR_FUNC_NAME, std::move(definitions)});
}

void BuiltInAggregateFunctions::registerCount() {
Expand All @@ -90,7 +93,7 @@ void BuiltInAggregateFunctions::registerCount() {
AggregateFunctionUtil::getCountFunction(inputType, isDistinct), isDistinct));
}
}
aggregateFunctions.insert({COUNT_FUNC_NAME, move(definitions)});
aggregateFunctions.insert({COUNT_FUNC_NAME, std::move(definitions)});
}

void BuiltInAggregateFunctions::registerSum() {
Expand All @@ -102,7 +105,7 @@ void BuiltInAggregateFunctions::registerSum() {
AggregateFunctionUtil::getSumFunction(DataType(typeID), isDistinct), isDistinct));
}
}
aggregateFunctions.insert({SUM_FUNC_NAME, move(definitions)});
aggregateFunctions.insert({SUM_FUNC_NAME, std::move(definitions)});
}

void BuiltInAggregateFunctions::registerAvg() {
Expand All @@ -114,7 +117,7 @@ void BuiltInAggregateFunctions::registerAvg() {
AggregateFunctionUtil::getAvgFunction(DataType(typeID), isDistinct), isDistinct));
}
}
aggregateFunctions.insert({AVG_FUNC_NAME, move(definitions)});
aggregateFunctions.insert({AVG_FUNC_NAME, std::move(definitions)});
}

void BuiltInAggregateFunctions::registerMin() {
Expand All @@ -126,7 +129,7 @@ void BuiltInAggregateFunctions::registerMin() {
AggregateFunctionUtil::getMinFunction(DataType(typeID), isDistinct), isDistinct));
}
}
aggregateFunctions.insert({MIN_FUNC_NAME, move(definitions)});
aggregateFunctions.insert({MIN_FUNC_NAME, std::move(definitions)});
}

void BuiltInAggregateFunctions::registerMax() {
Expand All @@ -138,7 +141,18 @@ void BuiltInAggregateFunctions::registerMax() {
AggregateFunctionUtil::getMaxFunction(DataType(typeID), isDistinct), isDistinct));
}
}
aggregateFunctions.insert({MAX_FUNC_NAME, move(definitions)});
aggregateFunctions.insert({MAX_FUNC_NAME, std::move(definitions)});
}

void BuiltInAggregateFunctions::registerCollect() {
std::vector<std::unique_ptr<AggregateFunctionDefinition>> definitions;
for (auto isDistinct : std::vector<bool>{true, false}) {
definitions.push_back(std::make_unique<AggregateFunctionDefinition>(COLLECT_FUNC_NAME,
std::vector<DataTypeID>{ANY}, LIST,
AggregateFunctionUtil::getCollectFunction(DataType(ANY), isDistinct), isDistinct,
CollectFunction::bindFunc));
}
aggregateFunctions.insert({COLLECT_FUNC_NAME, std::move(definitions)});
}

} // namespace function
Expand Down
67 changes: 35 additions & 32 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void VectorListOperations::ListCreation(
}

void ListCreationVectorOperation::listCreationBindFunc(const std::vector<DataType>& argumentTypes,
VectorOperationDefinition* definition, DataType& actualReturnType) {
FunctionDefinition* definition, DataType& actualReturnType) {
if (argumentTypes.empty()) {
throw BinderException(
"Cannot resolve child data type for " + LIST_CREATION_FUNC_NAME + ".");
Expand Down Expand Up @@ -80,40 +80,41 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> ListLenVectorOperation::
}

void ListExtractVectorOperation::listExtractBindFunc(const std::vector<DataType>& argumentTypes,
VectorOperationDefinition* definition, DataType& returnType) {
FunctionDefinition* definition, DataType& returnType) {
definition->returnTypeID = argumentTypes[0].childType->typeID;
auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
returnType = *argumentTypes[0].childType;
switch (definition->returnTypeID) {
case BOOL: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, uint8_t, operation::ListExtract>;
} break;
case INT64: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, int64_t, operation::ListExtract>;
} break;
case DOUBLE: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, double_t, operation::ListExtract>;
} break;
case DATE: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, date_t, operation::ListExtract>;
} break;
case TIMESTAMP: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, timestamp_t, operation::ListExtract>;
} break;
case INTERVAL: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, interval_t, operation::ListExtract>;
} break;
case STRING: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, ku_string_t, operation::ListExtract>;
} break;
case LIST: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, ku_list_t, operation::ListExtract>;
} break;
default: {
Expand All @@ -139,8 +140,8 @@ std::vector<std::unique_ptr<VectorOperationDefinition>>
ListConcatVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
auto execFunc = BinaryListExecFunction<ku_list_t, ku_list_t, ku_list_t, operation::ListConcat>;
auto bindFunc = [](const std::vector<DataType>& argumentTypes,
VectorOperationDefinition* definition, DataType& actualReturnType) {
auto bindFunc = [](const std::vector<DataType>& argumentTypes, FunctionDefinition* definition,
DataType& actualReturnType) {
if (argumentTypes[0] != argumentTypes[1]) {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_CONCAT_FUNC_NAME, argumentTypes[0], argumentTypes[1]));
Expand All @@ -155,44 +156,45 @@ ListConcatVectorOperation::getDefinitions() {
}

void ListAppendVectorOperation::listAppendBindFunc(const std::vector<DataType>& argumentTypes,
VectorOperationDefinition* definition, DataType& returnType) {
FunctionDefinition* definition, DataType& returnType) {
if (*argumentTypes[0].childType != argumentTypes[1]) {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_APPEND_FUNC_NAME, argumentTypes[0], argumentTypes[1]));
}
auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
definition->returnTypeID = argumentTypes[0].typeID;
returnType = argumentTypes[0];
switch (argumentTypes[1].typeID) {
case INT64: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, ku_list_t, operation::ListAppend>;
} break;
case DOUBLE: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, double_t, ku_list_t, operation::ListAppend>;
} break;
case BOOL: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, uint8_t, ku_list_t, operation::ListAppend>;
} break;
case STRING: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, ku_string_t, ku_list_t, operation::ListAppend>;
} break;
case DATE: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, date_t, ku_list_t, operation::ListAppend>;
} break;
case TIMESTAMP: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, timestamp_t, ku_list_t, operation::ListAppend>;
} break;
case INTERVAL: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, interval_t, ku_list_t, operation::ListAppend>;
} break;
case LIST: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, ku_list_t, ku_list_t, operation::ListAppend>;
} break;
default: {
Expand All @@ -211,44 +213,45 @@ ListAppendVectorOperation::getDefinitions() {
}

void ListPrependVectorOperation::listPrependBindFunc(const std::vector<DataType>& argumentTypes,
VectorOperationDefinition* definition, DataType& returnType) {
FunctionDefinition* definition, DataType& returnType) {
if (argumentTypes[0] != *argumentTypes[1].childType) {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_APPEND_FUNC_NAME, argumentTypes[0], argumentTypes[1]));
}
auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
definition->returnTypeID = argumentTypes[1].typeID;
returnType = argumentTypes[1];
switch (argumentTypes[0].typeID) {
case INT64: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<int64_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case DOUBLE: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<double_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case BOOL: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<uint8_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case STRING: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_string_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case DATE: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<date_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case TIMESTAMP: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<timestamp_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case INTERVAL: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<interval_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case LIST: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
default: {
Expand Down Expand Up @@ -280,8 +283,8 @@ ListContainsVectorOperation::getDefinitions() {

std::vector<std::unique_ptr<VectorOperationDefinition>> ListSliceVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
auto bindFunc = [](const std::vector<DataType>& argumentTypes,
VectorOperationDefinition* definition, DataType& actualReturnType) {
auto bindFunc = [](const std::vector<DataType>& argumentTypes, FunctionDefinition* definition,
DataType& actualReturnType) {
definition->returnTypeID = argumentTypes[0].typeID;
actualReturnType = argumentTypes[0];
};
Expand Down
1 change: 1 addition & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const std::string SUM_FUNC_NAME = "SUM";
const std::string AVG_FUNC_NAME = "AVG";
const std::string MIN_FUNC_NAME = "MIN";
const std::string MAX_FUNC_NAME = "MAX";
const std::string COLLECT_FUNC_NAME = "COLLECT";

// cast
const std::string CAST_TO_DATE_FUNC_NAME = "DATE";
Expand Down
Loading

0 comments on commit 6749ec0

Please sign in to comment.