Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Implicit Casting from List to Array #3375

Merged
merged 1 commit into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ uint32_t BuiltInFunctionsUtils::getCastCost(LogicalTypeID inputTypeID, LogicalTy
// currently don't allow timestamp to other timestamp types
// When we implement this in the future, revise tryGetMaxLogicalTypeID
return castTimestamp(targetTypeID);
case LogicalTypeID::LIST:
return castList(targetTypeID);
case LogicalTypeID::ARRAY:
return castArray(targetTypeID);
default:
return UNDEFINED_CAST_COST;
}
Expand Down Expand Up @@ -388,6 +392,24 @@ uint32_t BuiltInFunctionsUtils::castFromRDFVariant(LogicalTypeID inputTypeID) {
}
}

uint32_t BuiltInFunctionsUtils::castList(LogicalTypeID targetTypeID) {
switch (targetTypeID) {
case LogicalTypeID::ARRAY:
return getTargetTypeCost(targetTypeID);
default:
return UNDEFINED_CAST_COST;
}
}

uint32_t BuiltInFunctionsUtils::castArray(LogicalTypeID targetTypeID) {
switch (targetTypeID) {
case LogicalTypeID::LIST:
return getTargetTypeCost(targetTypeID);
default:
return UNDEFINED_CAST_COST;
}
}

// When there is multiple candidates functions, e.g. double + int and double + double for input
// "1.5 + parameter", we prefer the one without any implicit casting i.e. double + double.
// Additionally, we prefer function with string parameter because string is most permissive and can
Expand Down
9 changes: 9 additions & 0 deletions src/function/vector_cast_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ static bool hasImplicitCastArrayToList(const LogicalType& srcType, const Logical
*ListType::getChildType(&dstType));
}

static bool hasImplicitCastListToArray(const LogicalType& srcType, const LogicalType& dstType) {
return CastFunction::hasImplicitCast(*ListType::getChildType(&srcType),
*ArrayType::getChildType(&dstType));
}

static bool hasImplicitCastStruct(const LogicalType& srcType, const LogicalType& dstType) {
auto srcFields = StructType::getFields(&srcType), dstFields = StructType::getFields(&dstType);
if (srcFields.size() != dstFields.size()) {
Expand Down Expand Up @@ -178,6 +183,10 @@ bool CastFunction::hasImplicitCast(const LogicalType& srcType, const LogicalType
dstType.getLogicalTypeID() == LogicalTypeID::LIST) {
return hasImplicitCastArrayToList(srcType, dstType);
}
if (srcType.getLogicalTypeID() == LogicalTypeID::LIST &&
dstType.getLogicalTypeID() == LogicalTypeID::ARRAY) {
return hasImplicitCastListToArray(srcType, dstType);
}
if (srcType.getLogicalTypeID() != dstType.getLogicalTypeID()) {
return false;
}
Expand Down
4 changes: 4 additions & 0 deletions src/include/function/built_in_function_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ class BuiltInFunctionsUtils {

static uint32_t castUUID(common::LogicalTypeID targetTypeID);

static uint32_t castList(common::LogicalTypeID targetTypeID);

static uint32_t castArray(common::LogicalTypeID targetTypeID);

static Function* getBestMatch(std::vector<Function*>& functions);

static uint32_t getFunctionCost(const std::vector<common::LogicalType>& inputTypes,
Expand Down
22 changes: 22 additions & 0 deletions test/test_files/common/arrayimplicitcasting.test
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,25 @@
-STATEMENT CREATE (t:tab {a: cast([1, 2, 3], 'int64[3]')}) RETURN t.a
---- 1
[1,2,3]

-CASE ListArrayImplicitCast
-STATEMENT CREATE NODE TABLE T0(a SERIAL, b INT64[], c DOUBLE[2], d STRING[], PRIMARY KEY(a));
---- ok
-STATEMENT CREATE (:T0 {b: [1.0, 2.0], c: [1, 2], d: [3, 5]});
---- ok
-STATEMENT MATCH (a:T0) RETURN a.a, a.b, a.c, a.d;
---- 1
0|[1,2]|[1.000000,2.000000]|[3,5]

-CASE CastCostTest
-SKIP
-STATEMENT CREATE NODE TABLE Item(id UINT64, item STRING, price DOUBLE, vector DOUBLE[2], PRIMARY KEY (id))
---- ok
-STATEMENT MERGE (a:Item {id: 1, item: 'apple', price: 2.0, vector: [3.1, 4.1]})
---- ok
-STATEMENT MERGE (b:Item {id: 2, item: 'banana', price: 1.0, vector: [5.9, 26.5]})
---- ok
-STATEMENT MATCH (a:Item) RETURN a.item, a.price, array_cosine_similarity(a.vector, [6.0, 25.0]) AS sim ORDER BY sim DESC
---- 2
['banana',1.0,0.9998642653091405]
['apple',2.0,0.9163829638139936]
Loading