From ff051b2942b2d22f4f6aadadaf4596bc8073570a Mon Sep 17 00:00:00 2001 From: ziyi chen Date: Tue, 19 Mar 2024 15:13:29 -0400 Subject: [PATCH] Implement array functions --- src/function/CMakeLists.txt | 1 + src/function/built_in_function_utils.cpp | 1 + src/function/function_collection.cpp | 5 +- src/function/vector_array_functions.cpp | 167 ++++++++++++++++++ src/function/vector_list_functions.cpp | 9 +- .../array/functions/array_cosine_similarity.h | 33 ++++ .../array/functions/array_cross_product.h | 24 +++ .../function/array/functions/array_distance.h | 27 +++ .../array/functions/array_inner_product.h | 23 +++ .../function/array/vector_array_functions.h | 49 +++++ .../function/list/vector_list_functions.h | 1 + test/test_files/tinysnb/function/array.test | 121 +++++++++++++ 12 files changed, 457 insertions(+), 4 deletions(-) create mode 100644 src/function/vector_array_functions.cpp create mode 100644 src/include/function/array/functions/array_cosine_similarity.h create mode 100644 src/include/function/array/functions/array_cross_product.h create mode 100644 src/include/function/array/functions/array_distance.h create mode 100644 src/include/function/array/functions/array_inner_product.h create mode 100644 src/include/function/array/vector_array_functions.h create mode 100644 test/test_files/tinysnb/function/array.test diff --git a/src/function/CMakeLists.txt b/src/function/CMakeLists.txt index 8751c231461..f782566ac1c 100644 --- a/src/function/CMakeLists.txt +++ b/src/function/CMakeLists.txt @@ -16,6 +16,7 @@ add_library(kuzu_function function_collection.cpp scalar_macro_function.cpp vector_arithmetic_functions.cpp + vector_array_functions.cpp vector_boolean_functions.cpp vector_cast_functions.cpp vector_date_functions.cpp diff --git a/src/function/built_in_function_utils.cpp b/src/function/built_in_function_utils.cpp index 8f5be99ddbc..be164efa565 100644 --- a/src/function/built_in_function_utils.cpp +++ b/src/function/built_in_function_utils.cpp @@ -11,6 +11,7 @@ #include "function/aggregate/count_star.h" #include "function/aggregate_function.h" #include "function/arithmetic/vector_arithmetic_functions.h" +#include "function/array/vector_array_functions.h" #include "function/blob/vector_blob_functions.h" #include "function/cast/vector_cast_functions.h" #include "function/comparison/vector_comparison_functions.h" diff --git a/src/function/function_collection.cpp b/src/function/function_collection.cpp index 26e6fd304f7..414eb6f96f5 100644 --- a/src/function/function_collection.cpp +++ b/src/function/function_collection.cpp @@ -1,6 +1,7 @@ #include "function/function_collection.h" #include "function/arithmetic/vector_arithmetic_functions.h" +#include "function/array/vector_array_functions.h" namespace kuzu { namespace function { @@ -35,7 +36,9 @@ FunctionCollection* FunctionCollection::getFunctions() { SCALAR_FUNCTION_ALIAS(PowerFunction), SCALAR_FUNCTION(RadiansFunction), SCALAR_FUNCTION(RoundFunction), SCALAR_FUNCTION(SinFunction), SCALAR_FUNCTION(SignFunction), SCALAR_FUNCTION(SqrtFunction), SCALAR_FUNCTION(TanFunction), - + SCALAR_FUNCTION(ArrayValueFunction), SCALAR_FUNCTION(ArrayCrossProductFunction), + SCALAR_FUNCTION(ArrayCosineSimilarityFunction), SCALAR_FUNCTION(ArrayDistanceFunction), + SCALAR_FUNCTION(ArrayInnerProductFunction), SCALAR_FUNCTION(ArrayDotProductFunction), // End of array FINAL_FUNCTION}; diff --git a/src/function/vector_array_functions.cpp b/src/function/vector_array_functions.cpp new file mode 100644 index 00000000000..0aad82ebd6d --- /dev/null +++ b/src/function/vector_array_functions.cpp @@ -0,0 +1,167 @@ +#include "function/array/vector_array_functions.h" + +#include "function/array/functions/array_cosine_similarity.h" +#include "function/array/functions/array_cross_product.h" +#include "function/array/functions/array_distance.h" +#include "function/array/functions/array_inner_product.h" +#include "function/list/vector_list_functions.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +function_set ArrayValueFunction::getFunctionSet() { + function_set result; + result.push_back(std::make_unique(name, + std::vector{LogicalTypeID::ANY}, LogicalTypeID::ARRAY, + ListCreationFunction::execFunc, nullptr, bindFunc, true /* isVarLength */)); + return result; +} + +std::unique_ptr ArrayValueFunction::bindFunc( + const binder::expression_vector& arguments, Function* /*function*/) { + auto resultType = + LogicalType::ARRAY(ListCreationFunction::getChildType(arguments).copy(), arguments.size()); + return std::make_unique(std::move(resultType)); +} + +function_set ArrayCrossProductFunction::getFunctionSet() { + function_set result; + result.push_back(std::make_unique(name, + std::vector{ + LogicalTypeID::ARRAY, + LogicalTypeID::ARRAY, + }, + LogicalTypeID::ARRAY, nullptr, nullptr, bindFunc, false /* isVarLength */)); + return result; +} + +static void validateArrayFunctionParameters( + const LogicalType& leftType, const LogicalType& rightType, const std::string& functionName) { + if (leftType != rightType) { + throw common::BinderException(common::stringFormat( + "{} requires both arrays to have the same element type", functionName)); + } + if (ArrayType::getChildType(&leftType)->getLogicalTypeID() != LogicalTypeID::FLOAT && + ArrayType::getChildType(&leftType)->getLogicalTypeID() != LogicalTypeID::DOUBLE) { + throw common::BinderException( + common::stringFormat("{} requires argument type of FLOAT or DOUBLE.", functionName)); + } +} + +std::unique_ptr ArrayCrossProductFunction::bindFunc( + const binder::expression_vector& arguments, Function* function) { + auto leftType = arguments[0]->dataType; + auto rightType = arguments[1]->dataType; + if (leftType != rightType) { + throw common::BinderException(common::stringFormat( + "{} requires both arrays to have the same element type and size of 3", name)); + } + scalar_func_exec_t execFunc; + switch (ArrayType::getChildType(&leftType)->getLogicalTypeID()) { + case LogicalTypeID::INT128: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + case LogicalTypeID::INT64: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + case LogicalTypeID::INT32: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + case LogicalTypeID::INT16: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + case LogicalTypeID::INT8: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + case LogicalTypeID::FLOAT: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + case LogicalTypeID::DOUBLE: + execFunc = ScalarFunction::BinaryExecListStructFunction>; + break; + default: + throw BinderException{ + stringFormat("{} can only be applied on array of floating points or integers", name)}; + } + ku_dynamic_cast(function)->execFunc = execFunc; + auto resultType = LogicalType::ARRAY( + *ArrayType::getChildType(&leftType), ArrayType::getNumElements(&leftType)); + return std::make_unique(std::move(resultType)); +} + +template +static scalar_func_exec_t getBinaryArrayExecFuncSwitchResultType() { + auto execFunc = ScalarFunction::BinaryExecListStructFunction; + return execFunc; +} + +template +scalar_func_exec_t getScalarExecFunc(LogicalType type) { + scalar_func_exec_t execFunc; + switch (ArrayType::getChildType(&type)->getLogicalTypeID()) { + case LogicalTypeID::FLOAT: + execFunc = getBinaryArrayExecFuncSwitchResultType(); + break; + case LogicalTypeID::DOUBLE: + execFunc = getBinaryArrayExecFuncSwitchResultType(); + break; + default: + KU_UNREACHABLE; + } + return execFunc; +} + +template +std::unique_ptr arrayTemplateBindFunc( + std::string functionName, const binder::expression_vector& arguments, Function* function) { + auto leftType = arguments[0]->dataType; + auto rightType = arguments[1]->dataType; + validateArrayFunctionParameters(leftType, rightType, functionName); + ku_dynamic_cast(function)->execFunc = + getScalarExecFunc(leftType); + return std::make_unique(ArrayType::getChildType(&leftType)->copy()); +} + +template +function_set templateGetFunctionSet(const std::string& functionName) { + function_set result; + result.push_back(std::make_unique(functionName, + std::vector{ + LogicalTypeID::ARRAY, + LogicalTypeID::ARRAY, + }, + LogicalTypeID::ANY, nullptr, nullptr, + std::bind(arrayTemplateBindFunc, functionName, std::placeholders::_1, + std::placeholders::_2), + false /* isVarLength */)); + return result; +} + +function_set ArrayCosineSimilarityFunction::getFunctionSet() { + return templateGetFunctionSet(name); +} + +function_set ArrayDistanceFunction::getFunctionSet() { + return templateGetFunctionSet(name); +} + +function_set ArrayInnerProductFunction::getFunctionSet() { + return templateGetFunctionSet(name); +} + +function_set ArrayDotProductFunction::getFunctionSet() { + return templateGetFunctionSet(name); +} + +} // namespace function +} // namespace kuzu diff --git a/src/function/vector_list_functions.cpp b/src/function/vector_list_functions.cpp index 35c77e8e3b0..0fd61602d1a 100644 --- a/src/function/vector_list_functions.cpp +++ b/src/function/vector_list_functions.cpp @@ -33,7 +33,6 @@ static std::string getListFunctionIncompatibleChildrenTypeErrorMsg( void ListCreationFunction::execFunc(const std::vector>& parameters, ValueVector& result, void* /*dataPtr*/) { - KU_ASSERT(result.dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST); result.resetAuxiliaryBuffer(); for (auto selectedPos = 0u; selectedPos < result.state->selVector->selectedSize; ++selectedPos) { @@ -63,6 +62,11 @@ static LogicalType getValidLogicalType(const binder::expression_vector& expressi std::unique_ptr ListCreationFunction::bindFunc( const binder::expression_vector& arguments, Function* /*function*/) { + auto resultType = LogicalType::VAR_LIST(getChildType(arguments).copy()); + return std::make_unique(std::move(resultType)); +} + +LogicalType ListCreationFunction::getChildType(const binder::expression_vector& arguments) { // ListCreation requires all parameters to have the same type or be ANY type. The result type of // listCreation can be determined by the first non-ANY type parameter. If all parameters have // dataType ANY, then the resultType will be INT64[] (default type). @@ -82,8 +86,7 @@ std::unique_ptr ListCreationFunction::bindFunc( } } } - auto resultType = LogicalType::VAR_LIST(childType.copy()); - return std::make_unique(std::move(resultType)); + return childType; } function_set ListCreationFunction::getFunctionSet() { diff --git a/src/include/function/array/functions/array_cosine_similarity.h b/src/include/function/array/functions/array_cosine_similarity.h new file mode 100644 index 00000000000..3130eecaa3a --- /dev/null +++ b/src/include/function/array/functions/array_cosine_similarity.h @@ -0,0 +1,33 @@ +#pragma once + +#include "math.h" + +#include "common/vector/value_vector.h" + +namespace kuzu { +namespace function { + +struct ArrayCosineSimilarity { + template + static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, + common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& /*resultVector*/) { + auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); + auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); + T distance = 0; + T normLeft = 0; + T normRight = 0; + for (auto i = 0u; i < left.size; i++) { + auto x = leftElements[i]; + auto y = rightElements[i]; + distance += x * y; + normLeft += x * x; + normRight += y * y; + } + auto similarity = distance / (std::sqrt(normLeft) * std::sqrt(normRight)); + result = std::max(static_cast(-1), std::min(similarity, static_cast(1))); + } +}; + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/array/functions/array_cross_product.h b/src/include/function/array/functions/array_cross_product.h new file mode 100644 index 00000000000..f30c1ba6340 --- /dev/null +++ b/src/include/function/array/functions/array_cross_product.h @@ -0,0 +1,24 @@ +#pragma once + +#include "common/vector/value_vector.h" + +namespace kuzu { +namespace function { + +template +struct ArrayCrossProduct { + static inline void operation(common::list_entry_t& left, common::list_entry_t& right, + common::list_entry_t& result, common::ValueVector& leftVector, + common::ValueVector& rightVector, common::ValueVector& resultVector) { + auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); + auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); + result = common::ListVector::addList(&resultVector, left.size); + auto resultElements = (T*)common::ListVector::getListValues(&resultVector, result); + resultElements[0] = leftElements[1] * rightElements[2] - leftElements[2] * rightElements[1]; + resultElements[1] = leftElements[2] * rightElements[0] - leftElements[0] * rightElements[2]; + resultElements[2] = leftElements[0] * rightElements[1] - leftElements[1] * rightElements[0]; + } +}; + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/array/functions/array_distance.h b/src/include/function/array/functions/array_distance.h new file mode 100644 index 00000000000..1975343b336 --- /dev/null +++ b/src/include/function/array/functions/array_distance.h @@ -0,0 +1,27 @@ +#pragma once + +#include "math.h" + +#include "common/vector/value_vector.h" + +namespace kuzu { +namespace function { + +struct ArrayDistance { + template + static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, + common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& /*resultVector*/) { + auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); + auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); + result = 0; + for (auto i = 0u; i < left.size; i++) { + auto diff = leftElements[i] - rightElements[i]; + result += diff * diff; + } + result = std::sqrt(result); + } +}; + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/array/functions/array_inner_product.h b/src/include/function/array/functions/array_inner_product.h new file mode 100644 index 00000000000..d1853e5cf7d --- /dev/null +++ b/src/include/function/array/functions/array_inner_product.h @@ -0,0 +1,23 @@ +#pragma once + +#include "common/vector/value_vector.h" + +namespace kuzu { +namespace function { + +struct ArrayInnerProduct { + template + static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, + common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& /*resultVector*/) { + auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); + auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); + result = 0; + for (auto i = 0u; i < left.size; i++) { + result += leftElements[i] * rightElements[i]; + } + } +}; + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/array/vector_array_functions.h b/src/include/function/array/vector_array_functions.h new file mode 100644 index 00000000000..5a5a04bf25a --- /dev/null +++ b/src/include/function/array/vector_array_functions.h @@ -0,0 +1,49 @@ +#pragma once + +#include "function/scalar_function.h" + +namespace kuzu { +namespace function { + +struct ArrayValueFunction { + static constexpr const char* name = "ARRAY_VALUE"; + + static function_set getFunctionSet(); + static std::unique_ptr bindFunc( + const binder::expression_vector& arguments, Function* function); +}; + +struct ArrayCrossProductFunction { + static constexpr const char* name = "ARRAY_CROSS_PRODUCT"; + + static function_set getFunctionSet(); + static std::unique_ptr bindFunc( + const binder::expression_vector& arguments, Function* function); +}; + +struct ArrayCosineSimilarityFunction { + static constexpr const char* name = "ARRAY_COSINE_SIMILARITY"; + + static function_set getFunctionSet(); +}; + +struct ArrayDistanceFunction { + static constexpr const char* name = "ARRAY_DISTANCE"; + + static function_set getFunctionSet(); +}; + +struct ArrayInnerProductFunction { + static constexpr const char* name = "ARRAY_INNER_PRODUCT"; + + static function_set getFunctionSet(); +}; + +struct ArrayDotProductFunction { + static constexpr const char* name = "ARRAY_DOT_PRODUCT"; + + static function_set getFunctionSet(); +}; + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/list/vector_list_functions.h b/src/include/function/list/vector_list_functions.h index aa39c62e2d9..970d3f5832d 100644 --- a/src/include/function/list/vector_list_functions.h +++ b/src/include/function/list/vector_list_functions.h @@ -220,6 +220,7 @@ struct ListCreationFunction { const binder::expression_vector& arguments, Function* function); static void execFunc(const std::vector>& parameters, common::ValueVector& result, void* /*dataPtr*/ = nullptr); + static common::LogicalType getChildType(const binder::expression_vector& arguments); }; struct ListRangeFunction { diff --git a/test/test_files/tinysnb/function/array.test b/test/test_files/tinysnb/function/array.test new file mode 100644 index 00000000000..e93ba9210d8 --- /dev/null +++ b/test/test_files/tinysnb/function/array.test @@ -0,0 +1,121 @@ +-GROUP TinySnbReadTest +-DATASET CSV tinysnb + +-- + +-CASE ArrayValue + +-LOG CreateArrayValue +-STATEMENT RETURN ARRAY_VALUE(3.2, 5.4, 7.2, 32.3) +---- 1 +[3.200000,5.400000,7.200000,32.300000] + +-LOG CreateArrayValueWithNull +-STATEMENT RETURN ARRAY_VALUE(2, 7, NULL, 256, 32, NULL) +---- 1 +[2,7,,256,32,] + +-LOG CreateEmptyArray +-STATEMENT RETURN ARRAY_VALUE() +---- 1 +[] + +-LOG CreateArrayRecursively +-STATEMENT RETURN ARRAY_VALUE(ARRAY_VALUE(3,2), ARRAY_VALUE(4,7), ARRAY_VALUE(-2,3)) +---- 1 +[[3,2],[4,7],[-2,3]] + +-CASE ArrayCrossProduct +-LOG ArrayCrossProductINT128 +-STATEMENT RETURN ARRAY_CROSS_PRODUCT(ARRAY_VALUE(to_int128(1), to_int128(2), to_int128(3)), ARRAY_VALUE(to_int128(4), to_int128(5), to_int128(6))) +---- 1 +[-3,6,-3] + +-LOG ArrayCrossProductINT64 +-STATEMENT RETURN ARRAY_CROSS_PRODUCT(ARRAY_VALUE(to_int64(148), to_int64(176), to_int64(112)), ARRAY_VALUE(to_int64(182), to_int64(187), to_int64(190))) +---- 1 +[12496,-7736,-4356] + +-LOG ArrayCrossProductINT32 +-STATEMENT RETURN ARRAY_CROSS_PRODUCT(ARRAY_VALUE(to_int32(195), to_int32(894), to_int32(539)), ARRAY_VALUE(to_int32(823), to_int32(158), to_int32(177))) +---- 1 +[73076,409082,-704952] + +-LOG ArrayCrossProductINT16 +-STATEMENT RETURN ARRAY_CROSS_PRODUCT(ARRAY_VALUE(to_int16(463), to_int16(184), to_int16(189)), ARRAY_VALUE(to_int16(94), to_int16(161), to_int16(410))) +---- 1 +[-20525,24544,-8289] + +-LOG ArrayCrossProductINT8 +-STATEMENT RETURN ARRAY_CROSS_PRODUCT(ARRAY_VALUE(to_int8(12), to_int8(7), to_int8(6)), ARRAY_VALUE(to_int8(3), to_int8(4), to_int8(5))) +---- 1 +[11,-42,27] + +-LOG ArrayCrossProductFLOAT +-STATEMENT RETURN ARRAY_CROSS_PRODUCT(ARRAY_VALUE(TO_FLOAT(3.5), TO_FLOAT(2.55), TO_FLOAT(6.2)), ARRAY_VALUE(TO_FLOAT(4.2), TO_FLOAT(7.8), TO_FLOAT(9.254))) +---- 1 +[-24.762302,-6.349003,16.590002] + +-LOG ArrayCrossProductDOUBLE +-STATEMENT RETURN ARRAY_CROSS_PRODUCT(ARRAY_VALUE(TO_DOUBLE(12.62), TO_DOUBLE(7.54), TO_DOUBLE(6.13)), ARRAY_VALUE(TO_DOUBLE(3.23), TO_DOUBLE(4.56), TO_DOUBLE(5.34))) +---- 1 +[12.310800,-47.590900,33.193000] + +-LOG ArrayCrossProductTypeError +-STATEMENT RETURN ARRAY_CROSS_PRODUCT(ARRAY_VALUE('test','test1'), ARRAY_VALUE('test2', 'test3')) +---- error +Binder exception: ARRAY_CROSS_PRODUCT can only be applied on array of floating points or integers + +-LOG ArrayCrossProductInCorrectSize +-STATEMENT RETURN ARRAY_CROSS_PRODUCT(ARRAY_VALUE(32, 54, 77), ARRAY_VALUE(31, 24)) +---- error +Binder exception: ARRAY_CROSS_PRODUCT requires both arrays to have the same element type and size of 3 + +-LOG ArrayCosineSimilarity +-STATEMENT MATCH (p:person)-[e:meets]->(p1:person) return ARRAY_COSINE_SIMILARITY(e.location, array_value(to_float(3.4), to_float(2.7))) +---- 7 +0.790048 +0.881637 +0.930004 +0.933540 +0.954293 +0.954733 +0.969880 + +-LOG ArrayCosineSimilarityWrongType +-STATEMENT MATCH (p:person) return ARRAY_COSINE_SIMILARITY(p.grades, p.grades) +---- error +Binder exception: ARRAY_COSINE_SIMILARITY requires argument type of FLOAT or DOUBLE. + +-LOG ArrayDistance +-STATEMENT MATCH (p:person)-[e:meets]->(p1:person) return ARRAY_DISTANCE(e.location, array_value(to_float(3.4), to_float(2.7))) +---- 7 +1.350593 +1.603122 +1.619197 +2.531798 +4.499111 +5.745441 +6.413268 + +-LOG ArrayInnerProduct +-STATEMENT MATCH (p:person)-[e:meets]->(p1:person) return ARRAY_INNER_PRODUCT(e.location, array_value(to_float(3.4), to_float(2.7))) +---- 7 +14.870001 +15.544000 +21.179001 +24.240002 +31.780003 +35.198002 +36.146000 + +-LOG ArrayDotProduct +-STATEMENT MATCH (p:person)-[e:meets]->(p1:person) return ARRAY_DOT_PRODUCT(e.location, array_value(to_float(5.6), to_float(2.1))) +---- 7 +18.325998 +21.910000 +24.954998 +27.719997 +31.219999 +38.164001 +51.225998