Skip to content

Commit

Permalink
Fix issue 3248 (#3394)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU authored and manh9203 committed Apr 29, 2024
1 parent cd25709 commit 9a6b1c9
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 13 deletions.
56 changes: 45 additions & 11 deletions src/function/array/array_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,47 @@ function_set ArrayCrossProductFunction::getFunctionSet() {
return result;
}

static void validateArrayFunctionParameters(const LogicalType& leftType,
const LogicalType& rightType, const std::string& functionName) {
if (leftType != rightType) {
throw BinderException(
stringFormat("{} requires both arrays to have the same element type", functionName));
static LogicalType getChildType(const LogicalType& type) {
switch (type.getLogicalTypeID()) {
case LogicalTypeID::ARRAY:
return *ArrayType::getChildType(&type);
case LogicalTypeID::LIST:
return *ListType::getChildType(&type);
// LCOV_EXCL_START
default:
throw BinderException(stringFormat(
"Cannot retrieve child type of type {}. LIST or ARRAY is expected.", type.toString()));
// LCOV_EXCL_STOP
}
if (ArrayType::getChildType(&leftType)->getLogicalTypeID() != LogicalTypeID::FLOAT &&
ArrayType::getChildType(&leftType)->getLogicalTypeID() != LogicalTypeID::DOUBLE) {
}

static void validateChildType(const LogicalType& type, const std::string& functionName) {
switch (type.getLogicalTypeID()) {
case LogicalTypeID::DOUBLE:
case LogicalTypeID::FLOAT:
return;
default:
throw BinderException(
stringFormat("{} requires argument type of FLOAT or DOUBLE.", functionName));
stringFormat("{} requires argument type to be FLOAT[] or DOUBLE[].", functionName));
}
}

static LogicalType validateArrayFunctionParameters(const LogicalType& leftType,
const LogicalType& rightType, const std::string& functionName) {
auto leftChildType = getChildType(leftType);
auto rightChildType = getChildType(rightType);
validateChildType(leftChildType, functionName);
validateChildType(rightChildType, functionName);
if (leftType.getLogicalTypeID() == common::LogicalTypeID::ARRAY) {
return leftType;
} else if (rightType.getLogicalTypeID() == common::LogicalTypeID::ARRAY) {
return rightType;
}
throw BinderException(
stringFormat("{} requires at least one argument to be ARRAY but all parameters are LIST.",
functionName));
}

template<typename OPERATION, typename RESULT>
static scalar_func_exec_t getBinaryArrayExecFuncSwitchResultType() {
auto execFunc =
Expand Down Expand Up @@ -113,9 +141,15 @@ std::unique_ptr<FunctionBindData> arrayTemplateBindFunc(std::string functionName
const binder::expression_vector& arguments, Function* function) {
auto leftType = arguments[0]->dataType;
auto rightType = arguments[1]->dataType;
validateArrayFunctionParameters(leftType, rightType, functionName);
function->ptrCast<ScalarFunction>()->execFunc = getScalarExecFunc<OPERATION>(leftType);
return FunctionBindData::getSimpleBindData(arguments, *ArrayType::getChildType(&leftType));
auto paramType = validateArrayFunctionParameters(leftType, rightType, functionName);
function->ptrCast<ScalarFunction>()->execFunc = getScalarExecFunc<OPERATION>(paramType);
auto bindData = std::make_unique<FunctionBindData>(ArrayType::getChildType(&paramType)->copy());
std::vector<LogicalType> paramTypes;
for (auto& _ : arguments) {
(void)_;
bindData->paramTypes.push_back(paramType);
}
return bindData;
}

template<typename OPERATION>
Expand Down
33 changes: 33 additions & 0 deletions test/test_files/function/cast.test
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,36 @@ Runtime exception: Null value key is not allowed in map.
-STATEMENT MATCH (a:T0) RETURN a.a, a.b, a.c, a.d;
---- 1
0|[1,2]|[1.000000,2.000000]|[3,5]

-LOG 3248
-STATEMENT CREATE NODE TABLE Item(id UINT64, item STRING, price DOUBLE, vector DOUBLE[2], PRIMARY KEY (id));
---- ok
-STATEMENT CREATE (a:Item {id: 1, item: 'apple', price: 2.0, vector: [3.1, 4.1]});
---- ok
-STATEMENT MERGE (b:Item {id: 2, item: 'banana', price: 1.0, vector: [5.9, 26.5]});
---- ok
-STATEMENT MATCH (a:Item) RETURN a.item, a.price, array_cosine_similarity(a.vector, [6.0, 25.0]) AS sim ORDER BY sim DESC
---- 2
apple|2.000000|0.916383
banana|1.000000|0.999864
-STATEMENT MATCH (a:Item) RETURN a.item, a.price, array_cosine_similarity([6.0, 25.0], a.vector) AS sim ORDER BY sim DESC
---- 2
apple|2.000000|0.916383
banana|1.000000|0.999864
-STATEMENT MATCH (a:Item) RETURN a.item, a.price, array_distance([6.0, 25.0], a.vector) AS sim ORDER BY sim DESC
---- 2
apple|2.000000|21.100237
banana|1.000000|1.503330
-STATEMENT MATCH (a:Item) RETURN a.item, a.price, array_inner_product([6.0, 25.0], a.vector) AS sim ORDER BY sim DESC
---- 2
apple|2.000000|121.100000
banana|1.000000|697.900000
-STATEMENT RETURN array_cosine_similarity([1, 2], [3.0, 4.0]);
---- error
Binder exception: ARRAY_COSINE_SIMILARITY requires argument type to be FLOAT[] or DOUBLE[].
-STATEMENT RETURN array_cosine_similarity([1.0, 2.0], [3.0, 4.0]);
---- error
Binder exception: ARRAY_COSINE_SIMILARITY requires at least one argument to be ARRAY but all parameters are LIST.
-STATEMENT RETURN array_cosine_similarity(cast([1.0, 2.0], 'DOUBLE[2]'), [3.0, 4.0]);
---- 1
0.983870
3 changes: 1 addition & 2 deletions test/test_files/tinysnb/function/array.test
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
---- 1
[[3,2],[4,7],[-2,3]]

-CASE ArrayCrossProduct
-LOG ArrayCrossProductINT128
-STATEMENT RETURN ARRAY_CROSS_PRODUCT(ARRAY_VALUE(to_int128(1), to_int128(2), to_int128(3)), ARRAY_VALUE(to_int128(4), to_int128(5), to_int128(6)))
---- 1
Expand Down Expand Up @@ -85,7 +84,7 @@ Binder exception: ARRAY_CROSS_PRODUCT requires both arrays to have the same elem
-LOG ArrayCosineSimilarityWrongType
-STATEMENT MATCH (p:person) return ARRAY_COSINE_SIMILARITY(p.grades, p.grades)
---- error
Binder exception: ARRAY_COSINE_SIMILARITY requires argument type of FLOAT or DOUBLE.
Binder exception: ARRAY_COSINE_SIMILARITY requires argument type to be FLOAT[] or DOUBLE[].

-LOG ArrayDistance
-STATEMENT MATCH (p:person)-[e:meets]->(p1:person) return round(ARRAY_DISTANCE(e.location, array_value(to_float(3.4), to_float(2.7))),2)
Expand Down

0 comments on commit 9a6b1c9

Please sign in to comment.