Skip to content

Commit

Permalink
Refactor scalar function registration (#3119)
Browse files Browse the repository at this point in the history
Refactor comparison functions

Refactor date functions

Refactor timestamp functions

Refactor interval functions

Refactor blob functions

Refactor UUID functions

Refactor struct functions

Refactor map functions

Refactor union functions
  • Loading branch information
manh9203 committed Mar 25, 2024
1 parent 3813eed commit 53ef58e
Show file tree
Hide file tree
Showing 25 changed files with 294 additions and 375 deletions.
3 changes: 2 additions & 1 deletion src/binder/bind_expression/bind_property_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "common/cast.h"
#include "common/exception/binder.h"
#include "common/string_format.h"
#include "function/struct/vector_struct_functions.h"
#include "parser/expression/parsed_property_expression.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -103,7 +104,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindStructPropertyExpression(
std::shared_ptr<Expression> child, const std::string& propertyName) {
auto children =
expression_vector{std::move(child), createStringLiteralExpression(propertyName)};
return bindScalarFunctionExpression(children, STRUCT_EXTRACT_FUNC_NAME);
return bindScalarFunctionExpression(children, function::StructExtractFunctions::name);
}

} // namespace binder
Expand Down
3 changes: 2 additions & 1 deletion src/binder/expression_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "binder/expression/rel_expression.h"
#include "binder/expression/subquery_expression.h"
#include "common/cast.h"
#include "function/uuid/vector_uuid_functions.h"

using namespace kuzu::common;

Expand Down Expand Up @@ -107,7 +108,7 @@ bool ExpressionVisitor::isRandom(const Expression& expression) {
return false;
}
auto& funcExpr = ku_dynamic_cast<const Expression&, const FunctionExpression&>(expression);
if (funcExpr.getFunctionName() == GEN_RANDOM_UUID_FUNC_NAME) {
if (funcExpr.getFunctionName() == function::GenRandomUUIDFunction::name) {
return true;
}
for (auto& child : ExpressionChildrenCollector::collectChildren(expression)) {
Expand Down
15 changes: 9 additions & 6 deletions src/common/expression_type.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#include "common/enums/expression_type.h"

#include "common/assert.h"
#include "function/comparison/vector_comparison_functions.h"

using namespace kuzu::function;

namespace kuzu {
namespace common {
Expand Down Expand Up @@ -54,17 +57,17 @@ std::string expressionTypeToString(ExpressionType type) {
case ExpressionType::NOT:
return "NOT";
case ExpressionType::EQUALS:
return EQUALS_FUNC_NAME;
return EqualsFunction::name;
case ExpressionType::NOT_EQUALS:
return NOT_EQUALS_FUNC_NAME;
return NotEqualsFunction::name;
case ExpressionType::GREATER_THAN:
return GREATER_THAN_FUNC_NAME;
return GreaterThanFunction::name;
case ExpressionType::GREATER_THAN_EQUALS:
return GREATER_THAN_EQUALS_FUNC_NAME;
return GreaterThanEqualsFunction::name;
case ExpressionType::LESS_THAN:
return LESS_THAN_FUNC_NAME;
return LessThanFunction::name;
case ExpressionType::LESS_THAN_EQUALS:
return LESS_THAN_EQUALS_FUNC_NAME;
return LessThanEqualsFunction::name;
case ExpressionType::IS_NULL:
return "IS_NULL";
case ExpressionType::IS_NOT_NULL:
Expand Down
131 changes: 1 addition & 130 deletions src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,12 @@
#include "function/aggregate/count_star.h"
#include "function/aggregate_function.h"
#include "function/arithmetic/vector_arithmetic_functions.h"
#include "function/blob/vector_blob_functions.h"
#include "function/comparison/vector_comparison_functions.h"
#include "function/date/vector_date_functions.h"
#include "function/function_collection.h"
#include "function/interval/vector_interval_functions.h"
#include "function/list/vector_list_functions.h"
#include "function/map/vector_map_functions.h"
#include "function/path/vector_path_functions.h"
#include "function/rdf/vector_rdf_functions.h"
#include "function/scalar_function.h"
#include "function/schema/vector_node_rel_functions.h"
#include "function/struct/vector_struct_functions.h"
#include "function/table/call_functions.h"
#include "function/timestamp/vector_timestamp_functions.h"
#include "function/union/vector_union_functions.h"
#include "function/uuid/vector_uuid_functions.h"
#include "processor/operator/persistent/reader/csv/parallel_csv_reader.h"
#include "processor/operator/persistent/reader/csv/serial_csv_reader.h"
#include "processor/operator/persistent/reader/npy/npy_reader.h"
Expand Down Expand Up @@ -56,17 +47,8 @@ void BuiltInFunctionsUtils::createFunctions(CatalogSet* catalogSet) {
}

void BuiltInFunctionsUtils::registerScalarFunctions(CatalogSet* catalogSet) {
registerComparisonFunctions(catalogSet);
registerDateFunctions(catalogSet);
registerTimestampFunctions(catalogSet);
registerIntervalFunctions(catalogSet);
registerStructFunctions(catalogSet);
registerMapFunctions(catalogSet);
registerUnionFunctions(catalogSet);
registerNodeRelFunctions(catalogSet);
registerPathFunctions(catalogSet);
registerBlobFunctions(catalogSet);
registerUUIDFunctions(catalogSet);
registerRdfFunctions(catalogSet);
}

Expand Down Expand Up @@ -533,117 +515,6 @@ void BuiltInFunctionsUtils::validateSpecialCases(std::vector<Function*>& candida
}
}

void BuiltInFunctionsUtils::registerComparisonFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
EQUALS_FUNC_NAME, EqualsFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
NOT_EQUALS_FUNC_NAME, NotEqualsFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
GREATER_THAN_FUNC_NAME, GreaterThanFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
GREATER_THAN_EQUALS_FUNC_NAME, GreaterThanEqualsFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
LESS_THAN_FUNC_NAME, LessThanFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
LESS_THAN_EQUALS_FUNC_NAME, LessThanEqualsFunction::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerDateFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
DATE_PART_FUNC_NAME, DatePartFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
DATEPART_FUNC_NAME, DatePartFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
DATE_TRUNC_FUNC_NAME, DateTruncFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
DATETRUNC_FUNC_NAME, DateTruncFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
DAYNAME_FUNC_NAME, DayNameFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
GREATEST_FUNC_NAME, GreatestFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
LAST_DAY_FUNC_NAME, LastDayFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
LEAST_FUNC_NAME, LeastFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
MAKE_DATE_FUNC_NAME, MakeDateFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
MONTHNAME_FUNC_NAME, MonthNameFunction::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerTimestampFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
CENTURY_FUNC_NAME, CenturyFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
EPOCH_MS_FUNC_NAME, EpochMsFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
TO_TIMESTAMP_FUNC_NAME, ToTimestampFunction::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerIntervalFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
TO_YEARS_FUNC_NAME, ToYearsFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
TO_MONTHS_FUNC_NAME, ToMonthsFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
TO_DAYS_FUNC_NAME, ToDaysFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
TO_HOURS_FUNC_NAME, ToHoursFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
TO_MINUTES_FUNC_NAME, ToMinutesFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
TO_SECONDS_FUNC_NAME, ToSecondsFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
TO_MILLISECONDS_FUNC_NAME, ToMillisecondsFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
TO_MICROSECONDS_FUNC_NAME, ToMicrosecondsFunction::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerBlobFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
OCTET_LENGTH_FUNC_NAME, OctetLengthFunctions::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
ENCODE_FUNC_NAME, EncodeFunctions::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
DECODE_FUNC_NAME, DecodeFunctions::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerUUIDFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
GEN_RANDOM_UUID_FUNC_NAME, GenRandomUUIDFunction::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerStructFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
STRUCT_PACK_FUNC_NAME, StructPackFunctions::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
STRUCT_EXTRACT_FUNC_NAME, StructExtractFunctions::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerMapFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
MAP_CREATION_FUNC_NAME, MapCreationFunctions::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
MAP_EXTRACT_FUNC_NAME, MapExtractFunctions::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
ELEMENT_AT_FUNC_NAME, MapExtractFunctions::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
CARDINALITY_FUNC_NAME, SizeFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
MAP_KEYS_FUNC_NAME, MapKeysFunctions::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
MAP_VALUES_FUNC_NAME, MapValuesFunctions::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerUnionFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
UNION_VALUE_FUNC_NAME, UnionValueFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
UNION_TAG_FUNC_NAME, UnionTagFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
UNION_EXTRACT_FUNC_NAME, UnionExtractFunction::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerNodeRelFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
OFFSET_FUNC_NAME, OffsetFunction::getFunctionSet()));
Expand Down
50 changes: 50 additions & 0 deletions src/function/function_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,18 @@

#include "function/arithmetic/vector_arithmetic_functions.h"
#include "function/array/vector_array_functions.h"
#include "function/blob/vector_blob_functions.h"
#include "function/cast/vector_cast_functions.h"
#include "function/comparison/vector_comparison_functions.h"
#include "function/date/vector_date_functions.h"
#include "function/interval/vector_interval_functions.h"
#include "function/list/vector_list_functions.h"
#include "function/map/vector_map_functions.h"
#include "function/string/vector_string_functions.h"
#include "function/struct/vector_struct_functions.h"
#include "function/timestamp/vector_timestamp_functions.h"
#include "function/union/vector_union_functions.h"
#include "function/uuid/vector_uuid_functions.h"

namespace kuzu {
namespace function {
Expand Down Expand Up @@ -95,6 +104,47 @@ FunctionCollection* FunctionCollection::getFunctions() {
SCALAR_FUNCTION(CastToInt128Function), SCALAR_FUNCTION(CastToBoolFunction),
SCALAR_FUNCTION(CastAnyFunction),

// Comparison functions
SCALAR_FUNCTION(EqualsFunction), SCALAR_FUNCTION(NotEqualsFunction),
SCALAR_FUNCTION(GreaterThanFunction), SCALAR_FUNCTION(GreaterThanEqualsFunction),
SCALAR_FUNCTION(LessThanFunction), SCALAR_FUNCTION(LessThanEqualsFunction),

// Date functions
SCALAR_FUNCTION(DatePartFunction), SCALAR_FUNCTION_ALIAS(DatePartFunction),
SCALAR_FUNCTION(DateTruncFunction), SCALAR_FUNCTION_ALIAS(DateTruncFunction),
SCALAR_FUNCTION(DayNameFunction), SCALAR_FUNCTION(GreatestFunction),
SCALAR_FUNCTION(LastDayFunction), SCALAR_FUNCTION(LeastFunction),
SCALAR_FUNCTION(MakeDateFunction), SCALAR_FUNCTION(MonthNameFunction),

// Timestamp functions
SCALAR_FUNCTION(CenturyFunction), SCALAR_FUNCTION(EpochMsFunction),
SCALAR_FUNCTION(ToTimestampFunction),

// Interval functions
SCALAR_FUNCTION(ToYearsFunction), SCALAR_FUNCTION(ToMonthsFunction),
SCALAR_FUNCTION(ToDaysFunction), SCALAR_FUNCTION(ToHoursFunction),
SCALAR_FUNCTION(ToMinutesFunction), SCALAR_FUNCTION(ToSecondsFunction),
SCALAR_FUNCTION(ToMillisecondsFunction), SCALAR_FUNCTION(ToMicrosecondsFunction),

// Blob functions
SCALAR_FUNCTION(OctetLengthFunctions), SCALAR_FUNCTION(EncodeFunctions),
SCALAR_FUNCTION(DecodeFunctions),

// UUID functions
SCALAR_FUNCTION(GenRandomUUIDFunction),

// Struct functions
SCALAR_FUNCTION(StructPackFunctions), SCALAR_FUNCTION(StructExtractFunctions),

// Map functions
SCALAR_FUNCTION(MapCreationFunctions), SCALAR_FUNCTION(MapExtractFunctions),
SCALAR_FUNCTION_ALIAS(MapExtractFunctions), SCALAR_FUNCTION_ALIAS(SizeFunction),
SCALAR_FUNCTION(MapKeysFunctions), SCALAR_FUNCTION(MapValuesFunctions),

// Union functions
SCALAR_FUNCTION(UnionValueFunction), SCALAR_FUNCTION(UnionTagFunction),
SCALAR_FUNCTION(UnionExtractFunction),

// End of array
FINAL_FUNCTION};

Expand Down
3 changes: 2 additions & 1 deletion src/function/pattern/id_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "common/cast.h"
#include "function/rewrite_function.h"
#include "function/schema/vector_node_rel_functions.h"
#include "function/struct/vector_struct_functions.h"

using namespace kuzu::common;
using namespace kuzu::binder;
Expand All @@ -28,7 +29,7 @@ static std::shared_ptr<binder::Expression> rewriteFunc(
auto key = Value(LogicalType::STRING(), InternalKeyword::ID);
auto keyExpr = binder->createLiteralExpression(key.copy());
auto newParams = expression_vector{params[0], keyExpr};
return binder->bindScalarFunctionExpression(newParams, STRUCT_EXTRACT_FUNC_NAME);
return binder->bindScalarFunctionExpression(newParams, StructExtractFunctions::name);
}

function_set IDFunction::getFunctionSet() {
Expand Down
12 changes: 6 additions & 6 deletions src/function/vector_blob_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ namespace function {

function_set OctetLengthFunctions::getFunctionSet() {
function_set definitions;
definitions.push_back(make_unique<ScalarFunction>(OCTET_LENGTH_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::BLOB}, LogicalTypeID::INT64,
ScalarFunction::UnaryExecFunction<blob_t, int64_t, OctetLength>, nullptr, nullptr, nullptr,
false /* isVarLength */));
definitions.push_back(
make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::BLOB},
LogicalTypeID::INT64, ScalarFunction::UnaryExecFunction<blob_t, int64_t, OctetLength>,
nullptr, nullptr, nullptr, false /* isVarLength */));
return definitions;
}

function_set EncodeFunctions::getFunctionSet() {
function_set definitions;
definitions.push_back(make_unique<ScalarFunction>(ENCODE_FUNC_NAME,
definitions.push_back(make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}, LogicalTypeID::BLOB,
ScalarFunction::UnaryStringExecFunction<ku_string_t, blob_t, Encode>, nullptr,
false /* isVarLength */));
Expand All @@ -30,7 +30,7 @@ function_set EncodeFunctions::getFunctionSet() {

function_set DecodeFunctions::getFunctionSet() {
function_set definitions;
definitions.push_back(make_unique<ScalarFunction>(DECODE_FUNC_NAME,
definitions.push_back(make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::BLOB}, LogicalTypeID::STRING,
ScalarFunction::UnaryStringExecFunction<blob_t, ku_string_t, Decode>, nullptr,
false /* isVarLength */));
Expand Down
Loading

0 comments on commit 53ef58e

Please sign in to comment.