Skip to content

Commit

Permalink
Implement array functions
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Mar 19, 2024
1 parent 8fa40d6 commit ff051b2
Show file tree
Hide file tree
Showing 12 changed files with 457 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ add_library(kuzu_function
function_collection.cpp
scalar_macro_function.cpp
vector_arithmetic_functions.cpp
vector_array_functions.cpp
vector_boolean_functions.cpp
vector_cast_functions.cpp
vector_date_functions.cpp
Expand Down
1 change: 1 addition & 0 deletions src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "function/aggregate/count_star.h"
#include "function/aggregate_function.h"
#include "function/arithmetic/vector_arithmetic_functions.h"
#include "function/array/vector_array_functions.h"
#include "function/blob/vector_blob_functions.h"
#include "function/cast/vector_cast_functions.h"
#include "function/comparison/vector_comparison_functions.h"
Expand Down
5 changes: 4 additions & 1 deletion src/function/function_collection.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "function/function_collection.h"

#include "function/arithmetic/vector_arithmetic_functions.h"
#include "function/array/vector_array_functions.h"

namespace kuzu {
namespace function {
Expand Down Expand Up @@ -35,7 +36,9 @@ FunctionCollection* FunctionCollection::getFunctions() {
SCALAR_FUNCTION_ALIAS(PowerFunction), SCALAR_FUNCTION(RadiansFunction),
SCALAR_FUNCTION(RoundFunction), SCALAR_FUNCTION(SinFunction), SCALAR_FUNCTION(SignFunction),
SCALAR_FUNCTION(SqrtFunction), SCALAR_FUNCTION(TanFunction),

SCALAR_FUNCTION(ArrayValueFunction), SCALAR_FUNCTION(ArrayCrossProductFunction),
SCALAR_FUNCTION(ArrayCosineSimilarityFunction), SCALAR_FUNCTION(ArrayDistanceFunction),
SCALAR_FUNCTION(ArrayInnerProductFunction), SCALAR_FUNCTION(ArrayDotProductFunction),
// End of array
FINAL_FUNCTION};

Expand Down
167 changes: 167 additions & 0 deletions src/function/vector_array_functions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
#include "function/array/vector_array_functions.h"

#include "function/array/functions/array_cosine_similarity.h"
#include "function/array/functions/array_cross_product.h"
#include "function/array/functions/array_distance.h"
#include "function/array/functions/array_inner_product.h"
#include "function/list/vector_list_functions.h"

using namespace kuzu::common;

namespace kuzu {
namespace function {

// Builds the function set for ARRAY_VALUE: a single variable-length scalar function that accepts
// arguments declared as ANY and produces an ARRAY. Execution is delegated to
// ListCreationFunction::execFunc; the concrete result type is resolved in bindFunc.
function_set ArrayValueFunction::getFunctionSet() {
    auto arrayValueFunction = std::make_unique<ScalarFunction>(name,
        std::vector<LogicalTypeID>{LogicalTypeID::ANY}, LogicalTypeID::ARRAY,
        ListCreationFunction::execFunc, nullptr, bindFunc, true /* isVarLength */);
    function_set functions;
    functions.push_back(std::move(arrayValueFunction));
    return functions;
}

// Resolves the result type of ARRAY_VALUE: an ARRAY whose child type is the one chosen by
// ListCreationFunction::getChildType and whose element count equals the number of arguments.
std::unique_ptr<FunctionBindData> ArrayValueFunction::bindFunc(
    const binder::expression_vector& arguments, Function* /*function*/) {
    return std::make_unique<FunctionBindData>(
        LogicalType::ARRAY(ListCreationFunction::getChildType(arguments).copy(), arguments.size()));
}

// Builds the function set for ARRAY_CROSS_PRODUCT: one fixed-arity scalar function taking two
// ARRAY arguments and returning an ARRAY. No executor is registered here (both function slots are
// nullptr); bindFunc installs the executor once the element type is known.
function_set ArrayCrossProductFunction::getFunctionSet() {
    std::vector<LogicalTypeID> parameterTypeIDs{LogicalTypeID::ARRAY, LogicalTypeID::ARRAY};
    function_set functions;
    functions.push_back(std::make_unique<ScalarFunction>(name, parameterTypeIDs,
        LogicalTypeID::ARRAY, nullptr, nullptr, bindFunc, false /* isVarLength */));
    return functions;
}

// Validates the operand types of a binary array function.
// Throws BinderException when the two types are not equal, or when the child type of the left
// array is neither FLOAT nor DOUBLE. `functionName` is only used to build the error message.
static void validateArrayFunctionParameters(
    const LogicalType& leftType, const LogicalType& rightType, const std::string& functionName) {
    // NOTE(review): this compares the complete types, not just the element types, although the
    // message below only mentions the element type.
    if (leftType != rightType) {
        throw common::BinderException(common::stringFormat(
            "{} requires both arrays to have the same element type", functionName));
    }
    const auto childTypeID = ArrayType::getChildType(&leftType)->getLogicalTypeID();
    const bool isFloatingPoint =
        childTypeID == LogicalTypeID::FLOAT || childTypeID == LogicalTypeID::DOUBLE;
    if (!isFloatingPoint) {
        throw common::BinderException(
            common::stringFormat("{} requires argument type of FLOAT or DOUBLE.", functionName));
    }
}

std::unique_ptr<FunctionBindData> ArrayCrossProductFunction::bindFunc(
const binder::expression_vector& arguments, Function* function) {
auto leftType = arguments[0]->dataType;
auto rightType = arguments[1]->dataType;
if (leftType != rightType) {
throw common::BinderException(common::stringFormat(
"{} requires both arrays to have the same element type and size of 3", name));
}
scalar_func_exec_t execFunc;
switch (ArrayType::getChildType(&leftType)->getLogicalTypeID()) {
case LogicalTypeID::INT128:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<int128_t>>;
break;
case LogicalTypeID::INT64:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<int64_t>>;
break;
case LogicalTypeID::INT32:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<int32_t>>;
break;
case LogicalTypeID::INT16:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<int16_t>>;
break;
case LogicalTypeID::INT8:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<int8_t>>;
break;
case LogicalTypeID::FLOAT:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<float>>;
break;
case LogicalTypeID::DOUBLE:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<double>>;
break;
default:
throw BinderException{
stringFormat("{} can only be applied on array of floating points or integers", name)};
}
ku_dynamic_cast<Function*, ScalarFunction*>(function)->execFunc = execFunc;
auto resultType = LogicalType::ARRAY(
*ArrayType::getChildType(&leftType), ArrayType::getNumElements(&leftType));
return std::make_unique<FunctionBindData>(std::move(resultType));
}

// Returns the binary list/struct executor for OPERATION: both operands are list entries and the
// result is written as a value of type RESULT.
template<typename OPERATION, typename RESULT>
static scalar_func_exec_t getBinaryArrayExecFuncSwitchResultType() {
    return ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
        common::list_entry_t, RESULT, OPERATION>;
}

// Picks the executor of OPERATION whose scalar result type matches the child type of the array
// type `type` (FLOAT -> float, DOUBLE -> double). Any other child type hits KU_UNREACHABLE.
template<typename OPERATION>
scalar_func_exec_t getScalarExecFunc(LogicalType type) {
    scalar_func_exec_t selectedExecFunc;
    const auto childTypeID = ArrayType::getChildType(&type)->getLogicalTypeID();
    if (childTypeID == LogicalTypeID::FLOAT) {
        selectedExecFunc = getBinaryArrayExecFuncSwitchResultType<OPERATION, float>();
    } else if (childTypeID == LogicalTypeID::DOUBLE) {
        selectedExecFunc = getBinaryArrayExecFuncSwitchResultType<OPERATION, double>();
    } else {
        KU_UNREACHABLE;
    }
    return selectedExecFunc;
}

// Shared bind routine for binary array functions that return a scalar.
// Validates the two argument types, installs the executor of OPERATION chosen for the left
// argument's type on `function`, and returns bind data whose result type is a copy of the left
// array's child type. Validation failures surface as BinderException from
// validateArrayFunctionParameters.
template<typename OPERATION>
std::unique_ptr<FunctionBindData> arrayTemplateBindFunc(
    std::string functionName, const binder::expression_vector& arguments, Function* function) {
    auto lhsType = arguments[0]->dataType;
    auto rhsType = arguments[1]->dataType;
    validateArrayFunctionParameters(lhsType, rhsType, functionName);
    // Resolve the executor before touching `function`, matching the original evaluation order.
    auto selectedExecFunc = getScalarExecFunc<OPERATION>(lhsType);
    auto scalarFunction = ku_dynamic_cast<Function*, ScalarFunction*>(function);
    scalarFunction->execFunc = std::move(selectedExecFunc);
    return std::make_unique<FunctionBindData>(ArrayType::getChildType(&lhsType)->copy());
}

// Builds the function set shared by the scalar-returning binary array functions: one fixed-arity
// scalar function named `functionName` that takes two ARRAY arguments. The return type is declared
// as ANY and no executor is registered here; both are resolved by arrayTemplateBindFunc.
template<typename OPERATION>
function_set templateGetFunctionSet(const std::string& functionName) {
    // The callback keeps its own copy of the name so it stays valid after this call returns.
    auto bindCallback = [functionName](const binder::expression_vector& arguments,
                            Function* function) {
        return arrayTemplateBindFunc<OPERATION>(functionName, arguments, function);
    };
    std::vector<LogicalTypeID> parameterTypeIDs{LogicalTypeID::ARRAY, LogicalTypeID::ARRAY};
    function_set functions;
    functions.push_back(std::make_unique<ScalarFunction>(functionName, parameterTypeIDs,
        LogicalTypeID::ANY, nullptr, nullptr, bindCallback, false /* isVarLength */));
    return functions;
}

// Function set for ARRAY_COSINE_SIMILARITY, built from the shared binary-array template with the
// ArrayCosineSimilarity operation.
function_set ArrayCosineSimilarityFunction::getFunctionSet() {
return templateGetFunctionSet<ArrayCosineSimilarity>(name);
}

// Function set for ARRAY_DISTANCE, built from the shared binary-array template with the
// ArrayDistance operation.
function_set ArrayDistanceFunction::getFunctionSet() {
return templateGetFunctionSet<ArrayDistance>(name);
}

// Function set for ARRAY_INNER_PRODUCT, built from the shared binary-array template with the
// ArrayInnerProduct operation.
function_set ArrayInnerProductFunction::getFunctionSet() {
return templateGetFunctionSet<ArrayInnerProduct>(name);
}

// Function set for ARRAY_DOT_PRODUCT. It uses the same ArrayInnerProduct operation as
// ARRAY_INNER_PRODUCT and differs only in the registered name.
function_set ArrayDotProductFunction::getFunctionSet() {
return templateGetFunctionSet<ArrayInnerProduct>(name);
}

} // namespace function
} // namespace kuzu
9 changes: 6 additions & 3 deletions src/function/vector_list_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ static std::string getListFunctionIncompatibleChildrenTypeErrorMsg(

void ListCreationFunction::execFunc(const std::vector<std::shared_ptr<ValueVector>>& parameters,
ValueVector& result, void* /*dataPtr*/) {
KU_ASSERT(result.dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
result.resetAuxiliaryBuffer();
for (auto selectedPos = 0u; selectedPos < result.state->selVector->selectedSize;
++selectedPos) {
Expand Down Expand Up @@ -63,6 +62,11 @@ static LogicalType getValidLogicalType(const binder::expression_vector& expressi

// Resolves the result type of list creation: a VAR_LIST whose child type is the one selected by
// getChildType for the given arguments.
std::unique_ptr<FunctionBindData> ListCreationFunction::bindFunc(
    const binder::expression_vector& arguments, Function* /*function*/) {
    return std::make_unique<FunctionBindData>(
        LogicalType::VAR_LIST(getChildType(arguments).copy()));
}

LogicalType ListCreationFunction::getChildType(const binder::expression_vector& arguments) {
// ListCreation requires all parameters to have the same type or be ANY type. The result type of
// listCreation can be determined by the first non-ANY type parameter. If all parameters have
// dataType ANY, then the resultType will be INT64[] (default type).
Expand All @@ -82,8 +86,7 @@ std::unique_ptr<FunctionBindData> ListCreationFunction::bindFunc(
}
}
}
auto resultType = LogicalType::VAR_LIST(childType.copy());
return std::make_unique<FunctionBindData>(std::move(resultType));
return childType;
}

function_set ListCreationFunction::getFunctionSet() {
Expand Down
33 changes: 33 additions & 0 deletions src/include/function/array/functions/array_cosine_similarity.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once

#include "math.h"

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

// Cosine similarity of two arrays with elements of type T.
struct ArrayCosineSimilarity {
    // Accumulates the dot product and both squared norms over the first `left.size` elements of
    // the two arrays, then stores dot / (|left| * |right|) limited to [-1, 1] in `result`.
    // Only `left.size` bounds the loop; the right array is assumed to be at least that long.
    // NOTE(review): if either squared norm is zero the quotient is 0/0; confirm that the value
    // produced by the clamp below is what callers expect in that case.
    template<typename T>
    static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
        common::ValueVector& leftVector, common::ValueVector& rightVector,
        common::ValueVector& /*resultVector*/) {
        const T* lhsValues =
            reinterpret_cast<T*>(common::ListVector::getListValues(&leftVector, left));
        const T* rhsValues =
            reinterpret_cast<T*>(common::ListVector::getListValues(&rightVector, right));
        T dotProduct = 0;
        T lhsSquaredNorm = 0;
        T rhsSquaredNorm = 0;
        for (auto idx = 0u; idx < left.size; idx++) {
            const auto lhsValue = lhsValues[idx];
            const auto rhsValue = rhsValues[idx];
            dotProduct += lhsValue * rhsValue;
            lhsSquaredNorm += lhsValue * lhsValue;
            rhsSquaredNorm += rhsValue * rhsValue;
        }
        auto similarity = dotProduct / (std::sqrt(lhsSquaredNorm) * std::sqrt(rhsSquaredNorm));
        // Apply the upper bound first and the lower bound second, with the same argument order
        // as before, so the outcome for every input (including NaN) is unchanged.
        const T upperBounded = std::min(similarity, static_cast<T>(1));
        result = std::max(static_cast<T>(-1), upperBounded);
    }
};

} // namespace function
} // namespace kuzu
24 changes: 24 additions & 0 deletions src/include/function/array/functions/array_cross_product.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

// Cross product of two arrays with elements of type T.
template<typename T>
struct ArrayCrossProduct {
    // Appends a list of `left.size` elements to `resultVector`, stores its entry in `result`, and
    // writes the three cross product components into the first three slots of that list.
    // Indices 0..2 of all three buffers are accessed unconditionally, so both inputs are assumed
    // to hold at least three elements.
    static inline void operation(common::list_entry_t& left, common::list_entry_t& right,
        common::list_entry_t& result, common::ValueVector& leftVector,
        common::ValueVector& rightVector, common::ValueVector& resultVector) {
        const T* lhs = reinterpret_cast<T*>(common::ListVector::getListValues(&leftVector, left));
        const T* rhs =
            reinterpret_cast<T*>(common::ListVector::getListValues(&rightVector, right));
        result = common::ListVector::addList(&resultVector, left.size);
        T* output = reinterpret_cast<T*>(common::ListVector::getListValues(&resultVector, result));
        output[0] = lhs[1] * rhs[2] - lhs[2] * rhs[1];
        output[1] = lhs[2] * rhs[0] - lhs[0] * rhs[2];
        output[2] = lhs[0] * rhs[1] - lhs[1] * rhs[0];
    }
};

} // namespace function
} // namespace kuzu
27 changes: 27 additions & 0 deletions src/include/function/array/functions/array_distance.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include "math.h"

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

// Euclidean distance between two arrays with elements of type T.
struct ArrayDistance {
    // Sums the squared element-wise differences over the first `left.size` elements and stores
    // the square root of that sum in `result`. Only `left.size` bounds the loop; the right array
    // is assumed to be at least that long.
    template<typename T>
    static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
        common::ValueVector& leftVector, common::ValueVector& rightVector,
        common::ValueVector& /*resultVector*/) {
        const T* lhsValues =
            reinterpret_cast<T*>(common::ListVector::getListValues(&leftVector, left));
        const T* rhsValues =
            reinterpret_cast<T*>(common::ListVector::getListValues(&rightVector, right));
        T squaredDistance = 0;
        for (auto idx = 0u; idx < left.size; idx++) {
            const auto delta = lhsValues[idx] - rhsValues[idx];
            squaredDistance += delta * delta;
        }
        result = std::sqrt(squaredDistance);
    }
};

} // namespace function
} // namespace kuzu
23 changes: 23 additions & 0 deletions src/include/function/array/functions/array_inner_product.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#pragma once

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

// Inner (dot) product of two arrays with elements of type T.
struct ArrayInnerProduct {
    // Sums the element-wise products over the first `left.size` elements and stores the sum in
    // `result`. Only `left.size` bounds the loop; the right array is assumed to be at least that
    // long.
    template<typename T>
    static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
        common::ValueVector& leftVector, common::ValueVector& rightVector,
        common::ValueVector& /*resultVector*/) {
        const T* lhsValues =
            reinterpret_cast<T*>(common::ListVector::getListValues(&leftVector, left));
        const T* rhsValues =
            reinterpret_cast<T*>(common::ListVector::getListValues(&rightVector, right));
        T productSum = 0;
        for (auto idx = 0u; idx < left.size; idx++) {
            productSum += lhsValues[idx] * rhsValues[idx];
        }
        result = productSum;
    }
};

} // namespace function
} // namespace kuzu
49 changes: 49 additions & 0 deletions src/include/function/array/vector_array_functions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once

#include "function/scalar_function.h"

namespace kuzu {
namespace function {

// ARRAY_VALUE: builds an array from its arguments.
struct ArrayValueFunction {
// Name under which the function is registered.
static constexpr const char* name = "ARRAY_VALUE";

// Returns the function definitions registered under `name`.
static function_set getFunctionSet();
// Bind callback: computes the function's bind data (including the result type) from the
// argument expressions.
static std::unique_ptr<FunctionBindData> bindFunc(
const binder::expression_vector& arguments, Function* function);
};

// ARRAY_CROSS_PRODUCT: cross product of two arrays.
struct ArrayCrossProductFunction {
// Name under which the function is registered.
static constexpr const char* name = "ARRAY_CROSS_PRODUCT";

// Returns the function definitions registered under `name`.
static function_set getFunctionSet();
// Bind callback: validates the argument types, selects the executor for the element type and
// computes the result type.
static std::unique_ptr<FunctionBindData> bindFunc(
const binder::expression_vector& arguments, Function* function);
};

// ARRAY_COSINE_SIMILARITY: cosine similarity of two arrays.
struct ArrayCosineSimilarityFunction {
// Name under which the function is registered.
static constexpr const char* name = "ARRAY_COSINE_SIMILARITY";

// Returns the function definitions registered under `name`.
static function_set getFunctionSet();
};

// ARRAY_DISTANCE: distance between two arrays.
struct ArrayDistanceFunction {
// Name under which the function is registered.
static constexpr const char* name = "ARRAY_DISTANCE";

// Returns the function definitions registered under `name`.
static function_set getFunctionSet();
};

// ARRAY_INNER_PRODUCT: inner product of two arrays.
struct ArrayInnerProductFunction {
// Name under which the function is registered.
static constexpr const char* name = "ARRAY_INNER_PRODUCT";

// Returns the function definitions registered under `name`.
static function_set getFunctionSet();
};

// ARRAY_DOT_PRODUCT: dot product of two arrays, registered under its own name.
struct ArrayDotProductFunction {
// Name under which the function is registered.
static constexpr const char* name = "ARRAY_DOT_PRODUCT";

// Returns the function definitions registered under `name`.
static function_set getFunctionSet();
};

} // namespace function
} // namespace kuzu
Loading

0 comments on commit ff051b2

Please sign in to comment.