Skip to content

Commit

Permalink
finish CAST(item, type)
Browse files Browse the repository at this point in the history
  • Loading branch information
AEsir777 committed Nov 6, 2023
1 parent 4194475 commit 0da2553
Show file tree
Hide file tree
Showing 44 changed files with 1,520 additions and 390 deletions.
9 changes: 9 additions & 0 deletions dataset/reader/union_correct.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"false","34","255","18446744073709551615",fsdfa
" true ","432","0","-1.43241543","543fasf"
" 34234 ","4294967295","65535",-128,432
" -42342345 ","-1","-1","-129",fasf
" T ","2022-06-06","4324.123","-32768",fds
"TRUE","2019-03-19","-12.3432","32768",""
"1","-2147483648","1970-01-01 00:00:00.004666-10","-32769",fsdxcv
"0","0","2014-05-12 12:11:59",4324254534123134324321.4343252435,"fsaf"
" F","-4325"," 14 ",18446744073709551616," dfsa"
10 changes: 6 additions & 4 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,12 +398,14 @@ std::pair<uint64_t, uint64_t> Binder::bindVariableLengthRelBound(
const kuzu::parser::RelPattern& relPattern) {
auto recursiveInfo = relPattern.getRecursiveInfo();
uint32_t lowerBound;
function::CastStringToTypes::operation(
recursiveInfo->lowerBound.c_str(), recursiveInfo->lowerBound.length(), lowerBound);
function::CastString::operation(
ku_string_t{recursiveInfo->lowerBound.c_str(), recursiveInfo->lowerBound.length()},
lowerBound);
auto upperBound = clientContext->varLengthExtendMaxDepth;
if (!recursiveInfo->upperBound.empty()) {
function::CastStringToTypes::operation(
recursiveInfo->upperBound.c_str(), recursiveInfo->upperBound.length(), upperBound);
function::CastString::operation(
ku_string_t{recursiveInfo->upperBound.c_str(), recursiveInfo->upperBound.length()},
upperBound);
}
if (lowerBound > upperBound) {
throw BinderException(
Expand Down
28 changes: 17 additions & 11 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,23 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
auto function = reinterpret_cast<function::ScalarFunction*>(
builtInFunctions->matchScalarFunction(functionName, childrenTypes));
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
auto targetType =
function->isVarLength ? function->parameterTypeIDs[0] : function->parameterTypeIDs[i];
childrenAfterCast.push_back(implicitCastIfNecessary(children[i], targetType));
}
std::unique_ptr<function::FunctionBindData> bindData;
if (function->bindFunc) {
bindData = function->bindFunc(childrenAfterCast, function);
if (functionName == CAST_FUNC_NAME) {
bindData = function->bindFunc(children, function);
childrenAfterCast.push_back(
implicitCastIfNecessary(children[0], function->parameterTypeIDs[0]));
} else {
bindData =
std::make_unique<function::FunctionBindData>(LogicalType(function->returnTypeID));
for (auto i = 0u; i < children.size(); ++i) {
auto targetType = function->isVarLength ? function->parameterTypeIDs[0] :
function->parameterTypeIDs[i];
childrenAfterCast.push_back(implicitCastIfNecessary(children[i], targetType));
}
if (function->bindFunc) {
bindData = function->bindFunc(childrenAfterCast, function);
} else {
bindData =
std::make_unique<function::FunctionBindData>(LogicalType(function->returnTypeID));
}
}
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast);
Expand Down Expand Up @@ -253,8 +259,8 @@ std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression
auto bindData =
std::make_unique<function::FunctionBindData>(LogicalType(LogicalTypeID::STRING));
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
return make_shared<ScalarFunctionExpression>(LABEL_FUNC_NAME, FUNCTION, std::move(bindData),
std::move(children), execFunc, nullptr, uniqueExpressionName);
return std::make_shared<ScalarFunctionExpression>(LABEL_FUNC_NAME, FUNCTION,
std::move(bindData), std::move(children), execFunc, nullptr, uniqueExpressionName);
}

std::unique_ptr<Expression> ExpressionBinder::createInternalLengthExpression(
Expand Down
6 changes: 5 additions & 1 deletion src/binder/expression/function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace kuzu {
namespace binder {

std::string ScalarFunctionExpression::getUniqueName(
const std::string& functionName, kuzu::binder::expression_vector& children) {
const std::string& functionName, const kuzu::binder::expression_vector& children) {
auto result = functionName + "(";
for (auto& child : children) {
result += child->getUniqueName() + ", ";
Expand All @@ -18,6 +18,10 @@ std::string ScalarFunctionExpression::getUniqueName(
std::string ScalarFunctionExpression::toStringInternal() const {
auto result = functionName + "(";
result += ExpressionUtil::toString(children);
if (functionName == "CAST") {
result += ", ";
result += common::LogicalTypeUtils::dataTypeToString(bindData->resultType);
}
result += ")";
return result;
}
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ kuzu_int128_t kuzu_int128_t_from_string(const char* str) {
int128_t int128_val = 0;
kuzu_int128_t c_int128;
try {
kuzu::function::CastStringToTypes::operation(str, strlen(str), int128_val);
kuzu::function::CastString::operation(ku_string_t{str, strlen(str)}, int128_val);
c_int128.low = int128_val.low;
c_int128.high = int128_val.high;
} catch (ConversionException& e) {
Expand Down
9 changes: 9 additions & 0 deletions src/common/types/ku_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
namespace kuzu {
namespace common {

ku_string_t::ku_string_t(const char* value, uint64_t length) : len(length) {
if (isShortString(length)) {
memcpy(prefix, value, length);
return;
}
overflowPtr = (uint64_t)(value);
memcpy(prefix, value, PREFIX_LENGTH);
}

void ku_string_t::set(const std::string& value) {
set(value.data(), value.length());
}
Expand Down
11 changes: 9 additions & 2 deletions src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,15 @@ void FunctionExpressionEvaluator::evaluate() {
for (auto& child : children) {
child->evaluate();
}
auto expr = reinterpret_cast<binder::ScalarFunctionExpression*>(expression.get());
if (expr->getFunctionName() == CAST_FUNC_NAME &&
parameters[0]->dataType.getLogicalTypeID() == LogicalTypeID::STRING) {
execFunc(parameters, *resultVector,
reinterpret_cast<function::StringCastFunctionBindData*>(expr->getBindData()));
return;
}
if (execFunc != nullptr) {
execFunc(parameters, *resultVector);
execFunc(parameters, *resultVector, nullptr);
}
}

Expand All @@ -34,7 +41,7 @@ bool FunctionExpressionEvaluator::select(SelectionVector& selVector) {
// implemented (e.g. list_contains). We should remove this if statement eventually.
if (selectFunc == nullptr) {
assert(resultVector->dataType.getLogicalTypeID() == LogicalTypeID::BOOL);
execFunc(parameters, *resultVector);
execFunc(parameters, *resultVector, nullptr);
auto numSelectedValues = 0u;
for (auto i = 0u; i < resultVector->state->selVector->selectedSize; ++i) {
auto pos = resultVector->state->selVector->selectedPositions[i];
Expand Down
4 changes: 4 additions & 0 deletions src/function/built_in_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ Function* BuiltInFunctions::matchScalarFunction(
uint32_t minCost = UINT32_MAX;
for (auto& function : functionSet) {
auto func = reinterpret_cast<Function*>(function.get());
if (name == CAST_FUNC_NAME) {
return func;
}
auto cost = getFunctionCost(inputTypes, func, isOverload);
if (cost == UINT32_MAX) {
continue;
Expand Down Expand Up @@ -597,6 +600,7 @@ void BuiltInFunctions::registerCastFunctions() {
functions.insert({CAST_TO_UINT8_FUNC_NAME, CastToUInt8Function::getFunctionSet()});
functions.insert({CAST_TO_INT128_FUNC_NAME, CastToInt128Function::getFunctionSet()});
functions.insert({CAST_TO_BOOL_FUNC_NAME, CastToBoolFunction::getFunctionSet()});
functions.insert({CAST_FUNC_NAME, CastAnyFunction::getFunctionSet()});
}

void BuiltInFunctions::registerListFunctions() {
Expand Down
Loading

0 comments on commit 0da2553

Please sign in to comment.