From b491af2e022cc00b947996b3abdf264473e11665 Mon Sep 17 00:00:00 2001 From: Maxwell <49460053+mxwli@users.noreply.github.com> Date: Wed, 24 Apr 2024 17:06:38 -0400 Subject: [PATCH] add implicit casting from list to array (#3375) add tests add castcost add cast cost for list->array and array->list skip failed test --- src/function/built_in_function_utils.cpp | 22 +++++++++++++++++++ src/function/vector_cast_functions.cpp | 9 ++++++++ .../function/built_in_function_utils.h | 4 ++++ .../common/arrayimplicitcasting.test | 22 +++++++++++++++++++ 4 files changed, 57 insertions(+) diff --git a/src/function/built_in_function_utils.cpp b/src/function/built_in_function_utils.cpp index dc7edc9770..849d7a4f70 100644 --- a/src/function/built_in_function_utils.cpp +++ b/src/function/built_in_function_utils.cpp @@ -149,6 +149,10 @@ uint32_t BuiltInFunctionsUtils::getCastCost(LogicalTypeID inputTypeID, LogicalTy // currently don't allow timestamp to other timestamp types // When we implement this in the future, revise tryGetMaxLogicalTypeID return castTimestamp(targetTypeID); + case LogicalTypeID::LIST: + return castList(targetTypeID); + case LogicalTypeID::ARRAY: + return castArray(targetTypeID); default: return UNDEFINED_CAST_COST; } @@ -388,6 +392,24 @@ uint32_t BuiltInFunctionsUtils::castFromRDFVariant(LogicalTypeID inputTypeID) { } } +uint32_t BuiltInFunctionsUtils::castList(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::ARRAY: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + +uint32_t BuiltInFunctionsUtils::castArray(LogicalTypeID targetTypeID) { + switch (targetTypeID) { + case LogicalTypeID::LIST: + return getTargetTypeCost(targetTypeID); + default: + return UNDEFINED_CAST_COST; + } +} + // When there is multiple candidates functions, e.g. double + int and double + double for input // "1.5 + parameter", we prefer the one without any implicit casting i.e. double + double. // Additionally, we prefer function with string parameter because string is most permissive and can diff --git a/src/function/vector_cast_functions.cpp b/src/function/vector_cast_functions.cpp index 726fe17ab1..66fce84c87 100644 --- a/src/function/vector_cast_functions.cpp +++ b/src/function/vector_cast_functions.cpp @@ -138,6 +138,11 @@ static bool hasImplicitCastArrayToList(const LogicalType& srcType, const Logical *ListType::getChildType(&dstType)); } +static bool hasImplicitCastListToArray(const LogicalType& srcType, const LogicalType& dstType) { + return CastFunction::hasImplicitCast(*ListType::getChildType(&srcType), + *ArrayType::getChildType(&dstType)); +} + static bool hasImplicitCastStruct(const LogicalType& srcType, const LogicalType& dstType) { auto srcFields = StructType::getFields(&srcType), dstFields = StructType::getFields(&dstType); if (srcFields.size() != dstFields.size()) { @@ -178,6 +183,10 @@ bool CastFunction::hasImplicitCast(const LogicalType& srcType, const LogicalType dstType.getLogicalTypeID() == LogicalTypeID::LIST) { return hasImplicitCastArrayToList(srcType, dstType); } + if (srcType.getLogicalTypeID() == LogicalTypeID::LIST && + dstType.getLogicalTypeID() == LogicalTypeID::ARRAY) { + return hasImplicitCastListToArray(srcType, dstType); + } if (srcType.getLogicalTypeID() != dstType.getLogicalTypeID()) { return false; } diff --git a/src/include/function/built_in_function_utils.h b/src/include/function/built_in_function_utils.h index 0f748d3fbe..9b6c9c56eb 100644 --- a/src/include/function/built_in_function_utils.h +++ b/src/include/function/built_in_function_utils.h @@ -71,6 +71,10 @@ class BuiltInFunctionsUtils { static uint32_t castUUID(common::LogicalTypeID targetTypeID); + static uint32_t castList(common::LogicalTypeID targetTypeID); + + static uint32_t castArray(common::LogicalTypeID targetTypeID); + static Function* getBestMatch(std::vector& functions); static uint32_t getFunctionCost(const std::vector& inputTypes, diff --git a/test/test_files/common/arrayimplicitcasting.test b/test/test_files/common/arrayimplicitcasting.test index 253ef33bf5..646d7f8b72 100644 --- a/test/test_files/common/arrayimplicitcasting.test +++ b/test/test_files/common/arrayimplicitcasting.test @@ -9,3 +9,25 @@ -STATEMENT CREATE (t:tab {a: cast([1, 2, 3], 'int64[3]')}) RETURN t.a ---- 1 [1,2,3] + +-CASE ListArrayImplicitCast +-STATEMENT CREATE NODE TABLE T0(a SERIAL, b INT64[], c DOUBLE[2], d STRING[], PRIMARY KEY(a)); +---- ok +-STATEMENT CREATE (:T0 {b: [1.0, 2.0], c: [1, 2], d: [3, 5]}); +---- ok +-STATEMENT MATCH (a:T0) RETURN a.a, a.b, a.c, a.d; +---- 1 +0|[1,2]|[1.000000,2.000000]|[3,5] + +-CASE CastCostTest +-SKIP +-STATEMENT CREATE NODE TABLE Item(id UINT64, item STRING, price DOUBLE, vector DOUBLE[2], PRIMARY KEY (id)) +---- ok +-STATEMENT MERGE (a:Item {id: 1, item: 'apple', price: 2.0, vector: [3.1, 4.1]}) +---- ok +-STATEMENT MERGE (b:Item {id: 2, item: 'banana', price: 1.0, vector: [5.9, 26.5]}) +---- ok +-STATEMENT MATCH (a:Item) RETURN a.item, a.price, array_cosine_similarity(a.vector, [6.0, 25.0]) AS sim ORDER BY sim DESC +---- 2 +['banana',1.0,0.9998642653091405] +['apple',2.0,0.9163829638139936]