Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue 3248 #3394

Merged
merged 1 commit into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions src/function/array/array_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,47 @@ function_set ArrayCrossProductFunction::getFunctionSet() {
return result;
}

static void validateArrayFunctionParameters(const LogicalType& leftType,
const LogicalType& rightType, const std::string& functionName) {
if (leftType != rightType) {
throw BinderException(
stringFormat("{} requires both arrays to have the same element type", functionName));
static LogicalType getChildType(const LogicalType& type) {
switch (type.getLogicalTypeID()) {
case LogicalTypeID::ARRAY:
return *ArrayType::getChildType(&type);
case LogicalTypeID::LIST:
return *ListType::getChildType(&type);
// LCOV_EXCL_START
default:
throw BinderException(stringFormat(
"Cannot retrieve child type of type {}. LIST or ARRAY is expected.", type.toString()));
// LCOV_EXCL_STOP
}
if (ArrayType::getChildType(&leftType)->getLogicalTypeID() != LogicalTypeID::FLOAT &&
ArrayType::getChildType(&leftType)->getLogicalTypeID() != LogicalTypeID::DOUBLE) {
}

static void validateChildType(const LogicalType& type, const std::string& functionName) {
switch (type.getLogicalTypeID()) {
case LogicalTypeID::DOUBLE:
case LogicalTypeID::FLOAT:
return;
default:
throw BinderException(
stringFormat("{} requires argument type of FLOAT or DOUBLE.", functionName));
stringFormat("{} requires argument type to be FLOAT[] or DOUBLE[].", functionName));
}
}

static LogicalType validateArrayFunctionParameters(const LogicalType& leftType,
const LogicalType& rightType, const std::string& functionName) {
auto leftChildType = getChildType(leftType);
auto rightChildType = getChildType(rightType);
validateChildType(leftChildType, functionName);
validateChildType(rightChildType, functionName);
if (leftType.getLogicalTypeID() == common::LogicalTypeID::ARRAY) {
return leftType;
} else if (rightType.getLogicalTypeID() == common::LogicalTypeID::ARRAY) {
return rightType;
}
throw BinderException(
stringFormat("{} requires at least one argument to be ARRAY but all parameters are LIST.",
functionName));
}

template<typename OPERATION, typename RESULT>
static scalar_func_exec_t getBinaryArrayExecFuncSwitchResultType() {
auto execFunc =
Expand Down Expand Up @@ -113,9 +141,15 @@ std::unique_ptr<FunctionBindData> arrayTemplateBindFunc(std::string functionName
const binder::expression_vector& arguments, Function* function) {
auto leftType = arguments[0]->dataType;
auto rightType = arguments[1]->dataType;
validateArrayFunctionParameters(leftType, rightType, functionName);
function->ptrCast<ScalarFunction>()->execFunc = getScalarExecFunc<OPERATION>(leftType);
return FunctionBindData::getSimpleBindData(arguments, *ArrayType::getChildType(&leftType));
auto paramType = validateArrayFunctionParameters(leftType, rightType, functionName);
function->ptrCast<ScalarFunction>()->execFunc = getScalarExecFunc<OPERATION>(paramType);
auto bindData = std::make_unique<FunctionBindData>(ArrayType::getChildType(&paramType)->copy());
std::vector<LogicalType> paramTypes;
for (auto& _ : arguments) {
(void)_;
bindData->paramTypes.push_back(paramType);
}
return bindData;
}

template<typename OPERATION>
Expand Down
33 changes: 33 additions & 0 deletions test/test_files/function/cast.test
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,36 @@ Runtime exception: Null value key is not allowed in map.
-STATEMENT MATCH (a:T0) RETURN a.a, a.b, a.c, a.d;
---- 1
0|[1,2]|[1.000000,2.000000]|[3,5]

-LOG 3248
-STATEMENT CREATE NODE TABLE Item(id UINT64, item STRING, price DOUBLE, vector DOUBLE[2], PRIMARY KEY (id));
---- ok
-STATEMENT CREATE (a:Item {id: 1, item: 'apple', price: 2.0, vector: [3.1, 4.1]});
---- ok
-STATEMENT MERGE (b:Item {id: 2, item: 'banana', price: 1.0, vector: [5.9, 26.5]});
---- ok
-STATEMENT MATCH (a:Item) RETURN a.item, a.price, array_cosine_similarity(a.vector, [6.0, 25.0]) AS sim ORDER BY sim DESC
---- 2
apple|2.000000|0.916383
banana|1.000000|0.999864
-STATEMENT MATCH (a:Item) RETURN a.item, a.price, array_cosine_similarity([6.0, 25.0], a.vector) AS sim ORDER BY sim DESC
---- 2
apple|2.000000|0.916383
banana|1.000000|0.999864
-STATEMENT MATCH (a:Item) RETURN a.item, a.price, array_distance([6.0, 25.0], a.vector) AS sim ORDER BY sim DESC
---- 2
apple|2.000000|21.100237
banana|1.000000|1.503330
-STATEMENT MATCH (a:Item) RETURN a.item, a.price, array_inner_product([6.0, 25.0], a.vector) AS sim ORDER BY sim DESC
---- 2
apple|2.000000|121.100000
banana|1.000000|697.900000
-STATEMENT RETURN array_cosine_similarity([1, 2], [3.0, 4.0]);
---- error
Binder exception: ARRAY_COSINE_SIMILARITY requires argument type to be FLOAT[] or DOUBLE[].
-STATEMENT RETURN array_cosine_similarity([1.0, 2.0], [3.0, 4.0]);
---- error
Binder exception: ARRAY_COSINE_SIMILARITY requires at least one argument to be ARRAY but all parameters are LIST.
-STATEMENT RETURN array_cosine_similarity(cast([1.0, 2.0], 'DOUBLE[2]'), [3.0, 4.0]);
---- 1
0.983870
3 changes: 1 addition & 2 deletions test/test_files/tinysnb/function/array.test
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
---- 1
[[3,2],[4,7],[-2,3]]

-CASE ArrayCrossProduct
-LOG ArrayCrossProductINT128
-STATEMENT RETURN ARRAY_CROSS_PRODUCT(ARRAY_VALUE(to_int128(1), to_int128(2), to_int128(3)), ARRAY_VALUE(to_int128(4), to_int128(5), to_int128(6)))
---- 1
Expand Down Expand Up @@ -85,7 +84,7 @@ Binder exception: ARRAY_CROSS_PRODUCT requires both arrays to have the same elem
-LOG ArrayCosineSimilarityWrongType
-STATEMENT MATCH (p:person) return ARRAY_COSINE_SIMILARITY(p.grades, p.grades)
---- error
Binder exception: ARRAY_COSINE_SIMILARITY requires argument type of FLOAT or DOUBLE.
Binder exception: ARRAY_COSINE_SIMILARITY requires argument type to be FLOAT[] or DOUBLE[].

-LOG ArrayDistance
-STATEMENT MATCH (p:person)-[e:meets]->(p1:person) return round(ARRAY_DISTANCE(e.location, array_value(to_float(3.4), to_float(2.7))),2)
Expand Down
Loading