Skip to content

Commit

Permalink
Implement array functions
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Mar 19, 2024
1 parent e7c6d73 commit 51944dc
Show file tree
Hide file tree
Showing 13 changed files with 458 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ add_library(kuzu_function
find_function.cpp
scalar_macro_function.cpp
vector_arithmetic_functions.cpp
vector_array_functions.cpp
vector_boolean_functions.cpp
vector_cast_functions.cpp
vector_date_functions.cpp
Expand Down
17 changes: 17 additions & 0 deletions src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "function/aggregate/count_star.h"
#include "function/aggregate_function.h"
#include "function/arithmetic/vector_arithmetic_functions.h"
#include "function/array/vector_array_functions.h"
#include "function/blob/vector_blob_functions.h"
#include "function/cast/vector_cast_functions.h"
#include "function/comparison/vector_comparison_functions.h"
Expand Down Expand Up @@ -71,6 +72,7 @@ void BuiltInFunctionsUtils::registerScalarFunctions(CatalogSet* catalogSet) {
registerBlobFunctions(catalogSet);
registerUUIDFunctions(catalogSet);
registerRdfFunctions(catalogSet);
registerArrayFunctions(catalogSet);
}

void BuiltInFunctionsUtils::registerAggregateFunctions(CatalogSet* catalogSet) {
Expand Down Expand Up @@ -939,6 +941,21 @@ void BuiltInFunctionsUtils::registerRdfFunctions(CatalogSet* catalogSet) {
VALIDATE_PREDICATE_FUNC_NAME, ValidatePredicateFunction::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerArrayFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
ARRAY_VALUE_FUNC_NAME, ArrayValue::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
ARRAY_CROSS_PRODUCT_FUNC_NAME, ArrayCrossProductFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
ARRAY_COSINE_SIMILARITY_FUNC_NAME, ArrayCosineSimilarityFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
ARRAY_DISTANCE_FUNC_NAME, ArrayDistanceFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
ARRAY_INNER_PRODUCT_FUNC_NAME, ArrayInnerProductFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<ScalarFunctionCatalogEntry>(
ARRAY_DOT_PRODUCT_FUNC_NAME, ArrayInnerProductFunction::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerCountStar(CatalogSet* catalogSet) {
function_set functionSet;
functionSet.push_back(std::make_unique<AggregateFunction>(COUNT_STAR_FUNC_NAME,
Expand Down
165 changes: 165 additions & 0 deletions src/function/vector_array_functions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
#include "function/array/vector_array_functions.h"

#include "function/array/functions/array_cosine_similarity.h"
#include "function/array/functions/array_cross_product.h"
#include "function/array/functions/array_distance.h"
#include "function/array/functions/array_inner_product.h"
#include "function/list/vector_list_functions.h"

using namespace kuzu::common;

namespace kuzu {
namespace function {

function_set ArrayValue::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(ARRAY_VALUE_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::ANY}, LogicalTypeID::ARRAY,
ListCreationFunction::execFunc, nullptr, bindFunc, true /* isVarLength */));
return result;
}

Check warning on line 20 in src/function/vector_array_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_array_functions.cpp#L20

Added line #L20 was not covered by tests

std::unique_ptr<FunctionBindData> ArrayValue::bindFunc(
const binder::expression_vector& arguments, Function* /*function*/) {
auto resultType =
LogicalType::ARRAY(ListCreationFunction::getChildType(arguments).copy(), arguments.size());
return std::make_unique<FunctionBindData>(std::move(resultType));
}

function_set ArrayCrossProductFunction::getFunctionSet() {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(ARRAY_CROSS_PRODUCT_FUNC_NAME,
std::vector<LogicalTypeID>{
LogicalTypeID::ARRAY,
LogicalTypeID::ARRAY,
},
LogicalTypeID::ARRAY, nullptr, nullptr, bindFunc, false /* isVarLength */));
return result;
}

Check warning on line 38 in src/function/vector_array_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_array_functions.cpp#L38

Added line #L38 was not covered by tests

static void validateArrayFunctionParameters(
const LogicalType& leftType, const LogicalType& rightType, const std::string& functionName) {
if (leftType != rightType) {
throw common::BinderException(common::stringFormat(
"{} requires both arrays to have the same element type", functionName));

Check warning on line 44 in src/function/vector_array_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_array_functions.cpp#L43-L44

Added lines #L43 - L44 were not covered by tests
}
if (ArrayType::getChildType(&leftType)->getLogicalTypeID() != LogicalTypeID::FLOAT &&
ArrayType::getChildType(&leftType)->getLogicalTypeID() != LogicalTypeID::DOUBLE) {
throw common::BinderException(
common::stringFormat("{} requires argument type of FLOAT or DOUBLE.", functionName));
}
}

std::unique_ptr<FunctionBindData> ArrayCrossProductFunction::bindFunc(
const binder::expression_vector& arguments, Function* function) {
auto leftType = arguments[0]->dataType;
auto rightType = arguments[1]->dataType;
if (leftType != rightType) {
throw common::BinderException(common::stringFormat(
"{} requires both arrays to have the same element type and size of 3",
ARRAY_CROSS_PRODUCT_FUNC_NAME));
}
scalar_func_exec_t execFunc;
switch (ArrayType::getChildType(&leftType)->getLogicalTypeID()) {
case LogicalTypeID::INT128:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<int128_t>>;
break;
case LogicalTypeID::INT64:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<int64_t>>;
break;
case LogicalTypeID::INT32:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<int32_t>>;
break;
case LogicalTypeID::INT16:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<int16_t>>;
break;
case LogicalTypeID::INT8:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<int8_t>>;
break;
case LogicalTypeID::FLOAT:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<float>>;
break;
case LogicalTypeID::DOUBLE:
execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, common::list_entry_t, ArrayCrossProduct<double>>;
break;
default:
throw BinderException{
stringFormat("{} can only be applied on array of floating points or integers",
ARRAY_CROSS_PRODUCT_FUNC_NAME)};
}
ku_dynamic_cast<Function*, ScalarFunction*>(function)->execFunc = execFunc;
auto resultType = LogicalType::ARRAY(
*ArrayType::getChildType(&leftType), ArrayType::getNumElements(&leftType));
return std::make_unique<FunctionBindData>(std::move(resultType));
}

template<typename OPERATION, typename RESULT>
static scalar_func_exec_t getBinaryArrayExecFuncSwitchResultType() {
auto execFunc = ScalarFunction::BinaryExecListStructFunction<common::list_entry_t,
common::list_entry_t, RESULT, OPERATION>;
return execFunc;
}

template<typename OPERATION>
scalar_func_exec_t getScalarExecFunc(LogicalType type) {
scalar_func_exec_t execFunc;
switch (ArrayType::getChildType(&type)->getLogicalTypeID()) {
case LogicalTypeID::FLOAT:
execFunc = getBinaryArrayExecFuncSwitchResultType<OPERATION, float>();
break;
case LogicalTypeID::DOUBLE:
execFunc = getBinaryArrayExecFuncSwitchResultType<OPERATION, double>();
break;
default:

Check warning on line 120 in src/function/vector_array_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_array_functions.cpp#L117-L120

Added lines #L117 - L120 were not covered by tests
KU_UNREACHABLE;
}
return execFunc;
}

template<typename OPERATION>
std::unique_ptr<FunctionBindData> arrayTemplateBindFunc(
std::string functionName, const binder::expression_vector& arguments, Function* function) {
auto leftType = arguments[0]->dataType;
auto rightType = arguments[1]->dataType;
validateArrayFunctionParameters(leftType, rightType, functionName);
ku_dynamic_cast<Function*, ScalarFunction*>(function)->execFunc =
getScalarExecFunc<OPERATION>(leftType);
return std::make_unique<FunctionBindData>(ArrayType::getChildType(&leftType)->copy());
}

template<typename OPERATION>
function_set templateGetFunctionSet(const std::string& functionName) {
function_set result;
result.push_back(std::make_unique<ScalarFunction>(functionName,
std::vector<LogicalTypeID>{
LogicalTypeID::ARRAY,
LogicalTypeID::ARRAY,
},
LogicalTypeID::ANY, nullptr, nullptr,
std::bind(arrayTemplateBindFunc<OPERATION>, functionName, std::placeholders::_1,
std::placeholders::_2),
false /* isVarLength */));
return result;
}

Check warning on line 150 in src/function/vector_array_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_array_functions.cpp#L150

Added line #L150 was not covered by tests

function_set ArrayCosineSimilarityFunction::getFunctionSet() {
return templateGetFunctionSet<ArrayCosineSimilarity>(ARRAY_COSINE_SIMILARITY_FUNC_NAME);
}

function_set ArrayDistanceFunction::getFunctionSet() {
return templateGetFunctionSet<ArrayDistance>(ARRAY_DISTANCE_FUNC_NAME);
}

function_set ArrayInnerProductFunction::getFunctionSet() {
return templateGetFunctionSet<ArrayInnerProduct>(ARRAY_INNER_PRODUCT_FUNC_NAME);
}

} // namespace function
} // namespace kuzu
9 changes: 6 additions & 3 deletions src/function/vector_list_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ static std::string getListFunctionIncompatibleChildrenTypeErrorMsg(

void ListCreationFunction::execFunc(const std::vector<std::shared_ptr<ValueVector>>& parameters,
ValueVector& result, void* /*dataPtr*/) {
KU_ASSERT(result.dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
result.resetAuxiliaryBuffer();
for (auto selectedPos = 0u; selectedPos < result.state->selVector->selectedSize;
++selectedPos) {
Expand Down Expand Up @@ -63,6 +62,11 @@ static LogicalType getValidLogicalType(const binder::expression_vector& expressi

std::unique_ptr<FunctionBindData> ListCreationFunction::bindFunc(
const binder::expression_vector& arguments, Function* /*function*/) {
auto resultType = LogicalType::VAR_LIST(getChildType(arguments).copy());
return std::make_unique<FunctionBindData>(std::move(resultType));
}

LogicalType ListCreationFunction::getChildType(const binder::expression_vector& arguments) {
// ListCreation requires all parameters to have the same type or be ANY type. The result type of
// listCreation can be determined by the first non-ANY type parameter. If all parameters have
// dataType ANY, then the resultType will be INT64[] (default type).
Expand All @@ -82,8 +86,7 @@ std::unique_ptr<FunctionBindData> ListCreationFunction::bindFunc(
}
}
}
auto resultType = LogicalType::VAR_LIST(childType.copy());
return std::make_unique<FunctionBindData>(std::move(resultType));
return childType;
}

function_set ListCreationFunction::getFunctionSet() {
Expand Down
8 changes: 8 additions & 0 deletions src/include/common/enums/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ const char* const LIST_UNIQUE_FUNC_NAME = "LIST_UNIQUE";
const char* const LIST_ANY_VALUE_FUNC_NAME = "LIST_ANY_VALUE";
const char* const LIST_REVERSE_FUNC_NAME = "LIST_REVERSE";

// array
const char* const ARRAY_VALUE_FUNC_NAME = "ARRAY_VALUE";
const char* const ARRAY_CROSS_PRODUCT_FUNC_NAME = "ARRAY_CROSS_PRODUCT";
const char* const ARRAY_COSINE_SIMILARITY_FUNC_NAME = "ARRAY_COSINE_SIMILARITY";
const char* const ARRAY_DISTANCE_FUNC_NAME = "ARRAY_DISTANCE";
const char* const ARRAY_INNER_PRODUCT_FUNC_NAME = "ARRAY_INNER_PRODUCT";
const char* const ARRAY_DOT_PRODUCT_FUNC_NAME = "ARRAY_DOT_PRODUCT";

// struct
const char* const STRUCT_PACK_FUNC_NAME = "STRUCT_PACK";
const char* const STRUCT_EXTRACT_FUNC_NAME = "STRUCT_EXTRACT";
Expand Down
33 changes: 33 additions & 0 deletions src/include/function/array/functions/array_cosine_similarity.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once

#include "math.h"

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

struct ArrayCosineSimilarity {
template<typename T>
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
common::ValueVector& leftVector, common::ValueVector& rightVector,
common::ValueVector& /*resultVector*/) {
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
T distance = 0;
T normLeft = 0;
T normRight = 0;
for (auto i = 0u; i < left.size; i++) {
auto x = leftElements[i];
auto y = rightElements[i];
distance += x * y;
normLeft += x * x;
normRight += y * y;
}
auto similarity = distance / (std::sqrt(normLeft) * std::sqrt(normRight));
result = std::max(static_cast<T>(-1), std::min(similarity, static_cast<T>(1)));
}
};

} // namespace function
} // namespace kuzu
24 changes: 24 additions & 0 deletions src/include/function/array/functions/array_cross_product.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

template<typename T>
struct ArrayCrossProduct {
static inline void operation(common::list_entry_t& left, common::list_entry_t& right,
common::list_entry_t& result, common::ValueVector& leftVector,
common::ValueVector& rightVector, common::ValueVector& resultVector) {
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
result = common::ListVector::addList(&resultVector, left.size);
auto resultElements = (T*)common::ListVector::getListValues(&resultVector, result);
resultElements[0] = leftElements[1] * rightElements[2] - leftElements[2] * rightElements[1];
resultElements[1] = leftElements[2] * rightElements[0] - leftElements[0] * rightElements[2];
resultElements[2] = leftElements[0] * rightElements[1] - leftElements[1] * rightElements[0];
}
};

} // namespace function
} // namespace kuzu
25 changes: 25 additions & 0 deletions src/include/function/array/functions/array_distance.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

struct ArrayDistance {
template<typename T>
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
common::ValueVector& leftVector, common::ValueVector& rightVector,
common::ValueVector& /*resultVector*/) {
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
result = 0;
for (auto i = 0u; i < left.size; i++) {
auto diff = leftElements[i] - rightElements[i];
result += diff * diff;
}
result = std::sqrt(result);
}
};

} // namespace function
} // namespace kuzu
23 changes: 23 additions & 0 deletions src/include/function/array/functions/array_inner_product.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#pragma once

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

struct ArrayInnerProduct {
template<typename T>
static inline void operation(common::list_entry_t& left, common::list_entry_t& right, T& result,
common::ValueVector& leftVector, common::ValueVector& rightVector,
common::ValueVector& /*resultVector*/) {
auto leftElements = (T*)common::ListVector::getListValues(&leftVector, left);
auto rightElements = (T*)common::ListVector::getListValues(&rightVector, right);
result = 0;
for (auto i = 0u; i < left.size; i++) {
result += leftElements[i] * rightElements[i];
}
}
};

} // namespace function
} // namespace kuzu
33 changes: 33 additions & 0 deletions src/include/function/array/vector_array_functions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once

#include "function/scalar_function.h"

namespace kuzu {
namespace function {

struct ArrayValue {
static function_set getFunctionSet();
static std::unique_ptr<FunctionBindData> bindFunc(
const binder::expression_vector& arguments, Function* function);
};

struct ArrayCrossProductFunction {
static function_set getFunctionSet();
static std::unique_ptr<FunctionBindData> bindFunc(
const binder::expression_vector& arguments, Function* function);
};

struct ArrayCosineSimilarityFunction {
static function_set getFunctionSet();
};

struct ArrayDistanceFunction {
static function_set getFunctionSet();
};

struct ArrayInnerProductFunction {
static function_set getFunctionSet();
};

} // namespace function
} // namespace kuzu
Loading

0 comments on commit 51944dc

Please sign in to comment.