Skip to content

Commit

Permalink
Merge pull request #3079 from kuzudb/arithmetic-functions-refactor
Browse files Browse the repository at this point in the history
Refactor arithmetic functions
  • Loading branch information
manh9203 committed Mar 19, 2024
2 parents 8b2c768 + efdc1e4 commit 8fa40d6
Show file tree
Hide file tree
Showing 45 changed files with 428 additions and 378 deletions.
5 changes: 3 additions & 2 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "binder/expression/property_expression.h"
#include "binder/expression_binder.h"
#include "common/exception/binder.h"
#include "function/arithmetic/vector_arithmetic_functions.h"
#include "function/rewrite_function.h"
#include "function/schema/vector_label_functions.h"
#include "main/client_context.h"
Expand Down Expand Up @@ -323,12 +324,12 @@ std::shared_ptr<Expression> ExpressionBinder::bindRecursiveJoinLengthFunction(
children.push_back(std::move(numRelsExpression));
children.push_back(
ku_dynamic_cast<Expression&, RelExpression&>(*recursiveRels[0]).getLengthExpression());
auto result = bindScalarFunctionExpression(children, ADD_FUNC_NAME);
auto result = bindScalarFunctionExpression(children, AddFunction::name);
for (auto i = 1u; i < recursiveRels.size(); ++i) {
children[0] = std::move(result);
children[1] = ku_dynamic_cast<Expression&, RelExpression&>(*recursiveRels[i])
.getLengthExpression();
result = bindScalarFunctionExpression(children, ADD_FUNC_NAME);
result = bindScalarFunctionExpression(children, AddFunction::name);
}
return result;
} else if (ExpressionUtil::isRecursiveRelPattern(expression)) {
Expand Down
2 changes: 2 additions & 0 deletions src/catalog/catalog_entry/rel_table_catalog_entry.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "catalog/catalog_entry/rel_table_catalog_entry.h"

#include <sstream>

#include "catalog/catalog.h"

using namespace kuzu::common;
Expand Down
4 changes: 4 additions & 0 deletions src/catalog/catalog_entry/scalar_function_catalog_entry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
namespace kuzu {
namespace catalog {

ScalarFunctionCatalogEntry::ScalarFunctionCatalogEntry(
const char* name, function::function_set functionSet)
: ScalarFunctionCatalogEntry{std::string{name}, std::move(functionSet)} {}

ScalarFunctionCatalogEntry::ScalarFunctionCatalogEntry(
std::string name, function::function_set functionSet)
: FunctionCatalogEntry{
Expand Down
1 change: 1 addition & 0 deletions src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_library(kuzu_function
cast_from_string_functions.cpp
comparison_functions.cpp
find_function.cpp
function_collection.cpp
scalar_macro_function.cpp
vector_arithmetic_functions.cpp
vector_boolean_functions.cpp
Expand Down
98 changes: 13 additions & 85 deletions src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "function/cast/vector_cast_functions.h"
#include "function/comparison/vector_comparison_functions.h"
#include "function/date/vector_date_functions.h"
#include "function/function_collection.h"
#include "function/interval/vector_interval_functions.h"
#include "function/list/vector_list_functions.h"
#include "function/map/vector_map_functions.h"
Expand Down Expand Up @@ -52,11 +53,12 @@ void BuiltInFunctionsUtils::createFunctions(CatalogSet* catalogSet) {
registerScalarFunctions(catalogSet);
registerAggregateFunctions(catalogSet);
registerTableFunctions(catalogSet);

registerFunctions(catalogSet);
}

void BuiltInFunctionsUtils::registerScalarFunctions(CatalogSet* catalogSet) {
registerComparisonFunctions(catalogSet);
registerArithmeticFunctions(catalogSet);
registerDateFunctions(catalogSet);
registerTimestampFunctions(catalogSet);
registerIntervalFunctions(catalogSet);
Expand Down Expand Up @@ -514,7 +516,7 @@ void BuiltInFunctionsUtils::validateSpecialCases(std::vector<Function*>& candida
const std::string& name, const std::vector<LogicalType>& inputTypes,
function::function_set& set) {
// special case for add func
if (name == ADD_FUNC_NAME) {
if (name == AddFunction::name) {
auto targetType0 = candidateFunctions[0]->parameterTypeIDs[0];
auto targetType1 = candidateFunctions[0]->parameterTypeIDs[1];
auto inputType0 = inputTypes[0].getLogicalTypeID();
Expand Down Expand Up @@ -551,89 +553,6 @@ void BuiltInFunctionsUtils::registerComparisonFunctions(CatalogSet* catalogSet)
LESS_THAN_EQUALS_FUNC_NAME, LessThanEqualsFunction::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerArithmeticFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(
std::make_unique<ScalarFunctionCatalogEntry>(ADD_FUNC_NAME, AddFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
SUBTRACT_FUNC_NAME, SubtractFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
MULTIPLY_FUNC_NAME, MultiplyFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
DIVIDE_FUNC_NAME, DivideFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
MODULO_FUNC_NAME, ModuloFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
POWER_FUNC_NAME, PowerFunction::getFunctionSet()));
catalogSet->createEntry(
std::make_unique<ScalarFunctionCatalogEntry>(ABS_FUNC_NAME, AbsFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
ACOS_FUNC_NAME, AcosFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
ASIN_FUNC_NAME, AsinFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
ATAN_FUNC_NAME, AtanFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
ATAN2_FUNC_NAME, Atan2Function::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
BITWISE_XOR_FUNC_NAME, BitwiseXorFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
BITWISE_AND_FUNC_NAME, BitwiseAndFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
BITWISE_OR_FUNC_NAME, BitwiseOrFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
BITSHIFT_LEFT_FUNC_NAME, BitShiftLeftFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
BITSHIFT_RIGHT_FUNC_NAME, BitShiftRightFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
CBRT_FUNC_NAME, CbrtFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
CEIL_FUNC_NAME, CeilFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
CEILING_FUNC_NAME, CeilFunction::getFunctionSet()));
catalogSet->createEntry(
std::make_unique<ScalarFunctionCatalogEntry>(COS_FUNC_NAME, CosFunction::getFunctionSet()));
catalogSet->createEntry(
std::make_unique<ScalarFunctionCatalogEntry>(COT_FUNC_NAME, CotFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
DEGREES_FUNC_NAME, DegreesFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
EVEN_FUNC_NAME, EvenFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
FACTORIAL_FUNC_NAME, FactorialFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
FLOOR_FUNC_NAME, FloorFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
GAMMA_FUNC_NAME, GammaFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
LGAMMA_FUNC_NAME, LgammaFunction::getFunctionSet()));
catalogSet->createEntry(
std::make_unique<ScalarFunctionCatalogEntry>(LN_FUNC_NAME, LnFunction::getFunctionSet()));
catalogSet->createEntry(
std::make_unique<ScalarFunctionCatalogEntry>(LOG_FUNC_NAME, LogFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
LOG2_FUNC_NAME, Log2Function::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
LOG10_FUNC_NAME, LogFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
NEGATE_FUNC_NAME, NegateFunction::getFunctionSet()));
catalogSet->createEntry(
std::make_unique<ScalarFunctionCatalogEntry>(PI_FUNC_NAME, PiFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
POW_FUNC_NAME, PowerFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
RADIANS_FUNC_NAME, RadiansFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
ROUND_FUNC_NAME, RoundFunction::getFunctionSet()));
catalogSet->createEntry(
std::make_unique<ScalarFunctionCatalogEntry>(SIN_FUNC_NAME, SinFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
SIGN_FUNC_NAME, SignFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
SQRT_FUNC_NAME, SqrtFunction::getFunctionSet()));
catalogSet->createEntry(
std::make_unique<ScalarFunctionCatalogEntry>(TAN_FUNC_NAME, TanFunction::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerDateFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
DATE_PART_FUNC_NAME, DatePartFunction::getFunctionSet()));
Expand Down Expand Up @@ -1064,6 +983,15 @@ void BuiltInFunctionsUtils::registerTableFunctions(CatalogSet* catalogSet) {
READ_FTABLE_FUNC_NAME, FTableScan::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerFunctions(catalog::CatalogSet* catalogSet) {
auto functions = FunctionCollection::getFunctions();
for (auto i = 0u; functions[i].name != nullptr; ++i) {
auto functionSet = functions[i].getFunctionSetFunc();
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
functions[i].name, std::move(functionSet)));
}
}

static std::string getFunctionMatchFailureMsg(const std::string name,
const std::vector<LogicalType>& inputTypes, const std::string& supportedInputs,
bool isDistinct = false) {
Expand Down
46 changes: 46 additions & 0 deletions src/function/function_collection.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include "function/function_collection.h"

#include "function/arithmetic/vector_arithmetic_functions.h"

namespace kuzu {
namespace function {

#define SCALAR_FUNCTION(_PARAM) \
{ _PARAM::name, _PARAM::getFunctionSet }
#define SCALAR_FUNCTION_ALIAS(_PARAM) \
{ _PARAM::alias, _PARAM::getFunctionSet }
#define FINAL_FUNCTION \
{ nullptr, nullptr }

FunctionCollection* FunctionCollection::getFunctions() {
static FunctionCollection functions[] = {

// Arithmetic Functions
SCALAR_FUNCTION(AddFunction), SCALAR_FUNCTION(SubtractFunction),
SCALAR_FUNCTION(MultiplyFunction), SCALAR_FUNCTION(DivideFunction),
SCALAR_FUNCTION(ModuloFunction), SCALAR_FUNCTION(PowerFunction),
SCALAR_FUNCTION(AbsFunction), SCALAR_FUNCTION(AcosFunction), SCALAR_FUNCTION(AsinFunction),
SCALAR_FUNCTION(AtanFunction), SCALAR_FUNCTION(Atan2Function),
SCALAR_FUNCTION(BitwiseXorFunction), SCALAR_FUNCTION(BitwiseAndFunction),
SCALAR_FUNCTION(BitwiseOrFunction), SCALAR_FUNCTION(BitShiftLeftFunction),
SCALAR_FUNCTION(BitShiftRightFunction), SCALAR_FUNCTION(CbrtFunction),
SCALAR_FUNCTION(CeilFunction), SCALAR_FUNCTION_ALIAS(CeilFunction),
SCALAR_FUNCTION(CosFunction), SCALAR_FUNCTION(CotFunction),
SCALAR_FUNCTION(DegreesFunction), SCALAR_FUNCTION(EvenFunction),
SCALAR_FUNCTION(FactorialFunction), SCALAR_FUNCTION(FloorFunction),
SCALAR_FUNCTION(GammaFunction), SCALAR_FUNCTION(LgammaFunction),
SCALAR_FUNCTION(LnFunction), SCALAR_FUNCTION(LogFunction),
SCALAR_FUNCTION_ALIAS(LogFunction), SCALAR_FUNCTION(Log2Function),
SCALAR_FUNCTION(NegateFunction), SCALAR_FUNCTION(PiFunction),
SCALAR_FUNCTION_ALIAS(PowerFunction), SCALAR_FUNCTION(RadiansFunction),
SCALAR_FUNCTION(RoundFunction), SCALAR_FUNCTION(SinFunction), SCALAR_FUNCTION(SignFunction),
SCALAR_FUNCTION(SqrtFunction), SCALAR_FUNCTION(TanFunction),

// End of array
FINAL_FUNCTION};

return functions;
}

} // namespace function
} // namespace kuzu
1 change: 1 addition & 0 deletions src/function/pattern/id_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "binder/expression/node_expression.h"
#include "binder/expression/rel_expression.h"
#include "binder/expression_binder.h"
#include "common/cast.h"
#include "function/rewrite_function.h"
#include "function/schema/vector_node_rel_functions.h"

Expand Down
Loading

0 comments on commit 8fa40d6

Please sign in to comment.