Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add overflow check for casting #2114

Merged
merged 1 commit into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dataset/tinysnb/eStudyAt.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from,to,YEAR,Places,length,level,code,temprature,ulength,ulevel
0,1,2021,"[wwAewsdndweusd,wek]",5,5,6556,35,120,15
0,1,2021,"[wwAewsdndweusd,wek]",5,5,9223372036854775808,32800,33768,250
2,1,2020,"[anew,jsdnwusklklklwewsd]",55,120,6689,1,90,220
8,1,2020,"[awndsnjwejwen,isuhuwennjnuhuhuwewe]",22,2,23,20,180,12
9 changes: 8 additions & 1 deletion src/function/vector_cast_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,6 @@ vector_function_definitions CastToInt64VectorFunction::getDefinitions() {
CAST_TO_INT64_FUNC_NAME, LogicalTypeID::INT16, LogicalTypeID::INT64));
result.push_back(bindVectorFunction<int32_t, int64_t, CastToInt64>(
CAST_TO_INT64_FUNC_NAME, LogicalTypeID::INT32, LogicalTypeID::INT64));
// down cast
result.push_back(bindVectorFunction<uint64_t, int64_t, CastToInt64>(
CAST_TO_INT64_FUNC_NAME, LogicalTypeID::UINT64, LogicalTypeID::INT64));
result.push_back(bindVectorFunction<float_t, int64_t, CastToInt64>(
Expand Down Expand Up @@ -397,6 +396,14 @@ vector_function_definitions CastToInt8VectorFunction::getDefinitions() {
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::INT32, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<int64_t, int8_t, CastToInt8>(
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::INT64, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<uint64_t, int8_t, CastToInt8>(
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::UINT64, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<uint32_t, int8_t, CastToInt8>(
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::UINT32, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<uint16_t, int8_t, CastToInt8>(
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::UINT16, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<uint8_t, int8_t, CastToInt8>(
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::UINT8, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<float_t, int8_t, CastToInt8>(
CAST_TO_INT8_FUNC_NAME, LogicalTypeID::FLOAT, LogicalTypeID::INT8));
result.push_back(bindVectorFunction<uint64_t, int8_t, CastToInt8>(
Expand Down
242 changes: 47 additions & 195 deletions src/include/function/cast/cast_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

#include <cassert>

#include "common/exception/runtime.h"
#include "common/string_utils.h"
#include "common/type_utils.h"
#include "common/types/blob.h"
#include "common/vector/value_vector.h"
#include "numeric_cast.h"
Expand Down Expand Up @@ -83,19 +86,13 @@
return common::TypeUtils::toString(input, (void*)&inputVector);
}

template<typename SRC, typename DST>
static inline void numericDownCast(SRC& input, DST& result, const std::string& dstTypeStr) {
if (input < std::numeric_limits<DST>::min() || input > std::numeric_limits<DST>::max()) {
throw common::RuntimeException(
"Cast failed. " + std::to_string(input) + " is not in " + dstTypeStr + " range.");
}
result = (DST)input;
}

struct CastToDouble {
template<typename T>
static inline void operation(T& input, double_t& result) {
result = static_cast<double_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within DOUBLE range", common::TypeUtils::toString(input).c_str())};
}
}
};

Expand All @@ -114,7 +111,10 @@
struct CastToFloat {
template<typename T>
static inline void operation(T& input, float_t& result) {
result = static_cast<float_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within FLOAT range", common::TypeUtils::toString(input).c_str())};
}
}
};

Expand All @@ -130,28 +130,16 @@
common::LogicalType{common::LogicalTypeID::FLOAT});
}

template<>
inline void CastToFloat::operation(double_t& input, float_t& result) {
numericDownCast<double_t, float_t>(input, result, "FLOAT");
}

struct CastToInt64 {
template<typename T>
static inline void operation(T& input, int64_t& result) {
result = static_cast<int64_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within INT64 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToInt64::operation(double_t& input, int64_t& result) {
numericDownCast<double_t, int64_t>(input, result, "INT64");
}

template<>
inline void CastToInt64::operation(float_t& input, int64_t& result) {
numericDownCast<float_t, int64_t>(input, result, "INT64");
}

template<>
inline void CastToInt64::operation(char*& input, int64_t& result) {
simpleIntegerCast<int64_t, true>(
Expand All @@ -167,32 +155,23 @@
struct CastToSerial {
template<typename T>
static inline void operation(T& input, int64_t& result) {
CastToInt64::operation(input, result);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(

Check warning on line 159 in src/include/function/cast/cast_functions.h

View check run for this annotation

Codecov / codecov/patch

src/include/function/cast/cast_functions.h#L159

Added line #L159 was not covered by tests
"Value {} is not within INT64 range", common::TypeUtils::toString(input).c_str())};
}
}
};

struct CastToInt32 {
template<typename T>
static inline void operation(T& input, int32_t& result) {
result = static_cast<int32_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within INT32 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToInt32::operation(double_t& input, int32_t& result) {
numericDownCast<double_t, int32_t>(input, result, "INT32");
}

template<>
inline void CastToInt32::operation(float_t& input, int32_t& result) {
numericDownCast<float_t, int32_t>(input, result, "INT32");
}

template<>
inline void CastToInt32::operation(int64_t& input, int32_t& result) {
numericDownCast<int64_t, int32_t>(input, result, "INT32");
}

template<>
inline void CastToInt32::operation(char*& input, int32_t& result) {
simpleIntegerCast<int32_t, true>(
Expand All @@ -208,30 +187,13 @@
struct CastToInt16 {
template<typename T>
static inline void operation(T& input, int16_t& result) {
result = static_cast<int16_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within INT16 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToInt16::operation(double_t& input, int16_t& result) {
numericDownCast<double_t, int16_t>(input, result, "INT16");
}

template<>
inline void CastToInt16::operation(float_t& input, int16_t& result) {
numericDownCast<float_t, int16_t>(input, result, "INT16");
}

template<>
inline void CastToInt16::operation(int64_t& input, int16_t& result) {
numericDownCast<int64_t, int16_t>(input, result, "INT16");
}

template<>
inline void CastToInt16::operation(int32_t& input, int16_t& result) {
numericDownCast<int32_t, int16_t>(input, result, "INT16");
}

template<>
inline void CastToInt16::operation(common::ku_string_t& input, int16_t& result) {
simpleIntegerCast<int16_t, true>((char*)input.getData(), input.len, result,
Expand All @@ -247,35 +209,13 @@
struct CastToInt8 {
template<typename T>
static inline void operation(T& input, int8_t& result) {
result = static_cast<int8_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within INT8 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToInt8::operation(double_t& input, int8_t& result) {
numericDownCast<double_t, int8_t>(input, result, "INT8");
}

template<>
inline void CastToInt8::operation(float_t& input, int8_t& result) {
numericDownCast<float_t, int8_t>(input, result, "INT8");
}

template<>
inline void CastToInt8::operation(int64_t& input, int8_t& result) {
numericDownCast<int64_t, int8_t>(input, result, "INT8");
}

template<>
inline void CastToInt8::operation(int32_t& input, int8_t& result) {
numericDownCast<int32_t, int8_t>(input, result, "INT8");
}

template<>
inline void CastToInt8::operation(int16_t& input, int8_t& result) {
numericDownCast<int16_t, int8_t>(input, result, "INT8");
}

template<>
inline void CastToInt8::operation(common::ku_string_t& input, int8_t& result) {
simpleIntegerCast<int8_t, true>((char*)input.getData(), input.len, result,
Expand All @@ -291,20 +231,13 @@
struct CastToUInt64 {
template<typename T>
static inline void operation(T& input, uint64_t& result) {
result = static_cast<uint64_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within UINT64 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToUInt64::operation(double_t& input, uint64_t& result) {
numericDownCast<double_t, uint64_t>(input, result, "UINT64");
}

template<>
inline void CastToUInt64::operation(float_t& input, uint64_t& result) {
numericDownCast<float_t, uint64_t>(input, result, "UINT64");
}

template<>
inline void CastToUInt64::operation(common::ku_string_t& input, uint64_t& result) {
simpleIntegerCast<uint64_t, false>((char*)input.getData(), input.len, result,
Expand All @@ -320,30 +253,13 @@
struct CastToUInt32 {
template<typename T>
static inline void operation(T& input, uint32_t& result) {
result = static_cast<uint32_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within UINT32 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToUInt32::operation(double_t& input, uint32_t& result) {
numericDownCast<double_t, uint32_t>(input, result, "UINT32");
}

template<>
inline void CastToUInt32::operation(float_t& input, uint32_t& result) {
numericDownCast<float_t, uint32_t>(input, result, "UINT32");
}

template<>
inline void CastToUInt32::operation(int64_t& input, uint32_t& result) {
numericDownCast<int64_t, uint32_t>(input, result, "UINT32");
}

template<>
inline void CastToUInt32::operation(uint64_t& input, uint32_t& result) {
numericDownCast<uint64_t, uint32_t>(input, result, "UINT32");
}

template<>
inline void CastToUInt32::operation(common::ku_string_t& input, uint32_t& result) {
simpleIntegerCast<uint32_t, false>((char*)input.getData(), input.len, result,
Expand All @@ -359,40 +275,13 @@
struct CastToUInt16 {
template<typename T>
static inline void operation(T& input, uint16_t& result) {
result = static_cast<uint16_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within UINT16 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToUInt16::operation(double_t& input, uint16_t& result) {
numericDownCast<double_t, uint16_t>(input, result, "UINT16");
}

template<>
inline void CastToUInt16::operation(float_t& input, uint16_t& result) {
numericDownCast<float_t, uint16_t>(input, result, "UINT16");
}

template<>
inline void CastToUInt16::operation(int64_t& input, uint16_t& result) {
numericDownCast<int64_t, uint16_t>(input, result, "UINT16");
}

template<>
inline void CastToUInt16::operation(uint64_t& input, uint16_t& result) {
numericDownCast<uint64_t, uint16_t>(input, result, "UINT16");
}

template<>
inline void CastToUInt16::operation(int32_t& input, uint16_t& result) {
numericDownCast<int32_t, uint16_t>(input, result, "UINT16");
}

template<>
inline void CastToUInt16::operation(uint32_t& input, uint16_t& result) {
numericDownCast<uint32_t, uint16_t>(input, result, "UINT16");
}

template<>
inline void CastToUInt16::operation(common::ku_string_t& input, uint16_t& result) {
simpleIntegerCast<uint16_t, false>((char*)input.getData(), input.len, result,
Expand All @@ -408,50 +297,13 @@
struct CastToUInt8 {
template<typename T>
static inline void operation(T& input, uint8_t& result) {
result = static_cast<uint8_t>(input);
if (!tryCastWithOverflowCheck(input, result)) {
throw common::RuntimeException{common::StringUtils::string_format(
"Value {} is not within UINT8 range", common::TypeUtils::toString(input).c_str())};
}
}
};

template<>
inline void CastToUInt8::operation(double_t& input, uint8_t& result) {
numericDownCast<double_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(float_t& input, uint8_t& result) {
numericDownCast<float_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(int64_t& input, uint8_t& result) {
numericDownCast<int64_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(uint64_t& input, uint8_t& result) {
numericDownCast<uint64_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(int32_t& input, uint8_t& result) {
numericDownCast<int32_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(uint32_t& input, uint8_t& result) {
numericDownCast<uint32_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(int16_t& input, uint8_t& result) {
numericDownCast<int16_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(uint16_t& input, uint8_t& result) {
numericDownCast<uint16_t, uint8_t>(input, result, "UINT8");
}

template<>
inline void CastToUInt8::operation(common::ku_string_t& input, uint8_t& result) {
simpleIntegerCast<uint8_t, false>((char*)input.getData(), input.len, result,
Expand Down
Loading
Loading