Skip to content

Commit

Permalink
add implicit casting from list to array (#3375)
Browse files Browse the repository at this point in the history
add tests

add castcost

add cast cost for list->array and array->list

skip failed test
  • Loading branch information
mxwli committed Apr 24, 2024
1 parent 9253e97 commit b491af2
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ uint32_t BuiltInFunctionsUtils::getCastCost(LogicalTypeID inputTypeID, LogicalTy
// currently don't allow timestamp to other timestamp types
// When we implement this in the future, revise tryGetMaxLogicalTypeID
return castTimestamp(targetTypeID);
case LogicalTypeID::LIST:
return castList(targetTypeID);
case LogicalTypeID::ARRAY:
return castArray(targetTypeID);
default:
return UNDEFINED_CAST_COST;
}
Expand Down Expand Up @@ -388,6 +392,24 @@ uint32_t BuiltInFunctionsUtils::castFromRDFVariant(LogicalTypeID inputTypeID) {
}
}

// Cast cost from LIST to `targetTypeID`.
// ARRAY is the only accepted target and is priced via getTargetTypeCost(); every other
// target yields UNDEFINED_CAST_COST.
uint32_t BuiltInFunctionsUtils::castList(LogicalTypeID targetTypeID) {
    if (targetTypeID != LogicalTypeID::ARRAY) {
        return UNDEFINED_CAST_COST;
    }
    return getTargetTypeCost(targetTypeID);
}

// Cast cost from ARRAY to `targetTypeID`.
// LIST is the only accepted target and is priced via getTargetTypeCost(); every other
// target yields UNDEFINED_CAST_COST.
uint32_t BuiltInFunctionsUtils::castArray(LogicalTypeID targetTypeID) {
    if (targetTypeID != LogicalTypeID::LIST) {
        return UNDEFINED_CAST_COST;
    }
    return getTargetTypeCost(targetTypeID);
}

// When there are multiple candidate functions, e.g. double + int and double + double for input
// "1.5 + parameter", we prefer the one without any implicit casting i.e. double + double.
// Additionally, we prefer function with string parameter because string is most permissive and can
Expand Down
9 changes: 9 additions & 0 deletions src/function/vector_cast_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ static bool hasImplicitCastArrayToList(const LogicalType& srcType, const Logical
*ListType::getChildType(&dstType));
}

// Returns whether the child type of the LIST `srcType` has an implicit cast to the child
// type of the ARRAY `dstType`, as decided by CastFunction::hasImplicitCast.
static bool hasImplicitCastListToArray(const LogicalType& srcType, const LogicalType& dstType) {
    const auto listChild = ListType::getChildType(&srcType);
    const auto arrayChild = ArrayType::getChildType(&dstType);
    return CastFunction::hasImplicitCast(*listChild, *arrayChild);
}

static bool hasImplicitCastStruct(const LogicalType& srcType, const LogicalType& dstType) {
auto srcFields = StructType::getFields(&srcType), dstFields = StructType::getFields(&dstType);
if (srcFields.size() != dstFields.size()) {
Expand Down Expand Up @@ -178,6 +183,10 @@ bool CastFunction::hasImplicitCast(const LogicalType& srcType, const LogicalType
dstType.getLogicalTypeID() == LogicalTypeID::LIST) {
return hasImplicitCastArrayToList(srcType, dstType);
}
if (srcType.getLogicalTypeID() == LogicalTypeID::LIST &&
dstType.getLogicalTypeID() == LogicalTypeID::ARRAY) {
return hasImplicitCastListToArray(srcType, dstType);
}
if (srcType.getLogicalTypeID() != dstType.getLogicalTypeID()) {
return false;
}
Expand Down
4 changes: 4 additions & 0 deletions src/include/function/built_in_function_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ class BuiltInFunctionsUtils {

static uint32_t castUUID(common::LogicalTypeID targetTypeID);

// Cast cost from LIST to targetTypeID: the target type cost when the target is ARRAY,
// UNDEFINED_CAST_COST otherwise.
static uint32_t castList(common::LogicalTypeID targetTypeID);

// Cast cost from ARRAY to targetTypeID: the target type cost when the target is LIST,
// UNDEFINED_CAST_COST otherwise.
static uint32_t castArray(common::LogicalTypeID targetTypeID);

static Function* getBestMatch(std::vector<Function*>& functions);

static uint32_t getFunctionCost(const std::vector<common::LogicalType>& inputTypes,
Expand Down
22 changes: 22 additions & 0 deletions test/test_files/common/arrayimplicitcasting.test
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,25 @@
-STATEMENT CREATE (t:tab {a: cast([1, 2, 3], 'int64[3]')}) RETURN t.a
---- 1
[1,2,3]

-CASE ListArrayImplicitCast
-STATEMENT CREATE NODE TABLE T0(a SERIAL, b INT64[], c DOUBLE[2], d STRING[], PRIMARY KEY(a));
---- ok
-STATEMENT CREATE (:T0 {b: [1.0, 2.0], c: [1, 2], d: [3, 5]});
---- ok
-STATEMENT MATCH (a:T0) RETURN a.a, a.b, a.c, a.d;
---- 1
0|[1,2]|[1.000000,2.000000]|[3,5]

-CASE CastCostTest
-SKIP
-STATEMENT CREATE NODE TABLE Item(id UINT64, item STRING, price DOUBLE, vector DOUBLE[2], PRIMARY KEY (id))
---- ok
-STATEMENT MERGE (a:Item {id: 1, item: 'apple', price: 2.0, vector: [3.1, 4.1]})
---- ok
-STATEMENT MERGE (b:Item {id: 2, item: 'banana', price: 1.0, vector: [5.9, 26.5]})
---- ok
-STATEMENT MATCH (a:Item) RETURN a.item, a.price, array_cosine_similarity(a.vector, [6.0, 25.0]) AS sim ORDER BY sim DESC
---- 2
['banana',1.0,0.9998642653091405]
['apple',2.0,0.9163829638139936]

0 comments on commit b491af2

Please sign in to comment.