Skip to content

Commit

Permalink
Arithmetic function framework refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Feb 23, 2023
1 parent bc9599b commit ea5d274
Show file tree
Hide file tree
Showing 13 changed files with 224 additions and 103 deletions.
25 changes: 17 additions & 8 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#include "binder/expression_binder.h"

#include "binder/binder.h"
#include "binder/expression/function_expression.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression/parameter_expression.h"
#include "common/type_utils.h"
#include "function/cast/vector_cast_operations.h"

using namespace kuzu::common;
using namespace kuzu::function;
Expand Down Expand Up @@ -79,6 +81,21 @@ std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
return implicitCast(expression, DataType(targetTypeID));
}

// Wraps `expression` in a scalar cast function expression producing `targetType`'s type ID.
// Throws BinderException when BuiltInVectorOperations::getCastCost reports that no implicit
// cast exists between the expression's data type and `targetType`.
std::shared_ptr<Expression> ExpressionBinder::implicitCast(
    const std::shared_ptr<Expression>& expression, const common::DataType& targetType) {
    const auto castCost = BuiltInVectorOperations::getCastCost(expression->dataType, targetType);
    // UINT32_MAX is the "no implicit cast available" marker returned by getCastCost.
    if (castCost == UINT32_MAX) {
        throw common::BinderException("Expression " + expression->getRawName() +
                                      " has data type " +
                                      common::Types::dataTypeToString(expression->dataType) +
                                      " but expect " + common::Types::dataTypeToString(targetType) +
                                      ". Implicit cast is not supported.");
    }
    // The cast is modelled as a one-argument scalar function over the original expression.
    return std::make_shared<ScalarFunctionExpression>(FUNCTION, DataType{targetType.typeID},
        expression_vector{expression},
        VectorCastOperations::bindExecFunc(expression->dataType.typeID, targetType.typeID),
        nullptr /* selectFunc */, expression->getUniqueName() + " cast expression");
}

void ExpressionBinder::resolveAnyDataType(Expression& expression, const DataType& targetType) {
if (expression.expressionType == PARAMETER) { // expression is parameter
((ParameterExpression&)expression).setDataType(targetType);
Expand All @@ -88,14 +105,6 @@ void ExpressionBinder::resolveAnyDataType(Expression& expression, const DataType
}
}

// Unconditionally rejects the implicit cast by throwing a BinderException that names the
// expression, its current data type and the expected target type.
// NOTE(review): the same signature is also defined earlier in this view with real cast
// support; this copy looks like the pre-refactor version - confirm which one is live.
std::shared_ptr<Expression> ExpressionBinder::implicitCast(
    const std::shared_ptr<Expression>& expression, const DataType& targetType) {
    const auto actualTypeName = Types::dataTypeToString(expression->dataType);
    const auto expectedTypeName = Types::dataTypeToString(targetType);
    throw BinderException("Expression " + expression->getRawName() + " has data type " +
                          actualTypeName + " but expect " + expectedTypeName +
                          ". Implicit cast is not supported.");
}

void ExpressionBinder::validateExpectedDataType(
const Expression& expression, const std::unordered_set<DataTypeID>& targets) {
auto dataType = expression.dataType;
Expand Down
75 changes: 57 additions & 18 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,59 @@ std::vector<std::string> BuiltInVectorOperations::getFunctionNames() {
return result;
}

// Returns the cost of implicitly casting `inputTypeID` to `targetTypeID`.
// 0 means no cast is needed (same type, or ANY on either side); UINT32_MAX means the cast
// is not available; other values come from the per-source-type helpers.
uint32_t BuiltInVectorOperations::getCastCost(DataTypeID inputTypeID, DataTypeID targetTypeID) {
    // Identical types match directly, and ANY acts as a wildcard on either side.
    if (inputTypeID == targetTypeID || targetTypeID == common::ANY ||
        inputTypeID == common::ANY) {
        return 0;
    }
    if (inputTypeID == common::INT64) {
        return implicitCastInt64(targetTypeID);
    }
    if (inputTypeID == common::DOUBLE) {
        return implicitCastDouble(targetTypeID);
    }
    // No implicit cast rules are defined for any other source type.
    return UINT32_MAX;
}

// DataType-level overload of getCastCost.
// Equal data types cost 0; a FIXED_LIST or VAR_LIST input that is not equal to the target
// can never be cast (UINT32_MAX); everything else is decided on the type IDs alone.
uint32_t BuiltInVectorOperations::getCastCost(
    const DataType& inputType, const DataType& targetType) {
    if (inputType == targetType) {
        return 0;
    }
    const auto inputTypeID = inputType.typeID;
    if (inputTypeID == common::FIXED_LIST || inputTypeID == common::VAR_LIST) {
        return UINT32_MAX;
    }
    return getCastCost(inputTypeID, targetType.typeID);
}

// Cost of implicitly casting an INT64 value to `targetTypeID`.
// Only DOUBLE is a supported target (cost 102); every other target yields UINT32_MAX.
uint32_t BuiltInVectorOperations::implicitCastInt64(common::DataTypeID targetTypeID) {
    if (targetTypeID == common::DOUBLE) {
        return 102;
    }
    return UINT32_MAX;
}

// Cost of implicitly casting a DOUBLE value to `targetTypeID`.
// No target is currently supported, so the result is always UINT32_MAX.
uint32_t BuiltInVectorOperations::implicitCastDouble(common::DataTypeID targetTypeID) {
    // The parameter is kept for symmetry with implicitCastInt64 but is not consulted.
    static_cast<void>(targetTypeID);
    return UINT32_MAX;
}

// When there are multiple candidate functions, e.g. double + int and double + double for input
// "1.5 + parameter", we prefer the one without any implicit casting, i.e. double + double.
VectorOperationDefinition* BuiltInVectorOperations::getBestMatch(
Expand Down Expand Up @@ -113,7 +166,7 @@ uint32_t BuiltInVectorOperations::matchParameters(const std::vector<DataType>& i
}
auto cost = 0u;
for (auto i = 0u; i < inputTypes.size(); ++i) {
auto castCost = castRules(inputTypes[i].typeID, targetTypeIDs[i]);
auto castCost = getCastCost(inputTypes[i].typeID, targetTypeIDs[i]);
if (castCost == UINT32_MAX) {
return UINT32_MAX;
}
Expand All @@ -126,7 +179,7 @@ uint32_t BuiltInVectorOperations::matchVarLengthParameters(
const std::vector<DataType>& inputTypes, DataTypeID targetTypeID, bool isOverload) {
auto cost = 0u;
for (auto& inputType : inputTypes) {
auto castCost = castRules(inputType.typeID, targetTypeID);
auto castCost = getCastCost(inputType.typeID, targetTypeID);
if (castCost == UINT32_MAX) {
return UINT32_MAX;
}
Expand All @@ -135,22 +188,6 @@ uint32_t BuiltInVectorOperations::matchVarLengthParameters(
return cost;
}

// Exact-match cast rule: returns 0 when either side is ANY or when both type IDs are
// equal, and UINT32_MAX (unable to cast) otherwise.
uint32_t BuiltInVectorOperations::castRules(DataTypeID inputTypeID, DataTypeID targetTypeID) {
    // ANY is a wildcard on both the input and the target side.
    const bool hasWildcard = inputTypeID == ANY || targetTypeID == ANY;
    if (hasWildcard || inputTypeID == targetTypeID) {
        return 0;
    }
    return UINT32_MAX;
}

void BuiltInVectorOperations::validateNonEmptyCandidateFunctions(
std::vector<VectorOperationDefinition*>& candidateFunctions, const std::string& name,
const std::vector<DataType>& inputTypes) {
Expand Down Expand Up @@ -290,6 +327,8 @@ void BuiltInVectorOperations::registerCastOperations() {
{CAST_TO_INTERVAL_FUNC_NAME, CastToIntervalVectorOperation::getDefinitions()});
vectorOperations.insert(
{CAST_TO_STRING_FUNC_NAME, CastToStringVectorOperation::getDefinitions()});
vectorOperations.insert(
{CAST_TO_DOUBLE_FUNC_NAME, CastToDoubleVectorOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerListOperations() {
Expand Down
82 changes: 45 additions & 37 deletions src/function/vector_arithmetic_operations.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include "function/arithmetic/vector_arithmetic_operations.h"

#include "function/arithmetic/arithmetic_operations.h"

using namespace kuzu::common;

namespace kuzu {
Expand All @@ -16,12 +14,18 @@ static DataTypeID resolveResultType(DataTypeID leftTypeID, DataTypeID rightTypeI

std::vector<std::unique_ptr<VectorOperationDefinition>> AddVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
for (auto& leftTypeID : DataType::getNumericalTypeIDs()) {
for (auto& rightTypeID : DataType::getNumericalTypeIDs()) {
result.push_back(getBinaryDefinition<operation::Add>(ADD_FUNC_NAME, leftTypeID,
rightTypeID, resolveResultType(leftTypeID, rightTypeID)));
}
}
// int64 + int64 -> int64
result.push_back(make_unique<VectorOperationDefinition>(ADD_FUNC_NAME,
std::vector<common::DataTypeID>{INT64, INT64}, INT64,
BinaryExecFunction<int64_t, int64_t, int64_t, operation::Add>));
// double + double -> double
result.push_back(make_unique<VectorOperationDefinition>(ADD_FUNC_NAME,
std::vector<common::DataTypeID>{DOUBLE, DOUBLE}, DOUBLE,
BinaryExecFunction<double_t, double_t, double_t, operation::Add>));
// interval + interval → interval
result.push_back(make_unique<VectorOperationDefinition>(ADD_FUNC_NAME,
std::vector<DataTypeID>{INTERVAL, INTERVAL}, INTERVAL,
BinaryExecFunction<interval_t, interval_t, interval_t, operation::Add>));
// date + int → date
result.push_back(
make_unique<VectorOperationDefinition>(ADD_FUNC_NAME, std::vector<DataTypeID>{DATE, INT64},
Expand All @@ -46,22 +50,20 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> AddVectorOperation::getD
result.push_back(make_unique<VectorOperationDefinition>(ADD_FUNC_NAME,
std::vector<DataTypeID>{INTERVAL, TIMESTAMP}, TIMESTAMP,
BinaryExecFunction<interval_t, timestamp_t, timestamp_t, operation::Add>));
// interval + interval → interval
result.push_back(make_unique<VectorOperationDefinition>(ADD_FUNC_NAME,
std::vector<DataTypeID>{INTERVAL, INTERVAL}, INTERVAL,
BinaryExecFunction<interval_t, interval_t, interval_t, operation::Add>));
return result;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> SubtractVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
for (auto& leftTypeID : DataType::getNumericalTypeIDs()) {
for (auto& rightTypeID : DataType::getNumericalTypeIDs()) {
result.push_back(getBinaryDefinition<operation::Subtract>(SUBTRACT_FUNC_NAME,
leftTypeID, rightTypeID, resolveResultType(leftTypeID, rightTypeID)));
}
}
// date - date → integer
// int64_t - int64_t -> int64_t
result.push_back(make_unique<VectorOperationDefinition>(SUBTRACT_FUNC_NAME,
std::vector<DataTypeID>{INT64, INT64}, INT64,
BinaryExecFunction<int64_t, int64_t, int64_t, operation::Subtract>));
// double_t - double_t -> double_t
result.push_back(make_unique<VectorOperationDefinition>(SUBTRACT_FUNC_NAME,
std::vector<DataTypeID>{DOUBLE, DOUBLE}, DOUBLE,
BinaryExecFunction<double_t, double_t, double_t, operation::Subtract>));
// date - date → int64
result.push_back(make_unique<VectorOperationDefinition>(SUBTRACT_FUNC_NAME,
std::vector<DataTypeID>{DATE, DATE}, INT64,
BinaryExecFunction<date_t, date_t, int64_t, operation::Subtract>));
Expand Down Expand Up @@ -90,23 +92,27 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> SubtractVectorOperation:

// Builds the list of overloads registered under MULTIPLY_FUNC_NAME.
// NOTE(review): as displayed, this body contains both a loop over every pair of numerical
// type IDs and explicit INT64/DOUBLE overloads. This looks like two sides of a diff shown
// together - confirm which registration is intended before relying on it.
std::vector<std::unique_ptr<VectorOperationDefinition>> MultiplyVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
// One definition per (left, right) pair of numerical type IDs.
for (auto& leftTypeID : DataType::getNumericalTypeIDs()) {
for (auto& rightTypeID : DataType::getNumericalTypeIDs()) {
result.push_back(getBinaryDefinition<operation::Multiply>(MULTIPLY_FUNC_NAME,
leftTypeID, rightTypeID, resolveResultType(leftTypeID, rightTypeID)));
}
}
// int64_t * int64_t -> int64_t
result.push_back(make_unique<VectorOperationDefinition>(MULTIPLY_FUNC_NAME,
std::vector<DataTypeID>{INT64, INT64}, INT64,
BinaryExecFunction<int64_t, int64_t, int64_t, operation::Multiply>));
// double_t * double_t -> double_t
result.push_back(make_unique<VectorOperationDefinition>(MULTIPLY_FUNC_NAME,
std::vector<DataTypeID>{DOUBLE, DOUBLE}, DOUBLE,
BinaryExecFunction<double_t, double_t, double_t, operation::Multiply>));
return result;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> DivideVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
for (auto& leftType : DataType::getNumericalTypeIDs()) {
for (auto& rightType : DataType::getNumericalTypeIDs()) {
result.push_back(getBinaryDefinition<operation::Divide>(
DIVIDE_FUNC_NAME, leftType, rightType, resolveResultType(leftType, rightType)));
}
}
// int64_t / int64_t -> int64_t
result.push_back(make_unique<VectorOperationDefinition>(DIVIDE_FUNC_NAME,
std::vector<DataTypeID>{INT64, INT64}, INT64,
BinaryExecFunction<int64_t, int64_t, int64_t, operation::Divide>));
// double_t / double_t -> double_t
result.push_back(make_unique<VectorOperationDefinition>(DIVIDE_FUNC_NAME,
std::vector<DataTypeID>{DOUBLE, DOUBLE}, DOUBLE,
BinaryExecFunction<double_t, double_t, double_t, operation::Divide>));
// interval / int → interval
result.push_back(make_unique<VectorOperationDefinition>(DIVIDE_FUNC_NAME,
std::vector<DataTypeID>{INTERVAL, INT64}, INTERVAL,
Expand All @@ -116,12 +122,14 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> DivideVectorOperation::g

// Builds the list of overloads registered under MODULO_FUNC_NAME.
// NOTE(review): as displayed, this body contains both a loop over every pair of numerical
// type IDs and explicit INT64/DOUBLE overloads. This looks like two sides of a diff shown
// together - confirm which registration is intended before relying on it.
std::vector<std::unique_ptr<VectorOperationDefinition>> ModuloVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
// One definition per (left, right) pair of numerical type IDs.
for (auto& leftTypeID : DataType::getNumericalTypeIDs()) {
for (auto& rightTypeID : DataType::getNumericalTypeIDs()) {
result.push_back(getBinaryDefinition<operation::Modulo>(MODULO_FUNC_NAME, leftTypeID,
rightTypeID, resolveResultType(leftTypeID, rightTypeID)));
}
}
// int64_t % int64_t -> int64_t
result.push_back(make_unique<VectorOperationDefinition>(MODULO_FUNC_NAME,
std::vector<DataTypeID>{INT64, INT64}, INT64,
BinaryExecFunction<int64_t, int64_t, int64_t, operation::Modulo>));
// double_t % double_t -> double_t
result.push_back(make_unique<VectorOperationDefinition>(MODULO_FUNC_NAME,
std::vector<DataTypeID>{DOUBLE, DOUBLE}, DOUBLE,
BinaryExecFunction<double_t, double_t, double_t, operation::Modulo>));
return result;
}

Expand Down
28 changes: 28 additions & 0 deletions src/function/vector_cast_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,26 @@ using namespace kuzu::common;
namespace kuzu {
namespace function {

// Selects the execution function that casts values of `sourceTypeID` to `targetTypeID`.
// Only INT64 -> DOUBLE is implemented; any other combination throws InternalException.
scalar_exec_func VectorCastOperations::bindExecFunc(
    common::DataTypeID sourceTypeID, common::DataTypeID targetTypeID) {
    const bool isInt64ToDouble =
        sourceTypeID == common::INT64 && targetTypeID == common::DOUBLE;
    if (isInt64ToDouble) {
        return VectorOperations::UnaryExecFunction<int64_t, double_t, operation::CastToDouble>;
    }
    // Single failure path shared by every unsupported (source, target) pair.
    throw common::InternalException("Undefined casting operation from " +
                                    common::Types::dataTypeToString(sourceTypeID) + " to " +
                                    common::Types::dataTypeToString(targetTypeID) + ".");
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
CastToDateVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
Expand Down Expand Up @@ -65,5 +85,13 @@ CastToStringVectorOperation::getDefinitions() {
return result;
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
CastToDoubleVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(make_unique<VectorOperationDefinition>(CAST_TO_DOUBLE_FUNC_NAME,
std::vector<DataTypeID>{INT64}, DOUBLE, bindExecFunc(INT64, DOUBLE)));
return result;
}

} // namespace function
} // namespace kuzu
1 change: 1 addition & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const std::string CAST_TO_DATE_FUNC_NAME = "DATE";
const std::string CAST_TO_TIMESTAMP_FUNC_NAME = "TIMESTAMP";
const std::string CAST_TO_INTERVAL_FUNC_NAME = "INTERVAL";
const std::string CAST_TO_STRING_FUNC_NAME = "STRING";
const std::string CAST_TO_DOUBLE_FUNC_NAME = "TO_DOUBLE";
const std::string IMPLICIT_CAST_TO_BOOL_FUNC_NAME = "_BOOL";
const std::string IMPLICIT_CAST_TO_INT_FUNC_NAME = "_INT";
const std::string IMPLICIT_CAST_TO_STRING_FUNC_NAME = "_STRING";
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "arithmetic_operations.h"
#include "function/vector_operations.h"

namespace kuzu {
Expand Down
10 changes: 9 additions & 1 deletion src/include/function/built_in_vector_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,16 @@ class BuiltInVectorOperations {

std::vector<std::string> getFunctionNames();

// Cost of implicitly casting inputTypeID to targetTypeID: 0 when the IDs are equal or
// either side is ANY, UINT32_MAX when no implicit cast exists.
static uint32_t getCastCost(common::DataTypeID inputTypeID, common::DataTypeID targetTypeID);

// DataType overload: 0 for equal types, UINT32_MAX for a non-equal FIXED_LIST/VAR_LIST
// input, otherwise the cost computed from the two type IDs.
static uint32_t getCastCost(
const common::DataType& inputType, const common::DataType& targetType);

private:
// Cost of implicitly casting INT64 to targetTypeID; UINT32_MAX when unsupported.
static uint32_t implicitCastInt64(common::DataTypeID targetTypeID);

// Cost of implicitly casting DOUBLE to targetTypeID; UINT32_MAX when unsupported.
static uint32_t implicitCastDouble(common::DataTypeID targetTypeID);

VectorOperationDefinition* getBestMatch(std::vector<VectorOperationDefinition*>& functions);

uint32_t getFunctionCost(const std::vector<common::DataType>& inputTypes,
Expand All @@ -35,7 +44,6 @@ class BuiltInVectorOperations {
const std::vector<common::DataTypeID>& targetTypeIDs, bool isOverload);
uint32_t matchVarLengthParameters(const std::vector<common::DataType>& inputTypes,
common::DataTypeID targetTypeID, bool isOverload);
uint32_t castRules(common::DataTypeID inputTypeID, common::DataTypeID targetTypeID);

void validateNonEmptyCandidateFunctions(
std::vector<VectorOperationDefinition*>& candidateFunctions, const std::string& name,
Expand Down
6 changes: 6 additions & 0 deletions src/include/function/cast/cast_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ inline std::string CastToString::castToStringWithDataType(
return common::TypeUtils::toString(input, dataType);
}

// Cast operation converting a 64-bit signed integer into a double_t.
struct CastToDouble {
    // Stores the converted value of `input` in `result`; `input` itself is not modified.
    // NOTE(review): magnitudes beyond 2^53 presumably lose precision when double_t is a
    // 64-bit IEEE double - confirm this is acceptable for implicit casts.
    static inline void operation(int64_t& input, double_t& result) {
        const auto converted = static_cast<double_t>(input);
        result = converted;
    }
};

} // namespace operation
} // namespace function
} // namespace kuzu
7 changes: 7 additions & 0 deletions src/include/function/cast/vector_cast_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class VectorCastOperations : public VectorOperations {
assert(params.size() == 1);
UnaryOperationExecutor::executeCast<OPERAND_TYPE, RESULT_TYPE, FUNC>(*params[0], result);
}

// Returns the exec function casting sourceTypeID to targetTypeID; the definition throws
// InternalException for combinations it does not implement.
static scalar_exec_func bindExecFunc(
common::DataTypeID sourceTypeID, common::DataTypeID targetTypeID);
};

struct CastToDateVectorOperation : public VectorCastOperations {
Expand All @@ -38,5 +41,9 @@ struct CastToStringVectorOperation : public VectorCastOperations {
static std::vector<std::unique_ptr<VectorOperationDefinition>> getDefinitions();
};

// Supplies the vector operation definitions for the cast-to-DOUBLE function.
struct CastToDoubleVectorOperation : public VectorCastOperations {
// Returns the overloads to register for the cast-to-DOUBLE function name.
static std::vector<std::unique_ptr<VectorOperationDefinition>> getDefinitions();
};

} // namespace function
} // namespace kuzu
Loading

0 comments on commit ea5d274

Please sign in to comment.