Skip to content

Commit

Permalink
Fix issue-2269
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Oct 30, 2023
1 parent 627b5f5 commit d4772e1
Show file tree
Hide file tree
Showing 15 changed files with 153 additions and 84 deletions.
47 changes: 25 additions & 22 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,35 +73,41 @@ std::shared_ptr<Expression> ExpressionBinder::foldExpression(
return result;
}

static std::string unsupportedImplicitCastException(
const Expression& expression, const common::LogicalType& targetType) {
return stringFormat(
"Expression {} has data type {} but expected {}. Implicit cast is not supported.",
expression.toString(), LogicalTypeUtils::dataTypeToString(expression.dataType),
LogicalTypeUtils::dataTypeToString(targetType));
}

std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, const LogicalType& targetType) {
if (targetType.getLogicalTypeID() == LogicalTypeID::ANY || expression->dataType == targetType) {
return expression;
}
if (expression->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
resolveAnyDataType(*expression, targetType);
const std::shared_ptr<Expression>& expression, common::LogicalTypeID targetTypeID) {
if (LogicalTypeUtils::isNested(targetTypeID)) {
if (expression->getDataType().getLogicalTypeID() == common::LogicalTypeID::ANY) {
throw BinderException(stringFormat(
"Cannot resolve recursive data type for expression {}.", expression->toString()));
}
// We don't support casting to nested data type. So instead we validate type match.
if (expression->getDataType().getLogicalTypeID() != targetTypeID) {
throw BinderException(
unsupportedImplicitCastException(*expression, LogicalType{targetTypeID}));

Check warning on line 94 in src/binder/expression_binder.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/expression_binder.cpp#L93-L94

Added lines #L93 - L94 were not covered by tests
}
return expression;
}
return implicitCast(expression, targetType);
return implicitCastIfNecessary(expression, LogicalType(targetTypeID));
}

std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, LogicalTypeID targetTypeID) {
if (targetTypeID == LogicalTypeID::ANY ||
expression->dataType.getLogicalTypeID() == targetTypeID) {
const std::shared_ptr<Expression>& expression, const LogicalType& targetType) {
if (targetType.getLogicalTypeID() == LogicalTypeID::ANY || expression->dataType == targetType) {
return expression;
}
if (expression->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
if (targetTypeID == LogicalTypeID::VAR_LIST) {
// e.g. len($1) we cannot infer the child type for $1.
throw BinderException(stringFormat(
"Cannot resolve recursive data type for expression {}.", expression->toString()));
}
resolveAnyDataType(*expression, LogicalType(targetTypeID));
resolveAnyDataType(*expression, targetType);
return expression;
}
assert(targetTypeID != LogicalTypeID::VAR_LIST);
return implicitCast(expression, LogicalType(targetTypeID));
return implicitCast(expression, targetType);
}

std::shared_ptr<Expression> ExpressionBinder::implicitCast(
Expand All @@ -118,10 +124,7 @@ std::shared_ptr<Expression> ExpressionBinder::implicitCast(
std::move(bindData), std::move(children), execFunc, nullptr /* selectFunc */,
std::move(uniqueName));
} else {
throw BinderException(stringFormat(
"Expression {} has data type {} but expected {}. Implicit cast is not supported.",
expression->toString(), LogicalTypeUtils::dataTypeToString(expression->dataType),
LogicalTypeUtils::dataTypeToString(targetType)));
throw BinderException(unsupportedImplicitCastException(*expression, targetType));
}
}

Expand Down
9 changes: 8 additions & 1 deletion src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,12 +707,19 @@ bool LogicalTypeUtils::isNumerical(const LogicalType& dataType) {
}

bool LogicalTypeUtils::isNested(const LogicalType& dataType) {
switch (dataType.typeID) {
return isNested(dataType.typeID);
}

bool LogicalTypeUtils::isNested(kuzu::common::LogicalTypeID logicalTypeID) {
switch (logicalTypeID) {
case LogicalTypeID::STRUCT:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST:
case LogicalTypeID::UNION:
case LogicalTypeID::MAP:
case LogicalTypeID::NODE:
case LogicalTypeID::REL:
case LogicalTypeID::RECURSIVE_REL:
return true;
default:
return false;
Expand Down
11 changes: 9 additions & 2 deletions src/function/built_in_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,19 +370,26 @@ uint32_t BuiltInFunctions::castSerial(LogicalTypeID targetTypeID) {

// When there is multiple candidates functions, e.g. double + int and double + double for input
// "1.5 + parameter", we prefer the one without any implicit casting i.e. double + double.
// Additionally, we prefer function with string parameter because string is most permissive and can
// be cast to any type.
ScalarFunction* BuiltInFunctions::getBestMatch(std::vector<ScalarFunction*>& functionsToMatch) {
assert(functionsToMatch.size() > 1);
ScalarFunction* result = nullptr;
auto cost = UNDEFINED_CAST_COST;
for (auto& function : functionsToMatch) {
auto currentCost = 0;
std::unordered_set<LogicalTypeID> distinctParameterTypes;
for (auto& parameterTypeID : function->parameterTypeIDs) {
if (parameterTypeID != LogicalTypeID::STRING) {
currentCost++;
}
if (!distinctParameterTypes.contains(parameterTypeID)) {
currentCost++;
distinctParameterTypes.insert(parameterTypeID);
}
}
if (distinctParameterTypes.size() < cost) {
cost = distinctParameterTypes.size();
if (currentCost < cost) {
cost = currentCost;
result = function;
}
}
Expand Down
29 changes: 16 additions & 13 deletions src/function/vector_list_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,35 +56,38 @@ void ListCreationFunction::execFunc(
}
}

static LogicalType getValidLogicalType(const binder::expression_vector& expressions) {
for (auto& expression : expressions) {
if (expression->dataType.getLogicalTypeID() != LogicalTypeID::ANY) {
return expression->dataType;
}
}
return LogicalType(common::LogicalTypeID::ANY);
}

std::unique_ptr<FunctionBindData> ListCreationFunction::bindFunc(
const binder::expression_vector& arguments, Function* /*function*/) {
// ListCreation requires all parameters to have the same type or be ANY type. The result type of
// listCreation can be determined by the first non-ANY type parameter. If all parameters have
// dataType ANY, then the resultType will be INT64[] (default type).
auto varListTypeInfo =
std::make_unique<VarListTypeInfo>(std::make_unique<LogicalType>(LogicalTypeID::INT64));
auto resultType = LogicalType{LogicalTypeID::VAR_LIST, std::move(varListTypeInfo)};
for (auto& argument : arguments) {
if (argument->getDataType().getLogicalTypeID() != LogicalTypeID::ANY) {
varListTypeInfo = std::make_unique<VarListTypeInfo>(
std::make_unique<LogicalType>(argument->getDataType()));
resultType = LogicalType{LogicalTypeID::VAR_LIST, std::move(varListTypeInfo)};
break;
}
auto childType = getValidLogicalType(arguments);
if (childType.getLogicalTypeID() == common::LogicalTypeID::ANY) {
childType = LogicalType(common::LogicalTypeID::STRING);
}
auto resultChildType = VarListType::getChildType(&resultType);
// Cast parameters with ANY dataType to resultChildType.
for (auto& argument : arguments) {
auto& parameterType = argument->getDataTypeReference();
if (parameterType != *resultChildType) {
if (parameterType != childType) {
if (parameterType.getLogicalTypeID() == LogicalTypeID::ANY) {
binder::ExpressionBinder::resolveAnyDataType(*argument, *resultChildType);
binder::ExpressionBinder::resolveAnyDataType(*argument, childType);
} else {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_CREATION_FUNC_NAME, arguments[0]->getDataType(), argument->getDataType()));
}
}
}
auto varListTypeInfo = std::make_unique<VarListTypeInfo>(childType.copy());
auto resultType = LogicalType{LogicalTypeID::VAR_LIST, std::move(varListTypeInfo)};
return std::make_unique<FunctionBindData>(resultType);
}

Expand Down
2 changes: 1 addition & 1 deletion src/function/vector_struct_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ std::unique_ptr<FunctionBindData> StructPackFunctions::bindFunc(
for (auto& argument : arguments) {
if (argument->getDataType().getLogicalTypeID() == LogicalTypeID::ANY) {
binder::ExpressionBinder::resolveAnyDataType(
*argument, LogicalType{LogicalTypeID::INT64});
*argument, LogicalType{LogicalTypeID::STRING});
}
fields.emplace_back(
std::make_unique<StructField>(argument->getAlias(), argument->getDataType().copy()));
Expand Down
7 changes: 6 additions & 1 deletion src/function/vector_union_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "function/union/vector_union_functions.h"

#include "binder/expression_binder.h"
#include "function/struct/vector_struct_functions.h"
#include "function/union/functions/union_tag.h"

Expand All @@ -23,8 +24,12 @@ std::unique_ptr<FunctionBindData> UnionValueFunction::bindFunc(
// TODO(Ziy): Use UINT8 to represent tag value.
fields.push_back(std::make_unique<StructField>(
UnionType::TAG_FIELD_NAME, std::make_unique<LogicalType>(UnionType::TAG_FIELD_TYPE)));
if (arguments[0]->getDataType().getLogicalTypeID() == common::LogicalTypeID::ANY) {
binder::ExpressionBinder::resolveAnyDataType(
*arguments[0], LogicalType(LogicalTypeID::STRING));
}
fields.push_back(std::make_unique<StructField>(
arguments[0]->getAlias(), std::make_unique<LogicalType>(arguments[0]->getDataType())));
arguments[0]->getAlias(), arguments[0]->getDataType().copy()));
auto resultType =
LogicalType(LogicalTypeID::UNION, std::make_unique<StructTypeInfo>(std::move(fields)));
return std::make_unique<FunctionBindData>(resultType);
Expand Down
9 changes: 2 additions & 7 deletions src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,10 @@ class ExpressionBinder {
const parser::ParsedExpression& parsedExpression);

/****** cast *****/
// Note: we expose two implicitCastIfNecessary interfaces.
// For function binding we cast with data type ID because function definition cannot be
// recursively generated, e.g. list_extract(param) we only declare param with type LIST but do
// not specify its child type.
// For the rest, i.e. set clause binding, we cast with data type. For example, a.list = $1.
static std::shared_ptr<Expression> implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, const common::LogicalType& targetType);
static std::shared_ptr<Expression> implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, common::LogicalTypeID targetTypeID);
static std::shared_ptr<Expression> implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, const common::LogicalType& targetType);
static std::shared_ptr<Expression> implicitCast(
const std::shared_ptr<Expression>& expression, const common::LogicalType& targetType);

Expand Down
2 changes: 2 additions & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ class LogicalType {

inline PhysicalTypeID getPhysicalType() const { return physicalType; }

inline bool hasExtraTypeInfo() const { return extraTypeInfo != nullptr; }
inline void setExtraTypeInfo(std::unique_ptr<ExtraTypeInfo> typeInfo) {
extraTypeInfo = std::move(typeInfo);
}
Expand Down Expand Up @@ -432,6 +433,7 @@ class LogicalTypeUtils {
static uint32_t getRowLayoutSize(const LogicalType& logicalType);
static bool isNumerical(const LogicalType& dataType);
static bool isNested(const LogicalType& dataType);
static bool isNested(LogicalTypeID logicalTypeID);
static std::vector<LogicalType> getAllValidComparableLogicalTypes();
static std::vector<LogicalTypeID> getNumericalLogicalTypeIDs();
static std::vector<LogicalTypeID> getIntegerLogicalTypeIDs();
Expand Down
96 changes: 70 additions & 26 deletions test/main/prepare_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
using namespace kuzu::common;
using namespace kuzu::testing;

static void checkTuple(kuzu::processor::FlatTuple* tuple, const std::string& groundTruth) {
ASSERT_STREQ(tuple->toString().c_str(), groundTruth.c_str());
}

TEST_F(ApiTest, MultiParamsPrepare) {
auto preparedStatement = conn->prepare(
"MATCH (a:person) WHERE a.fName STARTS WITH $n OR a.fName CONTAINS $xx RETURN COUNT(*)");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("n"), "A"),
std::make_pair(std::string("xx"), "ooq"));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 2);
checkTuple(result->getNext().get(), "2\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -19,8 +22,7 @@ TEST_F(ApiTest, PrepareBool) {
conn->prepare("MATCH (a:person) WHERE a.isStudent = $1 RETURN COUNT(*)");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), true));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 3);
checkTuple(result->getNext().get(), "3\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -29,8 +31,7 @@ TEST_F(ApiTest, PrepareInt) {
auto result =
conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), (int64_t)10));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 45);
checkTuple(result->getNext().get(), "45\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -40,8 +41,7 @@ TEST_F(ApiTest, PrepareDouble) {
auto result =
conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), (double_t)10.5));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<double>(), 15.5);
checkTuple(result->getNext().get(), "15.500000\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -50,8 +50,7 @@ TEST_F(ApiTest, PrepareString) {
conn->prepare("MATCH (a:person) WHERE a.fName STARTS WITH $n RETURN COUNT(*)");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("n"), "A"));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 1);
checkTuple(result->getNext().get(), "1\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -61,8 +60,7 @@ TEST_F(ApiTest, PrepareDate) {
auto result = conn->execute(
preparedStatement.get(), std::make_pair(std::string("n"), Date::fromDate(1900, 1, 1)));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 2);
checkTuple(result->getNext().get(), "2\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -74,8 +72,7 @@ TEST_F(ApiTest, PrepareTimestamp) {
auto result = conn->execute(preparedStatement.get(),
std::make_pair(std::string("n"), Timestamp::fromDateTime(date, time)));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 1);
checkTuple(result->getNext().get(), "1\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -87,21 +84,68 @@ TEST_F(ApiTest, PrepareInterval) {
std::make_pair(
std::string("n"), Interval::fromCString(intervalStr.c_str(), intervalStr.length())));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 2);
checkTuple(result->getNext().get(), "2\n");
ASSERT_FALSE(result->hasNext());
}

// TEST_F(ApiTest, DefaultParam) {
// auto preparedStatement = conn->prepare("MATCH (a:person) WHERE $1 = $2 RETURN COUNT(*)");
// auto result =
// conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), (int64_t)1.4),
// std::make_pair(std::string("2"), (int64_t)1.4));
// ASSERT_TRUE(result->hasNext());
// auto tuple = result->getNext();
// ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 8);
// ASSERT_FALSE(result->hasNext());
//}
TEST_F(ApiTest, PrepareDefaultParam) {
auto preparedStatement = conn->prepare("RETURN to_int8($1)");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "1"));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "1\n");
ASSERT_FALSE(result->hasNext());
preparedStatement = conn->prepare("RETURN size($1)");
result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), 1));
ASSERT_FALSE(result->isSuccess());
ASSERT_STREQ(
result->getErrorMessage().c_str(), "Parameter 1 has data type INT32 but expects STRING.");
}

TEST_F(ApiTest, PrepareDefaultListParam) {
auto preparedStatement = conn->prepare("RETURN [1, $1]");
auto result =
conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), (int64_t)1));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "[1,1]\n");
result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "as"));
ASSERT_FALSE(result->isSuccess());
ASSERT_STREQ(
result->getErrorMessage().c_str(), "Parameter 1 has data type STRING but expects INT64.");
preparedStatement = conn->prepare("RETURN [$1]");
result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "as"));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "[as]\n");
preparedStatement = conn->prepare("RETURN [to_int32($1)]");
result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "10"));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "[10]\n");
}

TEST_F(ApiTest, PrepareDefaultStructParam) {
auto preparedStatement = conn->prepare("RETURN {a:$1}");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "10"));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "{a: 10}\n");
result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), 1));
ASSERT_FALSE(result->isSuccess());
ASSERT_STREQ(
result->getErrorMessage().c_str(), "Parameter 1 has data type INT32 but expects STRING.");
}

TEST_F(ApiTest, PrepareDefaultMapParam) {
auto preparedStatement = conn->prepare("RETURN map([$1], [$2])");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "10"),
std::make_pair(std::string("2"), "abc"));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "{10=abc}\n");
}

TEST_F(ApiTest, PrepareDefaultUnionParam) {
auto preparedStatement = conn->prepare("RETURN union_value(a := $1)");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "10"));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "10\n");
}

TEST_F(ApiTest, PrepareLargeJoin) {
auto preparedStatement = conn->prepare(
Expand Down
Loading

0 comments on commit d4772e1

Please sign in to comment.