-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3087 from kuzudb/array-functions
Implement array functions
- Loading branch information
Showing
12 changed files
with
463 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
#include "function/array/vector_array_functions.h" | ||
|
||
#include "function/array/functions/array_cosine_similarity.h" | ||
#include "function/array/functions/array_cross_product.h" | ||
#include "function/array/functions/array_distance.h" | ||
#include "function/array/functions/array_inner_product.h" | ||
#include "function/list/vector_list_functions.h" | ||
|
||
using namespace kuzu::common; | ||
|
||
namespace kuzu { | ||
namespace function { | ||
|
||
std::unique_ptr<FunctionBindData> ArrayValueBindFunc( | ||
const binder::expression_vector& arguments, Function* /*function*/) { | ||
auto resultType = | ||
LogicalType::ARRAY(ListCreationFunction::getChildType(arguments).copy(), arguments.size()); | ||
return std::make_unique<FunctionBindData>(std::move(resultType)); | ||
} | ||
|
||
function_set ArrayValueFunction::getFunctionSet() { | ||
function_set result; | ||
result.push_back(std::make_unique<ScalarFunction>(name, | ||
std::vector<LogicalTypeID>{LogicalTypeID::ANY}, LogicalTypeID::ARRAY, | ||
ListCreationFunction::execFunc, nullptr, ArrayValueBindFunc, true /* isVarLength */)); | ||
return result; | ||
} | ||
|
||
std::unique_ptr<FunctionBindData> ArrayCrossProductBindFunc( | ||
const binder::expression_vector& arguments, Function* function) { | ||
auto leftType = arguments[0]->dataType; | ||
auto rightType = arguments[1]->dataType; | ||
if (leftType != rightType) { | ||
throw BinderException( | ||
stringFormat("{} requires both arrays to have the same element type and size of 3", | ||
ArrayCrossProductFunction::name)); | ||
} | ||
scalar_func_exec_t execFunc; | ||
switch (ArrayType::getChildType(&leftType)->getLogicalTypeID()) { | ||
case LogicalTypeID::INT128: | ||
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t, | ||
list_entry_t, ArrayCrossProduct<int128_t>>; | ||
break; | ||
case LogicalTypeID::INT64: | ||
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t, | ||
list_entry_t, ArrayCrossProduct<int64_t>>; | ||
break; | ||
case LogicalTypeID::INT32: | ||
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t, | ||
list_entry_t, ArrayCrossProduct<int32_t>>; | ||
break; | ||
case LogicalTypeID::INT16: | ||
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t, | ||
list_entry_t, ArrayCrossProduct<int16_t>>; | ||
break; | ||
case LogicalTypeID::INT8: | ||
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t, | ||
list_entry_t, ArrayCrossProduct<int8_t>>; | ||
break; | ||
case LogicalTypeID::FLOAT: | ||
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t, | ||
list_entry_t, ArrayCrossProduct<float>>; | ||
break; | ||
case LogicalTypeID::DOUBLE: | ||
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t, | ||
list_entry_t, ArrayCrossProduct<double>>; | ||
break; | ||
default: | ||
throw BinderException{ | ||
stringFormat("{} can only be applied on array of floating points or integers", | ||
ArrayCrossProductFunction::name)}; | ||
} | ||
ku_dynamic_cast<Function*, ScalarFunction*>(function)->execFunc = execFunc; | ||
auto resultType = LogicalType::ARRAY( | ||
*ArrayType::getChildType(&leftType), ArrayType::getNumElements(&leftType)); | ||
return std::make_unique<FunctionBindData>(std::move(resultType)); | ||
} | ||
|
||
function_set ArrayCrossProductFunction::getFunctionSet() { | ||
function_set result; | ||
result.push_back(std::make_unique<ScalarFunction>(name, | ||
std::vector<LogicalTypeID>{ | ||
LogicalTypeID::ARRAY, | ||
LogicalTypeID::ARRAY, | ||
}, | ||
LogicalTypeID::ARRAY, nullptr, nullptr, ArrayCrossProductBindFunc, | ||
false /* isVarLength */)); | ||
return result; | ||
} | ||
|
||
static void validateArrayFunctionParameters( | ||
const LogicalType& leftType, const LogicalType& rightType, const std::string& functionName) { | ||
if (leftType != rightType) { | ||
throw BinderException( | ||
stringFormat("{} requires both arrays to have the same element type", functionName)); | ||
} | ||
if (ArrayType::getChildType(&leftType)->getLogicalTypeID() != LogicalTypeID::FLOAT && | ||
ArrayType::getChildType(&leftType)->getLogicalTypeID() != LogicalTypeID::DOUBLE) { | ||
throw BinderException( | ||
stringFormat("{} requires argument type of FLOAT or DOUBLE.", functionName)); | ||
} | ||
} | ||
|
||
template<typename OPERATION, typename RESULT> | ||
static scalar_func_exec_t getBinaryArrayExecFuncSwitchResultType() { | ||
auto execFunc = | ||
ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t, RESULT, OPERATION>; | ||
return execFunc; | ||
} | ||
|
||
template<typename OPERATION> | ||
scalar_func_exec_t getScalarExecFunc(LogicalType type) { | ||
scalar_func_exec_t execFunc; | ||
switch (ArrayType::getChildType(&type)->getLogicalTypeID()) { | ||
case LogicalTypeID::FLOAT: | ||
execFunc = getBinaryArrayExecFuncSwitchResultType<OPERATION, float>(); | ||
break; | ||
case LogicalTypeID::DOUBLE: | ||
execFunc = getBinaryArrayExecFuncSwitchResultType<OPERATION, double>(); | ||
break; | ||
default: | ||
KU_UNREACHABLE; | ||
} | ||
return execFunc; | ||
} | ||
|
||
template<typename OPERATION> | ||
std::unique_ptr<FunctionBindData> arrayTemplateBindFunc( | ||
std::string functionName, const binder::expression_vector& arguments, Function* function) { | ||
auto leftType = arguments[0]->dataType; | ||
auto rightType = arguments[1]->dataType; | ||
validateArrayFunctionParameters(leftType, rightType, functionName); | ||
ku_dynamic_cast<Function*, ScalarFunction*>(function)->execFunc = | ||
getScalarExecFunc<OPERATION>(leftType); | ||
return std::make_unique<FunctionBindData>(ArrayType::getChildType(&leftType)->copy()); | ||
} | ||
|
||
template<typename OPERATION> | ||
function_set templateGetFunctionSet(const std::string& functionName) { | ||
function_set result; | ||
result.push_back(std::make_unique<ScalarFunction>(functionName, | ||
std::vector<LogicalTypeID>{ | ||
LogicalTypeID::ARRAY, | ||
LogicalTypeID::ARRAY, | ||
}, | ||
LogicalTypeID::ANY, nullptr, nullptr, | ||
std::bind(arrayTemplateBindFunc<OPERATION>, functionName, std::placeholders::_1, | ||
std::placeholders::_2), | ||
false /* isVarLength */)); | ||
return result; | ||
} | ||
|
||
function_set ArrayCosineSimilarityFunction::getFunctionSet() { | ||
return templateGetFunctionSet<ArrayCosineSimilarity>(name); | ||
} | ||
|
||
function_set ArrayDistanceFunction::getFunctionSet() { | ||
return templateGetFunctionSet<ArrayDistance>(name); | ||
} | ||
|
||
function_set ArrayInnerProductFunction::getFunctionSet() { | ||
return templateGetFunctionSet<ArrayInnerProduct>(name); | ||
} | ||
|
||
function_set ArrayDotProductFunction::getFunctionSet() { | ||
return templateGetFunctionSet<ArrayInnerProduct>(name); | ||
} | ||
|
||
} // namespace function | ||
} // namespace kuzu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 33 additions & 0 deletions
33
src/include/function/array/functions/array_cosine_similarity.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#pragma once | ||
|
||
#include "math.h" | ||
|
||
#include "common/vector/value_vector.h" | ||
|
||
namespace kuzu { | ||
namespace function { | ||
|
||
struct ArrayCosineSimilarity { | ||
template<typename T> | ||
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, | ||
common::ValueVector& leftVector, common::ValueVector& rightVector, | ||
common::ValueVector& /*resultVector*/) { | ||
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); | ||
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); | ||
T distance = 0; | ||
T normLeft = 0; | ||
T normRight = 0; | ||
for (auto i = 0u; i < left.size; i++) { | ||
auto x = leftElements[i]; | ||
auto y = rightElements[i]; | ||
distance += x * y; | ||
normLeft += x * x; | ||
normRight += y * y; | ||
} | ||
auto similarity = distance / (std::sqrt(normLeft) * std::sqrt(normRight)); | ||
result = std::max(static_cast<T>(-1), std::min(similarity, static_cast<T>(1))); | ||
} | ||
}; | ||
|
||
} // namespace function | ||
} // namespace kuzu |
24 changes: 24 additions & 0 deletions
24
src/include/function/array/functions/array_cross_product.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#pragma once | ||
|
||
#include "common/vector/value_vector.h" | ||
|
||
namespace kuzu { | ||
namespace function { | ||
|
||
template<typename T> | ||
struct ArrayCrossProduct { | ||
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, | ||
common::list_entry_t& result, common::ValueVector& leftVector, | ||
common::ValueVector& rightVector, common::ValueVector& resultVector) { | ||
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); | ||
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); | ||
result = common::ListVector::addList(&resultVector, left.size); | ||
auto resultElements = (T*)common::ListVector::getListValues(&resultVector, result); | ||
resultElements[0] = leftElements[1] * rightElements[2] - leftElements[2] * rightElements[1]; | ||
resultElements[1] = leftElements[2] * rightElements[0] - leftElements[0] * rightElements[2]; | ||
resultElements[2] = leftElements[0] * rightElements[1] - leftElements[1] * rightElements[0]; | ||
} | ||
}; | ||
|
||
} // namespace function | ||
} // namespace kuzu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#pragma once | ||
|
||
#include "math.h" | ||
|
||
#include "common/vector/value_vector.h" | ||
|
||
namespace kuzu { | ||
namespace function { | ||
|
||
struct ArrayDistance { | ||
template<typename T> | ||
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, | ||
common::ValueVector& leftVector, common::ValueVector& rightVector, | ||
common::ValueVector& /*resultVector*/) { | ||
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); | ||
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); | ||
result = 0; | ||
for (auto i = 0u; i < left.size; i++) { | ||
auto diff = leftElements[i] - rightElements[i]; | ||
result += diff * diff; | ||
} | ||
result = std::sqrt(result); | ||
} | ||
}; | ||
|
||
} // namespace function | ||
} // namespace kuzu |
23 changes: 23 additions & 0 deletions
23
src/include/function/array/functions/array_inner_product.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#pragma once | ||
|
||
#include "common/vector/value_vector.h" | ||
|
||
namespace kuzu { | ||
namespace function { | ||
|
||
struct ArrayInnerProduct { | ||
template<typename T> | ||
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result, | ||
common::ValueVector& leftVector, common::ValueVector& rightVector, | ||
common::ValueVector& /*resultVector*/) { | ||
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left); | ||
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right); | ||
result = 0; | ||
for (auto i = 0u; i < left.size; i++) { | ||
result += leftElements[i] * rightElements[i]; | ||
} | ||
} | ||
}; | ||
|
||
} // namespace function | ||
} // namespace kuzu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#pragma once | ||
|
||
#include "function/function.h" | ||
|
||
namespace kuzu { | ||
namespace function { | ||
|
||
struct ArrayValueFunction { | ||
static constexpr const char* name = "ARRAY_VALUE"; | ||
|
||
static function_set getFunctionSet(); | ||
}; | ||
|
||
struct ArrayCrossProductFunction { | ||
static constexpr const char* name = "ARRAY_CROSS_PRODUCT"; | ||
|
||
static function_set getFunctionSet(); | ||
}; | ||
|
||
struct ArrayCosineSimilarityFunction { | ||
static constexpr const char* name = "ARRAY_COSINE_SIMILARITY"; | ||
|
||
static function_set getFunctionSet(); | ||
}; | ||
|
||
struct ArrayDistanceFunction { | ||
static constexpr const char* name = "ARRAY_DISTANCE"; | ||
|
||
static function_set getFunctionSet(); | ||
}; | ||
|
||
struct ArrayInnerProductFunction { | ||
static constexpr const char* name = "ARRAY_INNER_PRODUCT"; | ||
|
||
static function_set getFunctionSet(); | ||
}; | ||
|
||
struct ArrayDotProductFunction { | ||
static constexpr const char* name = "ARRAY_DOT_PRODUCT"; | ||
|
||
static function_set getFunctionSet(); | ||
}; | ||
|
||
} // namespace function | ||
} // namespace kuzu |
Oops, something went wrong.