Skip to content

Commit

Permalink
Add numerical downcast
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 31, 2023
1 parent 728a28b commit 31433c8
Show file tree
Hide file tree
Showing 15 changed files with 313 additions and 113 deletions.
9 changes: 6 additions & 3 deletions src/binder/bind/bind_ddl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,12 @@ uint32_t Binder::bindPrimaryKey(
}
auto primaryKey = propertyNameDataTypes[primaryKeyIdx];
StringUtils::toUpper(primaryKey.second);
// We only support INT64 and STRING column as the primary key.
if ((primaryKey.second != std::string("INT64")) &&
(primaryKey.second != std::string("STRING"))) {
// We only support INT64 and STRING columns as the primary key.
switch (Types::dataTypeFromString(primaryKey.second).typeID) {
case common::INT64:
case common::STRING:
break;
default:
throw BinderException("Invalid primary key type: " + primaryKey.second + ".");
}
return primaryKeyIdx;
Expand Down
4 changes: 2 additions & 2 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(

std::shared_ptr<Expression> ExpressionBinder::implicitCast(
const std::shared_ptr<Expression>& expression, const common::DataType& targetType) {
if (BuiltInVectorOperations::getCastCost(expression->dataType, targetType) != UINT32_MAX) {
auto functionName = VectorCastOperations::bindCastFunctionName(targetType.typeID);
if (VectorCastOperations::hasImplicitCast(expression->dataType, targetType)) {
auto functionName = VectorCastOperations::bindImplicitCastFuncName(targetType);
auto children = expression_vector{expression};
auto uniqueName = ScalarFunctionExpression::getUniqueName(functionName, children);
return std::make_shared<ScalarFunctionExpression>(functionName, FUNCTION,
Expand Down
13 changes: 13 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,19 @@ uint32_t Types::getDataTypeSize(const DataType& dataType) {
}
}

/**
 * Reports whether a data type counts as numerical.
 *
 * Only the type ID of the argument is inspected.
 *
 * @param dataType the data type to classify.
 * @return true when the type ID is INT64, INT32, INT16, DOUBLE or FLOAT; false for
 *         every other type ID.
 */
bool Types::isNumerical(const kuzu::common::DataType& dataType) {
    const auto typeID = dataType.typeID;
    const bool isIntegerID = typeID == INT64 || typeID == INT32 || typeID == INT16;
    const bool isFloatingID = typeID == DOUBLE || typeID == FLOAT;
    return isIntegerID || isFloatingID;
}

// Flips a relationship direction: FWD maps to BWD, and any other value maps to FWD.
RelDirection operator!(RelDirection& direction) {
    if (FWD == direction) {
        return BWD;
    }
    return FWD;
}
Expand Down
8 changes: 4 additions & 4 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ Value::Value(int64_t val_) : dataType{INT64}, isNull_{false} {
val.int64Val = val_;
}

// Constructs a non-null value of type FLOAT and stores val_ in the float slot of val.
Value::Value(float_t val_) : dataType{FLOAT}, isNull_{false} {
val.floatVal = val_;
}

// Constructs a non-null value of type DOUBLE and stores val_ in the double slot of val.
Value::Value(double val_) : dataType{DOUBLE}, isNull_{false} {
val.doubleVal = val_;
}
Expand Down Expand Up @@ -118,10 +122,6 @@ Value::Value(DataType dataType, std::vector<std::unique_ptr<Value>> vals)
listVal = std::move(vals);
}

// Constructs a non-null value of type FLOAT and stores val_ in the float slot of val.
Value::Value(float_t val_) : dataType{FLOAT}, isNull_{false} {
val.floatVal = val_;
}

// Constructs a non-null value of type NODE. The NodeVal passed in is moved into
// nodeVal, so the caller's unique_ptr is left empty afterwards.
Value::Value(std::unique_ptr<NodeVal> val_) : dataType{NODE}, isNull_{false} {
nodeVal = std::move(val_);
}
Expand Down
2 changes: 2 additions & 0 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,8 @@ void BuiltInVectorOperations::registerCastOperations() {
{CAST_TO_INT64_FUNC_NAME, CastToInt64VectorOperation::getDefinitions()});
vectorOperations.insert(
{CAST_TO_INT32_FUNC_NAME, CastToInt32VectorOperation::getDefinitions()});
vectorOperations.insert(
{CAST_TO_INT16_FUNC_NAME, CastToInt16VectorOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerListOperations() {
Expand Down
171 changes: 93 additions & 78 deletions src/function/vector_cast_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,103 +8,91 @@ using namespace kuzu::common;
namespace kuzu {
namespace function {

scalar_exec_func VectorCastOperations::bindImplicitCastFunc(
common::DataTypeID sourceTypeID, common::DataTypeID targetTypeID) {
switch (sourceTypeID) {
case common::INT16: {
return bindImplicitCastInt16Func(targetTypeID);
}
case common::INT32: {
return bindImplicitCastInt32Func(targetTypeID);
}
case common::INT64: {
return bindImplicitCastInt64Func(targetTypeID);
}
case common::FLOAT: {
return bindImplicitCastFloatFunc(targetTypeID);
bool VectorCastOperations::hasImplicitCast(
const common::DataType& srcType, const common::DataType& dstType) {
// We allow cast between any numerical types
if (Types::isNumerical(srcType) && Types::isNumerical(dstType)) {
return true;
}
switch (srcType.typeID) {
case common::STRING: {
switch (dstType.typeID) {
case common::DATE:
case common::TIMESTAMP:
case common::INTERVAL:
return true;
default:
return false;
}
}
default:
throw common::InternalException("Undefined casting operation from " +
common::Types::dataTypeToString(sourceTypeID) + " to " +
common::Types::dataTypeToString(targetTypeID) + ".");
return false;
}
}

scalar_exec_func VectorCastOperations::bindImplicitCastInt16Func(common::DataTypeID targetTypeID) {
switch (targetTypeID) {
case common::INT32: {
return VectorOperations::UnaryExecFunction<int16_t, int32_t, operation::CastToInt32>;
}
case common::INT64: {
return VectorOperations::UnaryExecFunction<int16_t, int64_t, operation::CastToInt64>;
}
case common::FLOAT: {
return VectorOperations::UnaryExecFunction<int16_t, float_t, operation::CastToFloat>;
}
case common::DOUBLE: {
return VectorOperations::UnaryExecFunction<int16_t, double_t, operation::CastToDouble>;
}
default: {
throw common::InternalException("Undefined casting operation from INT16 to " +
common::Types::dataTypeToString(targetTypeID) + ".");
}
}
}

scalar_exec_func VectorCastOperations::bindImplicitCastInt32Func(common::DataTypeID targetTypeID) {
switch (targetTypeID) {
std::string VectorCastOperations::bindImplicitCastFuncName(const common::DataType& dstType) {
switch (dstType.typeID) {
case common::INT16:
return CAST_TO_INT16_FUNC_NAME;
case common::INT32:
return CAST_TO_INT32_FUNC_NAME;
case common::INT64:
return VectorOperations::UnaryExecFunction<int32_t, int64_t, operation::CastToInt64>;
return CAST_TO_INT64_FUNC_NAME;
case common::FLOAT:
return VectorOperations::UnaryExecFunction<int32_t, float_t, operation::CastToFloat>;
return CAST_TO_FLOAT_FUNC_NAME;
case common::DOUBLE:
return VectorOperations::UnaryExecFunction<int32_t, double_t, operation::CastToDouble>;
return CAST_TO_DOUBLE_FUNC_NAME;
case common::DATE:
return CAST_TO_DATE_FUNC_NAME;
case common::TIMESTAMP:
return CAST_TO_TIMESTAMP_FUNC_NAME;
case common::INTERVAL:
return CAST_TO_INTERVAL_FUNC_NAME;
case common::STRING:
return CAST_TO_STRING_FUNC_NAME;
default:
throw common::InternalException("Undefined casting operation from INT32 to " +
common::Types::dataTypeToString(targetTypeID) + ".");
throw common::NotImplementedException("bindImplicitCastFuncName()");
}
}

scalar_exec_func VectorCastOperations::bindImplicitCastInt64Func(common::DataTypeID targetTypeID) {
scalar_exec_func VectorCastOperations::bindImplicitCastFunc(
common::DataTypeID sourceTypeID, common::DataTypeID targetTypeID) {
switch (targetTypeID) {
case common::FLOAT:
return VectorOperations::UnaryExecFunction<int64_t, float_t, operation::CastToFloat>;
case common::DOUBLE:
return VectorOperations::UnaryExecFunction<int64_t, double_t, operation::CastToDouble>;
default:
throw common::InternalException("Undefined casting operation from INT64 to " +
common::Types::dataTypeToString(targetTypeID) + ".");
case common::INT16: {
return bindImplicitNumericalCastFunc<int16_t, operation::CastToInt16>(sourceTypeID);
}
}

scalar_exec_func VectorCastOperations::bindImplicitCastFloatFunc(common::DataTypeID targetTypeID) {
switch (targetTypeID) {
case common::DOUBLE:
return VectorOperations::UnaryExecFunction<float_t, double_t, operation::CastToDouble>;
default:
throw common::InternalException("Undefined casting operation from FLOAT to " +
common::Types::dataTypeToString(targetTypeID) + ".");
case common::INT32: {
return bindImplicitNumericalCastFunc<int32_t, operation::CastToInt32>(sourceTypeID);
}
}

std::string VectorCastOperations::bindCastFunctionName(common::DataTypeID targetTypeID) {
switch (targetTypeID) {
case common::INT64: {
return CAST_TO_INT64_FUNC_NAME;
return bindImplicitNumericalCastFunc<int64_t, operation::CastToInt64>(sourceTypeID);
}
case common::INT32: {
return CAST_TO_INT32_FUNC_NAME;
case common::FLOAT: {
return bindImplicitNumericalCastFunc<float_t, operation::CastToFloat>(sourceTypeID);
}
case common::DOUBLE: {
return CAST_TO_DOUBLE_FUNC_NAME;
return bindImplicitNumericalCastFunc<double_t, operation::CastToDouble>(sourceTypeID);
}
case common::FLOAT: {
return CAST_TO_FLOAT_FUNC_NAME;
case common::DATE: {
assert(sourceTypeID == common::STRING);
return VectorOperations::UnaryExecFunction<ku_string_t, date_t,
operation::CastStringToDate>;
}
case common::TIMESTAMP: {
assert(sourceTypeID == common::STRING);
return VectorOperations::UnaryExecFunction<ku_string_t, timestamp_t,
operation::CastStringToTimestamp>;
}
default: {
throw common::InternalException("Cannot bind function name for cast to " +
common::Types::dataTypeToString(targetTypeID));
case common::INTERVAL: {
assert(sourceTypeID == common::STRING);
return VectorOperations::UnaryExecFunction<ku_string_t, interval_t,
operation::CastStringToInterval>;
}
default:
throw common::NotImplementedException("Unimplemented casting operation from " +
common::Types::dataTypeToString(sourceTypeID) +
" to " +
common::Types::dataTypeToString(targetTypeID) + ".");
}
}

Expand Down Expand Up @@ -162,9 +150,6 @@ CastToStringVectorOperation::getDefinitions() {
result.push_back(make_unique<VectorOperationDefinition>(CAST_TO_STRING_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, STRING,
UnaryCastExecFunction<ku_list_t, ku_string_t, operation::CastToString>));
result.push_back(make_unique<VectorOperationDefinition>(CAST_TO_STRING_FUNC_NAME,
std::vector<DataTypeID>{FLOAT}, STRING,
UnaryCastExecFunction<float_t, ku_string_t, operation::CastToString>));
return result;
}

Expand All @@ -191,6 +176,9 @@ CastToFloatVectorOperation::getDefinitions() {
CAST_TO_FLOAT_FUNC_NAME, INT32, FLOAT));
result.push_back(bindVectorOperation<int64_t, float_t, operation::CastToFloat>(
CAST_TO_FLOAT_FUNC_NAME, INT64, FLOAT));
// down cast
result.push_back(bindVectorOperation<double_t, float_t, operation::CastToFloat>(
CAST_TO_FLOAT_FUNC_NAME, DOUBLE, FLOAT));
return result;
}

Expand All @@ -201,6 +189,11 @@ CastToInt64VectorOperation::getDefinitions() {
CAST_TO_INT64_FUNC_NAME, INT16, INT64));
result.push_back(bindVectorOperation<int32_t, int64_t, operation::CastToInt64>(
CAST_TO_INT64_FUNC_NAME, INT32, INT64));
// down cast
result.push_back(bindVectorOperation<float_t, int64_t, operation::CastToInt64>(
CAST_TO_INT64_FUNC_NAME, FLOAT, INT64));
result.push_back(bindVectorOperation<double_t, int64_t, operation::CastToInt64>(
CAST_TO_INT64_FUNC_NAME, DOUBLE, INT64));
return result;
}

Expand All @@ -209,6 +202,28 @@ CastToInt32VectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(bindVectorOperation<int16_t, int32_t, operation::CastToInt32>(
CAST_TO_INT32_FUNC_NAME, INT16, INT32));
// down cast
result.push_back(bindVectorOperation<int64_t, int32_t, operation::CastToInt32>(
CAST_TO_INT32_FUNC_NAME, INT64, INT32));
result.push_back(bindVectorOperation<float_t, int32_t, operation::CastToInt32>(
CAST_TO_INT32_FUNC_NAME, FLOAT, INT32));
result.push_back(bindVectorOperation<double_t, int32_t, operation::CastToInt32>(
CAST_TO_INT32_FUNC_NAME, DOUBLE, INT32));
return result;
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
CastToInt16VectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
// down cast
result.push_back(bindVectorOperation<int32_t, int16_t, operation::CastToInt16>(
CAST_TO_INT32_FUNC_NAME, INT32, INT16));
result.push_back(bindVectorOperation<int64_t, int16_t, operation::CastToInt16>(
CAST_TO_INT32_FUNC_NAME, INT64, INT16));
result.push_back(bindVectorOperation<float_t, int16_t, operation::CastToInt16>(
CAST_TO_INT32_FUNC_NAME, FLOAT, INT16));
result.push_back(bindVectorOperation<double_t, int16_t, operation::CastToInt16>(
CAST_TO_INT32_FUNC_NAME, DOUBLE, INT16));
return result;
}

Expand Down
6 changes: 1 addition & 5 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ const std::string CAST_TO_DOUBLE_FUNC_NAME = "TO_DOUBLE";
const std::string CAST_TO_FLOAT_FUNC_NAME = "TO_FLOAT";
const std::string CAST_TO_INT64_FUNC_NAME = "TO_INT64";
const std::string CAST_TO_INT32_FUNC_NAME = "TO_INT32";
const std::string IMPLICIT_CAST_TO_BOOL_FUNC_NAME = "_BOOL";
const std::string IMPLICIT_CAST_TO_INT_FUNC_NAME = "_INT";
const std::string IMPLICIT_CAST_TO_STRING_FUNC_NAME = "_STRING";
const std::string IMPLICIT_CAST_TO_DATE_FUNC_NAME = "_DATE";
const std::string IMPLICIT_CAST_TO_TIMESTAMP_FUNC_NAME = "_TIMESTAMP";
const std::string CAST_TO_INT16_FUNC_NAME = "TO_INT16";

// list
const std::string LIST_CREATION_FUNC_NAME = "LIST_CREATION";
Expand Down
1 change: 1 addition & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class Types {
KUZU_API static DataType dataTypeFromString(const std::string& dataTypeString);
static uint32_t getDataTypeSize(DataTypeID dataTypeID);
static uint32_t getDataTypeSize(const DataType& dataType);
static bool isNumerical(const DataType& dataType);

private:
static DataTypeID dataTypeIDFromString(const std::string& dataTypeIDString);
Expand Down
9 changes: 9 additions & 0 deletions src/include/common/types/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,15 @@ inline int64_t Value::getValue() const {
return val.int64Val;
}

/**
 * Specialization of getValue() for float.
 * Asserts that the type ID of this value is FLOAT before reading the float slot of val.
 * @return float value.
 */
KUZU_API template<>
inline float Value::getValue() const {
assert(dataType.getTypeID() == FLOAT);
return val.floatVal;
}

/**
* @return double value.
*/
Expand Down
Loading

0 comments on commit 31433c8

Please sign in to comment.