Skip to content

Commit

Permalink
Merge pull request #2292 from kuzudb/function-framework
Browse files Browse the repository at this point in the history
Function framework refactor
  • Loading branch information
acquamarin committed Oct 29, 2023
2 parents 38b2497 + 9121550 commit 627b5f5
Show file tree
Hide file tree
Showing 74 changed files with 2,545 additions and 2,642 deletions.
7 changes: 3 additions & 4 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,21 @@ std::unique_ptr<BoundReadingClause> Binder::bindUnwindClause(const ReadingClause

std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause& readingClause) {
auto& call = reinterpret_cast<const InQueryCallClause&>(readingClause);
auto tableFunctionDefinition =
catalog.getBuiltInTableFunction()->mathTableFunction(call.getFuncName());
auto tableFunction = catalog.getBuiltInFunctions()->mathTableFunction(call.getFuncName());
auto inputValues = std::vector<Value>{};
for (auto& parameter : call.getParameters()) {
auto boundExpr = expressionBinder.bindLiteralExpression(*parameter);
inputValues.push_back(*reinterpret_cast<LiteralExpression*>(boundExpr.get())->getValue());
}
auto bindData = tableFunctionDefinition->bindFunc(clientContext,
auto bindData = tableFunction->bindFunc(clientContext,
function::TableFuncBindInput{std::move(inputValues)}, catalog.getReadOnlyVersion());
expression_vector outputExpressions;
for (auto i = 0u; i < bindData->returnColumnNames.size(); i++) {
outputExpressions.push_back(
createVariable(bindData->returnColumnNames[i], bindData->returnTypes[i]));
}
return std::make_unique<BoundInQueryCall>(
tableFunctionDefinition, std::move(bindData), std::move(outputExpressions));
std::move(tableFunction), std::move(bindData), std::move(outputExpressions));
}

static std::unique_ptr<LogicalType> bindFixedListType(
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(

std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
ExpressionType expressionType, const expression_vector& children) {
auto builtInFunctions = binder->catalog.getBuiltInVectorFunctions();
auto builtInFunctions = binder->catalog.getBuiltInFunctions();
auto functionName = expressionTypeToString(expressionType);
std::vector<LogicalType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto function = builtInFunctions->matchVectorFunction(functionName, childrenTypes);
auto function = builtInFunctions->matchScalarFunction(functionName, childrenTypes);
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
childrenAfterCast.push_back(
Expand Down
15 changes: 8 additions & 7 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(

std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const expression_vector& children, const std::string& functionName) {
auto builtInFunctions = binder->catalog.getBuiltInVectorFunctions();
auto builtInFunctions = binder->catalog.getBuiltInFunctions();
std::vector<LogicalType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto function = builtInFunctions->matchVectorFunction(functionName, childrenTypes);
auto function = builtInFunctions->matchScalarFunction(functionName, childrenTypes);
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
auto targetType =
Expand All @@ -79,7 +79,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(

std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
const ParsedExpression& parsedExpression, const std::string& functionName, bool isDistinct) {
auto builtInFunctions = binder->catalog.getBuiltInAggregateFunction();
auto builtInFunctions = binder->catalog.getBuiltInFunctions();
std::vector<LogicalType> childrenTypes;
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
Expand All @@ -92,7 +92,8 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
childrenTypes.push_back(child->dataType);
children.push_back(std::move(child));
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes, isDistinct);
auto function =
builtInFunctions->matchAggregateFunction(functionName, childrenTypes, isDistinct)->copy();
if (function->paramRewriteFunc) {
function->paramRewriteFunc(children);
}
Expand All @@ -103,13 +104,13 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
}
std::unique_ptr<function::FunctionBindData> bindData;
if (function->bindFunc) {
bindData = function->bindFunc(children, function);
bindData = function->bindFunc(children, function.get());
} else {
bindData =
std::make_unique<function::FunctionBindData>(LogicalType(function->returnTypeID));
}
return make_shared<AggregateFunctionExpression>(functionName, std::move(bindData),
std::move(children), function->aggregateFunction->clone(), uniqueExpressionName);
std::move(children), std::move(function), uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::bindMacroExpression(
Expand Down Expand Up @@ -247,7 +248,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression
default:
throw NotImplementedException("ExpressionBinder::bindLabelFunction");
}
auto execFunc = function::LabelVectorFunction::execFunction;
auto execFunc = function::LabelFunction::execFunction;
auto bindData =
std::make_unique<function::FunctionBindData>(LogicalType(LogicalTypeID::STRING));
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
Expand Down
6 changes: 3 additions & 3 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(

std::shared_ptr<Expression> ExpressionBinder::implicitCast(
const std::shared_ptr<Expression>& expression, const LogicalType& targetType) {
if (VectorCastFunction::hasImplicitCast(expression->dataType, targetType)) {
auto functionName = VectorCastFunction::bindImplicitCastFuncName(targetType);
if (CastFunction::hasImplicitCast(expression->dataType, targetType)) {
auto functionName = CastFunction::bindImplicitCastFuncName(targetType);
auto children = expression_vector{expression};
auto bindData = std::make_unique<FunctionBindData>(targetType);
function::scalar_exec_func execFunc;
VectorCastFunction::bindImplicitCastFunc(
CastFunction::bindImplicitCastFunc(
expression->dataType.getLogicalTypeID(), targetType.getLogicalTypeID(), execFunc);
auto uniqueName = ScalarFunctionExpression::getUniqueName(functionName, children);
return std::make_shared<ScalarFunctionExpression>(functionName, FUNCTION,
Expand Down
5 changes: 2 additions & 3 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,8 @@ std::unordered_set<TableSchema*> Catalog::getAllRelTableSchemasContainBoundTable
return relTableSchemas;
}

void Catalog::addVectorFunction(
std::string name, function::vector_function_definitions definitions) {
catalogContentForReadOnlyTrx->addVectorFunction(std::move(name), std::move(definitions));
void Catalog::addFunction(std::string name, function::function_set functionSet) {
catalogContentForReadOnlyTrx->addFunction(std::move(name), std::move(functionSet));
}

void Catalog::addScalarMacroFunction(
Expand Down
30 changes: 18 additions & 12 deletions src/catalog/catalog_content.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,21 +204,29 @@ void CatalogContent::readFromFile(const std::string& directory, DBFileType dbFil

ExpressionType CatalogContent::getFunctionType(const std::string& name) const {
auto upperCaseName = StringUtils::getUpper(name);
if (builtInVectorFunctions->containsFunction(upperCaseName)) {
return FUNCTION;
} else if (builtInAggregateFunctions->containsFunction(upperCaseName)) {
return AGGREGATE_FUNCTION;
} else if (macros.contains(upperCaseName)) {
if (macros.contains(upperCaseName)) {
return MACRO;
} else {
} else if (!builtInFunctions->containsFunction(name)) {
throw CatalogException(name + " function does not exist.");
} else {
// TODO(Ziyi): we should let table function use the same interface to bind.
auto funcType = builtInFunctions->getFunctionType(upperCaseName);
switch (funcType) {
case function::FunctionType::SCALAR:
return FUNCTION;
case function::FunctionType::AGGREGATE:
return AGGREGATE_FUNCTION;
// LCOV_EXCL_START
default:
throw NotImplementedException{"CatalogContent::getFunctionType"};
// LCOV_EXCL_END
}
}
}

void CatalogContent::addVectorFunction(
std::string name, function::vector_function_definitions definitions) {
void CatalogContent::addFunction(std::string name, function::function_set definitions) {
StringUtils::toUpper(name);
builtInVectorFunctions->addFunction(std::move(name), std::move(definitions));
builtInFunctions->addFunction(std::move(name), std::move(definitions));
}

void CatalogContent::addScalarMacroFunction(
Expand Down Expand Up @@ -272,9 +280,7 @@ void CatalogContent::writeMagicBytes(Serializer& serializer) {
}

void CatalogContent::registerBuiltInFunctions() {
builtInVectorFunctions = std::make_unique<function::BuiltInVectorFunctions>();
builtInAggregateFunctions = std::make_unique<function::BuiltInAggregateFunctions>();
builtInTableFunctions = std::make_unique<function::BuiltInTableFunctions>();
builtInFunctions = std::make_unique<function::BuiltInFunctions>();
}

bool CatalogContent::containsTable(const std::string& tableName, TableType tableType) const {
Expand Down
4 changes: 3 additions & 1 deletion src/common/expression_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ bool isExpressionSubquery(ExpressionType type) {
return EXISTENTIAL_SUBQUERY == type;
}

// LCOV_EXCL_START
std::string expressionTypeToString(ExpressionType type) {
switch (type) {
case OR:
Expand Down Expand Up @@ -77,7 +78,7 @@ std::string expressionTypeToString(ExpressionType type) {
case PARAMETER:
return "PARAMETER";
case FUNCTION:
return "FUNCTION";
return "SCALAR_FUNCTION";
case AGGREGATE_FUNCTION:
return "AGGREGATE_FUNCTION";
case EXISTENTIAL_SUBQUERY:
Expand All @@ -86,6 +87,7 @@ std::string expressionTypeToString(ExpressionType type) {
throw NotImplementedException("Cannot convert expression type to string");
}
}
// LCOV_EXCL_STOP

} // namespace common
} // namespace kuzu
4 changes: 2 additions & 2 deletions src/expression_evaluator/node_rel_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ void NodeRelExpressionEvaluator::evaluate() {
for (auto& child : children) {
child->evaluate();
}
StructPackVectorFunctions::execFunc(parameters, *resultVector);
StructPackFunctions::execFunc(parameters, *resultVector);
}

void NodeRelExpressionEvaluator::resolveResultVector(
Expand All @@ -26,7 +26,7 @@ void NodeRelExpressionEvaluator::resolveResultVector(
inputEvaluators.push_back(child.get());
}
resolveResultStateFromChildren(inputEvaluators);
StructPackVectorFunctions::compileFunc(nullptr, parameters, resultVector);
StructPackFunctions::compileFunc(nullptr, parameters, resultVector);
}

} // namespace evaluator
Expand Down
4 changes: 1 addition & 3 deletions src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ add_library(kuzu_function
OBJECT
aggregate_function.cpp
base_lower_upper_operation.cpp
built_in_aggregate_functions.cpp
built_in_vector_functions.cpp
built_in_table_functions.cpp
built_in_functions.cpp
cast_string_non_nested_functions.cpp
cast_string_to_functions.cpp
comparison_functions.cpp
Expand Down
Loading

0 comments on commit 627b5f5

Please sign in to comment.