Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create CAST(item, type) function #2326

Merged
merged 1 commit into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions dataset/reader/union_correct.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"false","34","255","18446744073709551615",fsdfa
" true ","432","0","-1.43241543","543fasf"
" 34234 ","4294967295","65535",-128,432
" -42342345 ","-1","-1","-129",fasf
" T ","2022-06-06","4324.123","-32768",fds
"TRUE","2019-03-19","-12.3432","32768",""
"1","-2147483648","1970-01-01 00:00:00.004666-10","-32769",fsdxcv
"0","0","2014-05-12 12:11:59",4324254534123134324321.4343252435,"fsaf"
" F","-4325"," 14 ",18446744073709551616," dfsa"
10 changes: 6 additions & 4 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,12 +398,14 @@ std::pair<uint64_t, uint64_t> Binder::bindVariableLengthRelBound(
const kuzu::parser::RelPattern& relPattern) {
auto recursiveInfo = relPattern.getRecursiveInfo();
uint32_t lowerBound;
function::CastStringToTypes::operation(
recursiveInfo->lowerBound.c_str(), recursiveInfo->lowerBound.length(), lowerBound);
function::CastString::operation(
ku_string_t{recursiveInfo->lowerBound.c_str(), recursiveInfo->lowerBound.length()},
lowerBound);
auto upperBound = clientContext->varLengthExtendMaxDepth;
if (!recursiveInfo->upperBound.empty()) {
function::CastStringToTypes::operation(
recursiveInfo->upperBound.c_str(), recursiveInfo->upperBound.length(), upperBound);
function::CastString::operation(
ku_string_t{recursiveInfo->upperBound.c_str(), recursiveInfo->upperBound.length()},
upperBound);
}
if (lowerBound > upperBound) {
throw BinderException(
Expand Down
28 changes: 17 additions & 11 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,23 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
auto function = reinterpret_cast<function::ScalarFunction*>(
builtInFunctions->matchScalarFunction(functionName, childrenTypes));
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
auto targetType =
function->isVarLength ? function->parameterTypeIDs[0] : function->parameterTypeIDs[i];
childrenAfterCast.push_back(implicitCastIfNecessary(children[i], targetType));
}
std::unique_ptr<function::FunctionBindData> bindData;
if (function->bindFunc) {
bindData = function->bindFunc(childrenAfterCast, function);
if (functionName == CAST_FUNC_NAME) {
bindData = function->bindFunc(children, function);
childrenAfterCast.push_back(
implicitCastIfNecessary(children[0], function->parameterTypeIDs[0]));
} else {
bindData =
std::make_unique<function::FunctionBindData>(LogicalType(function->returnTypeID));
for (auto i = 0u; i < children.size(); ++i) {
auto targetType = function->isVarLength ? function->parameterTypeIDs[0] :
function->parameterTypeIDs[i];
childrenAfterCast.push_back(implicitCastIfNecessary(children[i], targetType));
}
if (function->bindFunc) {
bindData = function->bindFunc(childrenAfterCast, function);
} else {
bindData =
std::make_unique<function::FunctionBindData>(LogicalType(function->returnTypeID));
}
}
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast);
Expand Down Expand Up @@ -253,8 +259,8 @@ std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression
auto bindData =
std::make_unique<function::FunctionBindData>(LogicalType(LogicalTypeID::STRING));
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(LABEL_FUNC_NAME, FUNCTION, std::move(bindData),
std::move(children), execFunc, nullptr, uniqueExpressionName);
return std::make_shared<ScalarFunctionExpression>(LABEL_FUNC_NAME, FUNCTION,
std::move(bindData), std::move(children), execFunc, nullptr, uniqueExpressionName);
}

std::unique_ptr<Expression> ExpressionBinder::createInternalLengthExpression(
Expand Down
6 changes: 5 additions & 1 deletion src/binder/expression/function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace kuzu {
namespace binder {

std::string ScalarFunctionExpression::getUniqueName(
const std::string& functionName, kuzu::binder::expression_vector& children) {
const std::string& functionName, const kuzu::binder::expression_vector& children) {
auto result = functionName + "(";
for (auto& child : children) {
result += child->getUniqueName() + ", ";
Expand All @@ -18,6 +18,10 @@ std::string ScalarFunctionExpression::getUniqueName(
std::string ScalarFunctionExpression::toStringInternal() const {
auto result = functionName + "(";
result += ExpressionUtil::toString(children);
if (functionName == "CAST") {
AEsir777 marked this conversation as resolved.
Show resolved Hide resolved
result += ", ";
result += common::LogicalTypeUtils::dataTypeToString(bindData->resultType);
}
result += ")";
return result;
}
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ kuzu_int128_t kuzu_int128_t_from_string(const char* str) {
int128_t int128_val = 0;
kuzu_int128_t c_int128;
try {
kuzu::function::CastStringToTypes::operation(str, strlen(str), int128_val);
kuzu::function::CastString::operation(ku_string_t{str, strlen(str)}, int128_val);
c_int128.low = int128_val.low;
c_int128.high = int128_val.high;
} catch (ConversionException& e) {
Expand Down
9 changes: 9 additions & 0 deletions src/common/types/ku_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
namespace kuzu {
namespace common {

ku_string_t::ku_string_t(const char* value, uint64_t length) : len(length) {
if (isShortString(length)) {
memcpy(prefix, value, length);
return;
}
overflowPtr = (uint64_t)(value);
memcpy(prefix, value, PREFIX_LENGTH);
}

void ku_string_t::set(const std::string& value) {
set(value.data(), value.length());
}
Expand Down
11 changes: 9 additions & 2 deletions src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,15 @@ void FunctionExpressionEvaluator::evaluate() {
for (auto& child : children) {
child->evaluate();
}
auto expr = reinterpret_cast<binder::ScalarFunctionExpression*>(expression.get());
AEsir777 marked this conversation as resolved.
Show resolved Hide resolved
if (expr->getFunctionName() == CAST_FUNC_NAME &&
parameters[0]->dataType.getLogicalTypeID() == LogicalTypeID::STRING) {
execFunc(parameters, *resultVector,
reinterpret_cast<function::StringCastFunctionBindData*>(expr->getBindData()));
return;
}
if (execFunc != nullptr) {
execFunc(parameters, *resultVector);
execFunc(parameters, *resultVector, nullptr);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can make the dataptr default type to null, so you don't need to pass in the nullptr there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since
using scalar_exec_func = std::function<void( const std::vector<std::shared_ptr<common::ValueVector>>&, common::ValueVector&, void*)>;
specifies the function needs to take in 3 arguments, I still need to pass nullptr.

}
}

Expand All @@ -34,7 +41,7 @@ bool FunctionExpressionEvaluator::select(SelectionVector& selVector) {
// implemented (e.g. list_contains). We should remove this if statement eventually.
if (selectFunc == nullptr) {
assert(resultVector->dataType.getLogicalTypeID() == LogicalTypeID::BOOL);
execFunc(parameters, *resultVector);
execFunc(parameters, *resultVector, nullptr);
auto numSelectedValues = 0u;
for (auto i = 0u; i < resultVector->state->selVector->selectedSize; ++i) {
auto pos = resultVector->state->selVector->selectedPositions[i];
Expand Down
4 changes: 4 additions & 0 deletions src/function/built_in_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ Function* BuiltInFunctions::matchScalarFunction(
uint32_t minCost = UINT32_MAX;
for (auto& function : functionSet) {
auto func = reinterpret_cast<Function*>(function.get());
if (name == CAST_FUNC_NAME) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove this. And inside bindScalarFunctionExpression we check functionName == "Cast" and directly get the function without going through the the matchScalarFunction

return func;
}
auto cost = getFunctionCost(inputTypes, func, isOverload);
if (cost == UINT32_MAX) {
continue;
Expand Down Expand Up @@ -597,6 +600,7 @@ void BuiltInFunctions::registerCastFunctions() {
functions.insert({CAST_TO_UINT8_FUNC_NAME, CastToUInt8Function::getFunctionSet()});
functions.insert({CAST_TO_INT128_FUNC_NAME, CastToInt128Function::getFunctionSet()});
functions.insert({CAST_TO_BOOL_FUNC_NAME, CastToBoolFunction::getFunctionSet()});
functions.insert({CAST_FUNC_NAME, CastAnyFunction::getFunctionSet()});
}

void BuiltInFunctions::registerListFunctions() {
Expand Down
Loading