Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement array functions #3087

Merged
merged 1 commit into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ add_library(kuzu_function
function_collection.cpp
scalar_macro_function.cpp
vector_arithmetic_functions.cpp
vector_array_functions.cpp
vector_boolean_functions.cpp
vector_cast_functions.cpp
vector_date_functions.cpp
Expand Down
5 changes: 4 additions & 1 deletion src/function/function_collection.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "function/function_collection.h"

#include "function/arithmetic/vector_arithmetic_functions.h"
#include "function/array/vector_array_functions.h"

namespace kuzu {
namespace function {
Expand Down Expand Up @@ -35,7 +36,9 @@ FunctionCollection* FunctionCollection::getFunctions() {
SCALAR_FUNCTION_ALIAS(PowerFunction), SCALAR_FUNCTION(RadiansFunction),
SCALAR_FUNCTION(RoundFunction), SCALAR_FUNCTION(SinFunction), SCALAR_FUNCTION(SignFunction),
SCALAR_FUNCTION(SqrtFunction), SCALAR_FUNCTION(TanFunction),

SCALAR_FUNCTION(ArrayValueFunction), SCALAR_FUNCTION(ArrayCrossProductFunction),
SCALAR_FUNCTION(ArrayCosineSimilarityFunction), SCALAR_FUNCTION(ArrayDistanceFunction),
SCALAR_FUNCTION(ArrayInnerProductFunction), SCALAR_FUNCTION(ArrayDotProductFunction),
// End of array
FINAL_FUNCTION};

Expand Down
170 changes: 170 additions & 0 deletions src/function/vector_array_functions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#include "function/array/vector_array_functions.h"

#include "function/array/functions/array_cosine_similarity.h"
#include "function/array/functions/array_cross_product.h"
#include "function/array/functions/array_distance.h"
#include "function/array/functions/array_inner_product.h"
#include "function/list/vector_list_functions.h"

using namespace kuzu::common;

namespace kuzu {
namespace function {

std::unique_ptr<FunctionBindData> ArrayValueBindFunc(
const binder::expression_vector& arguments, Function* /*function*/) {
auto resultType =
LogicalType::ARRAY(ListCreationFunction::getChildType(arguments).copy(), arguments.size());
return std::make_unique<FunctionBindData>(std::move(resultType));
}

function_set ArrayValueFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::ANY}, LogicalTypeID::ARRAY,
ListCreationFunction::execFunc, nullptr, ArrayValueBindFunc, true /* isVarLength */));
return result;
}

Check warning on line 27 in src/function/vector_array_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_array_functions.cpp#L27

Added line #L27 was not covered by tests

std::unique_ptr<FunctionBindData> ArrayCrossProductBindFunc(
const binder::expression_vector& arguments, Function* function) {
auto leftType = arguments[0]->dataType;
auto rightType = arguments[1]->dataType;
if (leftType != rightType) {
throw BinderException(
stringFormat("{} requires both arrays to have the same element type and size of 3",
ArrayCrossProductFunction::name));
}
scalar_func_exec_t execFunc;
switch (ArrayType::getChildType(&leftType)->getLogicalTypeID()) {
case LogicalTypeID::INT128:
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t,
list_entry_t, ArrayCrossProduct<int128_t>>;
break;
case LogicalTypeID::INT64:
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t,
list_entry_t, ArrayCrossProduct<int64_t>>;
break;
case LogicalTypeID::INT32:
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t,
list_entry_t, ArrayCrossProduct<int32_t>>;
break;
case LogicalTypeID::INT16:
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t,
list_entry_t, ArrayCrossProduct<int16_t>>;
break;
case LogicalTypeID::INT8:
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t,
list_entry_t, ArrayCrossProduct<int8_t>>;
break;
case LogicalTypeID::FLOAT:
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t,
list_entry_t, ArrayCrossProduct<float>>;
break;
case LogicalTypeID::DOUBLE:
execFunc = ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t,
list_entry_t, ArrayCrossProduct<double>>;
break;
default:
throw BinderException{
stringFormat("{} can only be applied on array of floating points or integers",
ArrayCrossProductFunction::name)};
}
ku_dynamic_cast<Function*, ScalarFunction*>(function)->execFunc = execFunc;
auto resultType = LogicalType::ARRAY(
*ArrayType::getChildType(&leftType), ArrayType::getNumElements(&leftType));
return std::make_unique<FunctionBindData>(std::move(resultType));
}

function_set ArrayCrossProductFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{
LogicalTypeID::ARRAY,
LogicalTypeID::ARRAY,
},
LogicalTypeID::ARRAY, nullptr, nullptr, ArrayCrossProductBindFunc,
false /* isVarLength */));
return result;
}

Check warning on line 89 in src/function/vector_array_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_array_functions.cpp#L89

Added line #L89 was not covered by tests

static void validateArrayFunctionParameters(
const LogicalType& leftType, const LogicalType& rightType, const std::string& functionName) {
if (leftType != rightType) {
throw BinderException(
stringFormat("{} requires both arrays to have the same element type", functionName));

Check warning on line 95 in src/function/vector_array_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_array_functions.cpp#L94-L95

Added lines #L94 - L95 were not covered by tests
}
if (ArrayType::getChildType(&leftType)->getLogicalTypeID() != LogicalTypeID::FLOAT &&
ArrayType::getChildType(&leftType)->getLogicalTypeID() != LogicalTypeID::DOUBLE) {
throw BinderException(
stringFormat("{} requires argument type of FLOAT or DOUBLE.", functionName));
}
}

template<typename OPERATION, typename RESULT>
static scalar_func_exec_t getBinaryArrayExecFuncSwitchResultType() {
auto execFunc =
ScalarFunction::BinaryExecListStructFunction<list_entry_t, list_entry_t, RESULT, OPERATION>;
return execFunc;
}

template<typename OPERATION>
scalar_func_exec_t getScalarExecFunc(LogicalType type) {
scalar_func_exec_t execFunc;
switch (ArrayType::getChildType(&type)->getLogicalTypeID()) {
case LogicalTypeID::FLOAT:
execFunc = getBinaryArrayExecFuncSwitchResultType<OPERATION, float>();
break;
case LogicalTypeID::DOUBLE:
execFunc = getBinaryArrayExecFuncSwitchResultType<OPERATION, double>();
break;
default:

Check warning on line 121 in src/function/vector_array_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_array_functions.cpp#L118-L121

Added lines #L118 - L121 were not covered by tests
KU_UNREACHABLE;
}
return execFunc;
}

template<typename OPERATION>
std::unique_ptr<FunctionBindData> arrayTemplateBindFunc(
std::string functionName, const binder::expression_vector& arguments, Function* function) {
auto leftType = arguments[0]->dataType;
auto rightType = arguments[1]->dataType;
validateArrayFunctionParameters(leftType, rightType, functionName);
ku_dynamic_cast<Function*, ScalarFunction*>(function)->execFunc =
getScalarExecFunc<OPERATION>(leftType);
return std::make_unique<FunctionBindData>(ArrayType::getChildType(&leftType)->copy());
}

template<typename OPERATION>
function_set templateGetFunctionSet(const std::string& functionName) {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(functionName,
std::vector<LogicalTypeID>{
LogicalTypeID::ARRAY,
LogicalTypeID::ARRAY,
},
LogicalTypeID::ANY, nullptr, nullptr,
std::bind(arrayTemplateBindFunc<OPERATION>, functionName, std::placeholders::_1,
std::placeholders::_2),
false /* isVarLength */));
return result;
}

Check warning on line 151 in src/function/vector_array_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_array_functions.cpp#L151

Added line #L151 was not covered by tests

function_set ArrayCosineSimilarityFunction::getFunctionSet() {
return templateGetFunctionSet<ArrayCosineSimilarity>(name);
}

function_set ArrayDistanceFunction::getFunctionSet() {
return templateGetFunctionSet<ArrayDistance>(name);
}

function_set ArrayInnerProductFunction::getFunctionSet() {
return templateGetFunctionSet<ArrayInnerProduct>(name);
}

function_set ArrayDotProductFunction::getFunctionSet() {
return templateGetFunctionSet<ArrayInnerProduct>(name);
}

} // namespace function
} // namespace kuzu
12 changes: 9 additions & 3 deletions src/function/vector_list_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ static std::string getListFunctionIncompatibleChildrenTypeErrorMsg(

void ListCreationFunction::execFunc(const std::vector<std::shared_ptr<ValueVector>>& parameters,
ValueVector& result, void* /*dataPtr*/) {
KU_ASSERT(result.dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
result.resetAuxiliaryBuffer();
for (auto selectedPos = 0u; selectedPos < result.state->selVector->selectedSize;
++selectedPos) {
Expand Down Expand Up @@ -63,6 +62,11 @@ static LogicalType getValidLogicalType(const binder::expression_vector& expressi

std::unique_ptr<FunctionBindData> ListCreationFunction::bindFunc(
const binder::expression_vector& arguments, Function* /*function*/) {
auto resultType = LogicalType::VAR_LIST(getChildType(arguments).copy());
return std::make_unique<FunctionBindData>(std::move(resultType));
}

LogicalType ListCreationFunction::getChildType(const binder::expression_vector& arguments) {
// ListCreation requires all parameters to have the same type or be ANY type. The result type of
// listCreation can be determined by the first non-ANY type parameter. If all parameters have
// dataType ANY, then the resultType will be INT64[] (default type).
Expand All @@ -82,8 +86,7 @@ std::unique_ptr<FunctionBindData> ListCreationFunction::bindFunc(
}
}
}
auto resultType = LogicalType::VAR_LIST(childType.copy());
return std::make_unique<FunctionBindData>(std::move(resultType));
return childType;
}

function_set ListCreationFunction::getFunctionSet() {
Expand Down Expand Up @@ -232,6 +235,9 @@ function_set ListExtractFunction::getFunctionSet() {
LogicalTypeID::STRING,
ScalarFunction::BinaryExecFunction<ku_string_t, int64_t, ku_string_t, ListExtract>,
false /* isVarlength */));
result.push_back(std::make_unique<ScalarFunction>(LIST_EXTRACT_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::ARRAY, LogicalTypeID::INT64}, LogicalTypeID::ANY,
nullptr, nullptr, bindFunc, false /* isVarlength*/));
return result;
}

Expand Down
33 changes: 33 additions & 0 deletions src/include/function/array/functions/array_cosine_similarity.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once

#include "math.h"

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

struct ArrayCosineSimilarity {
template<typename T>
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
common::ValueVector& leftVector, common::ValueVector& rightVector,
common::ValueVector& /*resultVector*/) {
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
T distance = 0;
T normLeft = 0;
T normRight = 0;
for (auto i = 0u; i < left.size; i++) {
auto x = leftElements[i];
auto y = rightElements[i];
distance += x * y;
normLeft += x * x;
normRight += y * y;
}
auto similarity = distance / (std::sqrt(normLeft) * std::sqrt(normRight));
result = std::max(static_cast<T>(-1), std::min(similarity, static_cast<T>(1)));
}
};

} // namespace function
} // namespace kuzu
24 changes: 24 additions & 0 deletions src/include/function/array/functions/array_cross_product.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

template<typename T>
struct ArrayCrossProduct {
static inline void operation(common::list_entry_t& left, common::list_entry_t& right,
common::list_entry_t& result, common::ValueVector& leftVector,
common::ValueVector& rightVector, common::ValueVector& resultVector) {
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
result = common::ListVector::addList(&resultVector, left.size);
auto resultElements = (T*)common::ListVector::getListValues(&resultVector, result);
resultElements[0] = leftElements[1] * rightElements[2] - leftElements[2] * rightElements[1];
resultElements[1] = leftElements[2] * rightElements[0] - leftElements[0] * rightElements[2];
resultElements[2] = leftElements[0] * rightElements[1] - leftElements[1] * rightElements[0];
}
};

} // namespace function
} // namespace kuzu
27 changes: 27 additions & 0 deletions src/include/function/array/functions/array_distance.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include "math.h"

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

struct ArrayDistance {
template<typename T>
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
common::ValueVector& leftVector, common::ValueVector& rightVector,
common::ValueVector& /*resultVector*/) {
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
result = 0;
for (auto i = 0u; i < left.size; i++) {
auto diff = leftElements[i] - rightElements[i];
result += diff * diff;
}
result = std::sqrt(result);
}
};

} // namespace function
} // namespace kuzu
23 changes: 23 additions & 0 deletions src/include/function/array/functions/array_inner_product.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#pragma once

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

struct ArrayInnerProduct {
template<typename T>
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
common::ValueVector& leftVector, common::ValueVector& rightVector,
common::ValueVector& /*resultVector*/) {
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
result = 0;
for (auto i = 0u; i < left.size; i++) {
result += leftElements[i] * rightElements[i];
}
}
};

} // namespace function
} // namespace kuzu
45 changes: 45 additions & 0 deletions src/include/function/array/vector_array_functions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#pragma once

#include "function/function.h"

namespace kuzu {
namespace function {

struct ArrayValueFunction {
static constexpr const char* name = "ARRAY_VALUE";

static function_set getFunctionSet();
};

struct ArrayCrossProductFunction {
static constexpr const char* name = "ARRAY_CROSS_PRODUCT";

static function_set getFunctionSet();
};

struct ArrayCosineSimilarityFunction {
static constexpr const char* name = "ARRAY_COSINE_SIMILARITY";

static function_set getFunctionSet();
};

struct ArrayDistanceFunction {
static constexpr const char* name = "ARRAY_DISTANCE";

static function_set getFunctionSet();
};

struct ArrayInnerProductFunction {
static constexpr const char* name = "ARRAY_INNER_PRODUCT";

static function_set getFunctionSet();
};

struct ArrayDotProductFunction {
static constexpr const char* name = "ARRAY_DOT_PRODUCT";

static function_set getFunctionSet();
};

} // namespace function
} // namespace kuzu
Loading
Loading