From e924547ef21d43e221937922836c67df3df443f0 Mon Sep 17 00:00:00 2001 From: ziyi chen Date: Tue, 11 Jul 2023 00:18:18 -0400 Subject: [PATCH] Add UDF support to c++ API --- src/catalog/catalog.cpp | 6 + src/catalog/catalog_structs.cpp | 4 + src/common/types/value.cpp | 1 + src/common/vector/auxiliary_buffer.cpp | 1 + src/expression_evaluator/path_evaluator.cpp | 1 + src/function/built_in_vector_functions.cpp | 324 +++++++++--------- src/function/vector_path_functions.cpp | 1 + src/include/catalog/catalog.h | 2 + src/include/catalog/catalog_structs.h | 5 +- src/include/common/vector/auxiliary_buffer.h | 5 +- .../function/binary_function_executor.h | 92 +++-- .../function/built_in_vector_functions.h | 6 +- .../function/ternary_function_executor.h | 136 ++++---- src/include/function/udf_function.h | 159 +++++++++ .../function/unary_function_executor.h | 69 ++-- src/include/main/connection.h | 11 + .../storage/buffer_manager/memory_manager.h | 4 +- src/main/connection.cpp | 5 + src/storage/buffer_manager/memory_manager.cpp | 1 + test/main/CMakeLists.txt | 3 +- test/main/udf_test.cpp | 106 ++++++ tools/rust_api/build.rs | 1 + 22 files changed, 655 insertions(+), 288 deletions(-) create mode 100644 src/include/function/udf_function.h create mode 100644 test/main/udf_test.cpp diff --git a/src/catalog/catalog.cpp b/src/catalog/catalog.cpp index 86d32ff103..7d093b87f2 100644 --- a/src/catalog/catalog.cpp +++ b/src/catalog/catalog.cpp @@ -463,5 +463,11 @@ std::unordered_set Catalog::getAllRelTableSchemasContainBoundTa return relTableSchemas; } +void Catalog::addVectorFunction( + std::string name, function::vector_function_definitions definitions) { + common::StringUtils::toUpper(name); + builtInVectorFunctions->addFunction(std::move(name), std::move(definitions)); +} + } // namespace catalog } // namespace kuzu diff --git a/src/catalog/catalog_structs.cpp b/src/catalog/catalog_structs.cpp index fdb7aa80c4..a5a8841fda 100644 --- a/src/catalog/catalog_structs.cpp +++ b/src/catalog/catalog_structs.cpp @@ -41,6 +41,10 @@ std::string getRelMultiplicityAsString(RelMultiplicity relMultiplicity) { } } +bool TableSchema::isReservedPropertyName(const std::string& propertyName) { + return common::StringUtils::getUpper(propertyName) == common::InternalKeyword::ID; +} + std::string TableSchema::getPropertyName(property_id_t propertyID) const { for (auto& property : properties) { if (property.propertyID == propertyID) { diff --git a/src/common/types/value.cpp b/src/common/types/value.cpp index f341b382c2..27b9431940 100644 --- a/src/common/types/value.cpp +++ b/src/common/types/value.cpp @@ -1,6 +1,7 @@ #include "common/types/value.h" #include "common/null_buffer.h" +#include "common/string_utils.h" #include "storage/storage_utils.h" namespace kuzu { diff --git a/src/common/vector/auxiliary_buffer.cpp b/src/common/vector/auxiliary_buffer.cpp index 3ef6171cf3..3ac9ce9f83 100644 --- a/src/common/vector/auxiliary_buffer.cpp +++ b/src/common/vector/auxiliary_buffer.cpp @@ -1,5 +1,6 @@ #include "common/vector/auxiliary_buffer.h" +#include "arrow/array.h" #include "common/vector/value_vector.h" namespace kuzu { diff --git a/src/expression_evaluator/path_evaluator.cpp b/src/expression_evaluator/path_evaluator.cpp index 37e44563b5..d3ebb37f4d 100644 --- a/src/expression_evaluator/path_evaluator.cpp +++ b/src/expression_evaluator/path_evaluator.cpp @@ -1,6 +1,7 @@ #include "expression_evaluator/path_evaluator.h" #include "binder/expression/path_expression.h" +#include "common/string_utils.h" using namespace kuzu::common; using namespace kuzu::binder; diff --git a/src/function/built_in_vector_functions.cpp b/src/function/built_in_vector_functions.cpp index 6760f33830..79bb36e66c 100644 --- a/src/function/built_in_vector_functions.cpp +++ b/src/function/built_in_vector_functions.cpp @@ -1,5 +1,6 @@ #include "function/built_in_vector_functions.h" +#include "common/string_utils.h" #include "function/arithmetic/vector_arithmetic_functions.h" #include "function/blob/vector_blob_functions.h" #include "function/cast/vector_cast_functions.h" @@ -50,7 +51,7 @@ bool BuiltInVectorFunctions::canApplyStaticEvaluation( VectorFunctionDefinition* BuiltInVectorFunctions::matchVectorFunction( const std::string& name, const std::vector& inputTypes) { - auto& functionDefinitions = VectorFunctions.at(name); + auto& functionDefinitions = vectorFunctions.at(name); bool isOverload = functionDefinitions.size() > 1; std::vector candidateFunctions; uint32_t minCost = UINT32_MAX; @@ -264,7 +265,7 @@ void BuiltInVectorFunctions::validateNonEmptyCandidateFunctions( const std::vector& inputTypes) { if (candidateFunctions.empty()) { std::string supportedInputsString; - for (auto& functionDefinition : VectorFunctions.at(name)) { + for (auto& functionDefinition : vectorFunctions.at(name)) { supportedInputsString += functionDefinition->signatureToString() + "\n"; } throw BinderException("Cannot match a built-in function for given function " + name + @@ -274,221 +275,230 @@ void BuiltInVectorFunctions::validateNonEmptyCandidateFunctions( } void BuiltInVectorFunctions::registerComparisonFunctions() { - VectorFunctions.insert({EQUALS_FUNC_NAME, EqualsVectorFunction::getDefinitions()}); - VectorFunctions.insert({NOT_EQUALS_FUNC_NAME, NotEqualsVectorFunction::getDefinitions()}); - VectorFunctions.insert({GREATER_THAN_FUNC_NAME, GreaterThanVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({EQUALS_FUNC_NAME, EqualsVectorFunction::getDefinitions()}); + vectorFunctions.insert({NOT_EQUALS_FUNC_NAME, NotEqualsVectorFunction::getDefinitions()}); + vectorFunctions.insert({GREATER_THAN_FUNC_NAME, GreaterThanVectorFunction::getDefinitions()}); + vectorFunctions.insert( {GREATER_THAN_EQUALS_FUNC_NAME, GreaterThanEqualsVectorFunction::getDefinitions()}); - VectorFunctions.insert({LESS_THAN_FUNC_NAME, LessThanVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({LESS_THAN_FUNC_NAME, LessThanVectorFunction::getDefinitions()}); + vectorFunctions.insert( {LESS_THAN_EQUALS_FUNC_NAME, LessThanEqualsVectorFunction::getDefinitions()}); } void BuiltInVectorFunctions::registerArithmeticFunctions() { - VectorFunctions.insert({ADD_FUNC_NAME, AddVectorFunction::getDefinitions()}); - VectorFunctions.insert({SUBTRACT_FUNC_NAME, SubtractVectorFunction::getDefinitions()}); - VectorFunctions.insert({MULTIPLY_FUNC_NAME, MultiplyVectorFunction::getDefinitions()}); - VectorFunctions.insert({DIVIDE_FUNC_NAME, DivideVectorFunction::getDefinitions()}); - VectorFunctions.insert({MODULO_FUNC_NAME, ModuloVectorFunction::getDefinitions()}); - VectorFunctions.insert({POWER_FUNC_NAME, PowerVectorFunction::getDefinitions()}); - - VectorFunctions.insert({ABS_FUNC_NAME, AbsVectorFunction::getDefinitions()}); - VectorFunctions.insert({ACOS_FUNC_NAME, AcosVectorFunction::getDefinitions()}); - VectorFunctions.insert({ASIN_FUNC_NAME, AsinVectorFunction::getDefinitions()}); - VectorFunctions.insert({ATAN_FUNC_NAME, AtanVectorFunction::getDefinitions()}); - VectorFunctions.insert({ATAN2_FUNC_NAME, Atan2VectorFunction::getDefinitions()}); - VectorFunctions.insert({BITWISE_XOR_FUNC_NAME, BitwiseXorVectorFunction::getDefinitions()}); - VectorFunctions.insert({BITWISE_AND_FUNC_NAME, BitwiseAndVectorFunction::getDefinitions()}); - VectorFunctions.insert({BITWISE_OR_FUNC_NAME, BitwiseOrVectorFunction::getDefinitions()}); - VectorFunctions.insert({BITSHIFT_LEFT_FUNC_NAME, BitShiftLeftVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({ADD_FUNC_NAME, AddVectorFunction::getDefinitions()}); + vectorFunctions.insert({SUBTRACT_FUNC_NAME, SubtractVectorFunction::getDefinitions()}); + vectorFunctions.insert({MULTIPLY_FUNC_NAME, MultiplyVectorFunction::getDefinitions()}); + vectorFunctions.insert({DIVIDE_FUNC_NAME, DivideVectorFunction::getDefinitions()}); + vectorFunctions.insert({MODULO_FUNC_NAME, ModuloVectorFunction::getDefinitions()}); + vectorFunctions.insert({POWER_FUNC_NAME, PowerVectorFunction::getDefinitions()}); + + vectorFunctions.insert({ABS_FUNC_NAME, AbsVectorFunction::getDefinitions()}); + vectorFunctions.insert({ACOS_FUNC_NAME, AcosVectorFunction::getDefinitions()}); + vectorFunctions.insert({ASIN_FUNC_NAME, AsinVectorFunction::getDefinitions()}); + vectorFunctions.insert({ATAN_FUNC_NAME, AtanVectorFunction::getDefinitions()}); + vectorFunctions.insert({ATAN2_FUNC_NAME, Atan2VectorFunction::getDefinitions()}); + vectorFunctions.insert({BITWISE_XOR_FUNC_NAME, BitwiseXorVectorFunction::getDefinitions()}); + vectorFunctions.insert({BITWISE_AND_FUNC_NAME, BitwiseAndVectorFunction::getDefinitions()}); + vectorFunctions.insert({BITWISE_OR_FUNC_NAME, BitwiseOrVectorFunction::getDefinitions()}); + vectorFunctions.insert({BITSHIFT_LEFT_FUNC_NAME, BitShiftLeftVectorFunction::getDefinitions()}); + vectorFunctions.insert( {BITSHIFT_RIGHT_FUNC_NAME, BitShiftRightVectorFunction::getDefinitions()}); - VectorFunctions.insert({CBRT_FUNC_NAME, CbrtVectorFunction::getDefinitions()}); - VectorFunctions.insert({CEIL_FUNC_NAME, CeilVectorFunction::getDefinitions()}); - VectorFunctions.insert({CEILING_FUNC_NAME, CeilVectorFunction::getDefinitions()}); - VectorFunctions.insert({COS_FUNC_NAME, CosVectorFunction::getDefinitions()}); - VectorFunctions.insert({COT_FUNC_NAME, CotVectorFunction::getDefinitions()}); - VectorFunctions.insert({DEGREES_FUNC_NAME, DegreesVectorFunction::getDefinitions()}); - VectorFunctions.insert({EVEN_FUNC_NAME, EvenVectorFunction::getDefinitions()}); - VectorFunctions.insert({FACTORIAL_FUNC_NAME, FactorialVectorFunction::getDefinitions()}); - VectorFunctions.insert({FLOOR_FUNC_NAME, FloorVectorFunction::getDefinitions()}); - VectorFunctions.insert({GAMMA_FUNC_NAME, GammaVectorFunction::getDefinitions()}); - VectorFunctions.insert({LGAMMA_FUNC_NAME, LgammaVectorFunction::getDefinitions()}); - VectorFunctions.insert({LN_FUNC_NAME, LnVectorFunction::getDefinitions()}); - VectorFunctions.insert({LOG_FUNC_NAME, LogVectorFunction::getDefinitions()}); - VectorFunctions.insert({LOG2_FUNC_NAME, Log2VectorFunction::getDefinitions()}); - VectorFunctions.insert({LOG10_FUNC_NAME, LogVectorFunction::getDefinitions()}); - VectorFunctions.insert({NEGATE_FUNC_NAME, NegateVectorFunction::getDefinitions()}); - VectorFunctions.insert({PI_FUNC_NAME, PiVectorFunction::getDefinitions()}); - VectorFunctions.insert({POW_FUNC_NAME, PowerVectorFunction::getDefinitions()}); - VectorFunctions.insert({RADIANS_FUNC_NAME, RadiansVectorFunction::getDefinitions()}); - VectorFunctions.insert({ROUND_FUNC_NAME, RoundVectorFunction::getDefinitions()}); - VectorFunctions.insert({SIN_FUNC_NAME, SinVectorFunction::getDefinitions()}); - VectorFunctions.insert({SIGN_FUNC_NAME, SignVectorFunction::getDefinitions()}); - VectorFunctions.insert({SQRT_FUNC_NAME, SqrtVectorFunction::getDefinitions()}); - VectorFunctions.insert({TAN_FUNC_NAME, TanVectorFunction::getDefinitions()}); + vectorFunctions.insert({CBRT_FUNC_NAME, CbrtVectorFunction::getDefinitions()}); + vectorFunctions.insert({CEIL_FUNC_NAME, CeilVectorFunction::getDefinitions()}); + vectorFunctions.insert({CEILING_FUNC_NAME, CeilVectorFunction::getDefinitions()}); + vectorFunctions.insert({COS_FUNC_NAME, CosVectorFunction::getDefinitions()}); + vectorFunctions.insert({COT_FUNC_NAME, CotVectorFunction::getDefinitions()}); + vectorFunctions.insert({DEGREES_FUNC_NAME, DegreesVectorFunction::getDefinitions()}); + vectorFunctions.insert({EVEN_FUNC_NAME, EvenVectorFunction::getDefinitions()}); + vectorFunctions.insert({FACTORIAL_FUNC_NAME, FactorialVectorFunction::getDefinitions()}); + vectorFunctions.insert({FLOOR_FUNC_NAME, FloorVectorFunction::getDefinitions()}); + vectorFunctions.insert({GAMMA_FUNC_NAME, GammaVectorFunction::getDefinitions()}); + vectorFunctions.insert({LGAMMA_FUNC_NAME, LgammaVectorFunction::getDefinitions()}); + vectorFunctions.insert({LN_FUNC_NAME, LnVectorFunction::getDefinitions()}); + vectorFunctions.insert({LOG_FUNC_NAME, LogVectorFunction::getDefinitions()}); + vectorFunctions.insert({LOG2_FUNC_NAME, Log2VectorFunction::getDefinitions()}); + vectorFunctions.insert({LOG10_FUNC_NAME, LogVectorFunction::getDefinitions()}); + vectorFunctions.insert({NEGATE_FUNC_NAME, NegateVectorFunction::getDefinitions()}); + vectorFunctions.insert({PI_FUNC_NAME, PiVectorFunction::getDefinitions()}); + vectorFunctions.insert({POW_FUNC_NAME, PowerVectorFunction::getDefinitions()}); + vectorFunctions.insert({RADIANS_FUNC_NAME, RadiansVectorFunction::getDefinitions()}); + vectorFunctions.insert({ROUND_FUNC_NAME, RoundVectorFunction::getDefinitions()}); + vectorFunctions.insert({SIN_FUNC_NAME, SinVectorFunction::getDefinitions()}); + vectorFunctions.insert({SIGN_FUNC_NAME, SignVectorFunction::getDefinitions()}); + vectorFunctions.insert({SQRT_FUNC_NAME, SqrtVectorFunction::getDefinitions()}); + vectorFunctions.insert({TAN_FUNC_NAME, TanVectorFunction::getDefinitions()}); } void BuiltInVectorFunctions::registerDateFunctions() { - VectorFunctions.insert({DATE_PART_FUNC_NAME, DatePartVectorFunction::getDefinitions()}); - VectorFunctions.insert({DATEPART_FUNC_NAME, DatePartVectorFunction::getDefinitions()}); - VectorFunctions.insert({DATE_TRUNC_FUNC_NAME, DateTruncVectorFunction::getDefinitions()}); - VectorFunctions.insert({DATETRUNC_FUNC_NAME, DateTruncVectorFunction::getDefinitions()}); - VectorFunctions.insert({DAYNAME_FUNC_NAME, DayNameVectorFunction::getDefinitions()}); - VectorFunctions.insert({GREATEST_FUNC_NAME, GreatestVectorFunction::getDefinitions()}); - VectorFunctions.insert({LAST_DAY_FUNC_NAME, LastDayVectorFunction::getDefinitions()}); - VectorFunctions.insert({LEAST_FUNC_NAME, LeastVectorFunction::getDefinitions()}); - VectorFunctions.insert({MAKE_DATE_FUNC_NAME, MakeDateVectorFunction::getDefinitions()}); - VectorFunctions.insert({MONTHNAME_FUNC_NAME, MonthNameVectorFunction::getDefinitions()}); + vectorFunctions.insert({DATE_PART_FUNC_NAME, DatePartVectorFunction::getDefinitions()}); + vectorFunctions.insert({DATEPART_FUNC_NAME, DatePartVectorFunction::getDefinitions()}); + vectorFunctions.insert({DATE_TRUNC_FUNC_NAME, DateTruncVectorFunction::getDefinitions()}); + vectorFunctions.insert({DATETRUNC_FUNC_NAME, DateTruncVectorFunction::getDefinitions()}); + vectorFunctions.insert({DAYNAME_FUNC_NAME, DayNameVectorFunction::getDefinitions()}); + vectorFunctions.insert({GREATEST_FUNC_NAME, GreatestVectorFunction::getDefinitions()}); + vectorFunctions.insert({LAST_DAY_FUNC_NAME, LastDayVectorFunction::getDefinitions()}); + vectorFunctions.insert({LEAST_FUNC_NAME, LeastVectorFunction::getDefinitions()}); + vectorFunctions.insert({MAKE_DATE_FUNC_NAME, MakeDateVectorFunction::getDefinitions()}); + vectorFunctions.insert({MONTHNAME_FUNC_NAME, MonthNameVectorFunction::getDefinitions()}); } void BuiltInVectorFunctions::registerTimestampFunctions() { - VectorFunctions.insert({CENTURY_FUNC_NAME, CenturyVectorFunction::getDefinitions()}); - VectorFunctions.insert({EPOCH_MS_FUNC_NAME, EpochMsVectorFunction::getDefinitions()}); - VectorFunctions.insert({TO_TIMESTAMP_FUNC_NAME, ToTimestampVectorFunction::getDefinitions()}); + vectorFunctions.insert({CENTURY_FUNC_NAME, CenturyVectorFunction::getDefinitions()}); + vectorFunctions.insert({EPOCH_MS_FUNC_NAME, EpochMsVectorFunction::getDefinitions()}); + vectorFunctions.insert({TO_TIMESTAMP_FUNC_NAME, ToTimestampVectorFunction::getDefinitions()}); } void BuiltInVectorFunctions::registerIntervalFunctions() { - VectorFunctions.insert({TO_YEARS_FUNC_NAME, ToYearsVectorFunction::getDefinitions()}); - VectorFunctions.insert({TO_MONTHS_FUNC_NAME, ToMonthsVectorFunction::getDefinitions()}); - VectorFunctions.insert({TO_DAYS_FUNC_NAME, ToDaysVectorFunction::getDefinitions()}); - VectorFunctions.insert({TO_HOURS_FUNC_NAME, ToHoursVectorFunction::getDefinitions()}); - VectorFunctions.insert({TO_MINUTES_FUNC_NAME, ToMinutesVectorFunction::getDefinitions()}); - VectorFunctions.insert({TO_SECONDS_FUNC_NAME, ToSecondsVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({TO_YEARS_FUNC_NAME, ToYearsVectorFunction::getDefinitions()}); + vectorFunctions.insert({TO_MONTHS_FUNC_NAME, ToMonthsVectorFunction::getDefinitions()}); + vectorFunctions.insert({TO_DAYS_FUNC_NAME, ToDaysVectorFunction::getDefinitions()}); + vectorFunctions.insert({TO_HOURS_FUNC_NAME, ToHoursVectorFunction::getDefinitions()}); + vectorFunctions.insert({TO_MINUTES_FUNC_NAME, ToMinutesVectorFunction::getDefinitions()}); + vectorFunctions.insert({TO_SECONDS_FUNC_NAME, ToSecondsVectorFunction::getDefinitions()}); + vectorFunctions.insert( {TO_MILLISECONDS_FUNC_NAME, ToMillisecondsVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert( {TO_MICROSECONDS_FUNC_NAME, ToMicrosecondsVectorFunction::getDefinitions()}); } void BuiltInVectorFunctions::registerBlobFunctions() { - VectorFunctions.insert({OCTET_LENGTH_FUNC_NAME, OctetLengthVectorFunctions::getDefinitions()}); - VectorFunctions.insert({ENCODE_FUNC_NAME, EncodeVectorFunctions::getDefinitions()}); - VectorFunctions.insert({DECODE_FUNC_NAME, DecodeVectorFunctions::getDefinitions()}); + vectorFunctions.insert({OCTET_LENGTH_FUNC_NAME, OctetLengthVectorFunctions::getDefinitions()}); + vectorFunctions.insert({ENCODE_FUNC_NAME, EncodeVectorFunctions::getDefinitions()}); + vectorFunctions.insert({DECODE_FUNC_NAME, DecodeVectorFunctions::getDefinitions()}); } void BuiltInVectorFunctions::registerStringFunctions() { - VectorFunctions.insert({ARRAY_EXTRACT_FUNC_NAME, ArrayExtractVectorFunction::getDefinitions()}); - VectorFunctions.insert({CONCAT_FUNC_NAME, ConcatVectorFunction::getDefinitions()}); - VectorFunctions.insert({CONTAINS_FUNC_NAME, ContainsVectorFunction::getDefinitions()}); - VectorFunctions.insert({ENDS_WITH_FUNC_NAME, EndsWithVectorFunction::getDefinitions()}); - VectorFunctions.insert({LCASE_FUNC_NAME, LowerVectorFunction::getDefinitions()}); - VectorFunctions.insert({LEFT_FUNC_NAME, LeftVectorFunction::getDefinitions()}); - VectorFunctions.insert({LENGTH_FUNC_NAME, LengthVectorFunction::getDefinitions()}); - VectorFunctions.insert({LOWER_FUNC_NAME, LowerVectorFunction::getDefinitions()}); - VectorFunctions.insert({LPAD_FUNC_NAME, LpadVectorFunction::getDefinitions()}); - VectorFunctions.insert({LTRIM_FUNC_NAME, LtrimVectorFunction::getDefinitions()}); - VectorFunctions.insert({PREFIX_FUNC_NAME, StartsWithVectorFunction::getDefinitions()}); - VectorFunctions.insert({REPEAT_FUNC_NAME, RepeatVectorFunction::getDefinitions()}); - VectorFunctions.insert({REVERSE_FUNC_NAME, ReverseVectorFunction::getDefinitions()}); - VectorFunctions.insert({RIGHT_FUNC_NAME, RightVectorFunction::getDefinitions()}); - VectorFunctions.insert({RPAD_FUNC_NAME, RpadVectorFunction::getDefinitions()}); - VectorFunctions.insert({RTRIM_FUNC_NAME, RtrimVectorFunction::getDefinitions()}); - VectorFunctions.insert({STARTS_WITH_FUNC_NAME, StartsWithVectorFunction::getDefinitions()}); - VectorFunctions.insert({SUBSTR_FUNC_NAME, SubStrVectorFunction::getDefinitions()}); - VectorFunctions.insert({SUBSTRING_FUNC_NAME, SubStrVectorFunction::getDefinitions()}); - VectorFunctions.insert({SUFFIX_FUNC_NAME, EndsWithVectorFunction::getDefinitions()}); - VectorFunctions.insert({TRIM_FUNC_NAME, TrimVectorFunction::getDefinitions()}); - VectorFunctions.insert({UCASE_FUNC_NAME, UpperVectorFunction::getDefinitions()}); - VectorFunctions.insert({UPPER_FUNC_NAME, UpperVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({ARRAY_EXTRACT_FUNC_NAME, ArrayExtractVectorFunction::getDefinitions()}); + vectorFunctions.insert({CONCAT_FUNC_NAME, ConcatVectorFunction::getDefinitions()}); + vectorFunctions.insert({CONTAINS_FUNC_NAME, ContainsVectorFunction::getDefinitions()}); + vectorFunctions.insert({ENDS_WITH_FUNC_NAME, EndsWithVectorFunction::getDefinitions()}); + vectorFunctions.insert({LCASE_FUNC_NAME, LowerVectorFunction::getDefinitions()}); + vectorFunctions.insert({LEFT_FUNC_NAME, LeftVectorFunction::getDefinitions()}); + vectorFunctions.insert({LENGTH_FUNC_NAME, LengthVectorFunction::getDefinitions()}); + vectorFunctions.insert({LOWER_FUNC_NAME, LowerVectorFunction::getDefinitions()}); + vectorFunctions.insert({LPAD_FUNC_NAME, LpadVectorFunction::getDefinitions()}); + vectorFunctions.insert({LTRIM_FUNC_NAME, LtrimVectorFunction::getDefinitions()}); + vectorFunctions.insert({PREFIX_FUNC_NAME, StartsWithVectorFunction::getDefinitions()}); + vectorFunctions.insert({REPEAT_FUNC_NAME, RepeatVectorFunction::getDefinitions()}); + vectorFunctions.insert({REVERSE_FUNC_NAME, ReverseVectorFunction::getDefinitions()}); + vectorFunctions.insert({RIGHT_FUNC_NAME, RightVectorFunction::getDefinitions()}); + vectorFunctions.insert({RPAD_FUNC_NAME, RpadVectorFunction::getDefinitions()}); + vectorFunctions.insert({RTRIM_FUNC_NAME, RtrimVectorFunction::getDefinitions()}); + vectorFunctions.insert({STARTS_WITH_FUNC_NAME, StartsWithVectorFunction::getDefinitions()}); + vectorFunctions.insert({SUBSTR_FUNC_NAME, SubStrVectorFunction::getDefinitions()}); + vectorFunctions.insert({SUBSTRING_FUNC_NAME, SubStrVectorFunction::getDefinitions()}); + vectorFunctions.insert({SUFFIX_FUNC_NAME, EndsWithVectorFunction::getDefinitions()}); + vectorFunctions.insert({TRIM_FUNC_NAME, TrimVectorFunction::getDefinitions()}); + vectorFunctions.insert({UCASE_FUNC_NAME, UpperVectorFunction::getDefinitions()}); + vectorFunctions.insert({UPPER_FUNC_NAME, UpperVectorFunction::getDefinitions()}); + vectorFunctions.insert( {REGEXP_FULL_MATCH_FUNC_NAME, RegexpFullMatchVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert( {REGEXP_MATCHES_FUNC_NAME, RegexpMatchesVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert( {REGEXP_REPLACE_FUNC_NAME, RegexpReplaceVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert( {REGEXP_EXTRACT_FUNC_NAME, RegexpExtractVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert( {REGEXP_EXTRACT_ALL_FUNC_NAME, RegexpExtractAllVectorFunction::getDefinitions()}); } void BuiltInVectorFunctions::registerCastFunctions() { - VectorFunctions.insert({CAST_TO_DATE_FUNC_NAME, CastToDateVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({CAST_TO_DATE_FUNC_NAME, CastToDateVectorFunction::getDefinitions()}); + vectorFunctions.insert( {CAST_TO_TIMESTAMP_FUNC_NAME, CastToTimestampVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert( {CAST_TO_INTERVAL_FUNC_NAME, CastToIntervalVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert( {CAST_TO_STRING_FUNC_NAME, CastToStringVectorFunction::getDefinitions()}); - VectorFunctions.insert({CAST_TO_BLOB_FUNC_NAME, CastToBlobVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({CAST_TO_BLOB_FUNC_NAME, CastToBlobVectorFunction::getDefinitions()}); + vectorFunctions.insert( {CAST_TO_DOUBLE_FUNC_NAME, CastToDoubleVectorFunction::getDefinitions()}); - VectorFunctions.insert({CAST_TO_FLOAT_FUNC_NAME, CastToFloatVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({CAST_TO_FLOAT_FUNC_NAME, CastToFloatVectorFunction::getDefinitions()}); + vectorFunctions.insert( {CAST_TO_SERIAL_FUNC_NAME, CastToSerialVectorFunction::getDefinitions()}); - VectorFunctions.insert({CAST_TO_INT64_FUNC_NAME, CastToInt64VectorFunction::getDefinitions()}); - VectorFunctions.insert({CAST_TO_INT32_FUNC_NAME, CastToInt32VectorFunction::getDefinitions()}); - VectorFunctions.insert({CAST_TO_INT16_FUNC_NAME, CastToInt16VectorFunction::getDefinitions()}); + vectorFunctions.insert({CAST_TO_INT64_FUNC_NAME, CastToInt64VectorFunction::getDefinitions()}); + vectorFunctions.insert({CAST_TO_INT32_FUNC_NAME, CastToInt32VectorFunction::getDefinitions()}); + vectorFunctions.insert({CAST_TO_INT16_FUNC_NAME, CastToInt16VectorFunction::getDefinitions()}); } void BuiltInVectorFunctions::registerListFunctions() { - VectorFunctions.insert({LIST_CREATION_FUNC_NAME, ListCreationVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_LEN_FUNC_NAME, ListLenVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_EXTRACT_FUNC_NAME, ListExtractVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_ELEMENT_FUNC_NAME, ListExtractVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_CONCAT_FUNC_NAME, ListConcatVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_CAT_FUNC_NAME, ListConcatVectorFunction::getDefinitions()}); - VectorFunctions.insert({ARRAY_CONCAT_FUNC_NAME, ListConcatVectorFunction::getDefinitions()}); - VectorFunctions.insert({ARRAY_CAT_FUNC_NAME, ListConcatVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_APPEND_FUNC_NAME, ListAppendVectorFunction::getDefinitions()}); - VectorFunctions.insert({ARRAY_APPEND_FUNC_NAME, ListAppendVectorFunction::getDefinitions()}); - VectorFunctions.insert({ARRAY_PUSH_BACK_FUNC_NAME, ListAppendVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_PREPEND_FUNC_NAME, ListPrependVectorFunction::getDefinitions()}); - VectorFunctions.insert({ARRAY_PREPEND_FUNC_NAME, ListPrependVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({LIST_CREATION_FUNC_NAME, ListCreationVectorFunction::getDefinitions()}); + vectorFunctions.insert({LIST_LEN_FUNC_NAME, ListLenVectorFunction::getDefinitions()}); + vectorFunctions.insert({LIST_EXTRACT_FUNC_NAME, ListExtractVectorFunction::getDefinitions()}); + vectorFunctions.insert({LIST_ELEMENT_FUNC_NAME, ListExtractVectorFunction::getDefinitions()}); + vectorFunctions.insert({LIST_CONCAT_FUNC_NAME, ListConcatVectorFunction::getDefinitions()}); + vectorFunctions.insert({LIST_CAT_FUNC_NAME, ListConcatVectorFunction::getDefinitions()}); + vectorFunctions.insert({ARRAY_CONCAT_FUNC_NAME, ListConcatVectorFunction::getDefinitions()}); + vectorFunctions.insert({ARRAY_CAT_FUNC_NAME, ListConcatVectorFunction::getDefinitions()}); + vectorFunctions.insert({LIST_APPEND_FUNC_NAME, ListAppendVectorFunction::getDefinitions()}); + vectorFunctions.insert({ARRAY_APPEND_FUNC_NAME, ListAppendVectorFunction::getDefinitions()}); + vectorFunctions.insert({ARRAY_PUSH_BACK_FUNC_NAME, ListAppendVectorFunction::getDefinitions()}); + vectorFunctions.insert({LIST_PREPEND_FUNC_NAME, ListPrependVectorFunction::getDefinitions()}); + vectorFunctions.insert({ARRAY_PREPEND_FUNC_NAME, ListPrependVectorFunction::getDefinitions()}); + vectorFunctions.insert( {ARRAY_PUSH_FRONT_FUNC_NAME, ListPrependVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_POSITION_FUNC_NAME, ListPositionVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({LIST_POSITION_FUNC_NAME, ListPositionVectorFunction::getDefinitions()}); + vectorFunctions.insert( {ARRAY_POSITION_FUNC_NAME, ListPositionVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_INDEXOF_FUNC_NAME, ListPositionVectorFunction::getDefinitions()}); - VectorFunctions.insert({ARRAY_INDEXOF_FUNC_NAME, ListPositionVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_CONTAINS_FUNC_NAME, ListContainsVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_HAS_FUNC_NAME, ListContainsVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({LIST_INDEXOF_FUNC_NAME, ListPositionVectorFunction::getDefinitions()}); + vectorFunctions.insert({ARRAY_INDEXOF_FUNC_NAME, ListPositionVectorFunction::getDefinitions()}); + vectorFunctions.insert({LIST_CONTAINS_FUNC_NAME, ListContainsVectorFunction::getDefinitions()}); + vectorFunctions.insert({LIST_HAS_FUNC_NAME, ListContainsVectorFunction::getDefinitions()}); + vectorFunctions.insert( {ARRAY_CONTAINS_FUNC_NAME, ListContainsVectorFunction::getDefinitions()}); - VectorFunctions.insert({ARRAY_HAS_FUNC_NAME, ListContainsVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_SLICE_FUNC_NAME, ListSliceVectorFunction::getDefinitions()}); - VectorFunctions.insert({ARRAY_SLICE_FUNC_NAME, ListSliceVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_SORT_FUNC_NAME, ListSortVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({ARRAY_HAS_FUNC_NAME, ListContainsVectorFunction::getDefinitions()}); + vectorFunctions.insert({LIST_SLICE_FUNC_NAME, ListSliceVectorFunction::getDefinitions()}); + vectorFunctions.insert({ARRAY_SLICE_FUNC_NAME, ListSliceVectorFunction::getDefinitions()}); + vectorFunctions.insert({LIST_SORT_FUNC_NAME, ListSortVectorFunction::getDefinitions()}); + vectorFunctions.insert( {LIST_REVERSE_SORT_FUNC_NAME, ListReverseSortVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_SUM_FUNC_NAME, ListSumVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_DISTINCT_FUNC_NAME, ListDistinctVectorFunction::getDefinitions()}); - VectorFunctions.insert({LIST_UNIQUE_FUNC_NAME, ListUniqueVectorFunction::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({LIST_SUM_FUNC_NAME, ListSumVectorFunction::getDefinitions()}); + vectorFunctions.insert({LIST_DISTINCT_FUNC_NAME, ListDistinctVectorFunction::getDefinitions()}); + vectorFunctions.insert({LIST_UNIQUE_FUNC_NAME, ListUniqueVectorFunction::getDefinitions()}); + vectorFunctions.insert( {LIST_ANY_VALUE_FUNC_NAME, ListAnyValueVectorFunction::getDefinitions()}); } void BuiltInVectorFunctions::registerStructFunctions() { - VectorFunctions.insert({STRUCT_PACK_FUNC_NAME, StructPackVectorFunctions::getDefinitions()}); - VectorFunctions.insert( + vectorFunctions.insert({STRUCT_PACK_FUNC_NAME, StructPackVectorFunctions::getDefinitions()}); + vectorFunctions.insert( {STRUCT_EXTRACT_FUNC_NAME, StructExtractVectorFunctions::getDefinitions()}); } void BuiltInVectorFunctions::registerMapFunctions() { - VectorFunctions.insert({MAP_CREATION_FUNC_NAME, MapCreationVectorFunctions::getDefinitions()}); - VectorFunctions.insert({MAP_EXTRACT_FUNC_NAME, MapExtractVectorFunctions::getDefinitions()}); - VectorFunctions.insert({ELEMENT_AT_FUNC_NAME, MapExtractVectorFunctions::getDefinitions()}); - VectorFunctions.insert({CARDINALITY_FUNC_NAME, ListLenVectorFunction::getDefinitions()}); - VectorFunctions.insert({MAP_KEYS_FUNC_NAME, MapKeysVectorFunctions::getDefinitions()}); - VectorFunctions.insert({MAP_VALUES_FUNC_NAME, MapValuesVectorFunctions::getDefinitions()}); + vectorFunctions.insert({MAP_CREATION_FUNC_NAME, MapCreationVectorFunctions::getDefinitions()}); + vectorFunctions.insert({MAP_EXTRACT_FUNC_NAME, MapExtractVectorFunctions::getDefinitions()}); + vectorFunctions.insert({ELEMENT_AT_FUNC_NAME, MapExtractVectorFunctions::getDefinitions()}); + vectorFunctions.insert({CARDINALITY_FUNC_NAME, ListLenVectorFunction::getDefinitions()}); + vectorFunctions.insert({MAP_KEYS_FUNC_NAME, MapKeysVectorFunctions::getDefinitions()}); + vectorFunctions.insert({MAP_VALUES_FUNC_NAME, MapValuesVectorFunctions::getDefinitions()}); } void BuiltInVectorFunctions::registerUnionFunctions() { - VectorFunctions.insert({UNION_VALUE_FUNC_NAME, UnionValueVectorFunction::getDefinitions()}); - VectorFunctions.insert({UNION_TAG_FUNC_NAME, UnionTagVectorFunction::getDefinitions()}); - VectorFunctions.insert({UNION_EXTRACT_FUNC_NAME, UnionExtractVectorFunction::getDefinitions()}); + vectorFunctions.insert({UNION_VALUE_FUNC_NAME, UnionValueVectorFunction::getDefinitions()}); + vectorFunctions.insert({UNION_TAG_FUNC_NAME, UnionTagVectorFunction::getDefinitions()}); + vectorFunctions.insert({UNION_EXTRACT_FUNC_NAME, UnionExtractVectorFunction::getDefinitions()}); } void BuiltInVectorFunctions::registerNodeRelFunctions() { - VectorFunctions.insert({OFFSET_FUNC_NAME, OffsetVectorFunction::getDefinitions()}); + vectorFunctions.insert({OFFSET_FUNC_NAME, OffsetVectorFunction::getDefinitions()}); } void BuiltInVectorFunctions::registerPathFunctions() { - VectorFunctions.insert({NODES_FUNC_NAME, NodesVectorFunction::getDefinitions()}); - VectorFunctions.insert({RELS_FUNC_NAME, RelsVectorFunction::getDefinitions()}); - VectorFunctions.insert({PROPERTIES_FUNC_NAME, PropertiesVectorFunction::getDefinitions()}); + vectorFunctions.insert({NODES_FUNC_NAME, NodesVectorFunction::getDefinitions()}); + vectorFunctions.insert({RELS_FUNC_NAME, RelsVectorFunction::getDefinitions()}); + vectorFunctions.insert({PROPERTIES_FUNC_NAME, PropertiesVectorFunction::getDefinitions()}); +} + +void BuiltInVectorFunctions::addFunction( + std::string name, function::vector_function_definitions definitions) { + if (vectorFunctions.contains(name)) { + throw common::CatalogException{ + common::StringUtils::string_format("function {} already exists.", name)}; + } + vectorFunctions.emplace(std::move(name), std::move(definitions)); } } // namespace function diff --git a/src/function/vector_path_functions.cpp b/src/function/vector_path_functions.cpp index 176cdbb214..7991fb7464 100644 --- a/src/function/vector_path_functions.cpp +++ b/src/function/vector_path_functions.cpp @@ -1,6 +1,7 @@ #include "function/path/vector_path_functions.h" #include "binder/expression/literal_expression.h" +#include "common/string_utils.h" #include "function/struct/vector_struct_functions.h" namespace kuzu { diff --git a/src/include/catalog/catalog.h b/src/include/catalog/catalog.h index 5a836abbee..efb1676ae1 100644 --- a/src/include/catalog/catalog.h +++ b/src/include/catalog/catalog.h @@ -215,6 +215,8 @@ class Catalog { std::unordered_set getAllRelTableSchemasContainBoundTable( common::table_id_t boundTableID) const; + void addVectorFunction(std::string name, function::vector_function_definitions definitions); + private: inline bool hasUpdates() { return catalogContentForWriteTrx != nullptr; } diff --git a/src/include/catalog/catalog_structs.h b/src/include/catalog/catalog_structs.h index 4d1beeb019..802c1f1aff 100644 --- a/src/include/catalog/catalog_structs.h +++ b/src/include/catalog/catalog_structs.h @@ -7,7 +7,6 @@ #include "common/constants.h" #include "common/exception.h" #include "common/rel_direction.h" -#include "common/string_utils.h" #include "common/types/types_include.h" namespace kuzu { @@ -48,9 +47,7 @@ struct TableSchema { virtual ~TableSchema() = default; - static inline bool isReservedPropertyName(const std::string& propertyName) { - return common::StringUtils::getUpper(propertyName) == common::InternalKeyword::ID; - } + static bool isReservedPropertyName(const std::string& propertyName); inline uint32_t getNumProperties() const { return properties.size(); } diff --git a/src/include/common/vector/auxiliary_buffer.h b/src/include/common/vector/auxiliary_buffer.h index c62e440c16..1254a228e7 100644 --- a/src/include/common/vector/auxiliary_buffer.h +++ b/src/include/common/vector/auxiliary_buffer.h @@ -1,8 +1,11 @@ #pragma once -#include "arrow/array.h" #include "common/in_mem_overflow_buffer.h" +namespace arrow { +class Array; +} + namespace kuzu { namespace common { diff --git a/src/include/function/binary_function_executor.h b/src/include/function/binary_function_executor.h index 968a18f7d1..86aab76301 100644 --- a/src/include/function/binary_function_executor.h +++ b/src/include/function/binary_function_executor.h @@ -18,7 +18,7 @@ struct BinaryFunctionWrapper { template static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, - common::ValueVector* resultValueVector) { + common::ValueVector* resultValueVector, void* dataPtr) { OP::operation(left, right, result); } }; @@ -27,7 +27,7 @@ struct BinaryListStructFunctionWrapper { template static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, - common::ValueVector* resultValueVector) { + common::ValueVector* resultValueVector, void* dataPtr) { OP::operation(left, right, result, *leftValueVector, *rightValueVector, *resultValueVector); } }; @@ -36,7 +36,7 @@ struct BinaryStringFunctionWrapper { template static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, - common::ValueVector* resultValueVector) { + common::ValueVector* resultValueVector, void* dataPtr) { OP::operation(left, right, result, *resultValueVector); } }; @@ -45,39 +45,50 @@ struct BinaryComparisonFunctionWrapper { template static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, - common::ValueVector* resultValueVector) { + common::ValueVector* resultValueVector, void* dataPtr) { OP::operation(left, right, result, leftValueVector, rightValueVector); } }; +struct BinaryUDFFunctionWrapper { + template + static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, + common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, + common::ValueVector* resultValueVector, void* dataPtr) { + OP::operation(left, right, result, dataPtr); + } +}; + struct BinaryFunctionExecutor { template static inline void executeOnValue(common::ValueVector& left, common::ValueVector& right, - common::ValueVector& resultValueVector, uint64_t lPos, uint64_t rPos, uint64_t resPos) { + common::ValueVector& resultValueVector, uint64_t lPos, uint64_t rPos, uint64_t resPos, + void* dataPtr) { OP_WRAPPER::template operation( ((LEFT_TYPE*)left.getData())[lPos], ((RIGHT_TYPE*)right.getData())[rPos], - ((RESULT_TYPE*)resultValueVector.getData())[resPos], &left, &right, &resultValueVector); + ((RESULT_TYPE*)resultValueVector.getData())[resPos], &left, &right, &resultValueVector, + dataPtr); } template - static void executeBothFlat( - common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) { + static void executeBothFlat(common::ValueVector& left, common::ValueVector& right, + common::ValueVector& result, void* dataPtr) { auto lPos = left.state->selVector->selectedPositions[0]; auto rPos = right.state->selVector->selectedPositions[0]; auto resPos = result.state->selVector->selectedPositions[0]; result.setNull(resPos, left.isNull(lPos) || right.isNull(rPos)); if (!result.isNull(resPos)) { executeOnValue( - left, right, result, lPos, rPos, resPos); + left, right, result, lPos, rPos, resPos, dataPtr); } } template - static void executeFlatUnFlat( - common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) { + static void executeFlatUnFlat(common::ValueVector& left, common::ValueVector& right, + common::ValueVector& result, void* dataPtr) { auto lPos = left.state->selVector->selectedPositions[0]; if (left.isNull(lPos)) { result.setAllNull(); @@ -85,13 +96,13 @@ struct BinaryFunctionExecutor { if (right.state->selVector->isUnfiltered()) { for (auto i = 0u; i < right.state->selVector->selectedSize; ++i) { executeOnValue( - left, right, result, lPos, i, i); + left, right, result, lPos, i, i, dataPtr); } } else { for (auto i = 0u; i < right.state->selVector->selectedSize; ++i) { auto rPos = right.state->selVector->selectedPositions[i]; executeOnValue( - left, right, result, lPos, rPos, rPos); + left, right, result, lPos, rPos, rPos, dataPtr); } } } else { @@ -100,7 +111,7 @@ struct BinaryFunctionExecutor { result.setNull(i, right.isNull(i)); // left is always not null if (!result.isNull(i)) { executeOnValue( - left, right, result, lPos, i, i); + left, right, result, lPos, i, i, dataPtr); } } } else { @@ -109,7 +120,7 @@ struct BinaryFunctionExecutor { result.setNull(rPos, right.isNull(rPos)); // left is always not null if (!result.isNull(rPos)) { executeOnValue( - left, right, result, lPos, rPos, rPos); + left, right, result, lPos, rPos, rPos, dataPtr); } } } @@ -118,8 +129,8 @@ struct BinaryFunctionExecutor { template - static void executeUnFlatFlat( - common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) { + static void executeUnFlatFlat(common::ValueVector& left, common::ValueVector& right, + common::ValueVector& result, void* dataPtr) { auto rPos = right.state->selVector->selectedPositions[0]; if (right.isNull(rPos)) { result.setAllNull(); @@ -127,13 +138,13 @@ struct BinaryFunctionExecutor { if (left.state->selVector->isUnfiltered()) { for (auto i = 0u; i < left.state->selVector->selectedSize; ++i) { executeOnValue( - left, right, result, i, rPos, i); + left, right, result, i, rPos, i, dataPtr); } } else { for (auto i = 0u; i < left.state->selVector->selectedSize; ++i) { auto lPos = left.state->selVector->selectedPositions[i]; executeOnValue( - left, right, result, lPos, rPos, lPos); + left, right, result, lPos, rPos, lPos, dataPtr); } } } else { @@ -142,7 +153,7 @@ struct BinaryFunctionExecutor { result.setNull(i, left.isNull(i)); // right is always not null if (!result.isNull(i)) { executeOnValue( - left, right, result, i, rPos, i); + left, right, result, i, rPos, i, dataPtr); } } } else { @@ -151,7 +162,7 @@ struct BinaryFunctionExecutor { result.setNull(lPos, left.isNull(lPos)); // right is always not null if (!result.isNull(lPos)) { executeOnValue( - left, right, result, lPos, rPos, lPos); + left, right, result, lPos, rPos, lPos, dataPtr); } } } @@ -160,20 +171,20 @@ struct BinaryFunctionExecutor { template - static void executeBothUnFlat( - common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) { + static void executeBothUnFlat(common::ValueVector& left, common::ValueVector& right, + common::ValueVector& result, void* dataPtr) { assert(left.state == right.state); if (left.hasNoNullsGuarantee() && right.hasNoNullsGuarantee()) { if (result.state->selVector->isUnfiltered()) { for (uint64_t i = 0; i < result.state->selVector->selectedSize; i++) { executeOnValue( - left, right, result, i, i, i); + left, right, result, i, i, i, dataPtr); } } else { for (uint64_t i = 0; i < result.state->selVector->selectedSize; i++) { auto pos = result.state->selVector->selectedPositions[i]; executeOnValue( - left, right, result, pos, pos, pos); + left, right, result, pos, pos, pos, dataPtr); } } } else { @@ -182,7 +193,7 @@ struct BinaryFunctionExecutor { result.setNull(i, left.isNull(i) || right.isNull(i)); if (!result.isNull(i)) { executeOnValue( - left, right, result, i, i, i); + left, right, result, i, i, i, dataPtr); } } } else { @@ -191,7 +202,7 @@ struct BinaryFunctionExecutor { result.setNull(pos, left.isNull(pos) || right.isNull(pos)); if (!result.isNull(pos)) { executeOnValue( - left, right, result, pos, pos, pos); + left, right, result, pos, pos, pos, dataPtr); } } } @@ -200,21 +211,21 @@ struct BinaryFunctionExecutor { template - static void executeSwitch( - common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) { + static void executeSwitch(common::ValueVector& left, common::ValueVector& right, + common::ValueVector& result, void* dataPtr) { result.resetAuxiliaryBuffer(); if (left.state->isFlat() && right.state->isFlat()) { executeBothFlat( - left, right, result); + left, right, result, dataPtr); } else if (left.state->isFlat() && !right.state->isFlat()) { executeFlatUnFlat( - left, right, result); + left, right, result, dataPtr); } else if (!left.state->isFlat() && right.state->isFlat()) { executeUnFlatFlat( - left, right, result); + left, right, result, dataPtr); } else if (!left.state->isFlat() && !right.state->isFlat()) { executeBothUnFlat( - left, right, result); + left, right, result, dataPtr); } else { assert(false); } @@ -224,28 +235,35 @@ struct BinaryFunctionExecutor { static void execute( common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) { executeSwitch( - left, right, result); + left, right, result, nullptr /* dataPtr */); } template static void executeString( common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) { executeSwitch( - left, right, result); + left, right, result, nullptr /* dataPtr */); } template static void executeListStruct( common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) { executeSwitch( - left, right, result); + left, right, result, nullptr /* dataPtr */); } template static void executeComparison( common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) { executeSwitch( - left, right, result); + left, right, result, nullptr /* dataPtr */); + } + + template + static void executeUDF(common::ValueVector& left, common::ValueVector& right, + common::ValueVector& result, void* dataPtr) { + executeSwitch( + left, right, result, dataPtr); } struct BinarySelectWrapper { diff --git a/src/include/function/built_in_vector_functions.h b/src/include/function/built_in_vector_functions.h index 7d144eef8c..a9947c939d 100644 --- a/src/include/function/built_in_vector_functions.h +++ b/src/include/function/built_in_vector_functions.h @@ -11,7 +11,7 @@ class BuiltInVectorFunctions { BuiltInVectorFunctions() { registerVectorFunctions(); } inline bool containsFunction(const std::string& functionName) { - return VectorFunctions.contains(functionName); + return vectorFunctions.contains(functionName); } /** @@ -27,6 +27,8 @@ class BuiltInVectorFunctions { static uint32_t getCastCost( common::LogicalTypeID inputTypeID, common::LogicalTypeID targetTypeID); + void addFunction(std::string name, function::vector_function_definitions definitions); + private: static uint32_t getTargetTypeCost(common::LogicalTypeID typeID); @@ -75,7 +77,7 @@ class BuiltInVectorFunctions { private: // TODO(Ziyi): Refactor VectorFunction/tableOperation to inherit from the same base class. - std::unordered_map VectorFunctions; + std::unordered_map vectorFunctions; }; } // namespace function diff --git a/src/include/function/ternary_function_executor.h b/src/include/function/ternary_function_executor.h index c83e86b179..7a85404f92 100644 --- a/src/include/function/ternary_function_executor.h +++ b/src/include/function/ternary_function_executor.h @@ -8,47 +8,55 @@ namespace kuzu { namespace function { -struct TernaryOperationWrapper { +struct TernaryFunctionWrapper { template static inline void operation(A_TYPE& a, B_TYPE& b, C_TYPE& c, RESULT_TYPE& result, - void* aValueVector, void* resultValueVector) { + void* aValueVector, void* resultValueVector, void* dataPtr) { OP::operation(a, b, c, result); } }; -struct TernaryStringOperationWrapper { +struct TernaryStringFunctionWrapper { template static inline void operation(A_TYPE& a, B_TYPE& b, C_TYPE& c, RESULT_TYPE& result, - void* aValueVector, void* resultValueVector) { + void* aValueVector, void* resultValueVector, void* dataPtr) { OP::operation(a, b, c, result, *(common::ValueVector*)resultValueVector); } }; -struct TernaryListOperationWrapper { +struct TernaryListFunctionWrapper { template static inline void operation(A_TYPE& a, B_TYPE& b, C_TYPE& c, RESULT_TYPE& result, - void* aValueVector, void* resultValueVector) { + void* aValueVector, void* resultValueVector, void* dataPtr) { OP::operation(a, b, c, result, *(common::ValueVector*)aValueVector, *(common::ValueVector*)resultValueVector); } }; +struct TernaryUDFFunctionWrapper { + template + static inline void operation(A_TYPE& a, B_TYPE& b, C_TYPE& c, RESULT_TYPE& result, + void* aValueVector, void* resultValueVector, void* dataPtr) { + OP::operation(a, b, c, result, dataPtr); + } +}; + struct TernaryFunctionExecutor { template static void executeOnValue(common::ValueVector& a, common::ValueVector& b, common::ValueVector& c, common::ValueVector& result, uint64_t aPos, uint64_t bPos, - uint64_t cPos, uint64_t resPos) { + uint64_t cPos, uint64_t resPos, void* dataPtr) { auto resValues = (RESULT_TYPE*)result.getData(); OP_WRAPPER::template operation( ((A_TYPE*)a.getData())[aPos], ((B_TYPE*)b.getData())[bPos], - ((C_TYPE*)c.getData())[cPos], resValues[resPos], (void*)&a, (void*)&result); + ((C_TYPE*)c.getData())[cPos], resValues[resPos], (void*)&a, (void*)&result, dataPtr); } template static void executeAllFlat(common::ValueVector& a, common::ValueVector& b, - common::ValueVector& c, common::ValueVector& result) { + common::ValueVector& c, common::ValueVector& result, void* dataPtr) { auto aPos = a.state->selVector->selectedPositions[0]; auto bPos = b.state->selVector->selectedPositions[0]; auto cPos = c.state->selVector->selectedPositions[0]; @@ -56,14 +64,14 @@ struct TernaryFunctionExecutor { result.setNull(resPos, a.isNull(aPos) || b.isNull(bPos) || c.isNull(cPos)); if (!result.isNull(resPos)) { executeOnValue( - a, b, c, result, aPos, bPos, cPos, resPos); + a, b, c, result, aPos, bPos, cPos, resPos, dataPtr); } } template static void executeFlatFlatUnflat(common::ValueVector& a, common::ValueVector& b, - common::ValueVector& c, common::ValueVector& result) { + common::ValueVector& c, common::ValueVector& result, void* dataPtr) { auto aPos = a.state->selVector->selectedPositions[0]; auto bPos = b.state->selVector->selectedPositions[0]; if (a.isNull(aPos) || b.isNull(bPos)) { @@ -72,13 +80,13 @@ struct TernaryFunctionExecutor { if (c.state->selVector->isUnfiltered()) { for (auto i = 0u; i < c.state->selVector->selectedSize; ++i) { executeOnValue( - a, b, c, result, aPos, bPos, i, i); + a, b, c, result, aPos, bPos, i, i, dataPtr); } } else { for (auto i = 0u; i < c.state->selVector->selectedSize; ++i) { auto pos = c.state->selVector->selectedPositions[i]; executeOnValue( - a, b, c, result, aPos, bPos, pos, pos); + a, b, c, result, aPos, bPos, pos, pos, dataPtr); } } } else { @@ -87,7 +95,7 @@ struct TernaryFunctionExecutor { result.setNull(i, c.isNull(i)); if (!result.isNull(i)) { executeOnValue( - a, b, c, result, aPos, bPos, i, i); + a, b, c, result, aPos, bPos, i, i, dataPtr); } } } else { @@ -96,7 +104,7 @@ struct TernaryFunctionExecutor { result.setNull(pos, c.isNull(pos)); if (!result.isNull(pos)) { executeOnValue( - a, b, c, result, aPos, bPos, pos, pos); + a, b, c, result, aPos, bPos, pos, pos, dataPtr); } } } @@ -106,7 +114,7 @@ struct TernaryFunctionExecutor { template static void executeFlatUnflatUnflat(common::ValueVector& a, common::ValueVector& b, - common::ValueVector& c, common::ValueVector& result) { + common::ValueVector& c, common::ValueVector& result, void* dataPtr) { assert(b.state == c.state); auto aPos = a.state->selVector->selectedPositions[0]; if (a.isNull(aPos)) { @@ -115,13 +123,13 @@ struct TernaryFunctionExecutor { if (b.state->selVector->isUnfiltered()) { for (auto i = 0u; i < b.state->selVector->selectedSize; ++i) { executeOnValue( - a, b, c, result, aPos, i, i, i); + a, b, c, result, aPos, i, i, i, dataPtr); } } else { for (auto i = 0u; i < b.state->selVector->selectedSize; ++i) { auto pos = b.state->selVector->selectedPositions[i]; executeOnValue( - a, b, c, result, aPos, pos, pos, pos); + a, b, c, result, aPos, pos, pos, pos, dataPtr); } } } else { @@ -130,7 +138,7 @@ struct TernaryFunctionExecutor { result.setNull(i, b.isNull(i) || c.isNull(i)); if (!result.isNull(i)) { executeOnValue( - a, b, c, result, aPos, i, i, i); + a, b, c, result, aPos, i, i, i, dataPtr); } } } else { @@ -139,7 +147,7 @@ struct TernaryFunctionExecutor { result.setNull(pos, b.isNull(pos) || c.isNull(pos)); if (!result.isNull(pos)) { executeOnValue( - a, b, c, result, aPos, pos, pos, pos); + a, b, c, result, aPos, pos, pos, pos, dataPtr); } } } @@ -149,7 +157,7 @@ struct TernaryFunctionExecutor { template static void executeFlatUnflatFlat(common::ValueVector& a, common::ValueVector& b, - common::ValueVector& c, common::ValueVector& result) { + common::ValueVector& c, common::ValueVector& result, void* dataPtr) { auto aPos = a.state->selVector->selectedPositions[0]; auto cPos = c.state->selVector->selectedPositions[0]; if (a.isNull(aPos) || c.isNull(cPos)) { @@ -158,13 +166,13 @@ struct TernaryFunctionExecutor { if (b.state->selVector->isUnfiltered()) { for (auto i = 0u; i < b.state->selVector->selectedSize; ++i) { executeOnValue( - a, b, c, result, aPos, i, cPos, i); + a, b, c, result, aPos, i, cPos, i, dataPtr); } } else { for (auto i = 0u; i < b.state->selVector->selectedSize; ++i) { auto pos = b.state->selVector->selectedPositions[i]; executeOnValue( - a, b, c, result, aPos, pos, cPos, pos); + a, b, c, result, aPos, pos, cPos, pos, dataPtr); } } } else { @@ -173,7 +181,7 @@ struct TernaryFunctionExecutor { result.setNull(i, b.isNull(i)); if (!result.isNull(i)) { executeOnValue( - a, b, c, result, aPos, i, cPos, i); + a, b, c, result, aPos, i, cPos, i, dataPtr); } } } else { @@ -182,7 +190,7 @@ struct TernaryFunctionExecutor { result.setNull(pos, b.isNull(pos)); if (!result.isNull(pos)) { executeOnValue( - a, b, c, result, aPos, pos, cPos, pos); + a, b, c, result, aPos, pos, cPos, pos, dataPtr); } } } @@ -192,19 +200,19 @@ struct TernaryFunctionExecutor { template static void executeAllUnFlat(common::ValueVector& a, common::ValueVector& b, - common::ValueVector& c, common::ValueVector& result) { + common::ValueVector& c, common::ValueVector& result, void* dataPtr) { assert(a.state == b.state && b.state == c.state); if (a.hasNoNullsGuarantee() && b.hasNoNullsGuarantee() && c.hasNoNullsGuarantee()) { if (a.state->selVector->isUnfiltered()) { for (uint64_t i = 0; i < a.state->selVector->selectedSize; i++) { executeOnValue( - a, b, c, result, i, i, i, i); + a, b, c, result, i, i, i, i, dataPtr); } } else { for (uint64_t i = 0; i < a.state->selVector->selectedSize; i++) { auto pos = a.state->selVector->selectedPositions[i]; executeOnValue( - a, b, c, result, pos, pos, pos, pos); + a, b, c, result, pos, pos, pos, pos, dataPtr); } } } else { @@ -213,7 +221,7 @@ struct TernaryFunctionExecutor { result.setNull(i, a.isNull(i) || b.isNull(i) || c.isNull(i)); if (!result.isNull(i)) { executeOnValue( - a, b, c, result, i, i, i, i); + a, b, c, result, i, i, i, i, dataPtr); } } } else { @@ -222,7 +230,7 @@ struct TernaryFunctionExecutor { result.setNull(pos, a.isNull(pos) || b.isNull(pos) || c.isNull(pos)); if (!result.isNull(pos)) { executeOnValue( - a, b, c, result, pos, pos, pos, pos); + a, b, c, result, pos, pos, pos, pos, dataPtr); } } } @@ -232,7 +240,7 @@ struct TernaryFunctionExecutor { template static void executeUnflatFlatFlat(common::ValueVector& a, common::ValueVector& b, - common::ValueVector& c, common::ValueVector& result) { + common::ValueVector& c, common::ValueVector& result, void* dataPtr) { auto bPos = b.state->selVector->selectedPositions[0]; auto cPos = c.state->selVector->selectedPositions[0]; if (b.isNull(bPos) || c.isNull(cPos)) { @@ -241,13 +249,13 @@ struct TernaryFunctionExecutor { if (a.state->selVector->isUnfiltered()) { for (auto i = 0u; i < a.state->selVector->selectedSize; ++i) { executeOnValue( - a, b, c, result, i, bPos, cPos, i); + a, b, c, result, i, bPos, cPos, i, dataPtr); } } else { for (auto i = 0u; i < a.state->selVector->selectedSize; ++i) { auto pos = a.state->selVector->selectedPositions[i]; executeOnValue( - a, b, c, result, pos, bPos, cPos, pos); + a, b, c, result, pos, bPos, cPos, pos, dataPtr); } } } else { @@ -256,7 +264,7 @@ struct TernaryFunctionExecutor { result.setNull(i, a.isNull(i)); if (!result.isNull(i)) { executeOnValue( - a, b, c, result, i, bPos, cPos, i); + a, b, c, result, i, bPos, cPos, i, dataPtr); } } } else { @@ -265,7 +273,7 @@ struct TernaryFunctionExecutor { result.setNull(pos, a.isNull(pos)); if (!result.isNull(pos)) { executeOnValue( - a, b, c, result, pos, bPos, cPos, pos); + a, b, c, result, pos, bPos, cPos, pos, dataPtr); } } } @@ -275,7 +283,7 @@ struct TernaryFunctionExecutor { template static void executeUnflatFlatUnflat(common::ValueVector& a, common::ValueVector& b, - common::ValueVector& c, common::ValueVector& result) { + common::ValueVector& c, common::ValueVector& result, void* dataPtr) { assert(a.state == c.state); auto bPos = b.state->selVector->selectedPositions[0]; if (b.isNull(bPos)) { @@ -284,13 +292,13 @@ struct TernaryFunctionExecutor { if (a.state->selVector->isUnfiltered()) { for (auto i = 0u; i < a.state->selVector->selectedSize; ++i) { executeOnValue( - a, b, c, result, i, bPos, i, i); + a, b, c, result, i, bPos, i, i, dataPtr); } } else { for (auto i = 0u; i < a.state->selVector->selectedSize; ++i) { auto pos = a.state->selVector->selectedPositions[i]; executeOnValue( - a, b, c, result, pos, bPos, pos, pos); + a, b, c, result, pos, bPos, pos, pos, dataPtr); } } } else { @@ -299,7 +307,7 @@ struct TernaryFunctionExecutor { result.setNull(i, a.isNull(i) || c.isNull(i)); if (!result.isNull(i)) { executeOnValue( - a, b, c, result, i, bPos, i, i); + a, b, c, result, i, bPos, i, i, dataPtr); } } } else { @@ -308,7 +316,7 @@ struct TernaryFunctionExecutor { result.setNull(pos, a.isNull(pos) || c.isNull(pos)); if (!result.isNull(pos)) { executeOnValue( - a, b, c, result, pos, bPos, pos, pos); + a, b, c, result, pos, bPos, pos, pos, dataPtr); } } } @@ -318,7 +326,7 @@ struct TernaryFunctionExecutor { template static void executeUnflatUnFlatFlat(common::ValueVector& a, common::ValueVector& b, - common::ValueVector& c, common::ValueVector& result) { + common::ValueVector& c, common::ValueVector& result, void* dataPtr) { assert(a.state == b.state); auto cPos = c.state->selVector->selectedPositions[0]; if (c.isNull(cPos)) { @@ -327,13 +335,13 @@ struct TernaryFunctionExecutor { if (a.state->selVector->isUnfiltered()) { for (auto i = 0u; i < a.state->selVector->selectedSize; ++i) { executeOnValue( - a, b, c, result, i, i, cPos, i); + a, b, c, result, i, i, cPos, i, dataPtr); } } else { for (auto i = 0u; i < a.state->selVector->selectedSize; ++i) { auto pos = a.state->selVector->selectedPositions[i]; executeOnValue( - a, b, c, result, pos, pos, cPos, pos); + a, b, c, result, pos, pos, cPos, pos, dataPtr); } } } else { @@ -342,7 +350,7 @@ struct TernaryFunctionExecutor { result.setNull(i, a.isNull(i) || b.isNull(i)); if (!result.isNull(i)) { executeOnValue( - a, b, c, result, i, i, cPos, i); + a, b, c, result, i, i, cPos, i, dataPtr); } } } else { @@ -351,7 +359,7 @@ struct TernaryFunctionExecutor { result.setNull(pos, a.isNull(pos) || b.isNull(pos)); if (!result.isNull(pos)) { executeOnValue( - a, b, c, result, pos, pos, cPos, pos); + a, b, c, result, pos, pos, cPos, pos, dataPtr); } } } @@ -361,31 +369,32 @@ struct TernaryFunctionExecutor { template static void executeSwitch(common::ValueVector& a, common::ValueVector& b, - common::ValueVector& c, common::ValueVector& result) { + common::ValueVector& c, common::ValueVector& result, void* dataPtr) { result.resetAuxiliaryBuffer(); if (a.state->isFlat() && b.state->isFlat() && c.state->isFlat()) { - executeAllFlat(a, b, c, result); + executeAllFlat( + a, b, c, result, dataPtr); } else if (a.state->isFlat() && b.state->isFlat() && !c.state->isFlat()) { executeFlatFlatUnflat( - a, b, c, result); + a, b, c, result, dataPtr); } else if (a.state->isFlat() && !b.state->isFlat() && !c.state->isFlat()) { executeFlatUnflatUnflat( - a, b, c, result); + a, b, c, result, dataPtr); } else if (a.state->isFlat() && !b.state->isFlat() && c.state->isFlat()) { executeFlatUnflatFlat( - a, b, c, result); + a, b, c, result, dataPtr); } else if (!a.state->isFlat() && !b.state->isFlat() && !c.state->isFlat()) { executeAllUnFlat( - a, b, c, result); + a, b, c, result, dataPtr); } else if (!a.state->isFlat() && !b.state->isFlat() && c.state->isFlat()) { executeUnflatUnFlatFlat( - a, b, c, result); + a, b, c, result, dataPtr); } else if (!a.state->isFlat() && b.state->isFlat() && c.state->isFlat()) { executeUnflatFlatFlat( - a, b, c, result); + a, b, c, result, dataPtr); } else if (!a.state->isFlat() && b.state->isFlat() && !c.state->isFlat()) { executeUnflatFlatUnflat( - a, b, c, result); + a, b, c, result, dataPtr); } else { assert(false); } @@ -394,22 +403,29 @@ struct TernaryFunctionExecutor { template static void execute(common::ValueVector& a, common::ValueVector& b, common::ValueVector& c, common::ValueVector& result) { - executeSwitch( - a, b, c, result); + executeSwitch( + a, b, c, result, nullptr /* dataPtr */); } template static void executeString(common::ValueVector& a, common::ValueVector& b, common::ValueVector& c, common::ValueVector& result) { - executeSwitch( - a, b, c, result); + executeSwitch( + a, b, c, result, nullptr /* dataPtr */); } template static void executeListStruct(common::ValueVector& a, common::ValueVector& b, common::ValueVector& c, common::ValueVector& result) { - executeSwitch( - a, b, c, result); + executeSwitch( + a, b, c, result, nullptr /* dataPtr */); + } + + template + static void executeUDF(common::ValueVector& a, common::ValueVector& b, common::ValueVector& c, + common::ValueVector& result, void* dataPtr) { + executeSwitch( + a, b, c, result, dataPtr); } }; diff --git a/src/include/function/udf_function.h b/src/include/function/udf_function.h new file mode 100644 index 0000000000..ea2f1b359e --- /dev/null +++ b/src/include/function/udf_function.h @@ -0,0 +1,159 @@ +#pragma once + +#include "function/vector_functions.h" + +namespace kuzu { +namespace function { + +struct UnaryUDFExecutor { + template + static inline void operation(OPERAND_TYPE& input, RESULT_TYPE& result, void* udfFunc) { + typedef RESULT_TYPE (*unary_udf_func)(OPERAND_TYPE); + auto unaryUDFFunc = (unary_udf_func)udfFunc; + result = unaryUDFFunc(input); + } +}; + +struct BinaryUDFExecutor { + template + static inline void operation( + LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, void* udfFunc) { + typedef RESULT_TYPE (*binary_udf_func)(LEFT_TYPE, RIGHT_TYPE); + auto binaryUDFFunc = (binary_udf_func)udfFunc; + result = binaryUDFFunc(left, right); + } +}; + +struct TernaryUDFExecutor { + template + static inline void operation( + A_TYPE& a, B_TYPE& b, C_TYPE& c, RESULT_TYPE& result, void* udfFunc) { + typedef RESULT_TYPE (*ternary_udf_func)(A_TYPE, B_TYPE, C_TYPE); + auto ternaryUDFFunc = (ternary_udf_func)udfFunc; + result = ternaryUDFFunc(a, b, c); + } +}; + +template +static inline function::scalar_exec_func createUnaryExecFunc(RESULT_TYPE (*udfFunc)(Args...)) { + throw common::NotImplementedException{"function::createUnaryExecFunc()"}; +} + +template +static inline function::scalar_exec_func createUnaryExecFunc(RESULT_TYPE (*udfFunc)(OPERAND_TYPE)) { + function::scalar_exec_func execFunc = + [=](const std::vector>& params, + common::ValueVector& result) -> void { + assert(params.size() == 1); + UnaryFunctionExecutor::executeUDF( + *params[0], result, (void*)udfFunc); + }; + return execFunc; +} + +template +static inline function::scalar_exec_func createBinaryExecFunc(RESULT_TYPE (*udfFunc)(Args...)) { + throw common::NotImplementedException{"function::createBinaryExecFunc()"}; +} + +template +static inline function::scalar_exec_func createBinaryExecFunc( + RESULT_TYPE (*udfFunc)(LEFT_TYPE, RIGHT_TYPE)) { + function::scalar_exec_func execFunc = + [=](const std::vector>& params, + common::ValueVector& result) -> void { + assert(params.size() == 2); + BinaryFunctionExecutor::executeUDF( + *params[0], *params[1], result, (void*)udfFunc); + }; + return execFunc; +} + +template +static inline function::scalar_exec_func createTernaryExecFunc(RESULT_TYPE (*udfFunc)(Args...)) { + throw common::NotImplementedException{"function::createTernaryExecFunc()"}; +} + +template +static inline function::scalar_exec_func createTernaryExecFunc( + RESULT_TYPE (*udfFunc)(A_TYPE, B_TYPE, C_TYPE)) { + function::scalar_exec_func execFunc = + [=](const std::vector>& params, + common::ValueVector& result) -> void { + assert(params.size() == 3); + TernaryFunctionExecutor::executeUDF(*params[0], *params[1], *params[2], result, (void*)udfFunc); + }; + return execFunc; +} + +template +inline static scalar_exec_func getScalarExecFunc(TR (*udfFunc)(Args...)) { + constexpr auto numArgs = sizeof...(Args); + switch (numArgs) { + case 1: + return createUnaryExecFunc(udfFunc); + case 2: + return createBinaryExecFunc(udfFunc); + case 3: + return createTernaryExecFunc(udfFunc); + default: + throw common::BinderException("UDF function only supported until ternary!"); + } +} + +template +inline static common::LogicalTypeID getParameterType() { + if (std::is_same()) { + return common::LogicalTypeID::BOOL; + } else if (std::is_same()) { + return common::LogicalTypeID::INT16; + } else if (std::is_same()) { + return common::LogicalTypeID::INT32; + } else if (std::is_same()) { + return common::LogicalTypeID::INT64; + } else if (std::is_same()) { + return common::LogicalTypeID::FLOAT; + } else if (std::is_same()) { + return common::LogicalTypeID::DOUBLE; + } else if (std::is_same()) { + return common::LogicalTypeID::STRING; + } else { + throw common::NotImplementedException{"function::getParameterType"}; + } +} + +template +inline static void getParameterTypesRecursive(std::vector& arguments) { + arguments.push_back(getParameterType()); +} + +template +inline static void getParameterTypesRecursive(std::vector& arguments) { + arguments.push_back(getParameterType()); + getParameterTypesRecursive(arguments); +} + +template +inline static std::unique_ptr getFunctionDefinition( + const std::string& name, TR (*udfFunc)(Args...), + std::vector parameterTypes, common::LogicalTypeID returnType) { + function::scalar_exec_func scalarExecFunc = function::getScalarExecFunc(udfFunc); + return std::make_unique( + name, std::move(parameterTypes), returnType, std::move(scalarExecFunc)); +} + +template +inline static std::unique_ptr getFunctionDefinition( + const std::string& name, TR (*udfFunc)(Args...)) { + std::vector parameterTypes; + getParameterTypesRecursive(parameterTypes); + common::LogicalTypeID returnType = getParameterType(); + if (returnType == common::LogicalTypeID::STRING) { + throw common::NotImplementedException{"function::getFunctionDefinition"}; + } + return getFunctionDefinition(name, udfFunc, std::move(parameterTypes), returnType); +} + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/unary_function_executor.h b/src/include/function/unary_function_executor.h index 1eb4e8c930..87d7e4c264 100644 --- a/src/include/function/unary_function_executor.h +++ b/src/include/function/unary_function_executor.h @@ -12,51 +12,60 @@ namespace function { * IS_NOT_NULL operation. */ -struct UnaryOperationWrapper { +struct UnaryFunctionWrapper { template - static inline void operation( - OPERAND_TYPE& input, RESULT_TYPE& result, void* inputVector, void* resultVector) { + static inline void operation(OPERAND_TYPE& input, RESULT_TYPE& result, void* inputVector, + void* resultVector, void* dataPtr) { FUNC::operation(input, result); } }; -struct UnaryStringOperationWrapper { +struct UnaryStringFunctionWrapper { template - static void operation( - OPERAND_TYPE& input, RESULT_TYPE& result, void* inputVector, void* resultVector) { + static void operation(OPERAND_TYPE& input, RESULT_TYPE& result, void* inputVector, + void* resultVector, void* dataPtr) { FUNC::operation(input, result, *(common::ValueVector*)resultVector); } }; -struct UnaryListOperationWrapper { +struct UnaryListFunctionWrapper { template - static inline void operation( - OPERAND_TYPE& input, RESULT_TYPE& result, void* leftValueVector, void* resultValueVector) { + static inline void operation(OPERAND_TYPE& input, RESULT_TYPE& result, void* leftValueVector, + void* resultValueVector, void* dataPtr) { FUNC::operation(input, result, *(common::ValueVector*)leftValueVector, *(common::ValueVector*)resultValueVector); } }; -struct UnaryCastOperationWrapper { +struct UnaryCastFunctionWrapper { template - static void operation( - OPERAND_TYPE& input, RESULT_TYPE& result, void* inputVector, void* resultVector) { + static void operation(OPERAND_TYPE& input, RESULT_TYPE& result, void* inputVector, + void* resultVector, void* dataPtr) { FUNC::operation( input, result, *(common::ValueVector*)inputVector, *(common::ValueVector*)resultVector); } }; +struct UnaryUDFFunctionWrapper { + template + static inline void operation(OPERAND_TYPE& input, RESULT_TYPE& result, void* inputVector, + void* resultVector, void* dataPtr) { + FUNC::operation(input, result, dataPtr); + } +}; + struct UnaryFunctionExecutor { template static void executeOnValue(common::ValueVector& operand, uint64_t operandPos, - RESULT_TYPE& resultValue, common::ValueVector& resultValueVector) { + RESULT_TYPE& resultValue, common::ValueVector& resultValueVector, void* dataPtr) { OP_WRAPPER::template operation( ((OPERAND_TYPE*)operand.getData())[operandPos], resultValue, (void*)&operand, - (void*)&resultValueVector); + (void*)&resultValueVector, dataPtr); } template - static void executeSwitch(common::ValueVector& operand, common::ValueVector& result) { + static void executeSwitch( + common::ValueVector& operand, common::ValueVector& result, void* dataPtr) { result.resetAuxiliaryBuffer(); auto resultValues = (RESULT_TYPE*)result.getData(); if (operand.state->isFlat()) { @@ -65,20 +74,20 @@ struct UnaryFunctionExecutor { result.setNull(resultPos, operand.isNull(inputPos)); if (!result.isNull(inputPos)) { executeOnValue( - operand, inputPos, resultValues[resultPos], result); + operand, inputPos, resultValues[resultPos], result, dataPtr); } } else { if (operand.hasNoNullsGuarantee()) { if (operand.state->selVector->isUnfiltered()) { for (auto i = 0u; i < operand.state->selVector->selectedSize; i++) { executeOnValue( - operand, i, resultValues[i], result); + operand, i, resultValues[i], result, dataPtr); } } else { for (auto i = 0u; i < operand.state->selVector->selectedSize; i++) { auto pos = operand.state->selVector->selectedPositions[i]; executeOnValue( - operand, pos, resultValues[pos], result); + operand, pos, resultValues[pos], result, dataPtr); } } } else { @@ -87,7 +96,7 @@ struct UnaryFunctionExecutor { result.setNull(i, operand.isNull(i)); if (!result.isNull(i)) { executeOnValue( - operand, i, resultValues[i], result); + operand, i, resultValues[i], result, dataPtr); } } } else { @@ -96,7 +105,7 @@ struct UnaryFunctionExecutor { result.setNull(pos, operand.isNull(pos)); if (!result.isNull(pos)) { executeOnValue( - operand, pos, resultValues[pos], result); + operand, pos, resultValues[pos], result, dataPtr); } } } @@ -106,23 +115,33 @@ struct UnaryFunctionExecutor { template static void execute(common::ValueVector& operand, common::ValueVector& result) { - executeSwitch(operand, result); + executeSwitch( + operand, result, nullptr /* dataPtr */); } template static void executeString(common::ValueVector& operand, common::ValueVector& result) { - executeSwitch( - operand, result); + executeSwitch( + operand, result, nullptr /* dataPtr */); } template static void executeListStruct(common::ValueVector& operand, common::ValueVector& result) { - executeSwitch(operand, result); + executeSwitch( + operand, result, nullptr /* dataPtr */); } template static void executeCast(common::ValueVector& operand, common::ValueVector& result) { - executeSwitch(operand, result); + executeSwitch( + operand, result, nullptr /* dataPtr */); + } + + template + static void executeUDF( + common::ValueVector& operand, common::ValueVector& result, void* dataPtr) { + executeSwitch( + operand, result, dataPtr); } }; diff --git a/src/include/main/connection.h b/src/include/main/connection.h index 07cd57c82f..f20f524314 100644 --- a/src/include/main/connection.h +++ b/src/include/main/connection.h @@ -4,6 +4,7 @@ #include "client_context.h" #include "database.h" +#include "function/udf_function.h" #include "prepared_statement.h" #include "query_result.h" @@ -147,6 +148,14 @@ class Connection { */ KUZU_API uint64_t getQueryTimeOut(); + template + void createScalarFunction(const std::string& name, TR (*udfFunc)(Args...)) { + function::vector_function_definitions definitions; + auto definition = function::getFunctionDefinition(name, udfFunc); + definitions.push_back(std::move(definition)); + addScalarFunction(name, std::move(definitions)); + } + protected: ConnectionTransactionMode getTransactionMode(); void setTransactionModeNoLock(ConnectionTransactionMode newTransactionMode); @@ -202,6 +211,8 @@ class Connection { return queryResultWithError(exceptionMessage); } + void addScalarFunction(std::string name, function::vector_function_definitions definitions); + protected: Database* database; std::unique_ptr clientContext; diff --git a/src/include/storage/buffer_manager/memory_manager.h b/src/include/storage/buffer_manager/memory_manager.h index bb10a960c9..171577d460 100644 --- a/src/include/storage/buffer_manager/memory_manager.h +++ b/src/include/storage/buffer_manager/memory_manager.h @@ -6,12 +6,14 @@ #include #include "common/constants.h" -#include "storage/buffer_manager/buffer_manager.h" +#include "common/types/types.h" namespace kuzu { namespace storage { class MemoryAllocator; +class BMFileHandle; +class BufferManager; class MemoryBuffer { public: diff --git a/src/main/connection.cpp b/src/main/connection.cpp index ef6ab04fcc..9cbfbf81d7 100644 --- a/src/main/connection.cpp +++ b/src/main/connection.cpp @@ -426,5 +426,10 @@ void Connection::beginTransactionIfAutoCommit(PreparedStatement* preparedStateme } } +void Connection::addScalarFunction( + std::string name, function::vector_function_definitions definitions) { + database->catalog->addVectorFunction(name, std::move(definitions)); +} + } // namespace main } // namespace kuzu diff --git a/src/storage/buffer_manager/memory_manager.cpp b/src/storage/buffer_manager/memory_manager.cpp index 0edf34d725..640b0e26e3 100644 --- a/src/storage/buffer_manager/memory_manager.cpp +++ b/src/storage/buffer_manager/memory_manager.cpp @@ -3,6 +3,7 @@ #include #include "common/utils.h" +#include "storage/buffer_manager/buffer_manager.h" using namespace kuzu::common; diff --git a/test/main/CMakeLists.txt b/test/main/CMakeLists.txt index e3231177bd..b5d319e043 100644 --- a/test/main/CMakeLists.txt +++ b/test/main/CMakeLists.txt @@ -4,4 +4,5 @@ add_kuzu_test(main_test csv_output_test.cpp prepare_test.cpp result_value_test.cpp - storage_driver_test.cpp) + storage_driver_test.cpp + udf_test.cpp) diff --git a/test/main/udf_test.cpp b/test/main/udf_test.cpp new file mode 100644 index 0000000000..af36ae67a3 --- /dev/null +++ b/test/main/udf_test.cpp @@ -0,0 +1,106 @@ +#include "main_test_helper/main_test_helper.h" + +namespace kuzu { +namespace testing { + +int64_t add5(int64_t x) { + return x + 5; +} + +TEST_F(ApiTest, UnaryUDFInt64) { + conn->createScalarFunction("add5", &add5); + auto actualResult = + TestHelper::convertResultToString(*conn->query("MATCH (p:person) return add5(p.age)")); + auto expectedResult = std::vector{"40", "35", "50", "25", "25", "30", "45", "88"}; + sortAndCheckTestResults(actualResult, expectedResult); +} + +float_t times2(int64_t x) { + return float_t(2 * x); +} + +TEST_F(ApiTest, UnaryUDFFloat) { + conn->createScalarFunction("times2", ×2); + auto actualResult = + TestHelper::convertResultToString(*conn->query("MATCH (p:person) return times2(p.age)")); + auto expectedResult = std::vector{"70.000000", "60.000000", "90.000000", + "40.000000", "40.000000", "50.000000", "80.000000", "166.000000"}; + sortAndCheckTestResults(actualResult, expectedResult); +} + +double_t timesFloat(int32_t x) { + return (double_t)2.4 * x; +} + +TEST_F(ApiTest, UnaryUDFDouble) { + conn->createScalarFunction("timesFloat", ×Float); + auto actualResult = TestHelper::convertResultToString( + *conn->query("MATCH (p:person) return timesFloat(to_int32(p.ID))")); + auto expectedResult = std::vector{"0.000000", "4.800000", "7.200000", "12.000000", + "16.800000", "19.200000", "21.600000", "24.000000"}; + sortAndCheckTestResults(actualResult, expectedResult); +} + +int16_t strDoubleLen(common::ku_string_t str) { + return str.len * 2; +} + +TEST_F(ApiTest, UnaryUDFString) { + conn->createScalarFunction("strDoubleLen", &strDoubleLen); + auto actualResult = TestHelper::convertResultToString( + *conn->query("MATCH (p:person) return strDoubleLen(p.fName)")); + auto expectedResult = std::vector{"10", "6", "10", "6", "18", "12", "8", "98"}; + sortAndCheckTestResults(actualResult, expectedResult); +} + +int64_t addSecondParamTwice(int16_t x, int32_t y) { + return x + y + y; +} + +TEST_F(ApiTest, BinaryUDFFlatUnflat) { + conn->createScalarFunction("addSecondParamTwice", &addSecondParamTwice); + auto actualResult = TestHelper::convertResultToString( + *conn->query("MATCH (p:person)-[:knows]->(p1:person) return " + "addSecondParamTwice(to_int16(p.ID), to_int32(p1.age))")); + auto expectedResult = std::vector{ + "60", "90", "40", "72", "92", "42", "73", "63", "43", "75", "65", "95", "57", "87"}; + sortAndCheckTestResults(actualResult, expectedResult); +} + +int64_t computeStringLenPlus(common::ku_string_t str, int32_t y) { + return str.len + y; +} + +TEST_F(ApiTest, BinaryUDFStr) { + conn->createScalarFunction("computeStringLenPlus", &computeStringLenPlus); + auto actualResult = TestHelper::convertResultToString( + *conn->query("MATCH (p:person) return computeStringLenPlus(p.fName, to_int32(p.gender))")); + auto expectedResult = std::vector{"6", "5", "6", "5", "10", "8", "6", "51"}; + sortAndCheckTestResults(actualResult, expectedResult); +} + +int64_t ternaryAdd(int16_t a, int32_t b, int64_t c) { + return a + b + c; +} + +TEST_F(ApiTest, TernaryUDFInt) { + conn->createScalarFunction("ternaryAdd", &ternaryAdd); + auto actualResult = TestHelper::convertResultToString(*conn->query( + "MATCH (p:person) return ternaryAdd(to_int16(p.gender), to_int32(p.ID), p.age)")); + auto expectedResult = std::vector{"36", "34", "49", "27", "28", "35", "51", "95"}; + sortAndCheckTestResults(actualResult, expectedResult); +} + +TEST_F(ApiTest, UDFError) { + conn->createScalarFunction("add5", &add5); + try { + conn->createScalarFunction("add5", &add5); + } catch (common::CatalogException& e) { + ASSERT_EQ(std::string(e.what()), "Catalog exception: function ADD5 already exists."); + return; + } catch (common::Exception&) { FAIL(); } + FAIL(); +} + +} // namespace testing +} // namespace kuzu diff --git a/tools/rust_api/build.rs b/tools/rust_api/build.rs index 55500e6227..fab3483faa 100644 --- a/tools/rust_api/build.rs +++ b/tools/rust_api/build.rs @@ -161,6 +161,7 @@ fn build_bundled_cmake() -> Result, Box> { Ok(vec![ kuzu_root.join("src/include"), + kuzu_root.join("third_party/concurrentqueue"), kuzu_root.join("third_party/nlohmann_json"), kuzu_root.join("third_party/spdlog"), arrow_install.join("include"),