Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Numerical downcast #1429

Merged
merged 1 commit into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/binder/bind/bind_ddl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,12 @@ uint32_t Binder::bindPrimaryKey(
}
auto primaryKey = propertyNameDataTypes[primaryKeyIdx];
StringUtils::toUpper(primaryKey.second);
// We only support INT64 and STRING column as the primary key.
if ((primaryKey.second != std::string("INT64")) &&
(primaryKey.second != std::string("STRING"))) {
// We only support INT64, and STRING column as the primary key.
switch (Types::dataTypeFromString(primaryKey.second).typeID) {
case common::INT64:
case common::STRING:
break;
default:
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
throw BinderException("Invalid primary key type: " + primaryKey.second + ".");
}
return primaryKeyIdx;
Expand Down
4 changes: 2 additions & 2 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(

std::shared_ptr<Expression> ExpressionBinder::implicitCast(
const std::shared_ptr<Expression>& expression, const common::DataType& targetType) {
if (BuiltInVectorOperations::getCastCost(expression->dataType, targetType) != UINT32_MAX) {
auto functionName = VectorCastOperations::bindCastFunctionName(targetType.typeID);
if (VectorCastOperations::hasImplicitCast(expression->dataType, targetType)) {
auto functionName = VectorCastOperations::bindImplicitCastFuncName(targetType);
auto children = expression_vector{expression};
auto uniqueName = ScalarFunctionExpression::getUniqueName(functionName, children);
return std::make_shared<ScalarFunctionExpression>(functionName, FUNCTION,
Expand Down
13 changes: 13 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,19 @@ uint32_t Types::getDataTypeSize(const DataType& dataType) {
}
}

bool Types::isNumerical(const kuzu::common::DataType& dataType) {
switch (dataType.typeID) {
case INT64:
case INT32:
case INT16:
case DOUBLE:
case FLOAT:
return true;
default:
return false;
}
}

RelDirection operator!(RelDirection& direction) {
return (FWD == direction) ? BWD : FWD;
}
Expand Down
8 changes: 4 additions & 4 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ Value::Value(int64_t val_) : dataType{INT64}, isNull_{false} {
val.int64Val = val_;
}

Value::Value(float_t val_) : dataType{FLOAT}, isNull_{false} {
val.floatVal = val_;
}

Value::Value(double val_) : dataType{DOUBLE}, isNull_{false} {
val.doubleVal = val_;
}
Expand Down Expand Up @@ -118,10 +122,6 @@ Value::Value(DataType dataType, std::vector<std::unique_ptr<Value>> vals)
listVal = std::move(vals);
}

Value::Value(float_t val_) : dataType{FLOAT}, isNull_{false} {
val.floatVal = val_;
}

Value::Value(std::unique_ptr<NodeVal> val_) : dataType{NODE}, isNull_{false} {
nodeVal = std::move(val_);
}
Expand Down
2 changes: 2 additions & 0 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,8 @@ void BuiltInVectorOperations::registerCastOperations() {
{CAST_TO_INT64_FUNC_NAME, CastToInt64VectorOperation::getDefinitions()});
vectorOperations.insert(
{CAST_TO_INT32_FUNC_NAME, CastToInt32VectorOperation::getDefinitions()});
vectorOperations.insert(
{CAST_TO_INT16_FUNC_NAME, CastToInt16VectorOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerListOperations() {
Expand Down
171 changes: 93 additions & 78 deletions src/function/vector_cast_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,103 +8,91 @@ using namespace kuzu::common;
namespace kuzu {
namespace function {

scalar_exec_func VectorCastOperations::bindImplicitCastFunc(
common::DataTypeID sourceTypeID, common::DataTypeID targetTypeID) {
switch (sourceTypeID) {
case common::INT16: {
return bindImplicitCastInt16Func(targetTypeID);
}
case common::INT32: {
return bindImplicitCastInt32Func(targetTypeID);
}
case common::INT64: {
return bindImplicitCastInt64Func(targetTypeID);
}
case common::FLOAT: {
return bindImplicitCastFloatFunc(targetTypeID);
bool VectorCastOperations::hasImplicitCast(
const common::DataType& srcType, const common::DataType& dstType) {
// We allow cast between any numerical types
if (Types::isNumerical(srcType) && Types::isNumerical(dstType)) {
return true;
}
switch (srcType.typeID) {
case common::STRING: {
switch (dstType.typeID) {
case common::DATE:
case common::TIMESTAMP:
case common::INTERVAL:
return true;
default:
return false;
}
}
default:
throw common::InternalException("Undefined casting operation from " +
common::Types::dataTypeToString(sourceTypeID) + " to " +
common::Types::dataTypeToString(targetTypeID) + ".");
return false;
}
}

scalar_exec_func VectorCastOperations::bindImplicitCastInt16Func(common::DataTypeID targetTypeID) {
switch (targetTypeID) {
case common::INT32: {
return VectorOperations::UnaryExecFunction<int16_t, int32_t, operation::CastToInt32>;
}
case common::INT64: {
return VectorOperations::UnaryExecFunction<int16_t, int64_t, operation::CastToInt64>;
}
case common::FLOAT: {
return VectorOperations::UnaryExecFunction<int16_t, float_t, operation::CastToFloat>;
}
case common::DOUBLE: {
return VectorOperations::UnaryExecFunction<int16_t, double_t, operation::CastToDouble>;
}
default: {
throw common::InternalException("Undefined casting operation from INT16 to " +
common::Types::dataTypeToString(targetTypeID) + ".");
}
}
}

scalar_exec_func VectorCastOperations::bindImplicitCastInt32Func(common::DataTypeID targetTypeID) {
switch (targetTypeID) {
std::string VectorCastOperations::bindImplicitCastFuncName(const common::DataType& dstType) {
switch (dstType.typeID) {
case common::INT16:
return CAST_TO_INT16_FUNC_NAME;
case common::INT32:
return CAST_TO_INT32_FUNC_NAME;
case common::INT64:
return VectorOperations::UnaryExecFunction<int32_t, int64_t, operation::CastToInt64>;
return CAST_TO_INT64_FUNC_NAME;
case common::FLOAT:
return VectorOperations::UnaryExecFunction<int32_t, float_t, operation::CastToFloat>;
return CAST_TO_FLOAT_FUNC_NAME;
case common::DOUBLE:
return VectorOperations::UnaryExecFunction<int32_t, double_t, operation::CastToDouble>;
return CAST_TO_DOUBLE_FUNC_NAME;
case common::DATE:
return CAST_TO_DATE_FUNC_NAME;
case common::TIMESTAMP:
return CAST_TO_TIMESTAMP_FUNC_NAME;
case common::INTERVAL:
return CAST_TO_INTERVAL_FUNC_NAME;
case common::STRING:
return CAST_TO_STRING_FUNC_NAME;
default:
throw common::InternalException("Undefined casting operation from INT32 to " +
common::Types::dataTypeToString(targetTypeID) + ".");
throw common::NotImplementedException("bindImplicitCastFuncName()");
}
}

scalar_exec_func VectorCastOperations::bindImplicitCastInt64Func(common::DataTypeID targetTypeID) {
scalar_exec_func VectorCastOperations::bindImplicitCastFunc(
common::DataTypeID sourceTypeID, common::DataTypeID targetTypeID) {
switch (targetTypeID) {
case common::FLOAT:
return VectorOperations::UnaryExecFunction<int64_t, float_t, operation::CastToFloat>;
case common::DOUBLE:
return VectorOperations::UnaryExecFunction<int64_t, double_t, operation::CastToDouble>;
default:
throw common::InternalException("Undefined casting operation from INT64 to " +
common::Types::dataTypeToString(targetTypeID) + ".");
case common::INT16: {
return bindImplicitNumericalCastFunc<int16_t, operation::CastToInt16>(sourceTypeID);
}
}

scalar_exec_func VectorCastOperations::bindImplicitCastFloatFunc(common::DataTypeID targetTypeID) {
switch (targetTypeID) {
case common::DOUBLE:
return VectorOperations::UnaryExecFunction<float_t, double_t, operation::CastToDouble>;
default:
throw common::InternalException("Undefined casting operation from FLOAT to " +
common::Types::dataTypeToString(targetTypeID) + ".");
case common::INT32: {
return bindImplicitNumericalCastFunc<int32_t, operation::CastToInt32>(sourceTypeID);
}
}

std::string VectorCastOperations::bindCastFunctionName(common::DataTypeID targetTypeID) {
switch (targetTypeID) {
case common::INT64: {
return CAST_TO_INT64_FUNC_NAME;
return bindImplicitNumericalCastFunc<int64_t, operation::CastToInt64>(sourceTypeID);
}
case common::INT32: {
return CAST_TO_INT32_FUNC_NAME;
case common::FLOAT: {
return bindImplicitNumericalCastFunc<float_t, operation::CastToFloat>(sourceTypeID);
}
case common::DOUBLE: {
return CAST_TO_DOUBLE_FUNC_NAME;
return bindImplicitNumericalCastFunc<double_t, operation::CastToDouble>(sourceTypeID);
}
case common::FLOAT: {
return CAST_TO_FLOAT_FUNC_NAME;
case common::DATE: {
assert(sourceTypeID == common::STRING);
return VectorOperations::UnaryExecFunction<ku_string_t, date_t,
operation::CastStringToDate>;
}
case common::TIMESTAMP: {
assert(sourceTypeID == common::STRING);
return VectorOperations::UnaryExecFunction<ku_string_t, timestamp_t,
operation::CastStringToTimestamp>;
}
default: {
throw common::InternalException("Cannot bind function name for cast to " +
common::Types::dataTypeToString(targetTypeID));
case common::INTERVAL: {
assert(sourceTypeID == common::STRING);
return VectorOperations::UnaryExecFunction<ku_string_t, interval_t,
operation::CastStringToInterval>;
}
default:
throw common::NotImplementedException("Unimplemented casting operation from " +
common::Types::dataTypeToString(sourceTypeID) +
" to " +
common::Types::dataTypeToString(targetTypeID) + ".");
}
}

Expand Down Expand Up @@ -162,9 +150,6 @@ CastToStringVectorOperation::getDefinitions() {
result.push_back(make_unique<VectorOperationDefinition>(CAST_TO_STRING_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, STRING,
UnaryCastExecFunction<ku_list_t, ku_string_t, operation::CastToString>));
result.push_back(make_unique<VectorOperationDefinition>(CAST_TO_STRING_FUNC_NAME,
std::vector<DataTypeID>{FLOAT}, STRING,
UnaryCastExecFunction<float_t, ku_string_t, operation::CastToString>));
return result;
}

Expand All @@ -191,6 +176,9 @@ CastToFloatVectorOperation::getDefinitions() {
CAST_TO_FLOAT_FUNC_NAME, INT32, FLOAT));
result.push_back(bindVectorOperation<int64_t, float_t, operation::CastToFloat>(
CAST_TO_FLOAT_FUNC_NAME, INT64, FLOAT));
// down cast
result.push_back(bindVectorOperation<double_t, float_t, operation::CastToFloat>(
CAST_TO_FLOAT_FUNC_NAME, DOUBLE, FLOAT));
return result;
}

Expand All @@ -201,6 +189,11 @@ CastToInt64VectorOperation::getDefinitions() {
CAST_TO_INT64_FUNC_NAME, INT16, INT64));
result.push_back(bindVectorOperation<int32_t, int64_t, operation::CastToInt64>(
CAST_TO_INT64_FUNC_NAME, INT32, INT64));
// down cast
result.push_back(bindVectorOperation<float_t, int64_t, operation::CastToInt64>(
CAST_TO_INT64_FUNC_NAME, FLOAT, INT64));
result.push_back(bindVectorOperation<double_t, int64_t, operation::CastToInt64>(
CAST_TO_INT64_FUNC_NAME, DOUBLE, INT64));
return result;
}

Expand All @@ -209,6 +202,28 @@ CastToInt32VectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(bindVectorOperation<int16_t, int32_t, operation::CastToInt32>(
CAST_TO_INT32_FUNC_NAME, INT16, INT32));
// down cast
result.push_back(bindVectorOperation<int64_t, int32_t, operation::CastToInt32>(
CAST_TO_INT32_FUNC_NAME, INT64, INT32));
result.push_back(bindVectorOperation<float_t, int32_t, operation::CastToInt32>(
CAST_TO_INT32_FUNC_NAME, FLOAT, INT32));
result.push_back(bindVectorOperation<double_t, int32_t, operation::CastToInt32>(
CAST_TO_INT32_FUNC_NAME, DOUBLE, INT32));
return result;
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
CastToInt16VectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
// down cast
result.push_back(bindVectorOperation<int32_t, int16_t, operation::CastToInt16>(
CAST_TO_INT32_FUNC_NAME, INT32, INT16));
result.push_back(bindVectorOperation<int64_t, int16_t, operation::CastToInt16>(
CAST_TO_INT32_FUNC_NAME, INT64, INT16));
result.push_back(bindVectorOperation<float_t, int16_t, operation::CastToInt16>(
CAST_TO_INT32_FUNC_NAME, FLOAT, INT16));
result.push_back(bindVectorOperation<double_t, int16_t, operation::CastToInt16>(
CAST_TO_INT32_FUNC_NAME, DOUBLE, INT16));
return result;
}

Expand Down
6 changes: 1 addition & 5 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ const std::string CAST_TO_DOUBLE_FUNC_NAME = "TO_DOUBLE";
const std::string CAST_TO_FLOAT_FUNC_NAME = "TO_FLOAT";
const std::string CAST_TO_INT64_FUNC_NAME = "TO_INT64";
const std::string CAST_TO_INT32_FUNC_NAME = "TO_INT32";
const std::string IMPLICIT_CAST_TO_BOOL_FUNC_NAME = "_BOOL";
const std::string IMPLICIT_CAST_TO_INT_FUNC_NAME = "_INT";
const std::string IMPLICIT_CAST_TO_STRING_FUNC_NAME = "_STRING";
const std::string IMPLICIT_CAST_TO_DATE_FUNC_NAME = "_DATE";
const std::string IMPLICIT_CAST_TO_TIMESTAMP_FUNC_NAME = "_TIMESTAMP";
const std::string CAST_TO_INT16_FUNC_NAME = "TO_INT16";

// list
const std::string LIST_CREATION_FUNC_NAME = "LIST_CREATION";
Expand Down
1 change: 1 addition & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class Types {
KUZU_API static DataType dataTypeFromString(const std::string& dataTypeString);
static uint32_t getDataTypeSize(DataTypeID dataTypeID);
static uint32_t getDataTypeSize(const DataType& dataType);
static bool isNumerical(const DataType& dataType);

private:
static DataTypeID dataTypeIDFromString(const std::string& dataTypeIDString);
Expand Down
9 changes: 9 additions & 0 deletions src/include/common/types/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,15 @@ inline int64_t Value::getValue() const {
return val.int64Val;
}

/**
* @return float value.
*/
KUZU_API template<>
inline float Value::getValue() const {
assert(dataType.getTypeID() == FLOAT);
return val.floatVal;
}

/**
* @return double value.
*/
Expand Down
Loading