Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add collect agg function #1292

Merged
merged 1 commit into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,14 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
if (children.empty()) {
uniqueExpressionName = binder->getUniqueExpressionName(uniqueExpressionName);
}
return make_shared<AggregateFunctionExpression>(DataType(function->returnTypeID),
std::move(children), function->aggregateFunction->clone(), uniqueExpressionName);
DataType returnType;
if (function->bindFunc) {
function->bindFunc(childrenTypes, function, returnType);
} else {
returnType = DataType(function->returnTypeID);
}
return make_shared<AggregateFunctionExpression>(returnType, std::move(children),
function->aggregateFunction->clone(), uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::staticEvaluate(const std::string& functionName,
Expand Down
8 changes: 8 additions & 0 deletions src/function/aggregate_function.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "function/aggregate/aggregate_function.h"

#include "function/aggregate/avg.h"
#include "function/aggregate/collect.h"
#include "function/aggregate/count.h"
#include "function/aggregate/count_star.h"
#include "function/aggregate/min_max.h"
Expand Down Expand Up @@ -69,6 +70,13 @@ std::unique_ptr<AggregateFunction> AggregateFunctionUtil::getMaxFunction(
return getMinMaxFunction<operation::GreaterThan>(inputType, isDistinct);
}

std::unique_ptr<AggregateFunction> AggregateFunctionUtil::getCollectFunction(
const common::DataType& inputType, bool isDistinct) {
return std::make_unique<AggregateFunction>(CollectFunction::initialize,
CollectFunction::updateAll, CollectFunction::updatePos, CollectFunction::combine,
CollectFunction::finalize, inputType, isDistinct);
}

template<typename FUNC>
std::unique_ptr<AggregateFunction> AggregateFunctionUtil::getMinMaxFunction(
const DataType& inputType, bool isDistinct) {
Expand Down
32 changes: 23 additions & 9 deletions src/function/built_in_aggregate_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "function/aggregate/built_in_aggregate_functions.h"

#include "function/aggregate/collect.h"

using namespace kuzu::common;

namespace kuzu {
Expand Down Expand Up @@ -35,10 +37,10 @@ uint32_t BuiltInAggregateFunctions::getFunctionCost(const std::vector<DataType>&
isDistinct != function->isDistinct) {
return UINT32_MAX;
}
// Currently all aggregate functions takes either 0 or 1 parameter. Therefore we do not allow
// any implicit cast and require a perfect match.
for (auto i = 0u; i < inputTypes.size(); ++i) {
if (inputTypes[i].typeID != function->parameterTypeIDs[i]) {
if (function->parameterTypeIDs[i] == ANY) {
acquamarin marked this conversation as resolved.
Show resolved Hide resolved
continue;
} else if (inputTypes[i].typeID != function->parameterTypeIDs[i]) {
return UINT32_MAX;
}
}
Expand Down Expand Up @@ -70,13 +72,14 @@ void BuiltInAggregateFunctions::registerAggregateFunctions() {
registerAvg();
registerMin();
registerMax();
registerCollect();
}

void BuiltInAggregateFunctions::registerCountStar() {
std::vector<std::unique_ptr<AggregateFunctionDefinition>> definitions;
definitions.push_back(std::make_unique<AggregateFunctionDefinition>(COUNT_STAR_FUNC_NAME,
std::vector<DataTypeID>{}, INT64, AggregateFunctionUtil::getCountStarFunction(), false));
aggregateFunctions.insert({COUNT_STAR_FUNC_NAME, move(definitions)});
aggregateFunctions.insert({COUNT_STAR_FUNC_NAME, std::move(definitions)});
}

void BuiltInAggregateFunctions::registerCount() {
Expand All @@ -90,7 +93,7 @@ void BuiltInAggregateFunctions::registerCount() {
AggregateFunctionUtil::getCountFunction(inputType, isDistinct), isDistinct));
}
}
aggregateFunctions.insert({COUNT_FUNC_NAME, move(definitions)});
aggregateFunctions.insert({COUNT_FUNC_NAME, std::move(definitions)});
}

void BuiltInAggregateFunctions::registerSum() {
Expand All @@ -102,7 +105,7 @@ void BuiltInAggregateFunctions::registerSum() {
AggregateFunctionUtil::getSumFunction(DataType(typeID), isDistinct), isDistinct));
}
}
aggregateFunctions.insert({SUM_FUNC_NAME, move(definitions)});
aggregateFunctions.insert({SUM_FUNC_NAME, std::move(definitions)});
}

void BuiltInAggregateFunctions::registerAvg() {
Expand All @@ -114,7 +117,7 @@ void BuiltInAggregateFunctions::registerAvg() {
AggregateFunctionUtil::getAvgFunction(DataType(typeID), isDistinct), isDistinct));
}
}
aggregateFunctions.insert({AVG_FUNC_NAME, move(definitions)});
aggregateFunctions.insert({AVG_FUNC_NAME, std::move(definitions)});
}

void BuiltInAggregateFunctions::registerMin() {
Expand All @@ -126,7 +129,7 @@ void BuiltInAggregateFunctions::registerMin() {
AggregateFunctionUtil::getMinFunction(DataType(typeID), isDistinct), isDistinct));
}
}
aggregateFunctions.insert({MIN_FUNC_NAME, move(definitions)});
aggregateFunctions.insert({MIN_FUNC_NAME, std::move(definitions)});
}

void BuiltInAggregateFunctions::registerMax() {
Expand All @@ -138,7 +141,18 @@ void BuiltInAggregateFunctions::registerMax() {
AggregateFunctionUtil::getMaxFunction(DataType(typeID), isDistinct), isDistinct));
}
}
aggregateFunctions.insert({MAX_FUNC_NAME, move(definitions)});
aggregateFunctions.insert({MAX_FUNC_NAME, std::move(definitions)});
}

void BuiltInAggregateFunctions::registerCollect() {
std::vector<std::unique_ptr<AggregateFunctionDefinition>> definitions;
for (auto isDistinct : std::vector<bool>{true, false}) {
definitions.push_back(std::make_unique<AggregateFunctionDefinition>(COLLECT_FUNC_NAME,
std::vector<DataTypeID>{ANY}, LIST,
AggregateFunctionUtil::getCollectFunction(DataType(ANY), isDistinct), isDistinct,
CollectFunction::bindFunc));
}
aggregateFunctions.insert({COLLECT_FUNC_NAME, std::move(definitions)});
}

} // namespace function
Expand Down
67 changes: 35 additions & 32 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void VectorListOperations::ListCreation(
}

void ListCreationVectorOperation::listCreationBindFunc(const std::vector<DataType>& argumentTypes,
VectorOperationDefinition* definition, DataType& actualReturnType) {
FunctionDefinition* definition, DataType& actualReturnType) {
if (argumentTypes.empty()) {
throw BinderException(
"Cannot resolve child data type for " + LIST_CREATION_FUNC_NAME + ".");
Expand Down Expand Up @@ -80,40 +80,41 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> ListLenVectorOperation::
}

void ListExtractVectorOperation::listExtractBindFunc(const std::vector<DataType>& argumentTypes,
VectorOperationDefinition* definition, DataType& returnType) {
FunctionDefinition* definition, DataType& returnType) {
definition->returnTypeID = argumentTypes[0].childType->typeID;
auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
returnType = *argumentTypes[0].childType;
switch (definition->returnTypeID) {
case BOOL: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, uint8_t, operation::ListExtract>;
} break;
case INT64: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, int64_t, operation::ListExtract>;
} break;
case DOUBLE: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, double_t, operation::ListExtract>;
} break;
case DATE: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, date_t, operation::ListExtract>;
} break;
case TIMESTAMP: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, timestamp_t, operation::ListExtract>;
} break;
case INTERVAL: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, interval_t, operation::ListExtract>;
} break;
case STRING: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, ku_string_t, operation::ListExtract>;
} break;
case LIST: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, ku_list_t, operation::ListExtract>;
} break;
default: {
Expand All @@ -139,8 +140,8 @@ std::vector<std::unique_ptr<VectorOperationDefinition>>
ListConcatVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
auto execFunc = BinaryListExecFunction<ku_list_t, ku_list_t, ku_list_t, operation::ListConcat>;
auto bindFunc = [](const std::vector<DataType>& argumentTypes,
VectorOperationDefinition* definition, DataType& actualReturnType) {
auto bindFunc = [](const std::vector<DataType>& argumentTypes, FunctionDefinition* definition,
DataType& actualReturnType) {
if (argumentTypes[0] != argumentTypes[1]) {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_CONCAT_FUNC_NAME, argumentTypes[0], argumentTypes[1]));
Expand All @@ -155,44 +156,45 @@ ListConcatVectorOperation::getDefinitions() {
}

void ListAppendVectorOperation::listAppendBindFunc(const std::vector<DataType>& argumentTypes,
VectorOperationDefinition* definition, DataType& returnType) {
FunctionDefinition* definition, DataType& returnType) {
if (*argumentTypes[0].childType != argumentTypes[1]) {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_APPEND_FUNC_NAME, argumentTypes[0], argumentTypes[1]));
}
auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
definition->returnTypeID = argumentTypes[0].typeID;
returnType = argumentTypes[0];
switch (argumentTypes[1].typeID) {
case INT64: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, int64_t, ku_list_t, operation::ListAppend>;
} break;
case DOUBLE: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, double_t, ku_list_t, operation::ListAppend>;
} break;
case BOOL: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, uint8_t, ku_list_t, operation::ListAppend>;
} break;
case STRING: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, ku_string_t, ku_list_t, operation::ListAppend>;
} break;
case DATE: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, date_t, ku_list_t, operation::ListAppend>;
} break;
case TIMESTAMP: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, timestamp_t, ku_list_t, operation::ListAppend>;
} break;
case INTERVAL: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, interval_t, ku_list_t, operation::ListAppend>;
} break;
case LIST: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, ku_list_t, ku_list_t, operation::ListAppend>;
} break;
default: {
Expand All @@ -211,44 +213,45 @@ ListAppendVectorOperation::getDefinitions() {
}

void ListPrependVectorOperation::listPrependBindFunc(const std::vector<DataType>& argumentTypes,
VectorOperationDefinition* definition, DataType& returnType) {
FunctionDefinition* definition, DataType& returnType) {
if (argumentTypes[0] != *argumentTypes[1].childType) {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_APPEND_FUNC_NAME, argumentTypes[0], argumentTypes[1]));
}
auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
definition->returnTypeID = argumentTypes[1].typeID;
returnType = argumentTypes[1];
switch (argumentTypes[0].typeID) {
case INT64: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<int64_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case DOUBLE: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<double_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case BOOL: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<uint8_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case STRING: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_string_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case DATE: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<date_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case TIMESTAMP: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<timestamp_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case INTERVAL: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<interval_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
case LIST: {
definition->execFunc =
vectorOperationDefinition->execFunc =
BinaryListExecFunction<ku_list_t, ku_list_t, ku_list_t, operation::ListPrepend>;
} break;
default: {
Expand Down Expand Up @@ -280,8 +283,8 @@ ListContainsVectorOperation::getDefinitions() {

std::vector<std::unique_ptr<VectorOperationDefinition>> ListSliceVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
auto bindFunc = [](const std::vector<DataType>& argumentTypes,
VectorOperationDefinition* definition, DataType& actualReturnType) {
auto bindFunc = [](const std::vector<DataType>& argumentTypes, FunctionDefinition* definition,
DataType& actualReturnType) {
definition->returnTypeID = argumentTypes[0].typeID;
actualReturnType = argumentTypes[0];
};
Expand Down
1 change: 1 addition & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const std::string SUM_FUNC_NAME = "SUM";
const std::string AVG_FUNC_NAME = "AVG";
const std::string MIN_FUNC_NAME = "MIN";
const std::string MAX_FUNC_NAME = "MAX";
const std::string COLLECT_FUNC_NAME = "COLLECT";

// cast
const std::string CAST_TO_DATE_FUNC_NAME = "DATE";
Expand Down
Loading