Skip to content

Commit

Permalink
add implicit casting from list to array
Browse files Browse the repository at this point in the history
add tests

add castcost

add cast cost for list->array and array->list

skip failed test
  • Loading branch information
mxwli committed Apr 24, 2024
1 parent fa3afd7 commit 5e1b603
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ uint32_t BuiltInFunctionsUtils::getCastCost(LogicalTypeID inputTypeID, LogicalTy
// currently don't allow timestamp to other timestamp types
// When we implement this in the future, revise tryGetMaxLogicalTypeID
return castTimestamp(targetTypeID);
case LogicalTypeID::LIST:
return castList(targetTypeID);
case LogicalTypeID::ARRAY:
return castArray(targetTypeID);
default:
return UNDEFINED_CAST_COST;
}
Expand Down Expand Up @@ -388,6 +392,24 @@ uint32_t BuiltInFunctionsUtils::castFromRDFVariant(LogicalTypeID inputTypeID) {
}
}

// Cost of implicitly casting a LIST value to `targetTypeID`.
// ARRAY is the only target accepted here, and its cost is delegated to getTargetTypeCost.
// Any other target reports UNDEFINED_CAST_COST.
uint32_t BuiltInFunctionsUtils::castList(LogicalTypeID targetTypeID) {
    if (targetTypeID == LogicalTypeID::ARRAY) {
        return getTargetTypeCost(targetTypeID);
    }
    return UNDEFINED_CAST_COST;
}

// Cost of implicitly casting an ARRAY value to `targetTypeID`.
// LIST is the only target accepted here, and its cost is delegated to getTargetTypeCost.
// Any other target reports UNDEFINED_CAST_COST.
uint32_t BuiltInFunctionsUtils::castArray(LogicalTypeID targetTypeID) {
    // Reject every non-LIST target up front.
    if (targetTypeID != LogicalTypeID::LIST) {
        return UNDEFINED_CAST_COST;
    }
    return getTargetTypeCost(targetTypeID);
}

// When there is multiple candidates functions, e.g. double + int and double + double for input
// "1.5 + parameter", we prefer the one without any implicit casting i.e. double + double.
// Additionally, we prefer function with string parameter because string is most permissive and can
Expand Down
9 changes: 9 additions & 0 deletions src/function/vector_cast_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ static bool hasImplicitCastArrayToList(const LogicalType& srcType, const Logical
*ListType::getChildType(&dstType));
}

// Decides whether a LIST `srcType` may be implicitly cast to an ARRAY `dstType`.
// The answer is whatever CastFunction::hasImplicitCast returns for the pair
// (child type of the list, child type of the array); only the child types are compared here.
static bool hasImplicitCastListToArray(const LogicalType& srcType, const LogicalType& dstType) {
    const auto& listChildType = *ListType::getChildType(&srcType);
    const auto& arrayChildType = *ArrayType::getChildType(&dstType);
    return CastFunction::hasImplicitCast(listChildType, arrayChildType);
}

static bool hasImplicitCastStruct(const LogicalType& srcType, const LogicalType& dstType) {
auto srcFields = StructType::getFields(&srcType), dstFields = StructType::getFields(&dstType);
if (srcFields.size() != dstFields.size()) {
Expand Down Expand Up @@ -178,6 +183,10 @@ bool CastFunction::hasImplicitCast(const LogicalType& srcType, const LogicalType
dstType.getLogicalTypeID() == LogicalTypeID::LIST) {
return hasImplicitCastArrayToList(srcType, dstType);
}
if (srcType.getLogicalTypeID() == LogicalTypeID::LIST &&
dstType.getLogicalTypeID() == LogicalTypeID::ARRAY) {
return hasImplicitCastListToArray(srcType, dstType);
}
if (srcType.getLogicalTypeID() != dstType.getLogicalTypeID()) {
return false;
}
Expand Down
4 changes: 4 additions & 0 deletions src/include/function/built_in_function_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ class BuiltInFunctionsUtils {

static uint32_t castUUID(common::LogicalTypeID targetTypeID);

static uint32_t castList(common::LogicalTypeID targetTypeID);

static uint32_t castArray(common::LogicalTypeID targetTypeID);

static Function* getBestMatch(std::vector<Function*>& functions);

static uint32_t getFunctionCost(const std::vector<common::LogicalType>& inputTypes,
Expand Down
22 changes: 22 additions & 0 deletions test/test_files/common/arrayimplicitcasting.test
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,25 @@
-STATEMENT CREATE (t:tab {a: cast([1, 2, 3], 'int64[3]')}) RETURN t.a
---- 1
[1,2,3]

-CASE ListArrayImplicitCast
-STATEMENT CREATE NODE TABLE T0(a SERIAL, b INT64[], c DOUBLE[2], d STRING[], PRIMARY KEY(a));
---- ok
-STATEMENT CREATE (:T0 {b: [1.0, 2.0], c: [1, 2], d: [3, 5]});
---- ok
-STATEMENT MATCH (a:T0) RETURN a.a, a.b, a.c, a.d;
---- 1
0|[1,2]|[1.000000,2.000000]|[3,5]

-CASE CastCostTest
-SKIP
-STATEMENT CREATE NODE TABLE Item(id UINT64, item STRING, price DOUBLE, vector DOUBLE[2], PRIMARY KEY (id))
---- ok
-STATEMENT MERGE (a:Item {id: 1, item: 'apple', price: 2.0, vector: [3.1, 4.1]})
---- ok
-STATEMENT MERGE (b:Item {id: 2, item: 'banana', price: 1.0, vector: [5.9, 26.5]})
---- ok
-STATEMENT MATCH (a:Item) RETURN a.item, a.price, array_cosine_similarity(a.vector, [6.0, 25.0]) AS sim ORDER BY sim DESC
---- 2
['banana',1.0,0.9998642653091405]
['apple',2.0,0.9163829638139936]

0 comments on commit 5e1b603

Please sign in to comment.