Skip to content

Commit

Permalink
Merge pull request #1586 from kuzudb/win-vector-operations-fix
Browse files Browse the repository at this point in the history
Fix return functor windows compatibility
  • Loading branch information
andyfengHKU committed May 29, 2023
2 parents 2478bc8 + d0dcbd5 commit 75a0bef
Show file tree
Hide file tree
Showing 23 changed files with 424 additions and 367 deletions.
10 changes: 6 additions & 4 deletions src/binder/bind_expression/bind_boolean_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ std::shared_ptr<Expression> ExpressionBinder::bindBooleanExpression(
childrenAfterCast.push_back(implicitCastIfNecessary(child, LogicalTypeID::BOOL));
}
auto functionName = expressionTypeToString(expressionType);
auto execFunc =
function::VectorBooleanOperations::bindExecFunction(expressionType, childrenAfterCast);
auto selectFunc =
function::VectorBooleanOperations::bindSelectFunction(expressionType, childrenAfterCast);
function::scalar_exec_func execFunc;
function::VectorBooleanOperations::bindExecFunction(
expressionType, childrenAfterCast, execFunc);
function::scalar_select_func selectFunc;
function::VectorBooleanOperations::bindSelectFunction(
expressionType, childrenAfterCast, selectFunc);
auto bindData = std::make_unique<function::FunctionBindData>(LogicalType(LogicalTypeID::BOOL));
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(functionName, childrenAfterCast);
Expand Down
6 changes: 4 additions & 2 deletions src/binder/bind_expression/bind_null_operator_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ std::shared_ptr<Expression> ExpressionBinder::bindNullOperatorExpression(
}
auto expressionType = parsedExpression.getExpressionType();
auto functionName = expressionTypeToString(expressionType);
auto execFunc = function::VectorNullOperations::bindExecFunction(expressionType, children);
auto selectFunc = function::VectorNullOperations::bindSelectFunction(expressionType, children);
function::scalar_exec_func execFunc;
function::VectorNullOperations::bindExecFunction(expressionType, children, execFunc);
function::scalar_select_func selectFunc;
function::VectorNullOperations::bindSelectFunction(expressionType, children, selectFunc);
auto bindData = std::make_unique<function::FunctionBindData>(
common::LogicalType(common::LogicalTypeID::BOOL));
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(functionName, children);
Expand Down
9 changes: 5 additions & 4 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,13 @@ std::shared_ptr<Expression> ExpressionBinder::implicitCast(
auto functionName = VectorCastOperations::bindImplicitCastFuncName(targetType);
auto children = expression_vector{expression};
auto bindData = std::make_unique<FunctionBindData>(targetType);
function::scalar_exec_func execFunc;
VectorCastOperations::bindImplicitCastFunc(
expression->dataType.getLogicalTypeID(), targetType.getLogicalTypeID(), execFunc);
auto uniqueName = ScalarFunctionExpression::getUniqueName(functionName, children);
return std::make_shared<ScalarFunctionExpression>(functionName, FUNCTION,
std::move(bindData), std::move(children),
VectorCastOperations::bindImplicitCastFunc(
expression->dataType.getLogicalTypeID(), targetType.getLogicalTypeID()),
nullptr /* selectFunc */, std::move(uniqueName));
std::move(bindData), std::move(children), execFunc, nullptr /* selectFunc */,
std::move(uniqueName));
} else {
throw common::BinderException(
"Expression " + expression->toString() + " has data type " +
Expand Down
56 changes: 32 additions & 24 deletions src/function/vector_boolean_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,91 +7,99 @@ using namespace kuzu::common;
namespace kuzu {
namespace function {

scalar_exec_func VectorBooleanOperations::bindExecFunction(
ExpressionType expressionType, const binder::expression_vector& children) {
void VectorBooleanOperations::bindExecFunction(ExpressionType expressionType,
const binder::expression_vector& children, scalar_exec_func& func) {
if (isExpressionBinary(expressionType)) {
return bindBinaryExecFunction(expressionType, children);
bindBinaryExecFunction(expressionType, children, func);
} else {
assert(isExpressionUnary(expressionType));
return bindUnaryExecFunction(expressionType, children);
bindUnaryExecFunction(expressionType, children, func);
}
}

scalar_select_func VectorBooleanOperations::bindSelectFunction(
ExpressionType expressionType, const binder::expression_vector& children) {
void VectorBooleanOperations::bindSelectFunction(ExpressionType expressionType,
const binder::expression_vector& children, scalar_select_func& func) {
if (isExpressionBinary(expressionType)) {
return bindBinarySelectFunction(expressionType, children);
bindBinarySelectFunction(expressionType, children, func);
} else {
assert(isExpressionUnary(expressionType));
return bindUnarySelectFunction(expressionType, children);
bindUnarySelectFunction(expressionType, children, func);
}
}

scalar_exec_func VectorBooleanOperations::bindBinaryExecFunction(
ExpressionType expressionType, const binder::expression_vector& children) {
void VectorBooleanOperations::bindBinaryExecFunction(ExpressionType expressionType,
const binder::expression_vector& children, scalar_exec_func& func) {
assert(children.size() == 2);
auto leftType = children[0]->dataType;
auto rightType = children[1]->dataType;
assert(leftType.getLogicalTypeID() == LogicalTypeID::BOOL &&
rightType.getLogicalTypeID() == LogicalTypeID::BOOL);
switch (expressionType) {
case AND: {
return &BinaryBooleanExecFunction<operation::And>;
func = &BinaryBooleanExecFunction<operation::And>;
return;
}
case OR: {
return &BinaryBooleanExecFunction<operation::Or>;
func = &BinaryBooleanExecFunction<operation::Or>;
return;
}
case XOR: {
return &BinaryBooleanExecFunction<operation::Xor>;
func = &BinaryBooleanExecFunction<operation::Xor>;
return;
}
default:
throw RuntimeException("Invalid expression type " + expressionTypeToString(expressionType) +
" for VectorBooleanOperations::bindBinaryExecFunction.");
}
}

scalar_select_func VectorBooleanOperations::bindBinarySelectFunction(
ExpressionType expressionType, const binder::expression_vector& children) {
void VectorBooleanOperations::bindBinarySelectFunction(ExpressionType expressionType,
const binder::expression_vector& children, scalar_select_func& func) {
assert(children.size() == 2);
auto leftType = children[0]->dataType;
auto rightType = children[1]->dataType;
assert(leftType.getLogicalTypeID() == LogicalTypeID::BOOL &&
rightType.getLogicalTypeID() == LogicalTypeID::BOOL);
switch (expressionType) {
case AND: {
return &BinaryBooleanSelectFunction<operation::And>;
func = &BinaryBooleanSelectFunction<operation::And>;
return;
}
case OR: {
return &BinaryBooleanSelectFunction<operation::Or>;
func = &BinaryBooleanSelectFunction<operation::Or>;
return;
}
case XOR: {
return &BinaryBooleanSelectFunction<operation::Xor>;
func = &BinaryBooleanSelectFunction<operation::Xor>;
return;
}
default:
throw RuntimeException("Invalid expression type " + expressionTypeToString(expressionType) +
" for VectorBooleanOperations::bindBinarySelectFunction.");
}
}

scalar_exec_func VectorBooleanOperations::bindUnaryExecFunction(
ExpressionType expressionType, const binder::expression_vector& children) {
void VectorBooleanOperations::bindUnaryExecFunction(ExpressionType expressionType,
const binder::expression_vector& children, scalar_exec_func& func) {
assert(children.size() == 1 && children[0]->dataType.getLogicalTypeID() == LogicalTypeID::BOOL);
switch (expressionType) {
case NOT: {
return &UnaryBooleanExecFunction<operation::Not>;
func = &UnaryBooleanExecFunction<operation::Not>;
return;
}
default:
throw RuntimeException("Invalid expression type " + expressionTypeToString(expressionType) +
" for VectorBooleanOperations::bindUnaryExecFunction.");
}
}

scalar_select_func VectorBooleanOperations::bindUnarySelectFunction(
ExpressionType expressionType, const binder::expression_vector& children) {
void VectorBooleanOperations::bindUnarySelectFunction(ExpressionType expressionType,
const binder::expression_vector& children, scalar_select_func& func) {
assert(children.size() == 1 && children[0]->dataType.getLogicalTypeID() == LogicalTypeID::BOOL);
switch (expressionType) {
case NOT: {
return &UnaryBooleanSelectFunction<operation::Not>;
func = &UnaryBooleanSelectFunction<operation::Not>;
return;
}
default:
throw RuntimeException("Invalid expression type " + expressionTypeToString(expressionType) +
Expand Down
28 changes: 18 additions & 10 deletions src/function/vector_cast_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,35 +55,43 @@ std::string VectorCastOperations::bindImplicitCastFuncName(const common::Logical
}
}

scalar_exec_func VectorCastOperations::bindImplicitCastFunc(
common::LogicalTypeID sourceTypeID, common::LogicalTypeID targetTypeID) {
void VectorCastOperations::bindImplicitCastFunc(common::LogicalTypeID sourceTypeID,
common::LogicalTypeID targetTypeID, scalar_exec_func& func) {
switch (targetTypeID) {
case common::LogicalTypeID::INT16: {
return bindImplicitNumericalCastFunc<int16_t, operation::CastToInt16>(sourceTypeID);
bindImplicitNumericalCastFunc<int16_t, operation::CastToInt16>(sourceTypeID, func);
return;
}
case common::LogicalTypeID::INT32: {
return bindImplicitNumericalCastFunc<int32_t, operation::CastToInt32>(sourceTypeID);
bindImplicitNumericalCastFunc<int32_t, operation::CastToInt32>(sourceTypeID, func);
return;
}
case common::LogicalTypeID::INT64: {
return bindImplicitNumericalCastFunc<int64_t, operation::CastToInt64>(sourceTypeID);
bindImplicitNumericalCastFunc<int64_t, operation::CastToInt64>(sourceTypeID, func);
return;
}
case common::LogicalTypeID::FLOAT: {
return bindImplicitNumericalCastFunc<float_t, operation::CastToFloat>(sourceTypeID);
bindImplicitNumericalCastFunc<float_t, operation::CastToFloat>(sourceTypeID, func);
return;
}
case common::LogicalTypeID::DOUBLE: {
return bindImplicitNumericalCastFunc<double_t, operation::CastToDouble>(sourceTypeID);
bindImplicitNumericalCastFunc<double_t, operation::CastToDouble>(sourceTypeID, func);
return;
}
case common::LogicalTypeID::DATE: {
assert(sourceTypeID == common::LogicalTypeID::STRING);
return &UnaryExecFunction<ku_string_t, date_t, operation::CastStringToDate>;
func = &UnaryExecFunction<ku_string_t, date_t, operation::CastStringToDate>;
return;
}
case common::LogicalTypeID::TIMESTAMP: {
assert(sourceTypeID == common::LogicalTypeID::STRING);
return &UnaryExecFunction<ku_string_t, timestamp_t, operation::CastStringToTimestamp>;
func = &UnaryExecFunction<ku_string_t, timestamp_t, operation::CastStringToTimestamp>;
return;
}
case common::LogicalTypeID::INTERVAL: {
assert(sourceTypeID == common::LogicalTypeID::STRING);
return &UnaryExecFunction<ku_string_t, interval_t, operation::CastStringToInterval>;
func = &UnaryExecFunction<ku_string_t, interval_t, operation::CastStringToInterval>;
return;
}
default:
throw common::NotImplementedException(
Expand Down
65 changes: 35 additions & 30 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,34 +356,34 @@ std::unique_ptr<FunctionBindData> ListSortVectorOperation::bindFunc(
auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
switch (VarListType::getChildType(&arguments[0]->dataType)->getLogicalTypeID()) {
case LogicalTypeID::INT64: {
vectorOperationDefinition->execFunc = getExecFunction<int64_t>(arguments);
getExecFunction<int64_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::INT32: {
vectorOperationDefinition->execFunc = getExecFunction<int32_t>(arguments);
getExecFunction<int32_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::INT16: {
vectorOperationDefinition->execFunc = getExecFunction<int16_t>(arguments);
getExecFunction<int16_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::DOUBLE: {
vectorOperationDefinition->execFunc = getExecFunction<double_t>(arguments);
getExecFunction<double_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::FLOAT: {
vectorOperationDefinition->execFunc = getExecFunction<float_t>(arguments);
getExecFunction<float_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::BOOL: {
vectorOperationDefinition->execFunc = getExecFunction<uint8_t>(arguments);
getExecFunction<uint8_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::STRING: {
vectorOperationDefinition->execFunc = getExecFunction<ku_string_t>(arguments);
getExecFunction<ku_string_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::DATE: {
vectorOperationDefinition->execFunc = getExecFunction<date_t>(arguments);
getExecFunction<date_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::TIMESTAMP: {
vectorOperationDefinition->execFunc = getExecFunction<timestamp_t>(arguments);
getExecFunction<timestamp_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::INTERVAL: {
vectorOperationDefinition->execFunc = getExecFunction<interval_t>(arguments);
getExecFunction<interval_t>(arguments, vectorOperationDefinition->execFunc);
} break;
default: {
throw common::NotImplementedException("ListSortVectorOperation::bindFunc");
Expand All @@ -393,16 +393,19 @@ std::unique_ptr<FunctionBindData> ListSortVectorOperation::bindFunc(
}

template<typename T>
scalar_exec_func ListSortVectorOperation::getExecFunction(
const binder::expression_vector& arguments) {
void ListSortVectorOperation::getExecFunction(
const binder::expression_vector& arguments, scalar_exec_func& func) {
if (arguments.size() == 1) {
return UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListSort<T>>;
func = UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListSort<T>>;
return;
} else if (arguments.size() == 2) {
return BinaryListExecFunction<list_entry_t, ku_string_t, list_entry_t,
operation::ListSort<T>>;
func =
BinaryListExecFunction<list_entry_t, ku_string_t, list_entry_t, operation::ListSort<T>>;
return;
} else if (arguments.size() == 3) {
return TernaryListExecFunction<list_entry_t, ku_string_t, ku_string_t, list_entry_t,
func = TernaryListExecFunction<list_entry_t, ku_string_t, ku_string_t, list_entry_t,
operation::ListSort<T>>;
return;
} else {
throw common::RuntimeException("Invalid number of arguments");
}
Expand All @@ -425,34 +428,34 @@ std::unique_ptr<FunctionBindData> ListReverseSortVectorOperation::bindFunc(
auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
switch (VarListType::getChildType(&arguments[0]->dataType)->getLogicalTypeID()) {
case LogicalTypeID::INT64: {
vectorOperationDefinition->execFunc = getExecFunction<int64_t>(arguments);
getExecFunction<int64_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::INT32: {
vectorOperationDefinition->execFunc = getExecFunction<int32_t>(arguments);
getExecFunction<int32_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::INT16: {
vectorOperationDefinition->execFunc = getExecFunction<int16_t>(arguments);
getExecFunction<int16_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::DOUBLE: {
vectorOperationDefinition->execFunc = getExecFunction<double_t>(arguments);
getExecFunction<double_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::FLOAT: {
vectorOperationDefinition->execFunc = getExecFunction<float_t>(arguments);
getExecFunction<float_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::BOOL: {
vectorOperationDefinition->execFunc = getExecFunction<uint8_t>(arguments);
getExecFunction<uint8_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::STRING: {
vectorOperationDefinition->execFunc = getExecFunction<ku_string_t>(arguments);
getExecFunction<ku_string_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::DATE: {
vectorOperationDefinition->execFunc = getExecFunction<date_t>(arguments);
getExecFunction<date_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::TIMESTAMP: {
vectorOperationDefinition->execFunc = getExecFunction<timestamp_t>(arguments);
getExecFunction<timestamp_t>(arguments, vectorOperationDefinition->execFunc);
} break;
case LogicalTypeID::INTERVAL: {
vectorOperationDefinition->execFunc = getExecFunction<interval_t>(arguments);
getExecFunction<interval_t>(arguments, vectorOperationDefinition->execFunc);
} break;
default: {
throw common::NotImplementedException("ListReverseSortVectorOperation::bindFunc");
Expand All @@ -462,13 +465,15 @@ std::unique_ptr<FunctionBindData> ListReverseSortVectorOperation::bindFunc(
}

template<typename T>
scalar_exec_func ListReverseSortVectorOperation::getExecFunction(
const binder::expression_vector& arguments) {
void ListReverseSortVectorOperation::getExecFunction(
const binder::expression_vector& arguments, scalar_exec_func& func) {
if (arguments.size() == 1) {
return UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListReverseSort<T>>;
func = UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListReverseSort<T>>;
return;
} else if (arguments.size() == 2) {
return BinaryListExecFunction<list_entry_t, ku_string_t, list_entry_t,
func = BinaryListExecFunction<list_entry_t, ku_string_t, list_entry_t,
operation::ListReverseSort<T>>;
return;
} else {
throw common::RuntimeException("Invalid number of arguments");
}
Expand Down
Loading

0 comments on commit 75a0bef

Please sign in to comment.