Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 2269 #2281

Merged
merged 1 commit into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,35 +73,41 @@
return result;
}

// Builds the error text reported when `expression` cannot be implicitly cast to `targetType`.
// The text contains the expression (via toString()), the expression's current data type and the
// expected target type, both rendered with LogicalTypeUtils::dataTypeToString.
// Note: despite the name, this returns only the message string; it does not throw.
static std::string unsupportedImplicitCastException(
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
const Expression& expression, const common::LogicalType& targetType) {
return stringFormat(
"Expression {} has data type {} but expected {}. Implicit cast is not supported.",
expression.toString(), LogicalTypeUtils::dataTypeToString(expression.dataType),
LogicalTypeUtils::dataTypeToString(targetType));
}

std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, const LogicalType& targetType) {
if (targetType.getLogicalTypeID() == LogicalTypeID::ANY || expression->dataType == targetType) {
return expression;
}
if (expression->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
resolveAnyDataType(*expression, targetType);
const std::shared_ptr<Expression>& expression, common::LogicalTypeID targetTypeID) {
if (LogicalTypeUtils::isNested(targetTypeID)) {
if (expression->getDataType().getLogicalTypeID() == common::LogicalTypeID::ANY) {
throw BinderException(stringFormat(
"Cannot resolve recursive data type for expression {}.", expression->toString()));
}
// We don't support casting to nested data type. So instead we validate type match.
if (expression->getDataType().getLogicalTypeID() != targetTypeID) {
throw BinderException(
unsupportedImplicitCastException(*expression, LogicalType{targetTypeID}));

Check warning on line 94 in src/binder/expression_binder.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/expression_binder.cpp#L93-L94

Added lines #L93 - L94 were not covered by tests
}
return expression;
}
return implicitCast(expression, targetType);
return implicitCastIfNecessary(expression, LogicalType(targetTypeID));
}

std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, LogicalTypeID targetTypeID) {
if (targetTypeID == LogicalTypeID::ANY ||
expression->dataType.getLogicalTypeID() == targetTypeID) {
const std::shared_ptr<Expression>& expression, const LogicalType& targetType) {
if (targetType.getLogicalTypeID() == LogicalTypeID::ANY || expression->dataType == targetType) {
return expression;
}
if (expression->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
if (targetTypeID == LogicalTypeID::VAR_LIST) {
// e.g. len($1) we cannot infer the child type for $1.
throw BinderException(stringFormat(
"Cannot resolve recursive data type for expression {}.", expression->toString()));
}
resolveAnyDataType(*expression, LogicalType(targetTypeID));
resolveAnyDataType(*expression, targetType);
return expression;
}
assert(targetTypeID != LogicalTypeID::VAR_LIST);
return implicitCast(expression, LogicalType(targetTypeID));
return implicitCast(expression, targetType);
}

std::shared_ptr<Expression> ExpressionBinder::implicitCast(
Expand All @@ -118,10 +124,7 @@
std::move(bindData), std::move(children), execFunc, nullptr /* selectFunc */,
std::move(uniqueName));
} else {
throw BinderException(stringFormat(
"Expression {} has data type {} but expected {}. Implicit cast is not supported.",
expression->toString(), LogicalTypeUtils::dataTypeToString(expression->dataType),
LogicalTypeUtils::dataTypeToString(targetType)));
throw BinderException(unsupportedImplicitCastException(*expression, targetType));
}
}

Expand Down
9 changes: 8 additions & 1 deletion src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,12 +707,19 @@ bool LogicalTypeUtils::isNumerical(const LogicalType& dataType) {
}

bool LogicalTypeUtils::isNested(const LogicalType& dataType) {
switch (dataType.typeID) {
return isNested(dataType.typeID);
}

bool LogicalTypeUtils::isNested(kuzu::common::LogicalTypeID logicalTypeID) {
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
switch (logicalTypeID) {
case LogicalTypeID::STRUCT:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST:
case LogicalTypeID::UNION:
case LogicalTypeID::MAP:
case LogicalTypeID::NODE:
case LogicalTypeID::REL:
case LogicalTypeID::RECURSIVE_REL:
return true;
default:
return false;
Expand Down
11 changes: 9 additions & 2 deletions src/function/built_in_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,19 +370,26 @@ uint32_t BuiltInFunctions::castSerial(LogicalTypeID targetTypeID) {

// When there are multiple candidate functions, e.g. double + int and double + double for input
// "1.5 + parameter", we prefer the one without any implicit casting, i.e. double + double.
// Additionally, we prefer functions with string parameters because string is the most permissive
// type and can be cast to any type.
ScalarFunction* BuiltInFunctions::getBestMatch(std::vector<ScalarFunction*>& functionsToMatch) {
assert(functionsToMatch.size() > 1);
ScalarFunction* result = nullptr;
auto cost = UNDEFINED_CAST_COST;
for (auto& function : functionsToMatch) {
auto currentCost = 0;
std::unordered_set<LogicalTypeID> distinctParameterTypes;
for (auto& parameterTypeID : function->parameterTypeIDs) {
if (parameterTypeID != LogicalTypeID::STRING) {
currentCost++;
}
if (!distinctParameterTypes.contains(parameterTypeID)) {
currentCost++;
distinctParameterTypes.insert(parameterTypeID);
}
}
if (distinctParameterTypes.size() < cost) {
cost = distinctParameterTypes.size();
if (currentCost < cost) {
cost = currentCost;
result = function;
}
}
Expand Down
29 changes: 16 additions & 13 deletions src/function/vector_list_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,35 +56,38 @@ void ListCreationFunction::execFunc(
}
}

// Walks `expressions` in order and returns the data type of the first one whose logical type ID
// is not ANY. When no such expression exists (including an empty vector), an ANY logical type is
// returned so the caller can decide on a fallback.
static LogicalType getValidLogicalType(const binder::expression_vector& expressions) {
    for (auto& candidate : expressions) {
        const bool isUnresolved = candidate->dataType.getLogicalTypeID() == LogicalTypeID::ANY;
        if (isUnresolved) {
            // Still ANY: keep looking for an argument with a concrete type.
            continue;
        }
        return candidate->dataType;
    }
    return LogicalType(common::LogicalTypeID::ANY);
}

std::unique_ptr<FunctionBindData> ListCreationFunction::bindFunc(
const binder::expression_vector& arguments, Function* /*function*/) {
// ListCreation requires all parameters to have the same type or be ANY type. The result type of
// listCreation can be determined by the first non-ANY type parameter. If all parameters have
// dataType ANY, then the resultType will be STRING[] (default type).
auto varListTypeInfo =
std::make_unique<VarListTypeInfo>(std::make_unique<LogicalType>(LogicalTypeID::INT64));
auto resultType = LogicalType{LogicalTypeID::VAR_LIST, std::move(varListTypeInfo)};
for (auto& argument : arguments) {
if (argument->getDataType().getLogicalTypeID() != LogicalTypeID::ANY) {
varListTypeInfo = std::make_unique<VarListTypeInfo>(
std::make_unique<LogicalType>(argument->getDataType()));
resultType = LogicalType{LogicalTypeID::VAR_LIST, std::move(varListTypeInfo)};
break;
}
auto childType = getValidLogicalType(arguments);
if (childType.getLogicalTypeID() == common::LogicalTypeID::ANY) {
childType = LogicalType(common::LogicalTypeID::STRING);
}
auto resultChildType = VarListType::getChildType(&resultType);
// Cast parameters with ANY dataType to resultChildType.
for (auto& argument : arguments) {
auto& parameterType = argument->getDataTypeReference();
if (parameterType != *resultChildType) {
if (parameterType != childType) {
if (parameterType.getLogicalTypeID() == LogicalTypeID::ANY) {
binder::ExpressionBinder::resolveAnyDataType(*argument, *resultChildType);
binder::ExpressionBinder::resolveAnyDataType(*argument, childType);
} else {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_CREATION_FUNC_NAME, arguments[0]->getDataType(), argument->getDataType()));
}
}
}
auto varListTypeInfo = std::make_unique<VarListTypeInfo>(childType.copy());
auto resultType = LogicalType{LogicalTypeID::VAR_LIST, std::move(varListTypeInfo)};
return std::make_unique<FunctionBindData>(resultType);
}

Expand Down
2 changes: 1 addition & 1 deletion src/function/vector_struct_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ std::unique_ptr<FunctionBindData> StructPackFunctions::bindFunc(
for (auto& argument : arguments) {
if (argument->getDataType().getLogicalTypeID() == LogicalTypeID::ANY) {
binder::ExpressionBinder::resolveAnyDataType(
*argument, LogicalType{LogicalTypeID::INT64});
*argument, LogicalType{LogicalTypeID::STRING});
}
fields.emplace_back(
std::make_unique<StructField>(argument->getAlias(), argument->getDataType().copy()));
Expand Down
7 changes: 6 additions & 1 deletion src/function/vector_union_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "function/union/vector_union_functions.h"

#include "binder/expression_binder.h"
#include "function/struct/vector_struct_functions.h"
#include "function/union/functions/union_tag.h"

Expand All @@ -23,8 +24,12 @@ std::unique_ptr<FunctionBindData> UnionValueFunction::bindFunc(
// TODO(Ziy): Use UINT8 to represent tag value.
fields.push_back(std::make_unique<StructField>(
UnionType::TAG_FIELD_NAME, std::make_unique<LogicalType>(UnionType::TAG_FIELD_TYPE)));
if (arguments[0]->getDataType().getLogicalTypeID() == common::LogicalTypeID::ANY) {
binder::ExpressionBinder::resolveAnyDataType(
*arguments[0], LogicalType(LogicalTypeID::STRING));
}
fields.push_back(std::make_unique<StructField>(
arguments[0]->getAlias(), std::make_unique<LogicalType>(arguments[0]->getDataType())));
arguments[0]->getAlias(), arguments[0]->getDataType().copy()));
auto resultType =
LogicalType(LogicalTypeID::UNION, std::make_unique<StructTypeInfo>(std::move(fields)));
return std::make_unique<FunctionBindData>(resultType);
Expand Down
9 changes: 2 additions & 7 deletions src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,10 @@ class ExpressionBinder {
const parser::ParsedExpression& parsedExpression);

/****** cast *****/
// Note: we expose two implicitCastIfNecessary interfaces.
// For function binding we cast with data type ID because function definition cannot be
// recursively generated, e.g. list_extract(param) we only declare param with type LIST but do
// not specify its child type.
// For the rest, i.e. set clause binding, we cast with data type. For example, a.list = $1.
static std::shared_ptr<Expression> implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, const common::LogicalType& targetType);
static std::shared_ptr<Expression> implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, common::LogicalTypeID targetTypeID);
static std::shared_ptr<Expression> implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, const common::LogicalType& targetType);
static std::shared_ptr<Expression> implicitCast(
const std::shared_ptr<Expression>& expression, const common::LogicalType& targetType);

Expand Down
2 changes: 2 additions & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ class LogicalType {

// Returns the value of the physicalType member.
inline PhysicalTypeID getPhysicalType() const { return physicalType; }

// Returns true when the extraTypeInfo member is non-null.
inline bool hasExtraTypeInfo() const { return extraTypeInfo != nullptr; }
// Moves `typeInfo` into the extraTypeInfo member. The argument is taken by value as a
// unique_ptr, so the caller hands over ownership of the passed object.
inline void setExtraTypeInfo(std::unique_ptr<ExtraTypeInfo> typeInfo) {
extraTypeInfo = std::move(typeInfo);
}
Expand Down Expand Up @@ -432,6 +433,7 @@ class LogicalTypeUtils {
static uint32_t getRowLayoutSize(const LogicalType& logicalType);
static bool isNumerical(const LogicalType& dataType);
static bool isNested(const LogicalType& dataType);
static bool isNested(LogicalTypeID logicalTypeID);
static std::vector<LogicalType> getAllValidComparableLogicalTypes();
static std::vector<LogicalTypeID> getNumericalLogicalTypeIDs();
static std::vector<LogicalTypeID> getIntegerLogicalTypeIDs();
Expand Down
96 changes: 70 additions & 26 deletions test/main/prepare_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
using namespace kuzu::common;
using namespace kuzu::testing;

// Test helper: asserts that the string rendering of `tuple` (tuple->toString()) is equal to
// `groundTruth`. The comparison is done on the C-string contents with ASSERT_STREQ.
// `tuple` is dereferenced unconditionally, so it must not be null.
static void checkTuple(kuzu::processor::FlatTuple* tuple, const std::string& groundTruth) {
ASSERT_STREQ(tuple->toString().c_str(), groundTruth.c_str());
}

TEST_F(ApiTest, MultiParamsPrepare) {
auto preparedStatement = conn->prepare(
"MATCH (a:person) WHERE a.fName STARTS WITH $n OR a.fName CONTAINS $xx RETURN COUNT(*)");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("n"), "A"),
std::make_pair(std::string("xx"), "ooq"));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 2);
checkTuple(result->getNext().get(), "2\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -19,8 +22,7 @@ TEST_F(ApiTest, PrepareBool) {
conn->prepare("MATCH (a:person) WHERE a.isStudent = $1 RETURN COUNT(*)");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), true));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 3);
checkTuple(result->getNext().get(), "3\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -29,8 +31,7 @@ TEST_F(ApiTest, PrepareInt) {
auto result =
conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), (int64_t)10));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 45);
checkTuple(result->getNext().get(), "45\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -40,8 +41,7 @@ TEST_F(ApiTest, PrepareDouble) {
auto result =
conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), (double_t)10.5));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<double>(), 15.5);
checkTuple(result->getNext().get(), "15.500000\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -50,8 +50,7 @@ TEST_F(ApiTest, PrepareString) {
conn->prepare("MATCH (a:person) WHERE a.fName STARTS WITH $n RETURN COUNT(*)");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("n"), "A"));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 1);
checkTuple(result->getNext().get(), "1\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -61,8 +60,7 @@ TEST_F(ApiTest, PrepareDate) {
auto result = conn->execute(
preparedStatement.get(), std::make_pair(std::string("n"), Date::fromDate(1900, 1, 1)));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 2);
checkTuple(result->getNext().get(), "2\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -74,8 +72,7 @@ TEST_F(ApiTest, PrepareTimestamp) {
auto result = conn->execute(preparedStatement.get(),
std::make_pair(std::string("n"), Timestamp::fromDateTime(date, time)));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 1);
checkTuple(result->getNext().get(), "1\n");
ASSERT_FALSE(result->hasNext());
}

Expand All @@ -87,21 +84,68 @@ TEST_F(ApiTest, PrepareInterval) {
std::make_pair(
std::string("n"), Interval::fromCString(intervalStr.c_str(), intervalStr.length())));
ASSERT_TRUE(result->hasNext());
auto tuple = result->getNext();
ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 2);
checkTuple(result->getNext().get(), "2\n");
ASSERT_FALSE(result->hasNext());
}

// TEST_F(ApiTest, DefaultParam) {
// auto preparedStatement = conn->prepare("MATCH (a:person) WHERE $1 = $2 RETURN COUNT(*)");
// auto result =
// conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), (int64_t)1.4),
// std::make_pair(std::string("2"), (int64_t)1.4));
// ASSERT_TRUE(result->hasNext());
// auto tuple = result->getNext();
// ASSERT_EQ(tuple->getValue(0)->getValue<int64_t>(), 8);
// ASSERT_FALSE(result->hasNext());
//}
// Checks how a parameter is typed when it is only used as a function argument.
//  1. `RETURN to_int8($1)` executed with the string "1" must produce a single row "1".
//  2. `RETURN size($1)` executed with the int 1 must fail, and the error message must state
//     that the parameter has type INT32 while STRING is expected.
TEST_F(ApiTest, PrepareDefaultParam) {
auto preparedStatement = conn->prepare("RETURN to_int8($1)");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "1"));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "1\n");
// Exactly one row is expected.
ASSERT_FALSE(result->hasNext());
// Re-use the same variables for the failing case.
preparedStatement = conn->prepare("RETURN size($1)");
result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), 1));
ASSERT_FALSE(result->isSuccess());
ASSERT_STREQ(
result->getErrorMessage().c_str(), "Parameter 1 has data type INT32 but expects STRING.");
}

// Checks parameter typing inside list-creation expressions.
//  1. `RETURN [1, $1]`: executing with an int64 1 yields "[1,1]"; executing the same prepared
//     statement with the string "as" must fail with "... STRING but expects INT64.".
//  2. `RETURN [$1]`: executing with the string "as" yields "[as]".
//  3. `RETURN [to_int32($1)]`: executing with the string "10" yields "[10]".
TEST_F(ApiTest, PrepareDefaultListParam) {
auto preparedStatement = conn->prepare("RETURN [1, $1]");
auto result =
conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), (int64_t)1));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "[1,1]\n");
// Same prepared statement, but a string value for the parameter: expected to be rejected.
result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "as"));
ASSERT_FALSE(result->isSuccess());
ASSERT_STREQ(
result->getErrorMessage().c_str(), "Parameter 1 has data type STRING but expects INT64.");
// List whose only element is the parameter.
preparedStatement = conn->prepare("RETURN [$1]");
result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "as"));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "[as]\n");
// Parameter wrapped in an explicit conversion function inside the list.
preparedStatement = conn->prepare("RETURN [to_int32($1)]");
result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "10"));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "[10]\n");
}

// Checks parameter typing inside a struct literal `{a:$1}`.
//  1. Executing with the string "10" yields "{a: 10}".
//  2. Executing the same prepared statement with the int 1 must fail, and the error message
//     must state that the parameter has type INT32 while STRING is expected.
TEST_F(ApiTest, PrepareDefaultStructParam) {
auto preparedStatement = conn->prepare("RETURN {a:$1}");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "10"));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "{a: 10}\n");
// Non-string value for the same statement: expected to be rejected.
result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), 1));
ASSERT_FALSE(result->isSuccess());
ASSERT_STREQ(
result->getErrorMessage().c_str(), "Parameter 1 has data type INT32 but expects STRING.");
}

// Checks parameters used as the sole elements of the key list and value list of `map(...)`.
// Executing `RETURN map([$1], [$2])` with the strings "10" and "abc" must yield "{10=abc}".
TEST_F(ApiTest, PrepareDefaultMapParam) {
auto preparedStatement = conn->prepare("RETURN map([$1], [$2])");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "10"),
std::make_pair(std::string("2"), "abc"));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "{10=abc}\n");
}

// Checks a parameter used as the value of `union_value`.
// Executing `RETURN union_value(a := $1)` with the string "10" must yield "10".
TEST_F(ApiTest, PrepareDefaultUnionParam) {
auto preparedStatement = conn->prepare("RETURN union_value(a := $1)");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "10"));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "10\n");
}

TEST_F(ApiTest, PrepareLargeJoin) {
auto preparedStatement = conn->prepare(
Expand Down
Loading