Skip to content

Commit

Permalink
add function CAST(item, type)
Browse files Browse the repository at this point in the history
  • Loading branch information
AEsir777 committed Nov 2, 2023
1 parent 94a611c commit 944a80e
Show file tree
Hide file tree
Showing 22 changed files with 529 additions and 293 deletions.
6 changes: 4 additions & 2 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,11 +401,13 @@ std::pair<uint64_t, uint64_t> Binder::bindVariableLengthRelBound(
auto recursiveInfo = relPattern.getRecursiveInfo();
uint32_t lowerBound;
function::CastStringToTypes::operation(
recursiveInfo->lowerBound.c_str(), recursiveInfo->lowerBound.length(), lowerBound);
ku_string_t{recursiveInfo->lowerBound.c_str(), recursiveInfo->lowerBound.length()},
lowerBound);
auto upperBound = clientContext->varLengthExtendMaxDepth;
if (!recursiveInfo->upperBound.empty()) {
function::CastStringToTypes::operation(
recursiveInfo->upperBound.c_str(), recursiveInfo->upperBound.length(), upperBound);
ku_string_t{recursiveInfo->upperBound.c_str(), recursiveInfo->upperBound.length()},
upperBound);
}
if (lowerBound > upperBound) {
throw BinderException(
Expand Down
35 changes: 29 additions & 6 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(
std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const ParsedExpression& parsedExpression, const std::string& functionName) {
expression_vector children;

if (functionName == "CAST") {
if (parsedExpression.getNumChildren() != 2) {
throw BinderException("Cannot match a built-in function for given function CAST");
}
auto type = binder->bindDataType(parsedExpression.getChild(1)->toString());
children.push_back(bindExpression(*parsedExpression.getChild(0)));

return bindScalarFunctionExpression(children, functionName, type.get());
}

for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto child = bindExpression(*parsedExpression.getChild(i));
children.push_back(std::move(child));
Expand All @@ -50,14 +61,21 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
}

std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const expression_vector& children, const std::string& functionName) {
const expression_vector& children, const std::string& functionName, const LogicalType* type) {
auto builtInFunctions = binder->catalog.getBuiltInFunctions();
std::vector<LogicalType*> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(&child->dataType);
}
auto function = reinterpret_cast<function::ScalarFunction*>(
builtInFunctions->matchScalarFunction(functionName, childrenTypes));

function::ScalarFunction* function;
if (type) {
function = reinterpret_cast<function::ScalarFunction*>(
builtInFunctions->matchCastScalarFunction(functionName, childrenTypes, type));
} else {
function = reinterpret_cast<function::ScalarFunction*>(
builtInFunctions->matchScalarFunction(functionName, childrenTypes));
}
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
auto targetType =
Expand All @@ -68,14 +86,19 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
if (function->bindFunc) {
bindData = function->bindFunc(childrenAfterCast, function);
} else {
bindData =
std::make_unique<function::FunctionBindData>(LogicalType(function->returnTypeID));
if (function->castFunc) {
bindData = std::make_unique<function::StringCastFunctionBindData>(
LogicalType(function->returnTypeID));
} else {
bindData =
std::make_unique<function::FunctionBindData>(LogicalType(function->returnTypeID));
}
}
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast);
return make_shared<ScalarFunctionExpression>(functionName, FUNCTION, std::move(bindData),
std::move(childrenAfterCast), function->execFunc, function->selectFunc,
function->compileFunc, uniqueExpressionName);
function->compileFunc, function->castFunc, uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ kuzu_int128_t kuzu_int128_t_from_string(const char* str) {
int128_t int128_val = 0;
kuzu_int128_t c_int128;
try {
kuzu::function::CastStringToTypes::operation(str, strlen(str), int128_val);
kuzu::function::CastStringToTypes::operation(ku_string_t{str, strlen(str)}, int128_val);
c_int128.low = int128_val.low;
c_int128.high = int128_val.high;
} catch (ConversionException& e) {
Expand Down
9 changes: 9 additions & 0 deletions src/common/types/ku_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
namespace kuzu {
namespace common {

ku_string_t::ku_string_t(const char* value, uint64_t length) : len(length) {
if (isShortString(length)) {
memcpy(prefix, value, length);
return;
}
overflowPtr = (uint64_t)(value);
memcpy(prefix, value, PREFIX_LENGTH);
}

void ku_string_t::set(const std::string& value) {
set(value.data(), value.length());
}
Expand Down
13 changes: 11 additions & 2 deletions src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,25 @@ namespace evaluator {

void FunctionExpressionEvaluator::init(const ResultSet& resultSet, MemoryManager* memoryManager) {
ExpressionEvaluator::init(resultSet, memoryManager);
execFunc = ((binder::ScalarFunctionExpression&)*expression).execFunc;
auto expr = reinterpret_cast<binder::ScalarFunctionExpression*>(expression.get());
execFunc = expr->execFunc;
castFunc = expr->castFunc;
if (expression->dataType.getLogicalTypeID() == LogicalTypeID::BOOL) {
selectFunc = ((binder::ScalarFunctionExpression&)*expression).selectFunc;
selectFunc = expr->selectFunc;
}
}

void FunctionExpressionEvaluator::evaluate() {
for (auto& child : children) {
child->evaluate();
}
if (castFunc != nullptr) {
auto expr = reinterpret_cast<binder::ScalarFunctionExpression*>(expression.get());
castFunc(parameters, *resultVector,
&reinterpret_cast<function::StringCastFunctionBindData*>(expr->getBindData())
->csvConfig);
return;
}
if (execFunc != nullptr) {
execFunc(parameters, *resultVector);
}
Expand Down
17 changes: 17 additions & 0 deletions src/function/built_in_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,22 @@ Function* BuiltInFunctions::matchScalarFunction(
return candidateFunctions[0];
}

Function* BuiltInFunctions::matchCastScalarFunction(
const std::string& name, const std::vector<LogicalType*>& inputTypes, const LogicalType* type) {
auto& functionSet = functions.at(name);
std::vector<Function*> candidateFunctions;
for (auto& function : functionSet) {
if (function->parameterTypeIDs[0] == inputTypes[0]->getLogicalTypeID() &&
reinterpret_cast<BaseScalarFunction*>(function.get())->returnTypeID ==
type->getLogicalTypeID()) {
candidateFunctions.push_back(function.get());
break;
}
}
validateNonEmptyCandidateFunctions(candidateFunctions, name, inputTypes);
return candidateFunctions[0];
}

AggregateFunction* BuiltInFunctions::matchAggregateFunction(
const std::string& name, const std::vector<common::LogicalType*>& inputTypes, bool isDistinct) {
auto& functionSet = functions.at(name);
Expand Down Expand Up @@ -597,6 +613,7 @@ void BuiltInFunctions::registerCastFunctions() {
functions.insert({CAST_TO_UINT8_FUNC_NAME, CastToUInt8Function::getFunctionSet()});
functions.insert({CAST_TO_INT128_FUNC_NAME, CastToInt128Function::getFunctionSet()});
functions.insert({CAST_TO_BOOL_FUNC_NAME, CastToBoolFunction::getFunctionSet()});
functions.insert({CAST_FUNC_NAME, CastAnyFunction::getFunctionSet()});
}

void BuiltInFunctions::registerListFunctions() {
Expand Down
Loading

0 comments on commit 944a80e

Please sign in to comment.