From 53ef58e8682b9988319f8b957fb6f4e2b5d3b16a Mon Sep 17 00:00:00 2001 From: Manh Dinh Date: Mon, 25 Mar 2024 12:03:54 -0400 Subject: [PATCH] Refactor scalar function registration (#3119) Refactor comparison functions Refactor date functions Refactor timestamp functions Refactor interval functions Refactor blob functions Refactor UUID functions Refactor struct functions Refactor map functions Refactor union functions --- .../bind_property_expression.cpp | 3 +- src/binder/expression_visitor.cpp | 3 +- src/common/expression_type.cpp | 15 +- src/function/built_in_function_utils.cpp | 131 +----------------- src/function/function_collection.cpp | 50 +++++++ src/function/pattern/id_function.cpp | 3 +- src/function/vector_blob_functions.cpp | 12 +- src/function/vector_date_functions.cpp | 40 +++--- src/function/vector_map_functions.cpp | 57 ++++---- src/function/vector_struct_functions.cpp | 73 +++++----- src/function/vector_timestamp_functions.cpp | 6 +- src/function/vector_union_functions.cpp | 30 ++-- src/function/vector_uuid_functions.cpp | 5 +- src/include/common/enums/expression_type.h | 60 -------- .../function/blob/vector_blob_functions.h | 6 + .../function/built_in_function_utils.h | 9 -- .../comparison/vector_comparison_functions.h | 40 ++++-- .../function/date/vector_date_functions.h | 20 +++ .../interval/vector_interval_functions.h | 52 ++++--- .../function/map/vector_map_functions.h | 18 +-- .../function/struct/vector_struct_functions.h | 11 +- .../timestamp/vector_timestamp_functions.h | 6 + .../function/union/vector_union_functions.h | 14 +- .../function/uuid/vector_uuid_functions.h | 2 + src/parser/transform/transform_expression.cpp | 3 +- 25 files changed, 294 insertions(+), 375 deletions(-) diff --git a/src/binder/bind_expression/bind_property_expression.cpp b/src/binder/bind_expression/bind_property_expression.cpp index e52303dfe1..631e345384 100644 --- a/src/binder/bind_expression/bind_property_expression.cpp +++ b/src/binder/bind_expression/bind_property_expression.cpp @@ -6,6 +6,7 @@ #include "common/cast.h" #include "common/exception/binder.h" #include "common/string_format.h" +#include "function/struct/vector_struct_functions.h" #include "parser/expression/parsed_property_expression.h" using namespace kuzu::common; @@ -103,7 +104,7 @@ std::shared_ptr ExpressionBinder::bindStructPropertyExpression( std::shared_ptr child, const std::string& propertyName) { auto children = expression_vector{std::move(child), createStringLiteralExpression(propertyName)}; - return bindScalarFunctionExpression(children, STRUCT_EXTRACT_FUNC_NAME); + return bindScalarFunctionExpression(children, function::StructExtractFunctions::name); } } // namespace binder diff --git a/src/binder/expression_visitor.cpp b/src/binder/expression_visitor.cpp index 837b83c830..1ed04359fb 100644 --- a/src/binder/expression_visitor.cpp +++ b/src/binder/expression_visitor.cpp @@ -7,6 +7,7 @@ #include "binder/expression/rel_expression.h" #include "binder/expression/subquery_expression.h" #include "common/cast.h" +#include "function/uuid/vector_uuid_functions.h" using namespace kuzu::common; @@ -107,7 +108,7 @@ bool ExpressionVisitor::isRandom(const Expression& expression) { return false; } auto& funcExpr = ku_dynamic_cast(expression); - if (funcExpr.getFunctionName() == GEN_RANDOM_UUID_FUNC_NAME) { + if (funcExpr.getFunctionName() == function::GenRandomUUIDFunction::name) { return true; } for (auto& child : ExpressionChildrenCollector::collectChildren(expression)) { diff --git a/src/common/expression_type.cpp b/src/common/expression_type.cpp index e24bb21c4e..4e33b2e157 100644 --- a/src/common/expression_type.cpp +++ b/src/common/expression_type.cpp @@ -1,6 +1,9 @@ #include "common/enums/expression_type.h" #include "common/assert.h" +#include "function/comparison/vector_comparison_functions.h" + +using namespace kuzu::function; namespace kuzu { namespace common { @@ -54,17 +57,17 @@ std::string expressionTypeToString(ExpressionType type) { case ExpressionType::NOT: return "NOT"; case ExpressionType::EQUALS: - return EQUALS_FUNC_NAME; + return EqualsFunction::name; case ExpressionType::NOT_EQUALS: - return NOT_EQUALS_FUNC_NAME; + return NotEqualsFunction::name; case ExpressionType::GREATER_THAN: - return GREATER_THAN_FUNC_NAME; + return GreaterThanFunction::name; case ExpressionType::GREATER_THAN_EQUALS: - return GREATER_THAN_EQUALS_FUNC_NAME; + return GreaterThanEqualsFunction::name; case ExpressionType::LESS_THAN: - return LESS_THAN_FUNC_NAME; + return LessThanFunction::name; case ExpressionType::LESS_THAN_EQUALS: - return LESS_THAN_EQUALS_FUNC_NAME; + return LessThanEqualsFunction::name; case ExpressionType::IS_NULL: return "IS_NULL"; case ExpressionType::IS_NOT_NULL: diff --git a/src/function/built_in_function_utils.cpp b/src/function/built_in_function_utils.cpp index cb6feee179..aed466bc1c 100644 --- a/src/function/built_in_function_utils.cpp +++ b/src/function/built_in_function_utils.cpp @@ -11,21 +11,12 @@ #include "function/aggregate/count_star.h" #include "function/aggregate_function.h" #include "function/arithmetic/vector_arithmetic_functions.h" -#include "function/blob/vector_blob_functions.h" -#include "function/comparison/vector_comparison_functions.h" -#include "function/date/vector_date_functions.h" #include "function/function_collection.h" -#include "function/interval/vector_interval_functions.h" -#include "function/list/vector_list_functions.h" -#include "function/map/vector_map_functions.h" #include "function/path/vector_path_functions.h" #include "function/rdf/vector_rdf_functions.h" +#include "function/scalar_function.h" #include "function/schema/vector_node_rel_functions.h" -#include "function/struct/vector_struct_functions.h" #include "function/table/call_functions.h" -#include "function/timestamp/vector_timestamp_functions.h" -#include "function/union/vector_union_functions.h" -#include "function/uuid/vector_uuid_functions.h" #include "processor/operator/persistent/reader/csv/parallel_csv_reader.h" #include "processor/operator/persistent/reader/csv/serial_csv_reader.h" #include "processor/operator/persistent/reader/npy/npy_reader.h" @@ -56,17 +47,8 @@ void BuiltInFunctionsUtils::createFunctions(CatalogSet* catalogSet) { } void BuiltInFunctionsUtils::registerScalarFunctions(CatalogSet* catalogSet) { - registerComparisonFunctions(catalogSet); - registerDateFunctions(catalogSet); - registerTimestampFunctions(catalogSet); - registerIntervalFunctions(catalogSet); - registerStructFunctions(catalogSet); - registerMapFunctions(catalogSet); - registerUnionFunctions(catalogSet); registerNodeRelFunctions(catalogSet); registerPathFunctions(catalogSet); - registerBlobFunctions(catalogSet); - registerUUIDFunctions(catalogSet); registerRdfFunctions(catalogSet); } @@ -533,117 +515,6 @@ void BuiltInFunctionsUtils::validateSpecialCases(std::vector& candida } } -void BuiltInFunctionsUtils::registerComparisonFunctions(CatalogSet* catalogSet) { - catalogSet->createEntry(std::make_unique( - EQUALS_FUNC_NAME, EqualsFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - NOT_EQUALS_FUNC_NAME, NotEqualsFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - GREATER_THAN_FUNC_NAME, GreaterThanFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - GREATER_THAN_EQUALS_FUNC_NAME, GreaterThanEqualsFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - LESS_THAN_FUNC_NAME, LessThanFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - LESS_THAN_EQUALS_FUNC_NAME, LessThanEqualsFunction::getFunctionSet())); -} - -void BuiltInFunctionsUtils::registerDateFunctions(CatalogSet* catalogSet) { - catalogSet->createEntry(std::make_unique( - DATE_PART_FUNC_NAME, DatePartFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - DATEPART_FUNC_NAME, DatePartFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - DATE_TRUNC_FUNC_NAME, DateTruncFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - DATETRUNC_FUNC_NAME, DateTruncFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - DAYNAME_FUNC_NAME, DayNameFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - GREATEST_FUNC_NAME, GreatestFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - LAST_DAY_FUNC_NAME, LastDayFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - LEAST_FUNC_NAME, LeastFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - MAKE_DATE_FUNC_NAME, MakeDateFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - MONTHNAME_FUNC_NAME, MonthNameFunction::getFunctionSet())); -} - -void BuiltInFunctionsUtils::registerTimestampFunctions(CatalogSet* catalogSet) { - catalogSet->createEntry(std::make_unique( - CENTURY_FUNC_NAME, CenturyFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - EPOCH_MS_FUNC_NAME, EpochMsFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - TO_TIMESTAMP_FUNC_NAME, ToTimestampFunction::getFunctionSet())); -} - -void BuiltInFunctionsUtils::registerIntervalFunctions(CatalogSet* catalogSet) { - catalogSet->createEntry(std::make_unique( - TO_YEARS_FUNC_NAME, ToYearsFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - TO_MONTHS_FUNC_NAME, ToMonthsFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - TO_DAYS_FUNC_NAME, ToDaysFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - TO_HOURS_FUNC_NAME, ToHoursFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - TO_MINUTES_FUNC_NAME, ToMinutesFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - TO_SECONDS_FUNC_NAME, ToSecondsFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - TO_MILLISECONDS_FUNC_NAME, ToMillisecondsFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - TO_MICROSECONDS_FUNC_NAME, ToMicrosecondsFunction::getFunctionSet())); -} - -void BuiltInFunctionsUtils::registerBlobFunctions(CatalogSet* catalogSet) { - catalogSet->createEntry(std::make_unique( - OCTET_LENGTH_FUNC_NAME, OctetLengthFunctions::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - ENCODE_FUNC_NAME, EncodeFunctions::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - DECODE_FUNC_NAME, DecodeFunctions::getFunctionSet())); -} - -void BuiltInFunctionsUtils::registerUUIDFunctions(CatalogSet* catalogSet) { - catalogSet->createEntry(std::make_unique( - GEN_RANDOM_UUID_FUNC_NAME, GenRandomUUIDFunction::getFunctionSet())); -} - -void BuiltInFunctionsUtils::registerStructFunctions(CatalogSet* catalogSet) { - catalogSet->createEntry(std::make_unique( - STRUCT_PACK_FUNC_NAME, StructPackFunctions::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - STRUCT_EXTRACT_FUNC_NAME, StructExtractFunctions::getFunctionSet())); -} - -void BuiltInFunctionsUtils::registerMapFunctions(CatalogSet* catalogSet) { - catalogSet->createEntry(std::make_unique( - MAP_CREATION_FUNC_NAME, MapCreationFunctions::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - MAP_EXTRACT_FUNC_NAME, MapExtractFunctions::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - ELEMENT_AT_FUNC_NAME, MapExtractFunctions::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - CARDINALITY_FUNC_NAME, SizeFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - MAP_KEYS_FUNC_NAME, MapKeysFunctions::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - MAP_VALUES_FUNC_NAME, MapValuesFunctions::getFunctionSet())); -} - -void BuiltInFunctionsUtils::registerUnionFunctions(CatalogSet* catalogSet) { - catalogSet->createEntry(std::make_unique( - UNION_VALUE_FUNC_NAME, UnionValueFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - UNION_TAG_FUNC_NAME, UnionTagFunction::getFunctionSet())); - catalogSet->createEntry(std::make_unique( - UNION_EXTRACT_FUNC_NAME, UnionExtractFunction::getFunctionSet())); -} - void BuiltInFunctionsUtils::registerNodeRelFunctions(CatalogSet* catalogSet) { catalogSet->createEntry(std::make_unique( OFFSET_FUNC_NAME, OffsetFunction::getFunctionSet())); diff --git a/src/function/function_collection.cpp b/src/function/function_collection.cpp index 0f4a78bc65..4fc4b13512 100644 --- a/src/function/function_collection.cpp +++ b/src/function/function_collection.cpp @@ -2,9 +2,18 @@ #include "function/arithmetic/vector_arithmetic_functions.h" #include "function/array/vector_array_functions.h" +#include "function/blob/vector_blob_functions.h" #include "function/cast/vector_cast_functions.h" +#include "function/comparison/vector_comparison_functions.h" +#include "function/date/vector_date_functions.h" +#include "function/interval/vector_interval_functions.h" #include "function/list/vector_list_functions.h" +#include "function/map/vector_map_functions.h" #include "function/string/vector_string_functions.h" +#include "function/struct/vector_struct_functions.h" +#include "function/timestamp/vector_timestamp_functions.h" +#include "function/union/vector_union_functions.h" +#include "function/uuid/vector_uuid_functions.h" namespace kuzu { namespace function { @@ -95,6 +104,47 @@ FunctionCollection* FunctionCollection::getFunctions() { SCALAR_FUNCTION(CastToInt128Function), SCALAR_FUNCTION(CastToBoolFunction), SCALAR_FUNCTION(CastAnyFunction), + // Comparison functions + SCALAR_FUNCTION(EqualsFunction), SCALAR_FUNCTION(NotEqualsFunction), + SCALAR_FUNCTION(GreaterThanFunction), SCALAR_FUNCTION(GreaterThanEqualsFunction), + SCALAR_FUNCTION(LessThanFunction), SCALAR_FUNCTION(LessThanEqualsFunction), + + // Date functions + SCALAR_FUNCTION(DatePartFunction), SCALAR_FUNCTION_ALIAS(DatePartFunction), + SCALAR_FUNCTION(DateTruncFunction), SCALAR_FUNCTION_ALIAS(DateTruncFunction), + SCALAR_FUNCTION(DayNameFunction), SCALAR_FUNCTION(GreatestFunction), + SCALAR_FUNCTION(LastDayFunction), SCALAR_FUNCTION(LeastFunction), + SCALAR_FUNCTION(MakeDateFunction), SCALAR_FUNCTION(MonthNameFunction), + + // Timestamp functions + SCALAR_FUNCTION(CenturyFunction), SCALAR_FUNCTION(EpochMsFunction), + SCALAR_FUNCTION(ToTimestampFunction), + + // Interval functions + SCALAR_FUNCTION(ToYearsFunction), SCALAR_FUNCTION(ToMonthsFunction), + SCALAR_FUNCTION(ToDaysFunction), SCALAR_FUNCTION(ToHoursFunction), + SCALAR_FUNCTION(ToMinutesFunction), SCALAR_FUNCTION(ToSecondsFunction), + SCALAR_FUNCTION(ToMillisecondsFunction), SCALAR_FUNCTION(ToMicrosecondsFunction), + + // Blob functions + SCALAR_FUNCTION(OctetLengthFunctions), SCALAR_FUNCTION(EncodeFunctions), + SCALAR_FUNCTION(DecodeFunctions), + + // UUID functions + SCALAR_FUNCTION(GenRandomUUIDFunction), + + // Struct functions + SCALAR_FUNCTION(StructPackFunctions), SCALAR_FUNCTION(StructExtractFunctions), + + // Map functions + SCALAR_FUNCTION(MapCreationFunctions), SCALAR_FUNCTION(MapExtractFunctions), + SCALAR_FUNCTION_ALIAS(MapExtractFunctions), SCALAR_FUNCTION_ALIAS(SizeFunction), + SCALAR_FUNCTION(MapKeysFunctions), SCALAR_FUNCTION(MapValuesFunctions), + + // Union functions + SCALAR_FUNCTION(UnionValueFunction), SCALAR_FUNCTION(UnionTagFunction), + SCALAR_FUNCTION(UnionExtractFunction), + // End of array FINAL_FUNCTION}; diff --git a/src/function/pattern/id_function.cpp b/src/function/pattern/id_function.cpp index 2a8dc679f1..a051313435 100644 --- a/src/function/pattern/id_function.cpp +++ b/src/function/pattern/id_function.cpp @@ -5,6 +5,7 @@ #include "common/cast.h" #include "function/rewrite_function.h" #include "function/schema/vector_node_rel_functions.h" +#include "function/struct/vector_struct_functions.h" using namespace kuzu::common; using namespace kuzu::binder; @@ -28,7 +29,7 @@ static std::shared_ptr rewriteFunc( auto key = Value(LogicalType::STRING(), InternalKeyword::ID); auto keyExpr = binder->createLiteralExpression(key.copy()); auto newParams = expression_vector{params[0], keyExpr}; - return binder->bindScalarFunctionExpression(newParams, STRUCT_EXTRACT_FUNC_NAME); + return binder->bindScalarFunctionExpression(newParams, StructExtractFunctions::name); } function_set IDFunction::getFunctionSet() { diff --git a/src/function/vector_blob_functions.cpp b/src/function/vector_blob_functions.cpp index b86dd04df9..ecc64a5d79 100644 --- a/src/function/vector_blob_functions.cpp +++ b/src/function/vector_blob_functions.cpp @@ -12,16 +12,16 @@ namespace function { function_set OctetLengthFunctions::getFunctionSet() { function_set definitions; - definitions.push_back(make_unique(OCTET_LENGTH_FUNC_NAME, - std::vector{LogicalTypeID::BLOB}, LogicalTypeID::INT64, - ScalarFunction::UnaryExecFunction, nullptr, nullptr, nullptr, - false /* isVarLength */)); + definitions.push_back( + make_unique(name, std::vector{LogicalTypeID::BLOB}, + LogicalTypeID::INT64, ScalarFunction::UnaryExecFunction, + nullptr, nullptr, nullptr, false /* isVarLength */)); return definitions; } function_set EncodeFunctions::getFunctionSet() { function_set definitions; - definitions.push_back(make_unique(ENCODE_FUNC_NAME, + definitions.push_back(make_unique(name, std::vector{LogicalTypeID::STRING}, LogicalTypeID::BLOB, ScalarFunction::UnaryStringExecFunction, nullptr, false /* isVarLength */)); @@ -30,7 +30,7 @@ function_set EncodeFunctions::getFunctionSet() { function_set DecodeFunctions::getFunctionSet() { function_set definitions; - definitions.push_back(make_unique(DECODE_FUNC_NAME, + definitions.push_back(make_unique(name, std::vector{LogicalTypeID::BLOB}, LogicalTypeID::STRING, ScalarFunction::UnaryStringExecFunction, nullptr, false /* isVarLength */)); diff --git a/src/function/vector_date_functions.cpp b/src/function/vector_date_functions.cpp index c3fa3c12b1..1b8397521c 100644 --- a/src/function/vector_date_functions.cpp +++ b/src/function/vector_date_functions.cpp @@ -10,15 +10,15 @@ namespace function { function_set DatePartFunction::getFunctionSet() { function_set result; - result.push_back(make_unique(DATE_PART_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::STRING, LogicalTypeID::DATE}, LogicalTypeID::INT64, ScalarFunction::BinaryExecFunction)); - result.push_back(make_unique(DATE_PART_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::STRING, LogicalTypeID::TIMESTAMP}, LogicalTypeID::INT64, ScalarFunction::BinaryExecFunction)); - result.push_back(make_unique(DATE_PART_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::STRING, LogicalTypeID::INTERVAL}, LogicalTypeID::INT64, ScalarFunction::BinaryExecFunction)); @@ -27,10 +27,10 @@ function_set DatePartFunction::getFunctionSet() { function_set DateTruncFunction::getFunctionSet() { function_set result; - result.push_back(make_unique(DATE_TRUNC_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::STRING, LogicalTypeID::DATE}, LogicalTypeID::DATE, ScalarFunction::BinaryExecFunction)); - result.push_back(make_unique(DATE_TRUNC_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::STRING, LogicalTypeID::TIMESTAMP}, LogicalTypeID::TIMESTAMP, ScalarFunction::BinaryExecFunction)); @@ -39,10 +39,10 @@ function_set DateTruncFunction::getFunctionSet() { function_set DayNameFunction::getFunctionSet() { function_set result; - result.push_back(make_unique(DAYNAME_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::DATE}, LogicalTypeID::STRING, ScalarFunction::UnaryExecFunction)); - result.push_back(make_unique(DAYNAME_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::TIMESTAMP}, LogicalTypeID::STRING, ScalarFunction::UnaryExecFunction)); return result; @@ -50,10 +50,10 @@ function_set DayNameFunction::getFunctionSet() { function_set GreatestFunction::getFunctionSet() { function_set result; - result.push_back(make_unique(GREATEST_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::DATE, LogicalTypeID::DATE}, LogicalTypeID::DATE, ScalarFunction::BinaryExecFunction)); - result.push_back(make_unique(GREATEST_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::TIMESTAMP, LogicalTypeID::TIMESTAMP}, LogicalTypeID::TIMESTAMP, ScalarFunction::BinaryExecFunction)); @@ -62,21 +62,21 @@ function_set GreatestFunction::getFunctionSet() { function_set LastDayFunction::getFunctionSet() { function_set result; - result.push_back(make_unique(LAST_DAY_FUNC_NAME, - std::vector{LogicalTypeID::DATE}, LogicalTypeID::DATE, - ScalarFunction::UnaryExecFunction)); - result.push_back(make_unique(LAST_DAY_FUNC_NAME, - std::vector{LogicalTypeID::TIMESTAMP}, LogicalTypeID::DATE, - ScalarFunction::UnaryExecFunction)); + result.push_back( + make_unique(name, std::vector{LogicalTypeID::DATE}, + LogicalTypeID::DATE, ScalarFunction::UnaryExecFunction)); + result.push_back( + make_unique(name, std::vector{LogicalTypeID::TIMESTAMP}, + LogicalTypeID::DATE, ScalarFunction::UnaryExecFunction)); return result; } function_set LeastFunction::getFunctionSet() { function_set result; - result.push_back(make_unique(LEAST_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::DATE, LogicalTypeID::DATE}, LogicalTypeID::DATE, ScalarFunction::BinaryExecFunction)); - result.push_back(make_unique(LEAST_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::TIMESTAMP, LogicalTypeID::TIMESTAMP}, LogicalTypeID::TIMESTAMP, ScalarFunction::BinaryExecFunction)); @@ -85,7 +85,7 @@ function_set LeastFunction::getFunctionSet() { function_set MakeDateFunction::getFunctionSet() { function_set result; - result.push_back(make_unique(MAKE_DATE_FUNC_NAME, + result.push_back(make_unique(name, std::vector{ LogicalTypeID::INT64, LogicalTypeID::INT64, LogicalTypeID::INT64}, LogicalTypeID::DATE, @@ -95,10 +95,10 @@ function_set MakeDateFunction::getFunctionSet() { function_set MonthNameFunction::getFunctionSet() { function_set result; - result.push_back(make_unique(MONTHNAME_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::DATE}, LogicalTypeID::STRING, ScalarFunction::UnaryExecFunction)); - result.push_back(make_unique(MONTHNAME_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::TIMESTAMP}, LogicalTypeID::STRING, ScalarFunction::UnaryExecFunction)); return result; diff --git a/src/function/vector_map_functions.cpp b/src/function/vector_map_functions.cpp index fc3f725995..f1dbf0bceb 100644 --- a/src/function/vector_map_functions.cpp +++ b/src/function/vector_map_functions.cpp @@ -11,17 +11,7 @@ using namespace kuzu::common; namespace kuzu { namespace function { -function_set MapCreationFunctions::getFunctionSet() { - auto execFunc = ScalarFunction::BinaryExecListStructFunction; - function_set functionSet; - functionSet.push_back(make_unique(MAP_CREATION_FUNC_NAME, - std::vector{LogicalTypeID::VAR_LIST, LogicalTypeID::VAR_LIST}, - LogicalTypeID::MAP, execFunc, nullptr, bindFunc, false /* isVarLength */)); - return functionSet; -} - -std::unique_ptr MapCreationFunctions::bindFunc( +static std::unique_ptr MapCreationFunctionsBindFunc( const binder::expression_vector& arguments, kuzu::function::Function* /*function*/) { auto keyType = VarListType::getChildType(&arguments[0]->dataType); auto valueType = VarListType::getChildType(&arguments[1]->dataType); @@ -29,11 +19,14 @@ std::unique_ptr MapCreationFunctions::bindFunc( return std::make_unique(std::move(resultType)); } -function_set MapExtractFunctions::getFunctionSet() { +function_set MapCreationFunctions::getFunctionSet() { + auto execFunc = ScalarFunction::BinaryExecListStructFunction; function_set functionSet; - functionSet.push_back(make_unique(MAP_EXTRACT_FUNC_NAME, - std::vector{LogicalTypeID::MAP, LogicalTypeID::ANY}, LogicalTypeID::VAR_LIST, - nullptr, nullptr, bindFunc, false /* isVarLength */)); + functionSet.push_back(make_unique(name, + std::vector{LogicalTypeID::VAR_LIST, LogicalTypeID::VAR_LIST}, + LogicalTypeID::MAP, execFunc, nullptr, MapCreationFunctionsBindFunc, + false /* isVarLength */)); return functionSet; } @@ -45,7 +38,7 @@ static void validateKeyType(const std::shared_ptr& mapExpres } } -std::unique_ptr MapExtractFunctions::bindFunc( +static std::unique_ptr MapExtractFunctionsBindFunc( const binder::expression_vector& arguments, kuzu::function::Function* function) { validateKeyType(arguments[0], arguments[1]); auto scalarFunction = ku_dynamic_cast(function); @@ -126,37 +119,45 @@ std::unique_ptr MapExtractFunctions::bindFunc( std::make_unique(*MapType::getValueType(&arguments[0]->dataType)))); } -function_set MapKeysFunctions::getFunctionSet() { - auto execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; +function_set MapExtractFunctions::getFunctionSet() { function_set functionSet; - functionSet.push_back(make_unique(MAP_KEYS_FUNC_NAME, - std::vector{LogicalTypeID::MAP}, LogicalTypeID::VAR_LIST, execFunc, nullptr, - bindFunc, false /* isVarLength */)); + functionSet.push_back(make_unique(name, + std::vector{LogicalTypeID::MAP, LogicalTypeID::ANY}, LogicalTypeID::VAR_LIST, + nullptr, nullptr, MapExtractFunctionsBindFunc, false /* isVarLength */)); return functionSet; } -std::unique_ptr MapKeysFunctions::bindFunc( +static std::unique_ptr MapKeysFunctionsBindFunc( const binder::expression_vector& arguments, kuzu::function::Function* /*function*/) { return std::make_unique(LogicalType::VAR_LIST( std::make_unique(*MapType::getKeyType(&arguments[0]->dataType)))); } -function_set MapValuesFunctions::getFunctionSet() { +function_set MapKeysFunctions::getFunctionSet() { auto execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; + ScalarFunction::UnaryExecNestedTypeFunction; function_set functionSet; - functionSet.push_back(make_unique(MAP_VALUES_FUNC_NAME, + functionSet.push_back(make_unique(name, std::vector{LogicalTypeID::MAP}, LogicalTypeID::VAR_LIST, execFunc, nullptr, - bindFunc, false /* isVarLength */)); + MapKeysFunctionsBindFunc, false /* isVarLength */)); return functionSet; } -std::unique_ptr MapValuesFunctions::bindFunc( +static std::unique_ptr MapValuesFunctionsBindFunc( const binder::expression_vector& arguments, kuzu::function::Function* /*function*/) { return std::make_unique(LogicalType::VAR_LIST( std::make_unique(*MapType::getValueType(&arguments[0]->dataType)))); } +function_set MapValuesFunctions::getFunctionSet() { + auto execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + function_set functionSet; + functionSet.push_back(make_unique(name, + std::vector{LogicalTypeID::MAP}, LogicalTypeID::VAR_LIST, execFunc, nullptr, + MapValuesFunctionsBindFunc, false /* isVarLength */)); + return functionSet; +} + } // namespace function } // namespace kuzu diff --git a/src/function/vector_struct_functions.cpp b/src/function/vector_struct_functions.cpp index 81052f3065..6fc3de61b7 100644 --- a/src/function/vector_struct_functions.cpp +++ b/src/function/vector_struct_functions.cpp @@ -4,22 +4,14 @@ #include "binder/expression_binder.h" #include "common/exception/binder.h" #include "common/string_format.h" -#include "function/function.h" +#include "function/scalar_function.h" using namespace kuzu::common; namespace kuzu { namespace function { -function_set StructPackFunctions::getFunctionSet() { - function_set functions; - functions.push_back(make_unique(STRUCT_PACK_FUNC_NAME, - std::vector{LogicalTypeID::ANY}, LogicalTypeID::STRUCT, execFunc, nullptr, - compileFunc, bindFunc, true /* isVarLength */)); - return functions; -} - -std::unique_ptr StructPackFunctions::bindFunc( +static std::unique_ptr StructPackFunctionsBindFunc( const binder::expression_vector& arguments, Function* /*function*/) { std::vector fields; for (auto& argument : arguments) { @@ -32,20 +24,6 @@ std::unique_ptr StructPackFunctions::bindFunc( return std::make_unique(std::move(resultType)); } -void StructPackFunctions::execFunc(const std::vector>& parameters, - ValueVector& result, void* /*dataPtr*/) { - for (auto i = 0u; i < parameters.size(); i++) { - auto& parameter = parameters[i]; - if (parameter->state == result.state) { - continue; - } - // If the parameter's state is inconsistent with the result's state, we need to copy the - // parameter's value to the corresponding child vector. - copyParameterValueToStructFieldVector( - parameter.get(), StructVector::getFieldVector(&result, i).get(), result.state.get()); - } -} - void StructPackFunctions::compileFunc(FunctionBindData* /*bindData*/, const std::vector>& parameters, std::shared_ptr& result) { @@ -60,7 +38,7 @@ void StructPackFunctions::compileFunc(FunctionBindData* /*bindData*/, } } -void StructPackFunctions::copyParameterValueToStructFieldVector( +static void copyParameterValueToStructFieldVector( const ValueVector* parameter, ValueVector* structField, DataChunkState* structVectorState) { // If the parameter is unFlat, then its state must be consistent with the result's state. // Thus, we don't need to copy values to structFieldVector. @@ -77,20 +55,26 @@ void StructPackFunctions::copyParameterValueToStructFieldVector( } } -function_set StructExtractFunctions::getFunctionSet() { - function_set functions; - auto inputTypeIDs = - std::vector{LogicalTypeID::STRUCT, LogicalTypeID::NODE, LogicalTypeID::REL}; - for (auto inputTypeID : inputTypeIDs) { - functions.push_back(getFunction(inputTypeID)); +void StructPackFunctions::execFunc(const std::vector>& parameters, + ValueVector& result, void* /*dataPtr*/) { + for (auto i = 0u; i < parameters.size(); i++) { + auto& parameter = parameters[i]; + if (parameter->state == result.state) { + continue; + } + // If the parameter's state is inconsistent with the result's state, we need to copy the + // parameter's value to the corresponding child vector. + copyParameterValueToStructFieldVector( + parameter.get(), StructVector::getFieldVector(&result, i).get(), result.state.get()); } - return functions; } -std::unique_ptr StructExtractFunctions::getFunction(LogicalTypeID logicalTypeID) { - return std::make_unique(STRUCT_EXTRACT_FUNC_NAME, - std::vector{logicalTypeID, LogicalTypeID::STRING}, LogicalTypeID::ANY, - nullptr, nullptr, compileFunc, bindFunc, false /* isVarLength */); +function_set StructPackFunctions::getFunctionSet() { + function_set functions; + functions.push_back(make_unique(name, + std::vector{LogicalTypeID::ANY}, LogicalTypeID::STRUCT, execFunc, nullptr, + compileFunc, StructPackFunctionsBindFunc, true /* isVarLength */)); + return functions; } std::unique_ptr StructExtractFunctions::bindFunc( @@ -117,5 +101,22 @@ void StructExtractFunctions::compileFunc(FunctionBindData* bindData, result->state = parameters[0]->state; } +static std::unique_ptr getStructExtractFunction(LogicalTypeID logicalTypeID) { + return std::make_unique(StructExtractFunctions::name, + std::vector{logicalTypeID, LogicalTypeID::STRING}, LogicalTypeID::ANY, + nullptr, nullptr, StructExtractFunctions::compileFunc, StructExtractFunctions::bindFunc, + false /* isVarLength */); +} + +function_set StructExtractFunctions::getFunctionSet() { + function_set functions; + auto inputTypeIDs = + std::vector{LogicalTypeID::STRUCT, LogicalTypeID::NODE, LogicalTypeID::REL}; + for (auto inputTypeID : inputTypeIDs) { + functions.push_back(getStructExtractFunction(inputTypeID)); + } + return functions; +} + } // namespace function } // namespace kuzu diff --git a/src/function/vector_timestamp_functions.cpp b/src/function/vector_timestamp_functions.cpp index 601b44053a..5a135aa28c 100644 --- a/src/function/vector_timestamp_functions.cpp +++ b/src/function/vector_timestamp_functions.cpp @@ -10,7 +10,7 @@ namespace function { function_set CenturyFunction::getFunctionSet() { function_set result; - result.push_back(make_unique(CENTURY_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::TIMESTAMP}, LogicalTypeID::INT64, ScalarFunction::UnaryExecFunction)); return result; @@ -18,7 +18,7 @@ function_set CenturyFunction::getFunctionSet() { function_set EpochMsFunction::getFunctionSet() { function_set result; - result.push_back(make_unique(EPOCH_MS_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::INT64}, LogicalTypeID::TIMESTAMP, ScalarFunction::UnaryExecFunction)); return result; @@ -26,7 +26,7 @@ function_set EpochMsFunction::getFunctionSet() { function_set ToTimestampFunction::getFunctionSet() { function_set result; - result.push_back(make_unique(TO_TIMESTAMP_FUNC_NAME, + result.push_back(make_unique(name, std::vector{LogicalTypeID::DOUBLE}, LogicalTypeID::TIMESTAMP, ScalarFunction::UnaryExecFunction)); return result; diff --git a/src/function/vector_union_functions.cpp b/src/function/vector_union_functions.cpp index 33b6eefabb..202bc9aa19 100644 --- a/src/function/vector_union_functions.cpp +++ b/src/function/vector_union_functions.cpp @@ -9,15 +9,7 @@ using namespace kuzu::common; namespace kuzu { namespace function { -function_set UnionValueFunction::getFunctionSet() { - function_set functionSet; - functionSet.push_back(make_unique(UNION_VALUE_FUNC_NAME, - std::vector{LogicalTypeID::ANY}, LogicalTypeID::UNION, execFunc, nullptr, - compileFunc, bindFunc, false /* isVarLength */)); - return functionSet; -} - -std::unique_ptr UnionValueFunction::bindFunc( +static std::unique_ptr UnionValueFunctionBindFunc( const binder::expression_vector& arguments, kuzu::function::Function* /*function*/) { KU_ASSERT(arguments.size() == 1); std::vector fields; @@ -32,12 +24,13 @@ std::unique_ptr UnionValueFunction::bindFunc( return std::make_unique(std::move(resultType)); } -void UnionValueFunction::execFunc(const std::vector>& /*parameters*/, - ValueVector& result, void* /*dataPtr*/) { +static void UnionValueFunctionExecFunc( + const std::vector>& /*parameters*/, ValueVector& result, + void* /*dataPtr*/) { UnionVector::setTagField(&result, UnionType::TAG_FIELD_IDX); } -void UnionValueFunction::compileFunc(FunctionBindData* /*bindData*/, +static void UnionValueFunctionCompileFunc(FunctionBindData* /*bindData*/, const std::vector>& parameters, std::shared_ptr& result) { KU_ASSERT(parameters.size() == 1); @@ -46,9 +39,18 @@ void UnionValueFunction::compileFunc(FunctionBindData* /*bindData*/, UnionVector::referenceVector(result.get(), UnionType::TAG_FIELD_IDX, parameters[0]); } +function_set UnionValueFunction::getFunctionSet() { + function_set functionSet; + functionSet.push_back( + make_unique(name, std::vector{LogicalTypeID::ANY}, + LogicalTypeID::UNION, UnionValueFunctionExecFunc, nullptr, + UnionValueFunctionCompileFunc, UnionValueFunctionBindFunc, false /* isVarLength */)); + return functionSet; +} + function_set UnionTagFunction::getFunctionSet() { function_set functionSet; - functionSet.push_back(make_unique(UNION_TAG_FUNC_NAME, + functionSet.push_back(make_unique(name, std::vector{LogicalTypeID::UNION}, LogicalTypeID::STRING, ScalarFunction::UnaryExecNestedTypeFunction, nullptr, nullptr, false /* isVarLength */)); @@ -57,7 +59,7 @@ function_set UnionTagFunction::getFunctionSet() { function_set UnionExtractFunction::getFunctionSet() { function_set functionSet; - functionSet.push_back(make_unique(UNION_EXTRACT_FUNC_NAME, + functionSet.push_back(make_unique(name, std::vector{LogicalTypeID::UNION, LogicalTypeID::STRING}, LogicalTypeID::ANY, nullptr, nullptr, StructExtractFunctions::compileFunc, StructExtractFunctions::bindFunc, false /* isVarLength */)); diff --git a/src/function/vector_uuid_functions.cpp b/src/function/vector_uuid_functions.cpp index f441d8310f..e4bb0a22cb 100644 --- a/src/function/vector_uuid_functions.cpp +++ b/src/function/vector_uuid_functions.cpp @@ -10,9 +10,8 @@ namespace function { function_set GenRandomUUIDFunction::getFunctionSet() { function_set definitions; - definitions.push_back( - make_unique(GEN_RANDOM_UUID_FUNC_NAME, std::vector{}, - LogicalTypeID::UUID, ScalarFunction::PoniterExecFunction)); + definitions.push_back(make_unique(name, std::vector{}, + LogicalTypeID::UUID, ScalarFunction::PoniterExecFunction)); return definitions; } diff --git a/src/include/common/enums/expression_type.h b/src/include/common/enums/expression_type.h index c81d0181a5..2ea8b0fe9b 100644 --- a/src/include/common/enums/expression_type.h +++ b/src/include/common/enums/expression_type.h @@ -19,61 +19,9 @@ const char* const MIN_FUNC_NAME = "MIN"; const char* const MAX_FUNC_NAME = "MAX"; const char* const COLLECT_FUNC_NAME = "COLLECT"; -// struct -const char* const STRUCT_PACK_FUNC_NAME = "STRUCT_PACK"; -const char* const STRUCT_EXTRACT_FUNC_NAME = "STRUCT_EXTRACT"; - -// map -const char* const MAP_CREATION_FUNC_NAME = "MAP"; -const char* const MAP_EXTRACT_FUNC_NAME = "MAP_EXTRACT"; -const char* const ELEMENT_AT_FUNC_NAME = "ELEMENT_AT"; // alias of MAP_EXTRACT -const char* const CARDINALITY_FUNC_NAME = "CARDINALITY"; -const char* const MAP_KEYS_FUNC_NAME = "MAP_KEYS"; -const char* const MAP_VALUES_FUNC_NAME = "MAP_VALUES"; - -// union -const char* const UNION_VALUE_FUNC_NAME = "UNION_VALUE"; -const char* const UNION_TAG_FUNC_NAME = "UNION_TAG"; -const char* const UNION_EXTRACT_FUNC_NAME = "UNION_EXTRACT"; - -// comparison -const char* const EQUALS_FUNC_NAME = "EQUALS"; -const char* const NOT_EQUALS_FUNC_NAME = "NOT_EQUALS"; -const char* const GREATER_THAN_FUNC_NAME = "GREATER_THAN"; -const char* const GREATER_THAN_EQUALS_FUNC_NAME = "GREATER_THAN_EQUALS"; -const char* const LESS_THAN_FUNC_NAME = "LESS_THAN"; -const char* const LESS_THAN_EQUALS_FUNC_NAME = "LESS_THAN_EQUALS"; - // string const char* const LENGTH_FUNC_NAME = "LENGTH"; -// Date functions. -const char* const DATE_PART_FUNC_NAME = "DATE_PART"; -const char* const DATEPART_FUNC_NAME = "DATEPART"; -const char* const DATE_TRUNC_FUNC_NAME = "DATE_TRUNC"; -const char* const DATETRUNC_FUNC_NAME = "DATETRUNC"; -const char* const DAYNAME_FUNC_NAME = "DAYNAME"; -const char* const GREATEST_FUNC_NAME = "GREATEST"; -const char* const LAST_DAY_FUNC_NAME = "LAST_DAY"; -const char* const LEAST_FUNC_NAME = "LEAST"; -const char* const MAKE_DATE_FUNC_NAME = "MAKE_DATE"; -const char* const MONTHNAME_FUNC_NAME = "MONTHNAME"; - -// Timestamp functions. -const char* const CENTURY_FUNC_NAME = "CENTURY"; -const char* const EPOCH_MS_FUNC_NAME = "EPOCH_MS"; -const char* const TO_TIMESTAMP_FUNC_NAME = "TO_TIMESTAMP"; - -// Interval functions. -const char* const TO_YEARS_FUNC_NAME = "TO_YEARS"; -const char* const TO_MONTHS_FUNC_NAME = "TO_MONTHS"; -const char* const TO_DAYS_FUNC_NAME = "TO_DAYS"; -const char* const TO_HOURS_FUNC_NAME = "TO_HOURS"; -const char* const TO_MINUTES_FUNC_NAME = "TO_MINUTES"; -const char* const TO_SECONDS_FUNC_NAME = "TO_SECONDS"; -const char* const TO_MILLISECONDS_FUNC_NAME = "TO_MILLISECONDS"; -const char* const TO_MICROSECONDS_FUNC_NAME = "TO_MICROSECONDS"; - // Node/Rel functions. const char* const ID_FUNC_NAME = "ID"; const char* const LABEL_FUNC_NAME = "LABEL"; @@ -88,14 +36,6 @@ const char* const PROPERTIES_FUNC_NAME = "PROPERTIES"; const char* const IS_TRAIL_FUNC_NAME = "IS_TRAIL"; const char* const IS_ACYCLIC_FUNC_NAME = "IS_ACYCLIC"; -// Blob functions -const char* const OCTET_LENGTH_FUNC_NAME = "OCTET_LENGTH"; -const char* const ENCODE_FUNC_NAME = "ENCODE"; -const char* const DECODE_FUNC_NAME = "DECODE"; - -// UUID functions -const char* const GEN_RANDOM_UUID_FUNC_NAME = "GEN_RANDOM_UUID"; - // RDF functions const char* const TYPE_FUNC_NAME = "TYPE"; const char* const VALIDATE_PREDICATE_FUNC_NAME = "VALIDATE_PREDICATE"; diff --git a/src/include/function/blob/vector_blob_functions.h b/src/include/function/blob/vector_blob_functions.h index af72f6039e..6c252b476e 100644 --- a/src/include/function/blob/vector_blob_functions.h +++ b/src/include/function/blob/vector_blob_functions.h @@ -6,14 +6,20 @@ namespace kuzu { namespace function { struct OctetLengthFunctions { + static constexpr const char* name = "OCTET_LENGTH"; + static function_set getFunctionSet(); }; struct EncodeFunctions { + static constexpr const char* name = "ENCODE"; + static function_set getFunctionSet(); }; struct DecodeFunctions { + static constexpr const char* name = "DECODE"; + static function_set getFunctionSet(); }; diff --git a/src/include/function/built_in_function_utils.h b/src/include/function/built_in_function_utils.h index 6611245ac7..4959de4368 100644 --- a/src/include/function/built_in_function_utils.h +++ b/src/include/function/built_in_function_utils.h @@ -82,15 +82,6 @@ class BuiltInFunctionsUtils { // Scalar functions. static void registerScalarFunctions(catalog::CatalogSet* catalogSet); - static void registerComparisonFunctions(catalog::CatalogSet* catalogSet); - static void registerDateFunctions(catalog::CatalogSet* catalogSet); - static void registerTimestampFunctions(catalog::CatalogSet* catalogSet); - static void registerIntervalFunctions(catalog::CatalogSet* catalogSet); - static void registerBlobFunctions(catalog::CatalogSet* catalogSet); - static void registerUUIDFunctions(catalog::CatalogSet* catalogSet); - static void registerStructFunctions(catalog::CatalogSet* catalogSet); - static void registerMapFunctions(catalog::CatalogSet* catalogSet); - static void registerUnionFunctions(catalog::CatalogSet* catalogSet); static void registerNodeRelFunctions(catalog::CatalogSet* catalogSet); static void registerPathFunctions(catalog::CatalogSet* catalogSet); static void registerRdfFunctions(catalog::CatalogSet* catalogSet); diff --git a/src/include/function/comparison/vector_comparison_functions.h b/src/include/function/comparison/vector_comparison_functions.h index 5b39b93918..3d05682136 100644 --- a/src/include/function/comparison/vector_comparison_functions.h +++ b/src/include/function/comparison/vector_comparison_functions.h @@ -46,7 +46,7 @@ struct ComparisonFunction { } template - static inline std::unique_ptr getFunction( + static std::unique_ptr getFunction( const std::string& name, common::LogicalTypeID leftType, common::LogicalTypeID rightType) { auto leftPhysical = common::LogicalType::getPhysicalType(leftType); auto rightPhysical = common::LogicalType::getPhysicalType(rightType); @@ -196,40 +196,50 @@ struct ComparisonFunction { }; struct EqualsFunction { - static inline function_set getFunctionSet() { - return ComparisonFunction::getFunctionSet(common::EQUALS_FUNC_NAME); + static constexpr const char* name = "EQUALS"; + + static function_set getFunctionSet() { + return ComparisonFunction::getFunctionSet(name); } }; struct NotEqualsFunction { - static inline function_set getFunctionSet() { - return ComparisonFunction::getFunctionSet(common::NOT_EQUALS_FUNC_NAME); + static constexpr const char* name = "NOT_EQUALS"; + + static function_set getFunctionSet() { + return ComparisonFunction::getFunctionSet(name); } }; struct GreaterThanFunction { - static inline function_set getFunctionSet() { - return ComparisonFunction::getFunctionSet(common::GREATER_THAN_FUNC_NAME); + static constexpr const char* name = "GREATER_THAN"; + + static function_set getFunctionSet() { + return ComparisonFunction::getFunctionSet(name); } }; struct GreaterThanEqualsFunction { - static inline function_set getFunctionSet() { - return ComparisonFunction::getFunctionSet( - common::GREATER_THAN_EQUALS_FUNC_NAME); + static constexpr const char* name = "GREATER_THAN_EQUALS"; + + static function_set getFunctionSet() { + return ComparisonFunction::getFunctionSet(name); } }; struct LessThanFunction { - static inline function_set getFunctionSet() { - return ComparisonFunction::getFunctionSet(common::LESS_THAN_FUNC_NAME); + static constexpr const char* name = "LESS_THAN"; + + static function_set getFunctionSet() { + return ComparisonFunction::getFunctionSet(name); } }; struct LessThanEqualsFunction { - static inline function_set getFunctionSet() { - return ComparisonFunction::getFunctionSet( - common::LESS_THAN_EQUALS_FUNC_NAME); + static constexpr const char* name = "LESS_THAN_EQUALS"; + + static function_set getFunctionSet() { + return ComparisonFunction::getFunctionSet(name); } }; diff --git a/src/include/function/date/vector_date_functions.h b/src/include/function/date/vector_date_functions.h index 5058bd31c6..2c95881e6c 100644 --- a/src/include/function/date/vector_date_functions.h +++ b/src/include/function/date/vector_date_functions.h @@ -6,34 +6,54 @@ namespace kuzu { namespace function { struct DatePartFunction { + static constexpr const char* name = "DATE_PART"; + + static constexpr const char* alias = "DATEPART"; + static function_set getFunctionSet(); }; struct DateTruncFunction { + static constexpr const char* name = "DATE_TRUNC"; + + static constexpr const char* alias = "DATETRUNC"; + static function_set getFunctionSet(); }; struct DayNameFunction { + static constexpr const char* name = "DAYNAME"; + static function_set getFunctionSet(); }; struct GreatestFunction { + static constexpr const char* name = "GREATEST"; + static function_set getFunctionSet(); }; struct LastDayFunction { + static constexpr const char* name = "LAST_DAY"; + static function_set getFunctionSet(); }; struct LeastFunction { + static constexpr const char* name = "LEAST"; + static function_set getFunctionSet(); }; struct MakeDateFunction { + static constexpr const char* name = "MAKE_DATE"; + static function_set getFunctionSet(); }; struct MonthNameFunction { + static constexpr const char* name = "MONTHNAME"; + static function_set getFunctionSet(); }; diff --git a/src/include/function/interval/vector_interval_functions.h b/src/include/function/interval/vector_interval_functions.h index f107f87f5b..fdeef159f7 100644 --- a/src/include/function/interval/vector_interval_functions.h +++ b/src/include/function/interval/vector_interval_functions.h @@ -9,7 +9,7 @@ namespace function { struct IntervalFunction { public: template - static inline function_set getUnaryIntervalFunction(std::string funcName) { + static function_set getUnaryIntervalFunction(std::string funcName) { function_set result; result.push_back(std::make_unique(funcName, std::vector{common::LogicalTypeID::INT64}, @@ -20,52 +20,66 @@ struct IntervalFunction { }; struct ToYearsFunction { - static inline function_set getFunctionSet() { - return IntervalFunction::getUnaryIntervalFunction(common::TO_YEARS_FUNC_NAME); + static constexpr const char* name = "TO_YEARS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); } }; struct ToMonthsFunction { - static inline function_set getFunctionSet() { - return IntervalFunction::getUnaryIntervalFunction(common::TO_MONTHS_FUNC_NAME); + static constexpr const char* name = "TO_MONTHS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); } }; struct ToDaysFunction { - static inline function_set getFunctionSet() { - return IntervalFunction::getUnaryIntervalFunction(common::TO_DAYS_FUNC_NAME); + static constexpr const char* name = "TO_DAYS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); } }; struct ToHoursFunction { - static inline function_set getFunctionSet() { - return IntervalFunction::getUnaryIntervalFunction(common::TO_HOURS_FUNC_NAME); + static constexpr const char* name = "TO_HOURS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); } }; struct ToMinutesFunction { - static inline function_set getFunctionSet() { - return IntervalFunction::getUnaryIntervalFunction(common::TO_MINUTES_FUNC_NAME); + static constexpr const char* name = "TO_MINUTES"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); } }; struct ToSecondsFunction { - static inline function_set getFunctionSet() { - return IntervalFunction::getUnaryIntervalFunction(common::TO_SECONDS_FUNC_NAME); + static constexpr const char* name = "TO_SECONDS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); } }; struct ToMillisecondsFunction { - static inline function_set getFunctionSet() { - return IntervalFunction::getUnaryIntervalFunction( - common::TO_MILLISECONDS_FUNC_NAME); + static constexpr const char* name = "TO_MILLISECONDS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); } }; struct ToMicrosecondsFunction { - static inline function_set getFunctionSet() { - return IntervalFunction::getUnaryIntervalFunction( - common::TO_MICROSECONDS_FUNC_NAME); + static constexpr const char* name = "TO_MICROSECONDS"; + + static function_set getFunctionSet() { + return IntervalFunction::getUnaryIntervalFunction(name); } }; diff --git a/src/include/function/map/vector_map_functions.h b/src/include/function/map/vector_map_functions.h index fe0484f910..da99678651 100644 --- a/src/include/function/map/vector_map_functions.h +++ b/src/include/function/map/vector_map_functions.h @@ -6,27 +6,29 @@ namespace kuzu { namespace function { struct MapCreationFunctions { + static constexpr const char* name = "MAP"; + static function_set getFunctionSet(); - static std::unique_ptr bindFunc( - const binder::expression_vector& arguments, Function* function); }; struct MapExtractFunctions { + static constexpr const char* name = "MAP_EXTRACT"; + + static constexpr const char* alias = "ELEMENT_AT"; + static function_set getFunctionSet(); - static std::unique_ptr bindFunc( - const binder::expression_vector& arguments, Function* function); }; struct MapKeysFunctions { + static constexpr const char* name = "MAP_KEYS"; + static function_set getFunctionSet(); - static std::unique_ptr bindFunc( - const binder::expression_vector& arguments, Function* function); }; struct MapValuesFunctions { + static constexpr const char* name = "MAP_VALUES"; + static function_set getFunctionSet(); - static std::unique_ptr bindFunc( - const binder::expression_vector& arguments, Function* function); }; } // namespace function diff --git a/src/include/function/struct/vector_struct_functions.h b/src/include/function/struct/vector_struct_functions.h index 4b41287d38..39f8a23a82 100644 --- a/src/include/function/struct/vector_struct_functions.h +++ b/src/include/function/struct/vector_struct_functions.h @@ -1,23 +1,21 @@ #pragma once #include "common/vector/value_vector.h" -#include "function/scalar_function.h" +#include "function/function.h" namespace kuzu { namespace function { struct StructPackFunctions { + static constexpr const char* name = "STRUCT_PACK"; + static function_set getFunctionSet(); - static std::unique_ptr bindFunc( - const binder::expression_vector& arguments, Function* function); static void execFunc(const std::vector>& parameters, common::ValueVector& result, void* /*dataPtr*/ = nullptr); static void compileFunc(FunctionBindData* bindData, const std::vector>& parameters, std::shared_ptr& result); - static void copyParameterValueToStructFieldVector(const common::ValueVector* parameter, - common::ValueVector* structField, common::DataChunkState* structVectorState); }; struct StructExtractBindData : public FunctionBindData { @@ -29,8 +27,9 @@ struct StructExtractBindData : public FunctionBindData { }; struct StructExtractFunctions { + static constexpr const char* name = "STRUCT_EXTRACT"; + static function_set getFunctionSet(); - static std::unique_ptr getFunction(common::LogicalTypeID logicalTypeID); static std::unique_ptr bindFunc( const binder::expression_vector& arguments, Function* function); diff --git a/src/include/function/timestamp/vector_timestamp_functions.h b/src/include/function/timestamp/vector_timestamp_functions.h index 5b2c0c0902..4073200308 100644 --- a/src/include/function/timestamp/vector_timestamp_functions.h +++ b/src/include/function/timestamp/vector_timestamp_functions.h @@ -6,14 +6,20 @@ namespace kuzu { namespace function { struct CenturyFunction { + static constexpr const char* name = "CENTURY"; + static function_set getFunctionSet(); }; struct EpochMsFunction { + static constexpr const char* name = "EPOCH_MS"; + static function_set getFunctionSet(); }; struct ToTimestampFunction { + static constexpr const char* name = "TO_TIMESTAMP"; + static function_set getFunctionSet(); }; diff --git a/src/include/function/union/vector_union_functions.h b/src/include/function/union/vector_union_functions.h index 9f33301953..ac1fb6c7c8 100644 --- a/src/include/function/union/vector_union_functions.h +++ b/src/include/function/union/vector_union_functions.h @@ -1,27 +1,25 @@ #pragma once -#include "common/vector/value_vector.h" #include "function/function.h" namespace kuzu { namespace function { struct UnionValueFunction { + static constexpr const char* name = "UNION_VALUE"; + static function_set getFunctionSet(); - static std::unique_ptr bindFunc( - const binder::expression_vector& arguments, Function* function); - static void execFunc(const std::vector>& parameters, - common::ValueVector& result, void* /*dataPtr*/ = nullptr); - static void compileFunc(FunctionBindData* bindData, - const std::vector>& parameters, - std::shared_ptr& result); }; struct UnionTagFunction { + static constexpr const char* name = "UNION_TAG"; + static function_set getFunctionSet(); }; struct UnionExtractFunction { + static constexpr const char* name = "UNION_EXTRACT"; + static function_set getFunctionSet(); }; diff --git a/src/include/function/uuid/vector_uuid_functions.h b/src/include/function/uuid/vector_uuid_functions.h index ce7ac3d6e3..5cd7587dfc 100644 --- a/src/include/function/uuid/vector_uuid_functions.h +++ b/src/include/function/uuid/vector_uuid_functions.h @@ -6,6 +6,8 @@ namespace kuzu { namespace function { struct GenRandomUUIDFunction { + static constexpr const char* name = "GEN_RANDOM_UUID"; + static function_set getFunctionSet(); }; diff --git a/src/parser/transform/transform_expression.cpp b/src/parser/transform/transform_expression.cpp index 7f8da14c1e..929814e462 100644 --- a/src/parser/transform/transform_expression.cpp +++ b/src/parser/transform/transform_expression.cpp @@ -2,6 +2,7 @@ #include "function/cast/functions/cast_from_string_functions.h" #include "function/list/vector_list_functions.h" #include "function/string/vector_string_functions.h" +#include "function/struct/vector_struct_functions.h" #include "parser/expression/parsed_case_expression.h" #include "parser/expression/parsed_function_expression.h" #include "parser/expression/parsed_literal_expression.h" @@ -439,7 +440,7 @@ std::unique_ptr Transformer::transformListLiteral( std::unique_ptr Transformer::transformStructLiteral( CypherParser::KU_StructLiteralContext& ctx) { auto structPack = - std::make_unique(STRUCT_PACK_FUNC_NAME, ctx.getText()); + std::make_unique(StructPackFunctions::name, ctx.getText()); for (auto& structField : ctx.kU_StructField()) { auto structExpr = transformExpression(*structField->oC_Expression()); std::string alias;