Skip to content

Commit

Permalink
Merge pull request #2114 from kuzudb/cast-fix
Browse files Browse the repository at this point in the history
Add overflow check for casting
  • Loading branch information
acquamarin committed Sep 29, 2023
2 parents 815580a + 81f77b7 commit bc74559
Show file tree
Hide file tree
Showing 23 changed files with 620 additions and 329 deletions.
2 changes: 1 addition & 1 deletion dataset/tinysnb/eStudyAt.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from,to,YEAR,Places,length,level,code,temprature,ulength,ulevel
0,1,2021,"[wwAewsdndweusd,wek]",5,5,6556,35,120,15
0,1,2021,"[wwAewsdndweusd,wek]",5,5,9223372036854775808,32800,33768,250
2,1,2020,"[anew,jsdnwusklklklwewsd]",55,120,6689,1,90,220
8,1,2020,"[awndsnjwejwen,isuhuwennjnuhuhuwewe]",22,2,23,20,180,12
9 changes: 8 additions & 1 deletion src/function/vector_cast_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,6 @@ vector_function_definitions CastToInt64VectorFunction::getDefinitions() {
CAST_TO_INT64_FUNC_NAME, LogicalTypeID::INT16, LogicalTypeID::INT64));
result.push_back(bindVectorFunction<int32_t, int64_t, CastToInt64>(
CAST_TO_INT64_FUNC_NAME, LogicalTypeID::INT32, LogicalTypeID::INT64));
// down cast
result.push_back(bindVectorFunction<uint64_t, int64_t, CastToInt64>(
CAST_TO_INT64_FUNC_NAME, LogicalTypeID::UINT64, LogicalTypeID::INT64));
result.push_back(bindVectorFunction<float_t, int64_t, CastToInt64>(
Expand Down Expand Up @@ -397,6 +396,14 @@ vector_function_definitions CastToInt8VectorFunction::getDefinitions() {
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::INT32, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<int64_t, int8_t, CastToInt8>(
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::INT64, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<uint64_t, int8_t, CastToInt8>(
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::UINT64, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<uint32_t, int8_t, CastToInt8>(
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::UINT32, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<uint16_t, int8_t, CastToInt8>(
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::UINT16, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<uint8_t, int8_t, CastToInt8>(
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::UINT8, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<float_t, int8_t, CastToInt8>(
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::FLOAT, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<uint64_t, int8_t, CastToInt8>(
Expand Down
242 changes: 47 additions & 195 deletions src/include/function/cast/cast_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

#include <cassert>

#include "common/exception/runtime.h"
#include "common/string_utils.h"
#include "common/type_utils.h"
#include "common/types/blob.h"
#include "common/vector/value_vector.h"
#include "numeric_cast.h"
Expand Down Expand Up @@ -83,19 +86,13 @@ inline std::string CastToString::castToStringWithVector(
return common::TypeUtils::toString(input, (void*)&inputVector);
}

template<typename SRC, typename DST>
static inline void numericDownCast(SRC& input, DST& result, const std::string& dstTypeStr) {
if (input < std::numeric_limits<DST>::min() || input > std::numeric_limits<DST>::max()) {
throw common::RuntimeException(
"Cast failed. " + std::to_string(input) + " is not in " + dstTypeStr + " range.");
}
result = (DST)input;
}

struct CastToDouble {
template<typename T>
static inline void operation(T& input, double_t& result) {
result = static_cast<double_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within DOUBLE range", common::TypeUtils::toString(input).c_str())};
}
}
};

Expand All @@ -114,7 +111,10 @@ inline void CastToDouble::operation(common::ku_string_t& input, double_t& result
struct CastToFloat {
template<typename T>
static inline void operation(T& input, float_t& result) {
result = static_cast<float_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within FLOAT range", common::TypeUtils::toString(input).c_str())};
}
}
};

Expand All @@ -130,28 +130,16 @@ inline void CastToFloat::operation(common::ku_string_t& input, float_t& result)
common::LogicalType{common::LogicalTypeID::FLOAT});
}

template<>
inline void CastToFloat::operation(double_t& input, float_t& result) {
numericDownCast<double_t, float_t>(input, result, "FLOAT");
}

struct CastToInt64 {
template<typename T>
static inline void operation(T& input, int64_t& result) {
result = static_cast<int64_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within INT64 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToInt64::operation(double_t& input, int64_t& result) {
numericDownCast<double_t, int64_t>(input, result, "INT64");
}

template<>
inline void CastToInt64::operation(float_t& input, int64_t& result) {
numericDownCast<float_t, int64_t>(input, result, "INT64");
}

template<>
inline void CastToInt64::operation(char*& input, int64_t& result) {
simpleIntegerCast<int64_t, true>(
Expand All @@ -167,32 +155,23 @@ inline void CastToInt64::operation(common::ku_string_t& input, int64_t& result)
struct CastToSerial {
template<typename T>
static inline void operation(T& input, int64_t& result) {
CastToInt64::operation(input, result);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within INT64 range", common::TypeUtils::toString(input).c_str())};
}
}
};

struct CastToInt32 {
template<typename T>
static inline void operation(T& input, int32_t& result) {
result = static_cast<int32_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within INT32 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToInt32::operation(double_t& input, int32_t& result) {
numericDownCast<double_t, int32_t>(input, result, "INT32");
}

template<>
inline void CastToInt32::operation(float_t& input, int32_t& result) {
numericDownCast<float_t, int32_t>(input, result, "INT32");
}

template<>
inline void CastToInt32::operation(int64_t& input, int32_t& result) {
numericDownCast<int64_t, int32_t>(input, result, "INT32");
}

template<>
inline void CastToInt32::operation(char*& input, int32_t& result) {
simpleIntegerCast<int32_t, true>(
Expand All @@ -208,30 +187,13 @@ inline void CastToInt32::operation(common::ku_string_t& input, int32_t& result)
struct CastToInt16 {
template<typename T>
static inline void operation(T& input, int16_t& result) {
result = static_cast<int16_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within INT16 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToInt16::operation(double_t& input, int16_t& result) {
numericDownCast<double_t, int16_t>(input, result, "INT16");
}

template<>
inline void CastToInt16::operation(float_t& input, int16_t& result) {
numericDownCast<float_t, int16_t>(input, result, "INT16");
}

template<>
inline void CastToInt16::operation(int64_t& input, int16_t& result) {
numericDownCast<int64_t, int16_t>(input, result, "INT16");
}

template<>
inline void CastToInt16::operation(int32_t& input, int16_t& result) {
numericDownCast<int32_t, int16_t>(input, result, "INT16");
}

template<>
inline void CastToInt16::operation(common::ku_string_t& input, int16_t& result) {
simpleIntegerCast<int16_t, true>((char*)input.getData(), input.len, result,
Expand All @@ -247,35 +209,13 @@ inline void CastToInt16::operation(char*& input, int16_t& result) {
struct CastToInt8 {
template<typename T>
static inline void operation(T& input, int8_t& result) {
result = static_cast<int8_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within INT8 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToInt8::operation(double_t& input, int8_t& result) {
numericDownCast<double_t, int8_t>(input, result, "INT8");
}

template<>
inline void CastToInt8::operation(float_t& input, int8_t& result) {
numericDownCast<float_t, int8_t>(input, result, "INT8");
}

template<>
inline void CastToInt8::operation(int64_t& input, int8_t& result) {
numericDownCast<int64_t, int8_t>(input, result, "INT8");
}

template<>
inline void CastToInt8::operation(int32_t& input, int8_t& result) {
numericDownCast<int32_t, int8_t>(input, result, "INT8");
}

template<>
inline void CastToInt8::operation(int16_t& input, int8_t& result) {
numericDownCast<int16_t, int8_t>(input, result, "INT8");
}

template<>
inline void CastToInt8::operation(common::ku_string_t& input, int8_t& result) {
simpleIntegerCast<int8_t, true>((char*)input.getData(), input.len, result,
Expand All @@ -291,20 +231,13 @@ inline void CastToInt8::operation(char*& input, int8_t& result) {
struct CastToUInt64 {
template<typename T>
static inline void operation(T& input, uint64_t& result) {
result = static_cast<uint64_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within UINT64 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToUInt64::operation(double_t& input, uint64_t& result) {
numericDownCast<double_t, uint64_t>(input, result, "UINT64");
}

template<>
inline void CastToUInt64::operation(float_t& input, uint64_t& result) {
numericDownCast<float_t, uint64_t>(input, result, "UINT64");
}

template<>
inline void CastToUInt64::operation(common::ku_string_t& input, uint64_t& result) {
simpleIntegerCast<uint64_t, false>((char*)input.getData(), input.len, result,
Expand All @@ -320,30 +253,13 @@ inline void CastToUInt64::operation(char*& input, uint64_t& result) {
struct CastToUInt32 {
template<typename T>
static inline void operation(T& input, uint32_t& result) {
result = static_cast<uint32_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within UINT32 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToUInt32::operation(double_t& input, uint32_t& result) {
numericDownCast<double_t, uint32_t>(input, result, "UINT32");
}

template<>
inline void CastToUInt32::operation(float_t& input, uint32_t& result) {
numericDownCast<float_t, uint32_t>(input, result, "UINT32");
}

template<>
inline void CastToUInt32::operation(int64_t& input, uint32_t& result) {
numericDownCast<int64_t, uint32_t>(input, result, "UINT32");
}

template<>
inline void CastToUInt32::operation(uint64_t& input, uint32_t& result) {
numericDownCast<uint64_t, uint32_t>(input, result, "UINT32");
}

template<>
inline void CastToUInt32::operation(common::ku_string_t& input, uint32_t& result) {
simpleIntegerCast<uint32_t, false>((char*)input.getData(), input.len, result,
Expand All @@ -359,40 +275,13 @@ inline void CastToUInt32::operation(char*& input, uint32_t& result) {
struct CastToUInt16 {
template<typename T>
static inline void operation(T& input, uint16_t& result) {
result = static_cast<uint16_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within UINT16 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToUInt16::operation(double_t& input, uint16_t& result) {
numericDownCast<double_t, uint16_t>(input, result, "UINT16");
}

template<>
inline void CastToUInt16::operation(float_t& input, uint16_t& result) {
numericDownCast<float_t, uint16_t>(input, result, "UINT16");
}

template<>
inline void CastToUInt16::operation(int64_t& input, uint16_t& result) {
numericDownCast<int64_t, uint16_t>(input, result, "UINT16");
}

template<>
inline void CastToUInt16::operation(uint64_t& input, uint16_t& result) {
numericDownCast<uint64_t, uint16_t>(input, result, "UINT16");
}

template<>
inline void CastToUInt16::operation(int32_t& input, uint16_t& result) {
numericDownCast<int32_t, uint16_t>(input, result, "UINT16");
}

template<>
inline void CastToUInt16::operation(uint32_t& input, uint16_t& result) {
numericDownCast<uint32_t, uint16_t>(input, result, "UINT16");
}

template<>
inline void CastToUInt16::operation(common::ku_string_t& input, uint16_t& result) {
simpleIntegerCast<uint16_t, false>((char*)input.getData(), input.len, result,
Expand All @@ -408,50 +297,13 @@ inline void CastToUInt16::operation(char*& input, uint16_t& result) {
struct CastToUInt8 {
template<typename T>
static inline void operation(T& input, uint8_t& result) {
result = static_cast<uint8_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within UINT8 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToUInt8::operation(double_t& input, uint8_t& result) {
numericDownCast<double_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(float_t& input, uint8_t& result) {
numericDownCast<float_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(int64_t& input, uint8_t& result) {
numericDownCast<int64_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(uint64_t& input, uint8_t& result) {
numericDownCast<uint64_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(int32_t& input, uint8_t& result) {
numericDownCast<int32_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(uint32_t& input, uint8_t& result) {
numericDownCast<uint32_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(int16_t& input, uint8_t& result) {
numericDownCast<int16_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(uint16_t& input, uint8_t& result) {
numericDownCast<uint16_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(common::ku_string_t& input, uint8_t& result) {
simpleIntegerCast<uint8_t, false>((char*)input.getData(), input.len, result,
Expand Down
Loading

0 comments on commit bc74559

Please sign in to comment.