From eb65ca04944327ac5f0315b4ffa7093967fdfeb5 Mon Sep 17 00:00:00 2001 From: AEsir777 Date: Fri, 20 Oct 2023 15:08:36 -0400 Subject: [PATCH] refactor to use wrapper function to wrap all the codes in cast_string_to_function.h --- src/binder/bind/bind_graph_pattern.cpp | 6 +- src/c_api/value.cpp | 8 +- src/function/CMakeLists.txt | 3 +- ...p => cast_string_non_nested_functions.cpp} | 8 +- src/function/cast_string_to_functions.cpp | 750 ++++++++++++++++++ src/function/vector_cast_functions.cpp | 128 ++- src/function/vector_union_functions.cpp | 2 +- src/include/common/type_utils.h | 7 +- src/include/common/types/types.h | 4 +- src/include/function/cast/cast_functions.h | 396 --------- .../function/cast/functions/cast_functions.h | 217 +++++ .../cast_string_non_nested_functions.h} | 218 +---- .../cast/functions/cast_string_to_functions.h | 269 +++++++ .../function/cast/functions/numeric_cast.h | 164 ++++ .../cast/{ => functions}/numeric_limits.h | 0 .../function/cast/vector_cast_functions.h | 44 +- .../function/string/vector_string_functions.h | 14 +- src/include/function/vector_functions.h | 16 + .../writer/parquet/standard_column_writer.h | 2 +- .../in_mem_column_chunk.h | 26 +- .../in_mem_storage_structure/in_mem_lists.h | 23 +- src/include/storage/store/column_chunk.h | 1 - src/parser/transform/transform_expression.cpp | 27 +- .../operator/persistent/copy_to_csv.cpp | 2 +- .../operator/persistent/reader/csv/driver.cpp | 688 +--------------- .../writer/parquet/basic_column_writer.cpp | 2 +- .../writer/parquet/column_writer.cpp | 2 +- .../in_mem_column_chunk.cpp | 34 - .../in_mem_storage_structure/in_mem_lists.cpp | 40 - src/storage/store/table_copy_utils.cpp | 89 +-- test/test_files/tinysnb/cast/cast_error.test | 78 +- .../tinysnb/exception/exception.test | 17 +- 32 files changed, 1634 insertions(+), 1651 deletions(-) rename src/function/{cast_utils.cpp => cast_string_non_nested_functions.cpp} (85%) create mode 100644 src/function/cast_string_to_functions.cpp delete 
mode 100644 src/include/function/cast/cast_functions.h create mode 100644 src/include/function/cast/functions/cast_functions.h rename src/include/function/cast/{cast_utils.h => functions/cast_string_non_nested_functions.h} (50%) create mode 100644 src/include/function/cast/functions/cast_string_to_functions.h create mode 100644 src/include/function/cast/functions/numeric_cast.h rename src/include/function/cast/{ => functions}/numeric_limits.h (100%) diff --git a/src/binder/bind/bind_graph_pattern.cpp b/src/binder/bind/bind_graph_pattern.cpp index dbb167e2ec3..4e90da4c5d1 100644 --- a/src/binder/bind/bind_graph_pattern.cpp +++ b/src/binder/bind/bind_graph_pattern.cpp @@ -11,7 +11,7 @@ #include "catalog/rel_table_schema.h" #include "common/exception/binder.h" #include "common/string_format.h" -#include "function/cast/cast_utils.h" +#include "function/cast/functions/cast_string_to_functions.h" #include "main/client_context.h" using namespace kuzu::common; @@ -428,11 +428,11 @@ std::pair Binder::bindVariableLengthRelBound( const kuzu::parser::RelPattern& relPattern) { auto recursiveInfo = relPattern.getRecursiveInfo(); uint32_t lowerBound; - function::simpleIntegerCast( + function::CastStringToTypes::operation( recursiveInfo->lowerBound.c_str(), recursiveInfo->lowerBound.length(), lowerBound); auto upperBound = clientContext->varLengthExtendMaxDepth; if (!recursiveInfo->upperBound.empty()) { - function::simpleIntegerCast( + function::CastStringToTypes::operation( recursiveInfo->upperBound.c_str(), recursiveInfo->upperBound.length(), upperBound); } if (lowerBound > upperBound) { diff --git a/src/c_api/value.cpp b/src/c_api/value.cpp index c7102652954..21f73c6260b 100644 --- a/src/c_api/value.cpp +++ b/src/c_api/value.cpp @@ -8,7 +8,7 @@ #include "common/types/value/node.h" #include "common/types/value/recursive_rel.h" #include "common/types/value/rel.h" -#include "function/cast/cast_functions.h" +#include "function/cast/functions/cast_string_to_functions.h" #include 
"main/kuzu.h" using namespace kuzu::common; @@ -269,10 +269,10 @@ kuzu_int128_t kuzu_int128_t_from_string(const char* str) { int128_t int128_val = 0; kuzu_int128_t c_int128; try { - kuzu::function::CastToInt128::operation(str, int128_val); + kuzu::function::CastStringToTypes::operation(str, strlen(str), int128_val); c_int128.low = int128_val.low; c_int128.high = int128_val.high; - } catch (kuzu::common::ConversionException& e) { + } catch (ConversionException& e) { c_int128.low = 0; c_int128.high = 0; } @@ -283,7 +283,7 @@ char* kuzu_int128_t_to_string(kuzu_int128_t int128_val) { int128_t c_int128; c_int128.low = int128_val.low; c_int128.high = int128_val.high; - return convertToOwnedCString(kuzu::common::Int128_t::ToString(c_int128)); + return convertToOwnedCString(TypeUtils::toString(c_int128)); } // TODO: bind all int128_t supported functions diff --git a/src/function/CMakeLists.txt b/src/function/CMakeLists.txt index 5ca76cee0e5..69393a53d51 100644 --- a/src/function/CMakeLists.txt +++ b/src/function/CMakeLists.txt @@ -7,7 +7,8 @@ add_library(kuzu_function built_in_aggregate_functions.cpp built_in_vector_functions.cpp built_in_table_functions.cpp - cast_utils.cpp + cast_string_non_nested_functions.cpp + cast_string_to_functions.cpp comparison_functions.cpp find_function.cpp scalar_macro_function.cpp diff --git a/src/function/cast_utils.cpp b/src/function/cast_string_non_nested_functions.cpp similarity index 85% rename from src/function/cast_utils.cpp rename to src/function/cast_string_non_nested_functions.cpp index 0a9930b50f2..1afe2adf057 100644 --- a/src/function/cast_utils.cpp +++ b/src/function/cast_string_non_nested_functions.cpp @@ -1,4 +1,6 @@ -#include "function/cast/cast_utils.h" +#include "function/cast/functions/cast_string_non_nested_functions.h" + +#include "common/string_format.h" namespace kuzu { namespace function { @@ -48,8 +50,8 @@ bool tryCastToBool(const char* input, uint64_t len, bool& result) { void castStringToBool(const char* input, 
uint64_t len, bool& result) { if (!tryCastToBool(input, len, result)) { - throw common::ConversionException( - "Cast failed. " + std::string{input, len} + " is not in BOOL range."); + throw common::ConversionException{ + common::stringFormat("Value {} is not a valid boolean", std::string{input, len})}; } } diff --git a/src/function/cast_string_to_functions.cpp b/src/function/cast_string_to_functions.cpp new file mode 100644 index 00000000000..37cc7cacc03 --- /dev/null +++ b/src/function/cast_string_to_functions.cpp @@ -0,0 +1,750 @@ +#include "function/cast/functions/cast_string_to_functions.h" + +#include "common/exception/copy.h" +#include "common/exception/message.h" +#include "common/exception/parser.h" +#include "common/string_format.h" +#include "common/types/blob.h" +#include "storage/store/table_copy_utils.h" +#include "utf8proc_wrapper.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +static void skipWhitespace(const char*& input, const char* end) { + while (input < end && isspace(*input)) { + input++; + } +} + +static void trimRightWhitespace(const char* input, const char*& end) { + while (input < end && isspace(*(end - 1))) { + end--; + } +} + +static bool skipToCloseQuotes(const char*& input, const char* end) { + auto ch = *input; + input++; // skip the first " ' + // TODO: escape char + while (input != end) { + if (*input == ch) { + return true; + } + input++; + } + return false; +} + +static bool skipToClose(const char*& input, const char* end, uint64_t& lvl, char target, + const CSVReaderConfig& csvReaderConfig) { + input++; + while (input != end) { + if (*input == '\'') { + if (!skipToCloseQuotes(input, end)) { + return false; + } + } else if (*input == '{') { // must have closing brackets fro {, ] if they are not quoted + if (!skipToClose(input, end, lvl, '}', csvReaderConfig)) { + return false; + } + } else if (*input == csvReaderConfig.listBeginChar) { + if (!skipToClose(input, end, lvl, csvReaderConfig.listEndChar, 
csvReaderConfig)) { + return false; + } + lvl++; // nested one more level + } else if (*input == target) { + if (target == csvReaderConfig.listEndChar) { + lvl--; + } + return true; + } + input++; + } + return false; // no corresponding closing bracket +} + +static bool isNull(std::string_view& str) { + auto start = str.data(); + auto end = start + str.length(); + skipWhitespace(start, end); + if (start == end) { + return true; + } + if (end - start >= 4 && (*start == 'N' || *start == 'n') && + (*(start + 1) == 'U' || *(start + 1) == 'u') && + (*(start + 2) == 'L' || *(start + 2) == 'l') && + (*(start + 3) == 'L' || *(start + 3) == 'l')) { + start += 4; + skipWhitespace(start, end); + if (start == end) { + return true; + } + } + return false; +} + +struct CountPartOperation { + uint64_t count = 0; + + static inline bool handleKey( + const char* start, const char* end, const CSVReaderConfig& config) { + return true; + } + inline void handleValue(const char* start, const char* end, const CSVReaderConfig& config) { + count++; + } +}; + +struct SplitStringListOperation { + SplitStringListOperation(uint64_t& offset, ValueVector* resultVector) + : offset(offset), resultVector(resultVector) {} + + uint64_t& offset; + ValueVector* resultVector; + + void handleValue(const char* start, const char* end, const CSVReaderConfig& csvReaderConfig) { + CastStringToTypes::operation(resultVector, offset, + std::string_view{start, (uint32_t)(end - start)}, csvReaderConfig); + offset++; + } +}; + +template +struct SplitStringFixedListOperation { + SplitStringFixedListOperation(uint64_t& offset, ValueVector* resultVector) + : offset(offset), resultVector(resultVector) {} + + uint64_t& offset; + ValueVector* resultVector; + + void handleValue(const char* start, const char* end, const CSVReaderConfig& csvReaderConfig) { + T value; + auto str = std::string_view{start, (uint32_t)(end - start)}; + if (str.empty() || isNull(str)) { + throw ConversionException("Cast failed. 
NULL is not allowed for FIXEDLIST."); + } + auto type = FixedListType::getChildType(&resultVector->dataType); + function::CastStringToTypes::operation(start, str.length(), value); + resultVector->setValue(offset, value); + offset++; + } +}; + +struct SplitStringMapOperation { + SplitStringMapOperation(uint64_t& offset, ValueVector* resultVector) + : offset(offset), resultVector(resultVector) {} + + uint64_t& offset; + ValueVector* resultVector; + + inline bool handleKey( + const char* start, const char* end, const CSVReaderConfig& csvReaderConfig) { + trimRightWhitespace(start, end); + CastStringToTypes::operation(StructVector::getFieldVector(resultVector, 0).get(), offset, + std::string_view{start, (uint32_t)(end - start)}, csvReaderConfig); + return true; + } + + inline void handleValue( + const char* start, const char* end, const CSVReaderConfig& csvReaderConfig) { + trimRightWhitespace(start, end); + CastStringToTypes::operation(StructVector::getFieldVector(resultVector, 1).get(), offset++, + std::string_view{start, (uint32_t)(end - start)}, csvReaderConfig); + } +}; + +template +static bool splitCStringList( + const char* input, uint64_t len, T& state, const CSVReaderConfig& csvReaderConfig) { + auto end = input + len; + uint64_t lvl = 1; + bool seen_value = false; + + // locate [ + skipWhitespace(input, end); + if (input == end || *input != csvReaderConfig.listBeginChar) { + return false; + } + input++; + + auto start_ptr = input; + while (input < end) { + auto ch = *input; + if (ch == csvReaderConfig.listBeginChar) { + if (!skipToClose(input, end, ++lvl, csvReaderConfig.listEndChar, csvReaderConfig)) { + return false; + } + } else if (ch == '\'' || ch == '"') { + if (!skipToCloseQuotes(input, end)) { + return false; + } + } else if (ch == '{') { + uint64_t struct_lvl = 0; + skipToClose(input, end, struct_lvl, '}', csvReaderConfig); + } else if (ch == csvReaderConfig.delimiter || ch == csvReaderConfig.listEndChar) { // split + if (ch != 
csvReaderConfig.listEndChar || start_ptr < input || seen_value) { + state.handleValue(start_ptr, input, csvReaderConfig); + seen_value = true; + } + if (ch == csvReaderConfig.listEndChar) { // last ] + lvl--; + break; + } + start_ptr = ++input; + continue; + } + input++; + } + skipWhitespace(++input, end); + return (input == end && lvl == 0); +} + +template +static inline void startListCast(const char* input, uint64_t len, T split, + const CSVReaderConfig& csvReaderConfig, ValueVector* vector) { + if (!splitCStringList(input, len, split, csvReaderConfig)) { + throw ConversionException("Cast failed. " + std::string{input, len} + " is not in " + + LogicalTypeUtils::dataTypeToString(vector->dataType) + " range."); + } +} + +void castStringToList(const char* input, uint64_t len, ValueVector* vector, uint64_t rowToAdd, + const CSVReaderConfig& csvReaderConfig) { + // calculate the number of elements in array + CountPartOperation state; + splitCStringList(input, len, state, csvReaderConfig); + + auto list_entry = ListVector::addList(vector, state.count); + vector->setValue(rowToAdd, list_entry); + auto listDataVector = ListVector::getDataVector(vector); + + SplitStringListOperation split{list_entry.offset, listDataVector}; + startListCast(input, len, split, csvReaderConfig, vector); +} + +static void validateNumElementsInList(uint64_t numElementsRead, const LogicalType& type) { + auto numElementsInList = FixedListType::getNumElementsInList(&type); + if (numElementsRead != numElementsInList) { + throw CopyException(stringFormat( + "Each fixed list should have fixed number of elements. 
Expected: {}, Actual: {}.", + numElementsInList, numElementsRead)); + } +} + +void castStringToFixedList(const char* input, uint64_t len, ValueVector* vector, uint64_t rowToAdd, + const CSVReaderConfig& csvReaderConfig) { + assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::FIXED_LIST); + auto childDataType = FixedListType::getChildType(&vector->dataType); + + // calculate the number of elements in array + CountPartOperation state; + splitCStringList(input, len, state, csvReaderConfig); + validateNumElementsInList(state.count, vector->dataType); + + auto startOffset = state.count * rowToAdd; + switch (childDataType->getLogicalTypeID()) { + // TODO: currently only allow these type + case LogicalTypeID::INT64: { + SplitStringFixedListOperation split{startOffset, vector}; + startListCast(input, len, split, csvReaderConfig, vector); + } break; + case LogicalTypeID::INT32: { + SplitStringFixedListOperation split{startOffset, vector}; + startListCast(input, len, split, csvReaderConfig, vector); + } break; + case LogicalTypeID::INT16: { + SplitStringFixedListOperation split{startOffset, vector}; + startListCast(input, len, split, csvReaderConfig, vector); + } break; + case LogicalTypeID::FLOAT: { + SplitStringFixedListOperation split{startOffset, vector}; + startListCast(input, len, split, csvReaderConfig, vector); + } break; + case LogicalTypeID::DOUBLE: { + SplitStringFixedListOperation split{startOffset, vector}; + startListCast(input, len, split, csvReaderConfig, vector); + } break; + default: { + throw NotImplementedException("Unsupported data type: Driver::castStringToFixedList"); + } + } +} + +template +static bool parseKeyOrValue(const char*& input, const char* end, T& state, bool isKey, + bool& closeBracket, const CSVReaderConfig& csvReaderConfig) { + auto start = input; + uint64_t lvl = 0; + + while (input < end) { + if (*input == '"' || *input == '\'') { + if (!skipToCloseQuotes(input, end)) { + return false; + }; + } else if (*input == '{') { + if 
(!skipToClose(input, end, lvl, '}', csvReaderConfig)) { + return false; + } + } else if (*input == csvReaderConfig.listBeginChar) { + if (!skipToClose(input, end, lvl, csvReaderConfig.listEndChar, csvReaderConfig)) { + return false; + }; + } else if (isKey && *input == '=') { + return state.handleKey(start, input, csvReaderConfig); + } else if (!isKey && (*input == csvReaderConfig.delimiter || *input == '}')) { + state.handleValue(start, input, csvReaderConfig); + if (*input == '}') { + closeBracket = true; + } + return true; + } + input++; + } + return false; +} + +// Split map of format: {a=12,b=13} +template +static bool splitCStringMap( + const char* input, uint64_t len, T& state, const CSVReaderConfig& csvReaderConfig) { + auto end = input + len; + bool closeBracket = false; + + skipWhitespace(input, end); + if (input == end || *input != '{') { // start with { + return false; + } + skipWhitespace(++input, end); + if (input == end) { + return false; + } + if (*input == '}') { + skipWhitespace(++input, end); // empty + return input == end; + } + + while (input < end) { + if (!parseKeyOrValue(input, end, state, true, closeBracket, csvReaderConfig)) { + return false; + } + skipWhitespace(++input, end); + if (!parseKeyOrValue(input, end, state, false, closeBracket, csvReaderConfig)) { + return false; + } + skipWhitespace(++input, end); + if (closeBracket) { + return (input == end); + } + } + return false; +} + +void castStringToMap(const char* input, uint64_t len, ValueVector* vector, uint64_t rowToAdd, + const CSVReaderConfig& csvReaderConfig) { + // count the number of maps in map + CountPartOperation state; + splitCStringMap(input, len, state, csvReaderConfig); + + auto list_entry = ListVector::addList(vector, state.count); + vector->setValue(rowToAdd, list_entry); + auto structVector = ListVector::getDataVector(vector); + + SplitStringMapOperation split{list_entry.offset, structVector}; + if (!splitCStringMap(input, len, split, csvReaderConfig)) { + throw 
ConversionException("Cast failed. " + std::string{input, len} + " is not in " + + LogicalTypeUtils::dataTypeToString(vector->dataType) + " range."); + } +} + +static bool parseStructFieldName(const char*& input, const char* end) { + while (input < end) { + if (*input == ':') { + return true; + } + input++; + } + return false; +} + +static bool parseStructFieldValue( + const char*& input, const char* end, const CSVReaderConfig& csvReaderConfig, bool& closeBrack) { + uint64_t lvl = 0; + while (input < end) { + if (*input == '"' || *input == '\'') { + if (!skipToCloseQuotes(input, end)) { + return false; + } + } else if (*input == '{') { + if (!skipToClose(input, end, lvl, '}', csvReaderConfig)) { + return false; + } + } else if (*input == csvReaderConfig.listBeginChar) { + if (!skipToClose(input, end, ++lvl, csvReaderConfig.listEndChar, csvReaderConfig)) { + return false; + } + } else if (*input == csvReaderConfig.delimiter || *input == '}') { + if (*input == '}') { + closeBrack = true; + } + return (lvl == 0); + } + input++; + } + return false; +} + +static bool tryCastStringToStruct(const char* input, uint64_t len, ValueVector* vector, + uint64_t rowToAdd, const CSVReaderConfig& csvReaderConfig) { + // check if start with { + auto end = input + len; + auto type = vector->dataType; + skipWhitespace(input, end); + if (input == end || *input != '{') { + return false; + } + skipWhitespace(++input, end); + + if (input == end) { // no closing bracket + return false; + } + if (*input == '}') { + skipWhitespace(++input, end); + return input == end; + } + + bool closeBracket = false; + while (input < end) { + auto keyStart = input; + if (!parseStructFieldName(input, end)) { // find key + return false; + } + auto keyEnd = input; + trimRightWhitespace(keyStart, keyEnd); + auto fieldIdx = StructType::getFieldIdx(&type, std::string{keyStart, keyEnd}); + if (fieldIdx == INVALID_STRUCT_FIELD_IDX) { + throw ParserException{"Invalid struct field name: " + std::string{keyStart, 
keyEnd}}; + } + + skipWhitespace(++input, end); + auto valStart = input; + if (!parseStructFieldValue(input, end, csvReaderConfig, closeBracket)) { // find value + return false; + } + auto valEnd = input; + trimRightWhitespace(valStart, valEnd); + skipWhitespace(++input, end); + + CastStringToTypes::operation(StructVector::getFieldVector(vector, fieldIdx).get(), rowToAdd, + std::string_view{valStart, (uint32_t)(valEnd - valStart)}, csvReaderConfig); + + if (closeBracket) { + return (input == end); + } + } + return false; +} + +void castStringToStruct(const char* input, uint64_t len, ValueVector* vector, uint64_t rowToAdd, + const CSVReaderConfig& csvReaderConfig) { + if (!tryCastStringToStruct(input, len, vector, rowToAdd, csvReaderConfig)) { + throw ConversionException("Cast failed. " + std::string{input, len} + " is not in " + + LogicalTypeUtils::dataTypeToString(vector->dataType) + " range."); + } +} + +template +static inline void testAndSetValue(ValueVector* vector, uint64_t rowToAdd, T result, bool success) { + if (success) { + vector->setValue(rowToAdd, result); + } +} + +static bool tryCastUnionField( + ValueVector* vector, uint64_t rowToAdd, const char* input, uint64_t len) { + auto& targetType = vector->dataType; + bool success = false; + switch (targetType.getLogicalTypeID()) { + case LogicalTypeID::BOOL: { + bool result; + success = function::tryCastToBool(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::INT64: { + int64_t result; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::INT32: { + int32_t result; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::INT16: { + int16_t result; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); 
+ } break; + case LogicalTypeID::INT8: { + int8_t result; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::UINT64: { + uint64_t result; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::UINT32: { + uint32_t result; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::UINT16: { + uint16_t result; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::UINT8: { + uint8_t result; + success = function::trySimpleIntegerCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::DOUBLE: { + double_t result; + success = function::tryDoubleCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::FLOAT: { + float_t result; + success = function::tryDoubleCast(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::DATE: { + date_t result; + uint64_t pos; + success = Date::tryConvertDate(input, len, pos, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::TIMESTAMP: { + timestamp_t result; + success = Timestamp::tryConvertTimestamp(input, len, result); + testAndSetValue(vector, rowToAdd, result, success); + } break; + case LogicalTypeID::STRING: { + storage::TableCopyUtils::validateStrLen(len); + if (!utf8proc::Utf8Proc::isValid(input, len)) { + throw CopyException{"Invalid UTF8-encoded string."}; + } + StringVector::addString(vector, rowToAdd, input, len); + return true; + } break; + default: { + return false; + } + } + return success; +} + +void castStringToUnion(const char* input, 
uint64_t len, ValueVector* vector, uint64_t rowToAdd) { + auto& type = vector->dataType; + union_field_idx_t selectedFieldIdx = INVALID_STRUCT_FIELD_IDX; + + for (auto i = 0u; i < UnionType::getNumFields(&type); i++) { + auto internalFieldIdx = UnionType::getInternalFieldIdx(i); + auto fieldVector = StructVector::getFieldVector(vector, internalFieldIdx).get(); + if (tryCastUnionField(fieldVector, rowToAdd, input, len)) { + fieldVector->setNull(rowToAdd, false /* isNull */); + selectedFieldIdx = i; + break; + } else { + fieldVector->setNull(rowToAdd, true /* isNull */); + } + } + + if (selectedFieldIdx == INVALID_STRUCT_FIELD_IDX) { + throw ConversionException{stringFormat("Could not convert to union type {}: {}.", + LogicalTypeUtils::dataTypeToString(type), std::string{input, len})}; + } + StructVector::getFieldVector(vector, UnionType::TAG_FIELD_IDX) + ->setValue(rowToAdd, selectedFieldIdx); + StructVector::getFieldVector(vector, UnionType::TAG_FIELD_IDX) + ->setNull(rowToAdd, false /* isNull */); +} + +template<> +void CastStringToTypes::operation(const char* input, uint64_t len, ValueVector* vector, + uint64_t rowToAdd, const CSVReaderConfig& csvReaderConfig) { + // base case: blob + storage::TableCopyUtils::validateStrLen(len); + auto blobBuffer = std::make_unique(len); + auto blobLen = Blob::fromString(input, len, blobBuffer.get()); + StringVector::addString(vector, rowToAdd, reinterpret_cast(blobBuffer.get()), blobLen); +} + +template<> +void CastStringToTypes::operation(const char* input, uint64_t len, + ValueVector* vector, uint64_t rowToAdd, const CSVReaderConfig& csvReaderConfig) { + castStringToList(input, len, vector, rowToAdd, csvReaderConfig); +} + +template<> +void CastStringToTypes::operation(const char* input, uint64_t len, ValueVector* vector, + uint64_t rowToAdd, const CSVReaderConfig& csvReaderConfig) { + castStringToMap(input, len, vector, rowToAdd, csvReaderConfig); +} + +template<> +void CastStringToTypes::operation(const char* input, 
uint64_t len, + ValueVector* vector, uint64_t rowToAdd, const CSVReaderConfig& csvReaderConfig) { + castStringToStruct(input, len, vector, rowToAdd, csvReaderConfig); +} + +template<> +void CastStringToTypes::operation(const char* input, uint64_t len, + ValueVector* vector, uint64_t rowToAdd, const CSVReaderConfig& csvReaderConfig) { + castStringToUnion(input, len, vector, rowToAdd); +} + +void CastStringToTypes::operation(ValueVector* vector, uint64_t rowToAdd, std::string_view strVal, + const CSVReaderConfig& csvReaderConfig) { + auto& type = vector->dataType; + + if (strVal.empty() || isNull(strVal)) { + vector->setNull(rowToAdd, true /* isNull */); + return; + } else { + vector->setNull(rowToAdd, false /* isNull */); + } + switch (type.getLogicalTypeID()) { + case LogicalTypeID::INT128: { + int128_t val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::INT64: { + int64_t val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::INT32: { + int32_t val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::INT16: { + int16_t val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::INT8: { + int8_t val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::UINT64: { + uint64_t val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::UINT32: { + uint32_t val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::UINT16: { + uint16_t val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::UINT8: { + uint8_t val; + operation(strVal.data(), 
strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::FLOAT: { + float_t val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::DOUBLE: { + double_t val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::BOOL: { + bool val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::BLOB: { + blob_t val; + operation(strVal.data(), strVal.length(), vector, rowToAdd, csvReaderConfig); + } break; + case LogicalTypeID::STRING: { + storage::TableCopyUtils::validateStrLen(strVal.length()); + if (!utf8proc::Utf8Proc::isValid(strVal.data(), strVal.length())) { + throw CopyException{"Invalid UTF8-encoded string."}; + } + StringVector::addString(vector, rowToAdd, strVal.data(), strVal.length()); + } break; + case LogicalTypeID::DATE: { + date_t val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::TIMESTAMP: { + timestamp_t val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::INTERVAL: { + interval_t val; + operation(strVal.data(), strVal.length(), val); + vector->setValue(rowToAdd, val); + } break; + case LogicalTypeID::MAP: { + operation(strVal.data(), strVal.length(), vector, rowToAdd, csvReaderConfig); + } break; + case LogicalTypeID::VAR_LIST: { + list_entry_t val; + operation(strVal.data(), strVal.length(), vector, rowToAdd, csvReaderConfig); + } break; + case LogicalTypeID::FIXED_LIST: { + // TODO: add fix list function wrapper + castStringToFixedList(strVal.data(), strVal.length(), vector, rowToAdd, csvReaderConfig); + } break; + case LogicalTypeID::STRUCT: { + operation( + strVal.data(), strVal.length(), vector, rowToAdd, csvReaderConfig); + } break; + case LogicalTypeID::UNION: { + 
operation(strVal.data(), strVal.length(), vector, rowToAdd, csvReaderConfig); + } break; + default: { // LCOV_EXCL_START + throw NotImplementedException("CastStringToTypes::operation"); + } // LCOV_EXCL_STOP + } +} + +template<> +void CastStringToTypes::operation( + common::ku_string_t& input, common::blob_t& result, common::ValueVector& resultVector) { + result.value.len = common::Blob::getBlobSize(input); + if (!common::ku_string_t::isShortString(result.value.len)) { + auto overflowBuffer = common::StringVector::getInMemOverflowBuffer(&resultVector); + auto overflowPtr = overflowBuffer->allocateSpace(result.value.len); + result.value.overflowPtr = reinterpret_cast(overflowPtr); + common::Blob::fromString( + reinterpret_cast(input.getData()), input.len, overflowPtr); + memcpy(result.value.prefix, overflowPtr, common::ku_string_t::PREFIX_LENGTH); + } else { + common::Blob::fromString( + reinterpret_cast(input.getData()), input.len, result.value.prefix); + } +} + +} // namespace function +} // namespace kuzu diff --git a/src/function/vector_cast_functions.cpp b/src/function/vector_cast_functions.cpp index f14b20b8f14..53f8069cd96 100644 --- a/src/function/vector_cast_functions.cpp +++ b/src/function/vector_cast_functions.cpp @@ -1,7 +1,5 @@ #include "function/cast/vector_cast_functions.h" -#include "function/cast/cast_functions.h" - using namespace kuzu::common; namespace kuzu { @@ -125,19 +123,19 @@ void VectorCastFunction::bindImplicitCastFunc( } case LogicalTypeID::DATE: { assert(sourceTypeID == LogicalTypeID::STRING); - func = &UnaryExecFunction; + func = &UnaryStringExecFunction; return; } case LogicalTypeID::TIMESTAMP: { assert(sourceTypeID == LogicalTypeID::STRING || sourceTypeID == LogicalTypeID::DATE); func = sourceTypeID == LogicalTypeID::STRING ? 
- &UnaryExecFunction : + &UnaryStringExecFunction : &UnaryExecFunction; return; } case LogicalTypeID::INTERVAL: { assert(sourceTypeID == LogicalTypeID::STRING); - func = &UnaryExecFunction; + func = &UnaryStringExecFunction; return; } default: @@ -149,29 +147,26 @@ void VectorCastFunction::bindImplicitCastFunc( vector_function_definitions CastToDateVectorFunction::getDefinitions() { vector_function_definitions result; - result.push_back(make_unique(CAST_TO_DATE_FUNC_NAME, - std::vector{LogicalTypeID::STRING}, LogicalTypeID::DATE, - UnaryExecFunction)); + result.push_back( + bindCastToStringVectorFunction(CAST_TO_DATE_FUNC_NAME, LogicalTypeID::DATE)); return result; } vector_function_definitions CastToTimestampVectorFunction::getDefinitions() { vector_function_definitions result; - result.push_back(std::make_unique(CAST_TO_TIMESTAMP_FUNC_NAME, - std::vector{LogicalTypeID::STRING}, LogicalTypeID::TIMESTAMP, - UnaryExecFunction)); + result.push_back(bindCastToStringVectorFunction( + CAST_TO_TIMESTAMP_FUNC_NAME, LogicalTypeID::TIMESTAMP)); return result; } vector_function_definitions CastToIntervalVectorFunction::getDefinitions() { vector_function_definitions result; - result.push_back(make_unique(CAST_TO_INTERVAL_FUNC_NAME, - std::vector{LogicalTypeID::STRING}, LogicalTypeID::INTERVAL, - UnaryExecFunction)); + result.push_back(bindCastToStringVectorFunction( + CAST_TO_INTERVAL_FUNC_NAME, LogicalTypeID::INTERVAL)); return result; } -void CastToStringVectorFunction::getUnaryCastExecFunction( +void CastToStringVectorFunction::getUnaryCastToStringExecFunction( common::LogicalTypeID typeID, scalar_exec_func& func) { switch (typeID) { case common::LogicalTypeID::BOOL: { @@ -241,7 +236,7 @@ void CastToStringVectorFunction::getUnaryCastExecFunction( // LCOV_EXCL_START default: throw common::NotImplementedException{ - "CastToStringVectorFunction::getUnaryCastExecFunction"}; + "CastToStringVectorFunction::getUnaryCastToStringExecFunction"}; // LCOV_EXCL_END } } @@ -251,7 +246,7 
@@ vector_function_definitions CastToStringVectorFunction::getDefinitions() { result.reserve(LogicalTypeUtils::getAllValidLogicTypes().size()); for (auto& type : LogicalTypeUtils::getAllValidLogicTypes()) { scalar_exec_func execFunc; - getUnaryCastExecFunction(type.getLogicalTypeID(), execFunc); + getUnaryCastToStringExecFunction(type.getLogicalTypeID(), execFunc); auto definition = std::make_unique(CAST_TO_STRING_FUNC_NAME, std::vector{type.getLogicalTypeID()}, LogicalTypeID::STRING, execFunc); result.push_back(std::move(definition)); @@ -261,150 +256,147 @@ vector_function_definitions CastToStringVectorFunction::getDefinitions() { vector_function_definitions CastToBlobVectorFunction::getDefinitions() { vector_function_definitions result; - result.push_back(make_unique(CAST_TO_BLOB_FUNC_NAME, - std::vector{LogicalTypeID::STRING}, LogicalTypeID::BLOB, - UnaryCastExecFunction)); + result.push_back( + bindCastToStringVectorFunction(CAST_TO_BLOB_FUNC_NAME, LogicalTypeID::BLOB)); return result; } vector_function_definitions CastToBoolVectorFunction::getDefinitions() { vector_function_definitions result; - result.push_back(make_unique(CAST_TO_BOOL_FUNC_NAME, - std::vector{LogicalTypeID::STRING}, LogicalTypeID::BOOL, - UnaryExecFunction)); + result.push_back( + bindCastToStringVectorFunction(CAST_TO_BOOL_FUNC_NAME, LogicalTypeID::BOOL)); return result; } vector_function_definitions CastToDoubleVectorFunction::getDefinitions() { vector_function_definitions result; for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - result.push_back(bindVectorFunction( + result.push_back(bindNumericCastVectorFunction( CAST_TO_DOUBLE_FUNC_NAME, typeID, LogicalTypeID::DOUBLE)); } - result.push_back(bindVectorFunction( - CAST_TO_DOUBLE_FUNC_NAME, LogicalTypeID::STRING, LogicalTypeID::DOUBLE)); + result.push_back( + bindCastToStringVectorFunction(CAST_TO_DOUBLE_FUNC_NAME, LogicalTypeID::DOUBLE)); return result; } vector_function_definitions 
CastToFloatVectorFunction::getDefinitions() { vector_function_definitions result; for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - result.push_back(bindVectorFunction( + result.push_back(bindNumericCastVectorFunction( CAST_TO_FLOAT_FUNC_NAME, typeID, LogicalTypeID::FLOAT)); } - result.push_back(bindVectorFunction( - CAST_TO_FLOAT_FUNC_NAME, LogicalTypeID::STRING, LogicalTypeID::FLOAT)); + result.push_back( + bindCastToStringVectorFunction(CAST_TO_FLOAT_FUNC_NAME, LogicalTypeID::FLOAT)); + return result; +} + +vector_function_definitions CastToInt128VectorFunction::getDefinitions() { + vector_function_definitions result; + for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { + result.push_back(bindNumericCastVectorFunction( + CAST_TO_INT128_FUNC_NAME, typeID, LogicalTypeID::INT128)); + } + result.push_back( + bindCastToStringVectorFunction(CAST_TO_INT128_FUNC_NAME, LogicalTypeID::INT128)); return result; } vector_function_definitions CastToSerialVectorFunction::getDefinitions() { vector_function_definitions result; for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - result.push_back(bindVectorFunction( + result.push_back(bindNumericCastVectorFunction( CAST_TO_SERIAL_FUNC_NAME, typeID, LogicalTypeID::SERIAL)); } - result.push_back(bindVectorFunction( - CAST_TO_SERIAL_FUNC_NAME, LogicalTypeID::STRING, LogicalTypeID::SERIAL)); + result.push_back( + bindCastToStringVectorFunction(CAST_TO_SERIAL_FUNC_NAME, LogicalTypeID::INT64)); return result; } vector_function_definitions CastToInt64VectorFunction::getDefinitions() { vector_function_definitions result; for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - result.push_back(bindVectorFunction( + result.push_back(bindNumericCastVectorFunction( CAST_TO_INT64_FUNC_NAME, typeID, LogicalTypeID::INT64)); } - result.push_back(bindVectorFunction( - CAST_TO_INT64_FUNC_NAME, LogicalTypeID::STRING, LogicalTypeID::INT64)); + result.push_back( + 
bindCastToStringVectorFunction(CAST_TO_INT64_FUNC_NAME, LogicalTypeID::INT64)); return result; } vector_function_definitions CastToInt32VectorFunction::getDefinitions() { vector_function_definitions result; for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - result.push_back(bindVectorFunction( + result.push_back(bindNumericCastVectorFunction( CAST_TO_INT32_FUNC_NAME, typeID, LogicalTypeID::INT32)); } - result.push_back(bindVectorFunction( - CAST_TO_INT32_FUNC_NAME, LogicalTypeID::STRING, LogicalTypeID::INT32)); + result.push_back( + bindCastToStringVectorFunction(CAST_TO_INT32_FUNC_NAME, LogicalTypeID::INT32)); return result; } vector_function_definitions CastToInt16VectorFunction::getDefinitions() { vector_function_definitions result; for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - result.push_back(bindVectorFunction( + result.push_back(bindNumericCastVectorFunction( CAST_TO_INT16_FUNC_NAME, typeID, LogicalTypeID::INT16)); } - result.push_back(bindVectorFunction( - CAST_TO_INT16_FUNC_NAME, LogicalTypeID::STRING, LogicalTypeID::INT16)); + result.push_back( + bindCastToStringVectorFunction(CAST_TO_INT16_FUNC_NAME, LogicalTypeID::INT16)); return result; } vector_function_definitions CastToInt8VectorFunction::getDefinitions() { vector_function_definitions result; for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - result.push_back(bindVectorFunction( + result.push_back(bindNumericCastVectorFunction( CAST_TO_INT8_FUNC_NAME, typeID, LogicalTypeID::INT8)); } - result.push_back(bindVectorFunction( - CAST_TO_INT8_FUNC_NAME, LogicalTypeID::STRING, LogicalTypeID::INT8)); + result.push_back( + bindCastToStringVectorFunction(CAST_TO_INT8_FUNC_NAME, LogicalTypeID::INT8)); return result; } vector_function_definitions CastToUInt64VectorFunction::getDefinitions() { vector_function_definitions result; for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - result.push_back(bindVectorFunction( + 
result.push_back(bindNumericCastVectorFunction( CAST_TO_UINT64_FUNC_NAME, typeID, LogicalTypeID::UINT64)); } - result.push_back(bindVectorFunction( - CAST_TO_UINT64_FUNC_NAME, LogicalTypeID::STRING, LogicalTypeID::UINT64)); + result.push_back( + bindCastToStringVectorFunction(CAST_TO_UINT64_FUNC_NAME, LogicalTypeID::UINT64)); return result; } vector_function_definitions CastToUInt32VectorFunction::getDefinitions() { vector_function_definitions result; for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - result.push_back(bindVectorFunction( + result.push_back(bindNumericCastVectorFunction( CAST_TO_UINT32_FUNC_NAME, typeID, LogicalTypeID::UINT32)); } - result.push_back(bindVectorFunction( - CAST_TO_UINT32_FUNC_NAME, LogicalTypeID::STRING, LogicalTypeID::UINT32)); + result.push_back( + bindCastToStringVectorFunction(CAST_TO_UINT32_FUNC_NAME, LogicalTypeID::UINT32)); return result; } vector_function_definitions CastToUInt16VectorFunction::getDefinitions() { vector_function_definitions result; for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - result.push_back(bindVectorFunction( + result.push_back(bindNumericCastVectorFunction( CAST_TO_UINT16_FUNC_NAME, typeID, LogicalTypeID::UINT16)); } - result.push_back(bindVectorFunction( - CAST_TO_UINT16_FUNC_NAME, LogicalTypeID::STRING, LogicalTypeID::UINT16)); + result.push_back( + bindCastToStringVectorFunction(CAST_TO_UINT16_FUNC_NAME, LogicalTypeID::UINT16)); return result; } vector_function_definitions CastToUInt8VectorFunction::getDefinitions() { vector_function_definitions result; for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - result.push_back(bindVectorFunction( + result.push_back(bindNumericCastVectorFunction( CAST_TO_UINT8_FUNC_NAME, typeID, LogicalTypeID::UINT8)); } - result.push_back(bindVectorFunction( - CAST_TO_UINT8_FUNC_NAME, LogicalTypeID::STRING, LogicalTypeID::UINT8)); - return result; -} - -vector_function_definitions 
CastToInt128VectorFunction::getDefinitions() { - vector_function_definitions result; - // down cast - for (auto typeID : LogicalTypeUtils::getNumericalLogicalTypeIDs()) { - result.push_back(bindVectorFunction( - CAST_TO_INT128_FUNC_NAME, typeID, LogicalTypeID::INT128)); - } - result.push_back(bindVectorFunction( - CAST_TO_INT128_FUNC_NAME, LogicalTypeID::STRING, LogicalTypeID::INT128)); + result.push_back( + bindCastToStringVectorFunction(CAST_TO_UINT8_FUNC_NAME, LogicalTypeID::UINT8)); return result; } diff --git a/src/function/vector_union_functions.cpp b/src/function/vector_union_functions.cpp index 51c5be21144..d0dcbd63769 100644 --- a/src/function/vector_union_functions.cpp +++ b/src/function/vector_union_functions.cpp @@ -49,7 +49,7 @@ vector_function_definitions UnionTagVectorFunction::getDefinitions() { vector_function_definitions definitions; definitions.push_back(make_unique(UNION_TAG_FUNC_NAME, std::vector{LogicalTypeID::UNION}, LogicalTypeID::STRING, - UnaryExecListStructFunction, nullptr, nullptr, + UnaryExecListStructFunction, nullptr, nullptr, false /* isVarLength */)); return definitions; } diff --git a/src/include/common/type_utils.h b/src/include/common/type_utils.h index 4aef8845e67..909eccb8df8 100644 --- a/src/include/common/type_utils.h +++ b/src/include/common/type_utils.h @@ -24,7 +24,7 @@ class TypeUtils { std::is_same::value || std::is_same::value); return std::to_string(val); } - static inline std::string toString(int128_t val) { return Int128_t::ToString(val); } + static inline void encodeOverflowPtr( uint64_t& overflowPtr, page_idx_t pageIdx, uint16_t pageOffset) { memcpy(&overflowPtr, &pageIdx, 4); @@ -44,6 +44,11 @@ class TypeUtils { const LogicalType& dataType, const uint8_t* value, void* vector); }; +template<> +inline std::string TypeUtils::toString(const int128_t& val, void* valueVector) { + return Int128_t::ToString(val); +} + // Forward declaration of template specializations. 
template<> std::string TypeUtils::toString(const bool& val, void* valueVector); diff --git a/src/include/common/types/types.h b/src/include/common/types/types.h index 7749d5ed15f..8a634f4dd15 100644 --- a/src/include/common/types/types.h +++ b/src/include/common/types/types.h @@ -68,7 +68,9 @@ struct map_entry_t { list_entry_t entry; }; -using union_entry_t = struct_entry_t; +struct union_entry_t { + int64_t pos; +}; enum class KUZU_API LogicalTypeID : uint8_t { ANY = 0, diff --git a/src/include/function/cast/cast_functions.h b/src/include/function/cast/cast_functions.h deleted file mode 100644 index 2735f0744af..00000000000 --- a/src/include/function/cast/cast_functions.h +++ /dev/null @@ -1,396 +0,0 @@ -#pragma once - -#include - -#include "cast_utils.h" -#include "common/exception/runtime.h" -#include "common/string_format.h" -#include "common/type_utils.h" -#include "common/types/blob.h" -#include "common/vector/value_vector.h" - -namespace kuzu { -namespace function { - -struct CastStringToDate { - static inline void operation(common::ku_string_t& input, common::date_t& result) { - result = common::Date::fromCString((const char*)input.getData(), input.len); - } -}; - -struct CastStringToTimestamp { - static inline void operation(common::ku_string_t& input, common::timestamp_t& result) { - result = common::Timestamp::fromCString((const char*)input.getData(), input.len); - } -}; - -struct CastStringToInterval { - static inline void operation(common::ku_string_t& input, common::interval_t& result) { - result = common::Interval::fromCString((const char*)input.getData(), input.len); - } -}; - -struct CastToString { - template - static inline std::string castToString(T& input, const common::ValueVector& inputVector) { - return common::TypeUtils::toString(input, (void*)&inputVector); - } - - template - static inline void operation(T& input, common::ku_string_t& result, - common::ValueVector& inputVector, common::ValueVector& resultVector) { - std::string resultStr = 
castToString(input, inputVector); - if (resultStr.length() > common::ku_string_t::SHORT_STR_LENGTH) { - result.overflowPtr = reinterpret_cast( - common::StringVector::getInMemOverflowBuffer(&resultVector) - ->allocateSpace(resultStr.length())); - } - result.set(resultStr); - } -}; - -template<> -inline std::string CastToString::castToString( - common::int128_t& input, const common::ValueVector& /*inputVector*/) { - return common::Int128_t::ToString(input); -} - -struct CastToBlob { - static inline void operation(common::ku_string_t& input, common::blob_t& result, - common::ValueVector& /*inputVector*/, common::ValueVector& resultVector) { - result.value.len = common::Blob::getBlobSize(input); - if (!common::ku_string_t::isShortString(result.value.len)) { - auto overflowBuffer = common::StringVector::getInMemOverflowBuffer(&resultVector); - auto overflowPtr = overflowBuffer->allocateSpace(result.value.len); - result.value.overflowPtr = reinterpret_cast(overflowPtr); - common::Blob::fromString( - reinterpret_cast(input.getData()), input.len, overflowPtr); - memcpy(result.value.prefix, overflowPtr, common::ku_string_t::PREFIX_LENGTH); - } else { - common::Blob::fromString( - reinterpret_cast(input.getData()), input.len, result.value.prefix); - } - } -}; - -struct CastDateToTimestamp { - static inline void operation(common::date_t& input, common::timestamp_t& result) { - result = common::Timestamp::fromDateTime(input, common::dtime_t{}); - } -}; - -struct CastToBool { - static inline void operation(common::ku_string_t& input, bool& result) { - if (!tryCastToBool(reinterpret_cast(input.getData()), input.len, result)) { - throw common::ConversionException{ - common::stringFormat("Value {} is not a valid boolean", input.getAsString())}; - } - } -}; - -struct CastToDouble { - template - static inline void operation(T& input, double_t& result) { - if (!tryCastWithOverflowCheck(input, result)) { - throw common::RuntimeException{common::stringFormat( - "Value {} is not within 
DOUBLE range", common::TypeUtils::toString(input).c_str())}; - } - } -}; - -template<> -inline void CastToDouble::operation(common::int128_t& input, double_t& result) { - common::Int128_t::tryCast(input, result); -} - -template<> -inline void CastToDouble::operation(char*& input, double_t& result) { - doubleCast( - input, strlen(input), result, common::LogicalType{common::LogicalTypeID::DOUBLE}); -} - -template<> -inline void CastToDouble::operation(common::ku_string_t& input, double_t& result) { - doubleCast((char*)input.getData(), input.len, result, - common::LogicalType{common::LogicalTypeID::DOUBLE}); -} - -struct CastToFloat { - template - static inline void operation(T& input, float_t& result) { - if (!tryCastWithOverflowCheck(input, result)) { - throw common::RuntimeException{common::stringFormat( - "Value {} is not within FLOAT range", common::TypeUtils::toString(input).c_str())}; - } - } -}; - -template<> -inline void CastToFloat::operation(char*& input, float_t& result) { - doubleCast( - input, strlen(input), result, common::LogicalType{common::LogicalTypeID::FLOAT}); -} - -template<> -inline void CastToFloat::operation(common::ku_string_t& input, float_t& result) { - doubleCast((char*)input.getData(), input.len, result, - common::LogicalType{common::LogicalTypeID::FLOAT}); -} - -template<> -inline void CastToFloat::operation(common::int128_t& input, float_t& result) { - common::Int128_t::tryCast(input, result); -} - -struct CastToInt64 { - template - static inline void operation(T& input, int64_t& result) { - if (!tryCastWithOverflowCheck(input, result)) { - throw common::RuntimeException{common::stringFormat( - "Value {} is not within INT64 range", common::TypeUtils::toString(input).c_str())}; - } - } -}; - -template<> -inline void CastToInt64::operation(char*& input, int64_t& result) { - simpleIntegerCast( - input, strlen(input), result, common::LogicalType{common::LogicalTypeID::INT64}); -} - -template<> -inline void 
CastToInt64::operation(common::ku_string_t& input, int64_t& result) { - simpleIntegerCast((char*)input.getData(), input.len, result, - common::LogicalType{common::LogicalTypeID::INT64}); -} - -template<> -inline void CastToInt64::operation(common::int128_t& input, int64_t& result) { - common::Int128_t::tryCast(input, result); -} - -struct CastToSerial { - template - static inline void operation(T& input, int64_t& result) { - if (!tryCastWithOverflowCheck(input, result)) { - throw common::RuntimeException{common::stringFormat( - "Value {} is not within INT64 range", common::TypeUtils::toString(input).c_str())}; - } - } -}; - -template<> -inline void CastToSerial::operation(common::ku_string_t& input, int64_t& result) { - simpleIntegerCast((char*)input.getData(), input.len, result, - common::LogicalType{common::LogicalTypeID::INT64}); -} - -struct CastToInt32 { - template - static inline void operation(T& input, int32_t& result) { - if (!tryCastWithOverflowCheck(input, result)) { - throw common::RuntimeException{common::stringFormat( - "Value {} is not within INT32 range", common::TypeUtils::toString(input).c_str())}; - } - } -}; - -template<> -inline void CastToInt32::operation(char*& input, int32_t& result) { - simpleIntegerCast( - input, strlen(input), result, common::LogicalType{common::LogicalTypeID::INT32}); -} - -template<> -inline void CastToInt32::operation(common::ku_string_t& input, int32_t& result) { - simpleIntegerCast((char*)input.getData(), input.len, result, - common::LogicalType{common::LogicalTypeID::INT32}); -} - -template<> -inline void CastToInt32::operation(common::int128_t& input, int32_t& result) { - common::Int128_t::tryCast(input, result); -} - -struct CastToInt16 { - template - static inline void operation(T& input, int16_t& result) { - if (!tryCastWithOverflowCheck(input, result)) { - throw common::RuntimeException{common::stringFormat( - "Value {} is not within INT16 range", common::TypeUtils::toString(input).c_str())}; - } - } -}; - 
-template<> -inline void CastToInt16::operation(common::ku_string_t& input, int16_t& result) { - simpleIntegerCast((char*)input.getData(), input.len, result, - common::LogicalType{common::LogicalTypeID::INT16}); -} - -template<> -inline void CastToInt16::operation(char*& input, int16_t& result) { - simpleIntegerCast( - input, strlen(input), result, common::LogicalType{common::LogicalTypeID::INT16}); -} - -template<> -inline void CastToInt16::operation(common::int128_t& input, int16_t& result) { - common::Int128_t::tryCast(input, result); -} - -struct CastToInt8 { - template - static inline void operation(T& input, int8_t& result) { - if (!tryCastWithOverflowCheck(input, result)) { - throw common::RuntimeException{common::stringFormat( - "Value {} is not within INT8 range", common::TypeUtils::toString(input).c_str())}; - } - } -}; - -template<> -inline void CastToInt8::operation(common::ku_string_t& input, int8_t& result) { - simpleIntegerCast((char*)input.getData(), input.len, result, - common::LogicalType{common::LogicalTypeID::INT8}); -} - -template<> -inline void CastToInt8::operation(char*& input, int8_t& result) { - simpleIntegerCast( - input, strlen(input), result, common::LogicalType{common::LogicalTypeID::INT8}); -} - -template<> -inline void CastToInt8::operation(common::int128_t& input, int8_t& result) { - common::Int128_t::tryCast(input, result); -} - -struct CastToUInt64 { - template - static inline void operation(T& input, uint64_t& result) { - if (!tryCastWithOverflowCheck(input, result)) { - throw common::RuntimeException{common::stringFormat( - "Value {} is not within UINT64 range", common::TypeUtils::toString(input).c_str())}; - } - } -}; - -template<> -inline void CastToUInt64::operation(common::ku_string_t& input, uint64_t& result) { - simpleIntegerCast((char*)input.getData(), input.len, result, - common::LogicalType{common::LogicalTypeID::UINT64}); -} - -template<> -inline void CastToUInt64::operation(char*& input, uint64_t& result) { - 
simpleIntegerCast( - input, strlen(input), result, common::LogicalType{common::LogicalTypeID::UINT64}); -} - -template<> -inline void CastToUInt64::operation(common::int128_t& input, uint64_t& result) { - common::Int128_t::tryCast(input, result); -} - -struct CastToUInt32 { - template - static inline void operation(T& input, uint32_t& result) { - if (!tryCastWithOverflowCheck(input, result)) { - throw common::RuntimeException{common::stringFormat( - "Value {} is not within UINT32 range", common::TypeUtils::toString(input).c_str())}; - } - } -}; - -template<> -inline void CastToUInt32::operation(common::ku_string_t& input, uint32_t& result) { - simpleIntegerCast((char*)input.getData(), input.len, result, - common::LogicalType{common::LogicalTypeID::UINT32}); -} - -template<> -inline void CastToUInt32::operation(char*& input, uint32_t& result) { - simpleIntegerCast( - input, strlen(input), result, common::LogicalType{common::LogicalTypeID::UINT32}); -} - -template<> -inline void CastToUInt32::operation(common::int128_t& input, uint32_t& result) { - common::Int128_t::tryCast(input, result); -} - -struct CastToUInt16 { - template - static inline void operation(T& input, uint16_t& result) { - if (!tryCastWithOverflowCheck(input, result)) { - throw common::RuntimeException{common::stringFormat( - "Value {} is not within UINT16 range", common::TypeUtils::toString(input).c_str())}; - } - } -}; - -template<> -inline void CastToUInt16::operation(common::ku_string_t& input, uint16_t& result) { - simpleIntegerCast((char*)input.getData(), input.len, result, - common::LogicalType{common::LogicalTypeID::UINT16}); -} - -template<> -inline void CastToUInt16::operation(char*& input, uint16_t& result) { - simpleIntegerCast( - input, strlen(input), result, common::LogicalType{common::LogicalTypeID::UINT16}); -} - -template<> -inline void CastToUInt16::operation(common::int128_t& input, uint16_t& result) { - common::Int128_t::tryCast(input, result); -} - -struct CastToUInt8 { - 
template - static inline void operation(T& input, uint8_t& result) { - if (!tryCastWithOverflowCheck(input, result)) { - throw common::RuntimeException{common::stringFormat( - "Value {} is not within UINT8 range", common::TypeUtils::toString(input).c_str())}; - } - } -}; - -template<> -inline void CastToUInt8::operation(common::ku_string_t& input, uint8_t& result) { - simpleIntegerCast((char*)input.getData(), input.len, result, - common::LogicalType{common::LogicalTypeID::UINT8}); -} - -template<> -inline void CastToUInt8::operation(char*& input, uint8_t& result) { - simpleIntegerCast( - input, strlen(input), result, common::LogicalType{common::LogicalTypeID::UINT8}); -} - -template<> -inline void CastToUInt8::operation(common::int128_t& input, uint8_t& result) { - common::Int128_t::tryCast(input, result); -} - -struct CastToInt128 { - template - static inline void operation(T& input, common::int128_t& result) { - common::Int128_t::tryCastTo(input, result); - } -}; - -template<> -inline void CastToInt128::operation(common::ku_string_t& input, common::int128_t& result) { - auto data = (char*)input.getData(); - simpleInt128Cast(data, input.len, result); -} - -template<> -inline void CastToInt128::operation(const char*& input, common::int128_t& result) { - simpleInt128Cast(input, strlen(input), result); -} - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/cast/functions/cast_functions.h b/src/include/function/cast/functions/cast_functions.h new file mode 100644 index 00000000000..4a9a3a4f0a4 --- /dev/null +++ b/src/include/function/cast/functions/cast_functions.h @@ -0,0 +1,217 @@ +#pragma once + +#include + +#include "cast_string_non_nested_functions.h" +#include "common/exception/conversion.h" +#include "common/exception/overflow.h" +#include "common/string_format.h" +#include "common/type_utils.h" +#include "common/types/blob.h" +#include "common/vector/value_vector.h" + +namespace kuzu { +namespace function { + +struct CastToString 
{ + template + static inline std::string castToString(T& input, const common::ValueVector& inputVector) { + return common::TypeUtils::toString(input, (void*)&inputVector); + } + + template + static inline void operation(T& input, common::ku_string_t& result, + common::ValueVector& inputVector, common::ValueVector& resultVector) { + std::string resultStr = castToString(input, inputVector); + if (resultStr.length() > common::ku_string_t::SHORT_STR_LENGTH) { + result.overflowPtr = reinterpret_cast( + common::StringVector::getInMemOverflowBuffer(&resultVector) + ->allocateSpace(resultStr.length())); + } + result.set(resultStr); + } +}; + +struct CastDateToTimestamp { + static inline void operation(common::date_t& input, common::timestamp_t& result) { + result = common::Timestamp::fromDateTime(input, common::dtime_t{}); + } +}; + +struct CastToDouble { + template + static inline void operation(T& input, double_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within DOUBLE range", common::TypeUtils::toString(input).c_str())}; + } + } +}; + +template<> +inline void CastToDouble::operation(common::int128_t& input, double_t& result) { + common::Int128_t::tryCast(input, result); +} + +struct CastToFloat { + template + static inline void operation(T& input, float_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within FLOAT range", common::TypeUtils::toString(input).c_str())}; + } + } +}; + +template<> +inline void CastToFloat::operation(common::int128_t& input, float_t& result) { + common::Int128_t::tryCast(input, result); +} + +struct CastToInt128 { + template + static inline void operation(T& input, common::int128_t& result) { + if (!common::Int128_t::tryCastTo(input, result)) { + throw common::ConversionException{common::stringFormat( + "Value {} is not within INT64 range", 
common::TypeUtils::toString(input).c_str())}; + } + } +}; + +struct CastToInt64 { + template + static inline void operation(T& input, int64_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within INT64 range", common::TypeUtils::toString(input).c_str())}; + } + } +}; + +template<> +inline void CastToInt64::operation(common::int128_t& input, int64_t& result) { + common::Int128_t::tryCast(input, result); +} + +struct CastToSerial { + template + static inline void operation(T& input, int64_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within INT64 range", common::TypeUtils::toString(input).c_str())}; + } + } +}; + +template<> +inline void CastToSerial::operation(common::int128_t& input, int64_t& result) { + common::Int128_t::tryCast(input, result); +} + +struct CastToInt32 { + template + static inline void operation(T& input, int32_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within INT32 range", common::TypeUtils::toString(input).c_str())}; + } + } +}; + +template<> +inline void CastToInt32::operation(common::int128_t& input, int32_t& result) { + common::Int128_t::tryCast(input, result); +} + +struct CastToInt16 { + template + static inline void operation(T& input, int16_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within INT16 range", common::TypeUtils::toString(input).c_str())}; + } + } +}; + +template<> +inline void CastToInt16::operation(common::int128_t& input, int16_t& result) { + common::Int128_t::tryCast(input, result); +} + +struct CastToInt8 { + template + static inline void operation(T& input, int8_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw 
common::OverflowException{common::stringFormat( + "Value {} is not within INT8 range", common::TypeUtils::toString(input).c_str())}; + } + } +}; + +template<> +inline void CastToInt8::operation(common::int128_t& input, int8_t& result) { + common::Int128_t::tryCast(input, result); +} + +struct CastToUInt64 { + template + static inline void operation(T& input, uint64_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within UINT64 range", common::TypeUtils::toString(input).c_str())}; + } + } +}; + +template<> +inline void CastToUInt64::operation(common::int128_t& input, uint64_t& result) { + common::Int128_t::tryCast(input, result); +} + +struct CastToUInt32 { + template + static inline void operation(T& input, uint32_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within UINT32 range", common::TypeUtils::toString(input).c_str())}; + } + } +}; + +template<> +inline void CastToUInt32::operation(common::int128_t& input, uint32_t& result) { + common::Int128_t::tryCast(input, result); +} + +struct CastToUInt16 { + template + static inline void operation(T& input, uint16_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within UINT16 range", common::TypeUtils::toString(input).c_str())}; + } + } +}; + +template<> +inline void CastToUInt16::operation(common::int128_t& input, uint16_t& result) { + common::Int128_t::tryCast(input, result); +} + +struct CastToUInt8 { + template + static inline void operation(T& input, uint8_t& result) { + if (!tryCastWithOverflowCheck(input, result)) { + throw common::OverflowException{common::stringFormat( + "Value {} is not within UINT8 range", common::TypeUtils::toString(input).c_str())}; + } + } +}; + +template<> +inline void CastToUInt8::operation(common::int128_t& input, 
uint8_t& result) { + common::Int128_t::tryCast(input, result); +} + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/cast/cast_utils.h b/src/include/function/cast/functions/cast_string_non_nested_functions.h similarity index 50% rename from src/include/function/cast/cast_utils.h rename to src/include/function/cast/functions/cast_string_non_nested_functions.h index 7e132f3b1f3..0097fc4eb08 100644 --- a/src/include/function/cast/cast_utils.h +++ b/src/include/function/cast/functions/cast_string_non_nested_functions.h @@ -3,14 +3,13 @@ #include #include "common/exception/conversion.h" -#include "common/exception/overflow.h" #include "common/string_utils.h" #include "common/type_utils.h" #include "common/types/int128_t.h" #include "common/types/ku_string.h" #include "common/vector/value_vector.h" #include "fast_float.h" -#include "numeric_limits.h" +#include "numeric_cast.h" namespace kuzu { namespace function { @@ -39,8 +38,8 @@ struct IntegerCastOperation { } return true; } - // TODO: handle decimals + // TODO: handle decimals template static bool finalize(T& /*state*/) { return true; @@ -101,6 +100,7 @@ static bool tryIntegerCast(const char* input, uint64_t& len, T& result) { if (len > 1 && *input == '0') { return false; } + return integerCastLoop(input, len, result); } @@ -174,7 +174,7 @@ static bool trySimpleInt128Cast(const char* input, uint64_t len, common::int128_ static void simpleInt128Cast(const char* input, uint64_t len, common::int128_t& result) { if (!trySimpleInt128Cast(input, len, result)) { - throw common::OverflowException( + throw common::ConversionException( "Cast failed. 
" + std::string{input, len} + " is not within INT128 range."); } } @@ -230,215 +230,5 @@ static void doubleCast(const char* input, uint64_t len, T& result, } } -template -static inline T castStringToNum(const char* input, uint64_t len, - const common::LogicalType& type = common::LogicalType{common::LogicalTypeID::ANY}) { - T result; - simpleIntegerCast(input, len, result, type); - return result; -} - -template<> -inline common::int128_t castStringToNum(const char* input, uint64_t len, - const common::LogicalType& /*type*/) { // NOLINT(misc-unused-parameters): False positive - common::int128_t result{}; - simpleInt128Cast(input, len, result); - return result; -} - -template<> -inline uint64_t castStringToNum(const char* input, uint64_t len, const common::LogicalType& type) { - uint64_t result; - simpleIntegerCast(input, len, result, type); - return result; -} - -template<> -uint32_t castStringToNum(const char* input, uint64_t len, const common::LogicalType& type) { - uint32_t result; - simpleIntegerCast(input, len, result, type); - return result; -} - -template<> -uint16_t castStringToNum(const char* input, uint64_t len, const common::LogicalType& type) { - uint16_t result; - simpleIntegerCast(input, len, result, type); - return result; -} - -template<> -uint8_t castStringToNum(const char* input, uint64_t len, const common::LogicalType& type) { - uint8_t result; - simpleIntegerCast(input, len, result, type); - return result; -} - -template<> -inline double_t castStringToNum(const char* input, uint64_t len, const common::LogicalType& type) { - double_t result; - doubleCast(input, len, result, type); - return result; -} - -template<> -inline float_t castStringToNum(const char* input, uint64_t len, const common::LogicalType& type) { - float_t result; - doubleCast(input, len, result, type); - return result; -} - -template -static bool tryCastWithOverflowCheck(SRC value, DST& result) { - if (NumericLimits::isSigned() != NumericLimits::isSigned()) { - if 
(NumericLimits::isSigned()) { - if (NumericLimits::digits() > NumericLimits::digits()) { - if (value < 0 || value > (SRC)NumericLimits::maximum()) { - return false; - } - } else { - if (value < 0) { - return false; - } - } - result = (DST)value; - return true; - } else { - // unsigned to signed conversion - if (NumericLimits::digits() >= NumericLimits::digits()) { - if (value <= (SRC)NumericLimits::maximum()) { - result = (DST)value; - return true; - } - return false; - } else { - result = (DST)value; - return true; - } - } - } else { - // same sign conversion - if (NumericLimits::digits() >= NumericLimits::digits()) { - result = (DST)value; - return true; - } else { - if (value < SRC(NumericLimits::minimum()) || - value > SRC(NumericLimits::maximum())) { - return false; - } - result = (DST)value; - return true; - } - } -} - -template -bool tryCastWithOverflowCheckFloat(SRC value, T& result, SRC min, SRC max) { - if (!(value >= min && value < max)) { - return false; - } - // PG FLOAT => INT casts use statistical rounding. 
- result = std::nearbyint(value); - return true; -} - -template<> -bool tryCastWithOverflowCheck(float value, int8_t& result) { - return tryCastWithOverflowCheckFloat(value, result, -128.0f, 128.0f); -} - -template<> -bool tryCastWithOverflowCheck(float value, int16_t& result) { - return tryCastWithOverflowCheckFloat(value, result, -32768.0f, 32768.0f); -} - -template<> -bool tryCastWithOverflowCheck(float value, int32_t& result) { - return tryCastWithOverflowCheckFloat( - value, result, -2147483648.0f, 2147483648.0f); -} - -template<> -bool tryCastWithOverflowCheck(float value, int64_t& result) { - return tryCastWithOverflowCheckFloat( - value, result, -9223372036854775808.0f, 9223372036854775808.0f); -} - -template<> -bool tryCastWithOverflowCheck(float value, uint8_t& result) { - return tryCastWithOverflowCheckFloat(value, result, 0.0f, 256.0f); -} - -template<> -bool tryCastWithOverflowCheck(float value, uint16_t& result) { - return tryCastWithOverflowCheckFloat(value, result, 0.0f, 65536.0f); -} - -template<> -bool tryCastWithOverflowCheck(float value, uint32_t& result) { - return tryCastWithOverflowCheckFloat(value, result, 0.0f, 4294967296.0f); -} - -template<> -bool tryCastWithOverflowCheck(float value, uint64_t& result) { - return tryCastWithOverflowCheckFloat( - value, result, 0.0f, 18446744073709551616.0f); -} - -template<> -bool tryCastWithOverflowCheck(double value, int8_t& result) { - return tryCastWithOverflowCheckFloat(value, result, -128.0, 128.0); -} - -template<> -bool tryCastWithOverflowCheck(double value, int16_t& result) { - return tryCastWithOverflowCheckFloat(value, result, -32768.0, 32768.0); -} - -template<> -bool tryCastWithOverflowCheck(double value, int32_t& result) { - return tryCastWithOverflowCheckFloat( - value, result, -2147483648.0, 2147483648.0); -} - -template<> -bool tryCastWithOverflowCheck(double value, int64_t& result) { - return tryCastWithOverflowCheckFloat( - value, result, -9223372036854775808.0, 9223372036854775808.0); 
-} - -template<> -bool tryCastWithOverflowCheck(double value, uint8_t& result) { - return tryCastWithOverflowCheckFloat(value, result, 0.0, 256.0); -} - -template<> -bool tryCastWithOverflowCheck(double value, uint16_t& result) { - return tryCastWithOverflowCheckFloat(value, result, 0.0, 65536.0); -} - -template<> -bool tryCastWithOverflowCheck(double value, uint32_t& result) { - return tryCastWithOverflowCheckFloat(value, result, 0.0, 4294967296.0); -} - -template<> -bool tryCastWithOverflowCheck(double value, uint64_t& result) { - return tryCastWithOverflowCheckFloat( - value, result, 0.0, 18446744073709551615.0); -} - -template<> -bool tryCastWithOverflowCheck(float input, double& result) { - result = double(input); - return true; -} - -template<> -bool tryCastWithOverflowCheck(double input, float& result) { - result = float(input); - return true; -} - } // namespace function } // namespace kuzu diff --git a/src/include/function/cast/functions/cast_string_to_functions.h b/src/include/function/cast/functions/cast_string_to_functions.h new file mode 100644 index 00000000000..937fa049045 --- /dev/null +++ b/src/include/function/cast/functions/cast_string_to_functions.h @@ -0,0 +1,269 @@ +#pragma once + +#include + +#include "cast_string_non_nested_functions.h" +#include "common/copier_config/copier_config.h" +#include "common/exception/runtime.h" +#include "common/string_format.h" +#include "common/type_utils.h" +#include "common/types/blob.h" +#include "common/vector/value_vector.h" + +namespace kuzu { +namespace function { + +struct CastStringToTypes { + // non-nested types: used externally + template + static inline void operation( + common::ku_string_t& input, T& result, common::ValueVector& resultVector) { + // base case: int64 + simpleIntegerCast((char*)input.getData(), input.len, result, + common::LogicalType{common::LogicalTypeID::INT64}); + } + + template + static inline bool tryCast(const char* input, uint64_t len, T& result) { + // try cast for signed 
integer types + return trySimpleIntegerCast(input, len, result); + } + + // non nested-types (const char* are used internally) + template + static inline void operation(const char* input, uint64_t len, T& result) { + // base case: int64 + simpleIntegerCast( + input, len, result, common::LogicalType{common::LogicalTypeID::INT64}); + } + + // nested types + template + static void operation(const char* input, uint64_t len, common::ValueVector* vector, + uint64_t rowToAdd, const common::CSVReaderConfig& csvReaderConfig); + + // operation used by driver.cpp + static void operation(common::ValueVector* vector, uint64_t rowToAdd, std::string_view strVal, + const common::CSVReaderConfig& csvReaderConfig); +}; + +template<> +inline void CastStringToTypes::operation( + common::ku_string_t& input, common::int128_t& result, common::ValueVector& resultVector) { + simpleInt128Cast((char*)input.getData(), input.len, result); +} + +template<> +inline void CastStringToTypes::operation( + common::ku_string_t& input, int32_t& result, common::ValueVector& resultVector) { + simpleIntegerCast((char*)input.getData(), input.len, result, + common::LogicalType{common::LogicalTypeID::INT32}); +} + +template<> +inline void CastStringToTypes::operation( + common::ku_string_t& input, int16_t& result, common::ValueVector& resultVector) { + simpleIntegerCast((char*)input.getData(), input.len, result, + common::LogicalType{common::LogicalTypeID::INT16}); +} + +template<> +inline void CastStringToTypes::operation( + common::ku_string_t& input, int8_t& result, common::ValueVector& resultVector) { + simpleIntegerCast((char*)input.getData(), input.len, result, + common::LogicalType{common::LogicalTypeID::INT8}); +} + +template<> +inline void CastStringToTypes::operation( + common::ku_string_t& input, uint64_t& result, common::ValueVector& resultVector) { + simpleIntegerCast((char*)input.getData(), input.len, result, + common::LogicalType{common::LogicalTypeID::UINT64}); +} + +template<> +inline void 
CastStringToTypes::operation( + common::ku_string_t& input, uint32_t& result, common::ValueVector& resultVector) { + simpleIntegerCast((char*)input.getData(), input.len, result, + common::LogicalType{common::LogicalTypeID::UINT32}); +} + +template<> +inline void CastStringToTypes::operation( + common::ku_string_t& input, uint16_t& result, common::ValueVector& resultVector) { + simpleIntegerCast((char*)input.getData(), input.len, result, + common::LogicalType{common::LogicalTypeID::UINT16}); +} + +template<> +inline void CastStringToTypes::operation( + common::ku_string_t& input, uint8_t& result, common::ValueVector& resultVector) { + simpleIntegerCast((char*)input.getData(), input.len, result, + common::LogicalType{common::LogicalTypeID::UINT8}); +} + +template<> +inline void CastStringToTypes::operation( + common::ku_string_t& input, float_t& result, common::ValueVector& resultVector) { + doubleCast((char*)input.getData(), input.len, result, + common::LogicalType{common::LogicalTypeID::FLOAT}); +} + +template<> +inline void CastStringToTypes::operation( + common::ku_string_t& input, double_t& result, common::ValueVector& resultVector) { + doubleCast((char*)input.getData(), input.len, result, + common::LogicalType{common::LogicalTypeID::DOUBLE}); +} + +template<> +inline void CastStringToTypes::operation( + common::ku_string_t& input, common::date_t& result, common::ValueVector& resultVector) { + result = common::Date::fromCString((const char*)input.getData(), input.len); +} + +template<> +inline void CastStringToTypes::operation( + common::ku_string_t& input, common::timestamp_t& result, common::ValueVector& resultVector) { + result = common::Timestamp::fromCString((const char*)input.getData(), input.len); +} + +template<> +inline void CastStringToTypes::operation( + common::ku_string_t& input, common::interval_t& result, common::ValueVector& resultVector) { + result = common::Interval::fromCString((const char*)input.getData(), input.len); +} + +template<> +inline 
void CastStringToTypes::operation( + common::ku_string_t& input, bool& result, common::ValueVector& resultVector) { + castStringToBool(reinterpret_cast(input.getData()), input.len, result); +} + +template<> +void CastStringToTypes::operation( + common::ku_string_t& input, common::blob_t& result, common::ValueVector& resultVector); + +template<> +void CastStringToTypes::operation( + common::ku_string_t& input, common::list_entry_t& result, common::ValueVector& resultVector); + +template<> +void CastStringToTypes::operation( + common::ku_string_t& input, common::map_entry_t& result, common::ValueVector& resultVector); + +template<> +void CastStringToTypes::operation( + common::ku_string_t& input, common::union_entry_t& result, common::ValueVector& resultVector); + +template<> +void CastStringToTypes::operation( + common::ku_string_t& input, common::struct_entry_t& result, common::ValueVector& resultVector); + +template<> +inline void CastStringToTypes::operation( + const char* input, uint64_t len, common::int128_t& result) { + simpleInt128Cast(input, len, result); +} + +template<> +inline void CastStringToTypes::operation(const char* input, uint64_t len, int64_t& result) { + simpleIntegerCast( + input, len, result, common::LogicalType{common::LogicalTypeID::INT64}); +} + +template<> +inline void CastStringToTypes::operation(const char* input, uint64_t len, int32_t& result) { + simpleIntegerCast( + input, len, result, common::LogicalType{common::LogicalTypeID::INT32}); +} + +template<> +inline void CastStringToTypes::operation(const char* input, uint64_t len, int16_t& result) { + simpleIntegerCast( + input, len, result, common::LogicalType{common::LogicalTypeID::INT16}); +} + +template<> +inline void CastStringToTypes::operation(const char* input, uint64_t len, int8_t& result) { + simpleIntegerCast(input, len, result, common::LogicalType{common::LogicalTypeID::INT8}); +} + +template<> +inline void CastStringToTypes::operation(const char* input, uint64_t len, uint64_t& 
result) { + simpleIntegerCast( + input, len, result, common::LogicalType{common::LogicalTypeID::UINT64}); +} + +template<> +inline void CastStringToTypes::operation(const char* input, uint64_t len, uint32_t& result) { + simpleIntegerCast( + input, len, result, common::LogicalType{common::LogicalTypeID::UINT32}); +} + +template<> +inline void CastStringToTypes::operation(const char* input, uint64_t len, uint16_t& result) { + simpleIntegerCast( + input, len, result, common::LogicalType{common::LogicalTypeID::UINT16}); +} + +template<> +inline void CastStringToTypes::operation(const char* input, uint64_t len, uint8_t& result) { + simpleIntegerCast( + input, len, result, common::LogicalType{common::LogicalTypeID::UINT8}); +} + +template<> +inline void CastStringToTypes::operation(const char* input, uint64_t len, float_t& result) { + doubleCast(input, len, result, common::LogicalType{common::LogicalTypeID::FLOAT}); +} + +template<> +inline void CastStringToTypes::operation(const char* input, uint64_t len, double_t& result) { + doubleCast(input, len, result, common::LogicalType{common::LogicalTypeID::DOUBLE}); +} + +template<> +inline void CastStringToTypes::operation(const char* input, uint64_t len, bool& result) { + castStringToBool(input, len, result); +} + +template<> +inline void CastStringToTypes::operation(const char* input, uint64_t len, common::date_t& result) { + result = common::Date::fromCString(input, len); +} + +template<> +inline void CastStringToTypes::operation( + const char* input, uint64_t len, common::timestamp_t& result) { + result = common::Timestamp::fromCString(input, len); +} + +template<> +inline void CastStringToTypes::operation( + const char* input, uint64_t len, common::interval_t& result) { + result = common::Interval::fromCString(input, len); +} + +template<> +void CastStringToTypes::operation(const char* input, uint64_t len, + common::ValueVector* vector, uint64_t rowToAdd, const common::CSVReaderConfig& csvReaderConfig); + +template<> 
+void CastStringToTypes::operation(const char* input, uint64_t len, + common::ValueVector* vector, uint64_t rowToAdd, const common::CSVReaderConfig& csvReaderConfig); + +template<> +void CastStringToTypes::operation(const char* input, uint64_t len, + common::ValueVector* vector, uint64_t rowToAdd, const common::CSVReaderConfig& csvReaderConfig); + +template<> +void CastStringToTypes::operation(const char* input, uint64_t len, + common::ValueVector* vector, uint64_t rowToAdd, const common::CSVReaderConfig& csvReaderConfig); + +template<> +void CastStringToTypes::operation(const char* input, uint64_t len, + common::ValueVector* vector, uint64_t rowToAdd, const common::CSVReaderConfig& csvReaderConfig); + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/cast/functions/numeric_cast.h b/src/include/function/cast/functions/numeric_cast.h new file mode 100644 index 00000000000..4c31f2ec456 --- /dev/null +++ b/src/include/function/cast/functions/numeric_cast.h @@ -0,0 +1,164 @@ +#pragma once + +#include +#include + +#include "numeric_limits.h" + +namespace kuzu { +namespace function { + +template +static bool tryCastWithOverflowCheck(SRC value, DST& result) { + if (NumericLimits::isSigned() != NumericLimits::isSigned()) { + if (NumericLimits::isSigned()) { + if (NumericLimits::digits() > NumericLimits::digits()) { + if (value < 0 || value > (SRC)NumericLimits::maximum()) { + return false; + } + } else { + if (value < 0) { + return false; + } + } + result = (DST)value; + return true; + } else { + // unsigned to signed conversion + if (NumericLimits::digits() >= NumericLimits::digits()) { + if (value <= (SRC)NumericLimits::maximum()) { + result = (DST)value; + return true; + } + return false; + } else { + result = (DST)value; + return true; + } + } + } else { + // same sign conversion + if (NumericLimits::digits() >= NumericLimits::digits()) { + result = (DST)value; + return true; + } else { + if (value < SRC(NumericLimits::minimum()) || + 
value > SRC(NumericLimits::maximum())) { + return false; + } + result = (DST)value; + return true; + } + } +} + +template +bool tryCastWithOverflowCheckFloat(SRC value, T& result, SRC min, SRC max) { + if (!(value >= min && value < max)) { + return false; + } + // PG FLOAT => INT casts use statistical rounding. + result = std::nearbyint(value); + return true; +} + +template<> +bool tryCastWithOverflowCheck(float value, int8_t& result) { + return tryCastWithOverflowCheckFloat(value, result, -128.0f, 128.0f); +} + +template<> +bool tryCastWithOverflowCheck(float value, int16_t& result) { + return tryCastWithOverflowCheckFloat(value, result, -32768.0f, 32768.0f); +} + +template<> +bool tryCastWithOverflowCheck(float value, int32_t& result) { + return tryCastWithOverflowCheckFloat( + value, result, -2147483648.0f, 2147483648.0f); +} + +template<> +bool tryCastWithOverflowCheck(float value, int64_t& result) { + return tryCastWithOverflowCheckFloat( + value, result, -9223372036854775808.0f, 9223372036854775808.0f); +} + +template<> +bool tryCastWithOverflowCheck(float value, uint8_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0f, 256.0f); +} + +template<> +bool tryCastWithOverflowCheck(float value, uint16_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0f, 65536.0f); +} + +template<> +bool tryCastWithOverflowCheck(float value, uint32_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0f, 4294967296.0f); +} + +template<> +bool tryCastWithOverflowCheck(float value, uint64_t& result) { + return tryCastWithOverflowCheckFloat( + value, result, 0.0f, 18446744073709551616.0f); +} + +template<> +bool tryCastWithOverflowCheck(double value, int8_t& result) { + return tryCastWithOverflowCheckFloat(value, result, -128.0, 128.0); +} + +template<> +bool tryCastWithOverflowCheck(double value, int16_t& result) { + return tryCastWithOverflowCheckFloat(value, result, -32768.0, 32768.0); +} + +template<> +bool 
tryCastWithOverflowCheck(double value, int32_t& result) { + return tryCastWithOverflowCheckFloat( + value, result, -2147483648.0, 2147483648.0); +} + +template<> +bool tryCastWithOverflowCheck(double value, int64_t& result) { + return tryCastWithOverflowCheckFloat( + value, result, -9223372036854775808.0, 9223372036854775808.0); +} + +template<> +bool tryCastWithOverflowCheck(double value, uint8_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0, 256.0); +} + +template<> +bool tryCastWithOverflowCheck(double value, uint16_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0, 65536.0); +} + +template<> +bool tryCastWithOverflowCheck(double value, uint32_t& result) { + return tryCastWithOverflowCheckFloat(value, result, 0.0, 4294967296.0); +} + +template<> +bool tryCastWithOverflowCheck(double value, uint64_t& result) { + return tryCastWithOverflowCheckFloat( + value, result, 0.0, 18446744073709551615.0); +} + +template<> +bool tryCastWithOverflowCheck(float input, double& result) { + result = double(input); + return true; +} + +template<> +bool tryCastWithOverflowCheck(double input, float& result) { + result = float(input); + return true; +} + +} // namespace function +} // namespace kuzu diff --git a/src/include/function/cast/numeric_limits.h b/src/include/function/cast/functions/numeric_limits.h similarity index 100% rename from src/include/function/cast/numeric_limits.h rename to src/include/function/cast/functions/numeric_limits.h diff --git a/src/include/function/cast/vector_cast_functions.h b/src/include/function/cast/vector_cast_functions.h index 182497c5cdd..03260036b8a 100644 --- a/src/include/function/cast/vector_cast_functions.h +++ b/src/include/function/cast/vector_cast_functions.h @@ -1,6 +1,7 @@ #pragma once -#include "cast_functions.h" +#include "function/cast/functions/cast_functions.h" +#include "function/cast/functions/cast_string_to_functions.h" #include "function/vector_functions.h" namespace kuzu { @@ 
-23,23 +24,15 @@ class VectorCastFunction : public VectorFunction { protected: template - inline static std::unique_ptr bindVectorFunction( + inline static std::unique_ptr bindNumericCastVectorFunction( const std::string& funcName, common::LogicalTypeID sourceTypeID, common::LogicalTypeID targetTypeID) { scalar_exec_func func; - getUnaryExecFunc(sourceTypeID, func); + bindImplicitNumericalCastFunc(sourceTypeID, func); return std::make_unique( funcName, std::vector{sourceTypeID}, targetTypeID, func); } - template - static void UnaryCastExecFunction( - const std::vector>& params, - common::ValueVector& result) { - assert(params.size() == 1); - UnaryFunctionExecutor::executeCast(*params[0], result); - } - template static void bindImplicitNumericalCastFunc( common::LogicalTypeID srcTypeID, scalar_exec_func& func) { @@ -78,7 +71,7 @@ class VectorCastFunction : public VectorFunction { return; } case common::LogicalTypeID::INT128: { - func = UnaryExecFunction; + func = UnaryExecFunction; return; } case common::LogicalTypeID::FLOAT: { @@ -96,16 +89,12 @@ class VectorCastFunction : public VectorFunction { } } - template - static void getUnaryExecFunc(common::LogicalTypeID srcTypeID, scalar_exec_func& func) { - switch (srcTypeID) { - case common::LogicalTypeID::STRING: { - func = UnaryExecFunction; - return; - } - default: - bindImplicitNumericalCastFunc(srcTypeID, func); - } + template + inline static std::unique_ptr bindCastToStringVectorFunction( + const std::string& funcName, common::LogicalTypeID targetTypeID) { + return std::make_unique(funcName, + std::vector{common::LogicalTypeID::STRING}, targetTypeID, + UnaryStringExecFunction); } }; @@ -122,7 +111,8 @@ struct CastToIntervalVectorFunction : public VectorCastFunction { }; struct CastToStringVectorFunction : public VectorCastFunction { - static void getUnaryCastExecFunction(common::LogicalTypeID typeID, scalar_exec_func& func); + static void getUnaryCastToStringExecFunction( + common::LogicalTypeID typeID, 
scalar_exec_func& func); static vector_function_definitions getDefinitions(); }; @@ -146,6 +136,10 @@ struct CastToSerialVectorFunction : public VectorCastFunction { static vector_function_definitions getDefinitions(); }; +struct CastToInt128VectorFunction : public VectorCastFunction { + static vector_function_definitions getDefinitions(); +}; + struct CastToInt64VectorFunction : public VectorCastFunction { static vector_function_definitions getDefinitions(); }; @@ -178,9 +172,5 @@ struct CastToUInt8VectorFunction : public VectorCastFunction { static vector_function_definitions getDefinitions(); }; -struct CastToInt128VectorFunction : public VectorCastFunction { - static vector_function_definitions getDefinitions(); -}; - } // namespace function } // namespace kuzu diff --git a/src/include/function/string/vector_string_functions.h b/src/include/function/string/vector_string_functions.h index 4385b70b629..8bb1bfd6226 100644 --- a/src/include/function/string/vector_string_functions.h +++ b/src/include/function/string/vector_string_functions.h @@ -40,7 +40,7 @@ struct VectorStringFunction : public VectorFunction { } template - static inline vector_function_definitions getUnaryStrFunctionDefintion(std::string funcName) { + static inline vector_function_definitions getUnaryStrFunctionDefinition(std::string funcName) { vector_function_definitions definitions; definitions.emplace_back(std::make_unique(funcName, std::vector{common::LogicalTypeID::STRING}, @@ -73,7 +73,7 @@ struct LeftVectorFunction : public VectorStringFunction { struct LowerVectorFunction : public VectorStringFunction { static inline vector_function_definitions getDefinitions() { - return getUnaryStrFunctionDefintion(common::LOWER_FUNC_NAME); + return getUnaryStrFunctionDefinition(common::LOWER_FUNC_NAME); } }; @@ -83,7 +83,7 @@ struct LpadVectorFunction : public VectorStringFunction { struct LtrimVectorFunction : public VectorStringFunction { static inline vector_function_definitions getDefinitions() { - 
return getUnaryStrFunctionDefintion(common::LTRIM_FUNC_NAME); + return getUnaryStrFunctionDefinition(common::LTRIM_FUNC_NAME); } }; @@ -93,7 +93,7 @@ struct RepeatVectorFunction : public VectorStringFunction { struct ReverseVectorFunction : public VectorStringFunction { static inline vector_function_definitions getDefinitions() { - return getUnaryStrFunctionDefintion(common::REVERSE_FUNC_NAME); + return getUnaryStrFunctionDefinition(common::REVERSE_FUNC_NAME); } }; @@ -107,7 +107,7 @@ struct RpadVectorFunction : public VectorStringFunction { struct RtrimVectorFunction : public VectorStringFunction { static inline vector_function_definitions getDefinitions() { - return getUnaryStrFunctionDefintion(common::RTRIM_FUNC_NAME); + return getUnaryStrFunctionDefinition(common::RTRIM_FUNC_NAME); } }; @@ -121,13 +121,13 @@ struct SubStrVectorFunction : public VectorStringFunction { struct TrimVectorFunction : public VectorStringFunction { static inline vector_function_definitions getDefinitions() { - return getUnaryStrFunctionDefintion(common::TRIM_FUNC_NAME); + return getUnaryStrFunctionDefinition(common::TRIM_FUNC_NAME); } }; struct UpperVectorFunction : public VectorStringFunction { static inline vector_function_definitions getDefinitions() { - return getUnaryStrFunctionDefintion(common::UPPER_FUNC_NAME); + return getUnaryStrFunctionDefinition(common::UPPER_FUNC_NAME); } }; diff --git a/src/include/function/vector_functions.h b/src/include/function/vector_functions.h index 76a3591490f..a2d4f2b0824 100644 --- a/src/include/function/vector_functions.h +++ b/src/include/function/vector_functions.h @@ -91,6 +91,22 @@ struct VectorFunction { UnaryFunctionExecutor::execute(*params[0], result); } + template + static void UnaryStringExecFunction( + const std::vector>& params, + common::ValueVector& result) { + assert(params.size() == 1); + UnaryFunctionExecutor::executeString(*params[0], result); + } + + template + static void UnaryCastExecFunction( + const std::vector>& params, + 
common::ValueVector& result) { + assert(params.size() == 1); + UnaryFunctionExecutor::executeCast(*params[0], result); + } + template static void ConstExecFunction(const std::vector>& params, common::ValueVector& result) { diff --git a/src/include/processor/operator/persistent/writer/parquet/standard_column_writer.h b/src/include/processor/operator/persistent/writer/parquet/standard_column_writer.h index 2bc6a449f5b..c9054821f7d 100644 --- a/src/include/processor/operator/persistent/writer/parquet/standard_column_writer.h +++ b/src/include/processor/operator/persistent/writer/parquet/standard_column_writer.h @@ -2,7 +2,7 @@ #include "basic_column_writer.h" #include "common/serializer/serializer.h" -#include "function/cast/numeric_limits.h" +#include "function/cast/functions/numeric_limits.h" #include "function/comparison/comparison_functions.h" namespace kuzu { diff --git a/src/include/storage/in_mem_storage_structure/in_mem_column_chunk.h b/src/include/storage/in_mem_storage_structure/in_mem_column_chunk.h index 1e1bfa3ae50..0700a2135a7 100644 --- a/src/include/storage/in_mem_storage_structure/in_mem_column_chunk.h +++ b/src/include/storage/in_mem_storage_structure/in_mem_column_chunk.h @@ -1,7 +1,7 @@ #pragma once #include "common/types/types.h" -#include "function/cast/cast_utils.h" +#include "function/cast/functions/cast_string_to_functions.h" #include "storage/storage_structure/in_mem_file.h" #include "storage/store/table_copy_utils.h" #include @@ -50,10 +50,10 @@ class InMemColumnChunk { template void templateCopyValuesToPage(arrow::Array& array, arrow::Array* nodeOffsets); - template - void setValueFromString( - const char* value, uint64_t length, common::offset_t pos, Args... 
/*args*/) { - auto val = function::castStringToNum(value, length); + template + void setValueFromString(const char* value, uint64_t length, common::offset_t pos) { + T val; + function::CastStringToTypes::operation(value, length, val); setValue(val, pos); } @@ -141,26 +141,10 @@ template<> void InMemColumnChunk::templateCopyValuesToPage( arrow::Array& array, arrow::Array* offsets); -// BOOL -template<> -void InMemColumnChunk::setValueFromString( - const char* value, uint64_t length, common::offset_t pos); // FIXED_LIST template<> void InMemColumnChunk::setValueFromString( const char* value, uint64_t length, uint64_t pos); -// INTERVAL -template<> -void InMemColumnChunk::setValueFromString( - const char* value, uint64_t length, uint64_t pos); -// DATE -template<> -void InMemColumnChunk::setValueFromString( - const char* value, uint64_t length, uint64_t pos); -// TIMESTAMP -template<> -void InMemColumnChunk::setValueFromString( - const char* value, uint64_t length, uint64_t pos); } // namespace storage } // namespace kuzu diff --git a/src/include/storage/in_mem_storage_structure/in_mem_lists.h b/src/include/storage/in_mem_storage_structure/in_mem_lists.h index 1da77239d9d..c21f3b70621 100644 --- a/src/include/storage/in_mem_storage_structure/in_mem_lists.h +++ b/src/include/storage/in_mem_storage_structure/in_mem_lists.h @@ -1,6 +1,7 @@ #pragma once #include "common/copier_config/copier_config.h" +#include "function/cast/functions/cast_string_to_functions.h" #include "storage/storage_structure/in_mem_file.h" #include "storage/storage_structure/lists/list_headers.h" #include "storage/storage_structure/lists/lists_metadata.h" @@ -48,7 +49,11 @@ class InMemLists { void setValue(common::offset_t nodeOffset, uint64_t pos, uint8_t* val); template void setValueFromString( - common::offset_t nodeOffset, uint64_t pos, const char* val, uint64_t length); + common::offset_t nodeOffset, uint64_t pos, const char* val, uint64_t length) { + T result; + 
function::CastStringToTypes::operation(val, length, result); + setValue(nodeOffset, pos, (uint8_t*)&result); + } virtual inline InMemOverflowFile* getInMemOverflowFile() { return nullptr; } inline ListsMetadataBuilder* getListsMetadataBuilder() { return listsMetadataBuilder.get(); } @@ -200,26 +205,10 @@ template<> void InMemLists::templateCopyArrayToRelLists( arrow::Array* boundNodeOffsets, arrow::Array* posInRelList, arrow::Array* array); -// BOOL -template<> -void InMemLists::setValueFromString( - common::offset_t nodeOffset, uint64_t pos, const char* val, uint64_t length); // FIXED_LIST template<> void InMemLists::setValueFromString( common::offset_t nodeOffset, uint64_t pos, const char* val, uint64_t length); -// INTERVAL -template<> -void InMemLists::setValueFromString( - common::offset_t nodeOffset, uint64_t pos, const char* val, uint64_t length); -// DATE -template<> -void InMemLists::setValueFromString( - common::offset_t nodeOffset, uint64_t pos, const char* val, uint64_t length); -// TIMESTAMP -template<> -void InMemLists::setValueFromString( - common::offset_t nodeOffset, uint64_t pos, const char* val, uint64_t length); template<> void InMemListsWithOverflow::setValueFromStringWithOverflow( diff --git a/src/include/storage/store/column_chunk.h b/src/include/storage/store/column_chunk.h index f99115f25c1..c697cb29478 100644 --- a/src/include/storage/store/column_chunk.h +++ b/src/include/storage/store/column_chunk.h @@ -5,7 +5,6 @@ #include "common/types/types.h" #include "common/vector/value_vector.h" #include "compression.h" -#include "function/cast/cast_utils.h" #include "storage/buffer_manager/bm_file_handle.h" #include "storage/wal/wal.h" #include "transaction/transaction.h" diff --git a/src/parser/transform/transform_expression.cpp b/src/parser/transform/transform_expression.cpp index 402b13eac95..9c1ad4cac50 100644 --- a/src/parser/transform/transform_expression.cpp +++ b/src/parser/transform/transform_expression.cpp @@ -1,5 +1,5 @@ #include 
"common/string_utils.h" -#include "function/cast/cast_utils.h" +#include "function/cast/functions/cast_string_to_functions.h" #include "parser/expression/parsed_case_expression.h" #include "parser/expression/parsed_function_expression.h" #include "parser/expression/parsed_literal_expression.h" @@ -579,25 +579,24 @@ std::string Transformer::transformPropertyKeyName(CypherParser::OC_PropertyKeyNa std::unique_ptr Transformer::transformIntegerLiteral( CypherParser::OC_IntegerLiteralContext& ctx) { auto text = ctx.DecimalInteger()->getText(); - function::IntegerCastData data{}; - auto len = text.length(); - if (tryIntegerCast, true>( - text.c_str(), (uint64_t&)len, data)) { - auto value = std::make_unique( - function::castStringToNum(text.c_str(), text.length())); - return std::make_unique(std::move(value), ctx.getText()); + int64_t result; + if (function::CastStringToTypes::tryCast(text.c_str(), text.length(), result)) { + return std::make_unique( + std::make_unique(result), ctx.getText()); } - auto value = std::make_unique( - function::castStringToNum(text.c_str(), text.length())); - return std::make_unique(std::move(value), ctx.getText()); + int128_t result128; + function::CastStringToTypes::operation(text.c_str(), text.length(), result128); + return std::make_unique( + std::make_unique(result128), ctx.getText()); } std::unique_ptr Transformer::transformDoubleLiteral( CypherParser::OC_DoubleLiteralContext& ctx) { auto text = ctx.RegularDecimalReal()->getText(); - auto value = - std::make_unique(function::castStringToNum(text.c_str(), text.length())); - return std::make_unique(std::move(value), ctx.getText()); + double_t result; + function::CastStringToTypes::operation(text.c_str(), text.length(), result); + return std::make_unique( + std::make_unique(result), ctx.getText()); } } // namespace parser diff --git a/src/processor/operator/persistent/copy_to_csv.cpp b/src/processor/operator/persistent/copy_to_csv.cpp index fa5f7cbd299..4d429ca8e80 100644 --- 
a/src/processor/operator/persistent/copy_to_csv.cpp +++ b/src/processor/operator/persistent/copy_to_csv.cpp @@ -28,7 +28,7 @@ void CopyToCSVLocalState::init(CopyToInfo* info, MemoryManager* mm, ResultSet* r castFuncs.resize(info->dataPoses.size()); for (auto i = 0u; i < info->dataPoses.size(); i++) { auto vectorToCast = resultSet->getValueVector(info->dataPoses[i]); - function::CastToStringVectorFunction::getUnaryCastExecFunction( + function::CastToStringVectorFunction::getUnaryCastToStringExecFunction( vectorToCast->dataType.getLogicalTypeID(), castFuncs[i]); vectorsToCast.push_back(std::move(vectorToCast)); auto castVector = std::make_unique(LogicalTypeID::STRING, mm); diff --git a/src/processor/operator/persistent/reader/csv/driver.cpp b/src/processor/operator/persistent/reader/csv/driver.cpp index 810cd9eae81..323d076b5be 100644 --- a/src/processor/operator/persistent/reader/csv/driver.cpp +++ b/src/processor/operator/persistent/reader/csv/driver.cpp @@ -2,16 +2,12 @@ #include "common/exception/copy.h" #include "common/exception/message.h" -#include "common/exception/parser.h" #include "common/string_format.h" -#include "common/type_utils.h" -#include "common/types/blob.h" #include "common/types/value/value.h" -#include "function/cast/cast_utils.h" +#include "function/cast/functions/cast_string_to_functions.h" #include "processor/operator/persistent/reader/csv/parallel_csv_reader.h" #include "processor/operator/persistent/reader/csv/serial_csv_reader.h" #include "storage/store/table_copy_utils.h" -#include "utf8proc_wrapper.h" using namespace kuzu::common; @@ -20,686 +16,6 @@ namespace processor { ParsingDriver::ParsingDriver(common::DataChunk& chunk) : chunk(chunk), rowEmpty(false) {} -void copyStringToVector(ValueVector* vector, uint64_t rowToAdd, std::string_view strVal, - const CSVReaderConfig& csvReaderConfig); - -static void skipWhitespace(const char*& input, const char* end) { - while (input < end && isspace(*input)) { - input++; - } -} - -static void 
trimRightWhitespace(const char* input, const char*& end) { - while (input < end && isspace(*(end - 1))) { - end--; - } -} - -bool skipToCloseQuotes(const char*& input, const char* end) { - auto ch = *input; - input++; // skip the first " ' - // TODO: escape char - while (input != end) { - if (*input == ch) { - return true; - } - input++; - } - return false; -} - -static bool skipToClose(const char*& input, const char* end, uint64_t& lvl, char target, - const CSVReaderConfig& csvReaderConfig) { - input++; - while (input != end) { - if (*input == '\'') { - if (!skipToCloseQuotes(input, end)) { - return false; - } - } else if (*input == '{') { // must have closing brackets fro {, ] if they are not quoted - if (!skipToClose(input, end, lvl, '}', csvReaderConfig)) { - return false; - } - } else if (*input == csvReaderConfig.listBeginChar) { - if (!skipToClose(input, end, lvl, csvReaderConfig.listEndChar, csvReaderConfig)) { - return false; - } - lvl++; // nested one more level - } else if (*input == target) { - if (target == csvReaderConfig.listEndChar) { - lvl--; - } - return true; - } - input++; - } - return false; // no corresponding closing bracket -} - -static bool isNull(std::string_view& str) { - auto start = str.data(); - auto end = start + str.length(); - skipWhitespace(start, end); - if (start == end) { - return true; - } - if (end - start >= 4 && (*start == 'N' || *start == 'n') && - (*(start + 1) == 'U' || *(start + 1) == 'u') && - (*(start + 2) == 'L' || *(start + 2) == 'l') && - (*(start + 3) == 'L' || *(start + 3) == 'l')) { - start += 4; - skipWhitespace(start, end); - if (start == end) { - return true; - } - } - return false; -} - -struct CountPartOperation { - uint64_t count = 0; - - static inline bool handleKey( - const char* /*start*/, const char* /*end*/, const CSVReaderConfig& /*config*/) { - return true; - } - inline void handleValue( - const char* /*start*/, const char* /*end*/, const CSVReaderConfig& /*config*/) { - count++; - } -}; - -struct 
SplitStringListOperation { - SplitStringListOperation(uint64_t& offset, ValueVector* resultVector) - : offset(offset), resultVector(resultVector) {} - - uint64_t& offset; - ValueVector* resultVector; - - void handleValue(const char* start, const char* end, const CSVReaderConfig& csvReaderConfig) { - copyStringToVector(resultVector, offset, std::string_view{start, (uint32_t)(end - start)}, - csvReaderConfig); - offset++; - } -}; - -template -struct SplitStringFixedListOperation { - SplitStringFixedListOperation(uint64_t& offset, ValueVector* resultVector) - : offset(offset), resultVector(resultVector) {} - - uint64_t& offset; - ValueVector* resultVector; - - void handleValue( - const char* start, const char* end, const CSVReaderConfig& /*csvReaderConfig*/) { - T value; - auto str = std::string_view{start, (uint32_t)(end - start)}; - if (str.empty() || isNull(str)) { - throw ConversionException("Cast failed. NULL is not allowed for FIXEDLIST."); - } - auto type = FixedListType::getChildType(&resultVector->dataType); - value = function::castStringToNum(start, str.length(), *type); - resultVector->setValue(offset, value); - offset++; - } -}; - -struct SplitStringMapOperation { - SplitStringMapOperation(uint64_t& offset, ValueVector* resultVector) - : offset(offset), resultVector(resultVector) {} - - uint64_t& offset; - ValueVector* resultVector; - - inline bool handleKey( - const char* start, const char* end, const CSVReaderConfig& csvReaderConfig) { - trimRightWhitespace(start, end); - copyStringToVector(StructVector::getFieldVector(resultVector, 0).get(), offset, - std::string_view{start, (uint32_t)(end - start)}, csvReaderConfig); - return true; - } - - inline void handleValue( - const char* start, const char* end, const CSVReaderConfig& csvReaderConfig) { - trimRightWhitespace(start, end); - copyStringToVector(StructVector::getFieldVector(resultVector, 1).get(), offset++, - std::string_view{start, (uint32_t)(end - start)}, csvReaderConfig); - } -}; - -template 
-static bool splitCStringList( - const char* input, uint64_t len, T& state, const CSVReaderConfig& csvReaderConfig) { - auto end = input + len; - uint64_t lvl = 1; - bool seen_value = false; - - // locate [ - skipWhitespace(input, end); - if (input == end || *input != csvReaderConfig.listBeginChar) { - return false; - } - input++; - - auto start_ptr = input; - while (input < end) { - auto ch = *input; - if (ch == csvReaderConfig.listBeginChar) { - if (!skipToClose(input, end, ++lvl, csvReaderConfig.listEndChar, csvReaderConfig)) { - return false; - } - } else if (ch == '\'' || ch == '"') { - if (!skipToCloseQuotes(input, end)) { - return false; - } - } else if (ch == '{') { - uint64_t struct_lvl = 0; - skipToClose(input, end, struct_lvl, '}', csvReaderConfig); - } else if (ch == csvReaderConfig.delimiter || ch == csvReaderConfig.listEndChar) { // split - if (ch != csvReaderConfig.listEndChar || start_ptr < input || seen_value) { - state.handleValue(start_ptr, input, csvReaderConfig); - seen_value = true; - } - if (ch == csvReaderConfig.listEndChar) { // last ] - lvl--; - break; - } - start_ptr = ++input; - continue; - } - input++; - } - skipWhitespace(++input, end); - return (input == end && lvl == 0); -} - -template -static void tryListCast(const char* input, uint64_t len, T split, - const CSVReaderConfig& csvReaderConfig, ValueVector* vector) { - if (!splitCStringList(input, len, split, csvReaderConfig)) { - throw ConversionException("Cast failed. 
" + std::string{input, len} + " is not in " + - LogicalTypeUtils::dataTypeToString(vector->dataType) + " range."); - } -} - -static void castStringToList(const char* input, uint64_t len, ValueVector* vector, - uint64_t rowToAdd, const CSVReaderConfig& csvReaderConfig) { - // calculate the number of elements in array - CountPartOperation state; - splitCStringList(input, len, state, csvReaderConfig); - - auto list_entry = ListVector::addList(vector, state.count); - vector->setValue(rowToAdd, list_entry); - auto listDataVector = common::ListVector::getDataVector(vector); - - SplitStringListOperation split{list_entry.offset, listDataVector}; - tryListCast(input, len, split, csvReaderConfig, vector); -} - -static void validateNumElementsInList(uint64_t numElementsRead, const LogicalType& type) { - auto numElementsInList = FixedListType::getNumElementsInList(&type); - if (numElementsRead != numElementsInList) { - throw CopyException(stringFormat( - "Each fixed list should have fixed number of elements. 
Expected: {}, Actual: {}.", - numElementsInList, numElementsRead)); - } -} - -static void castStringToFixedList(const char* input, uint64_t len, ValueVector* vector, - uint64_t rowToAdd, const CSVReaderConfig& csvReaderConfig) { - assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::FIXED_LIST); - auto childDataType = FixedListType::getChildType(&vector->dataType); - - // calculate the number of elements in array - CountPartOperation state; - splitCStringList(input, len, state, csvReaderConfig); - validateNumElementsInList(state.count, vector->dataType); - - auto startOffset = state.count * rowToAdd; - switch (childDataType->getLogicalTypeID()) { - // TODO: currently only allow these type - case LogicalTypeID::INT64: { - SplitStringFixedListOperation split{startOffset, vector}; - tryListCast(input, len, split, csvReaderConfig, vector); - } break; - case LogicalTypeID::INT32: { - SplitStringFixedListOperation split{startOffset, vector}; - tryListCast(input, len, split, csvReaderConfig, vector); - } break; - case LogicalTypeID::INT16: { - SplitStringFixedListOperation split{startOffset, vector}; - tryListCast(input, len, split, csvReaderConfig, vector); - } break; - case LogicalTypeID::FLOAT: { - SplitStringFixedListOperation split{startOffset, vector}; - tryListCast(input, len, split, csvReaderConfig, vector); - } break; - case LogicalTypeID::DOUBLE: { - SplitStringFixedListOperation split{startOffset, vector}; - tryListCast(input, len, split, csvReaderConfig, vector); - } break; - default: { - throw NotImplementedException("Unsupported data type: Driver::castStringToFixedList"); - } - } -} - -template -static bool parseKeyOrValue(const char*& input, const char* end, T& state, bool isKey, - bool& closeBracket, const CSVReaderConfig& csvReaderConfig) { - auto start = input; - uint64_t lvl = 0; - - while (input < end) { - if (*input == '"' || *input == '\'') { - if (!skipToCloseQuotes(input, end)) { - return false; - }; - } else if (*input == '{') { - if 
(!skipToClose(input, end, lvl, '}', csvReaderConfig)) { - return false; - } - } else if (*input == csvReaderConfig.listBeginChar) { - if (!skipToClose(input, end, lvl, csvReaderConfig.listEndChar, csvReaderConfig)) { - return false; - }; - } else if (isKey && *input == '=') { - return state.handleKey(start, input, csvReaderConfig); - } else if (!isKey && (*input == csvReaderConfig.delimiter || *input == '}')) { - state.handleValue(start, input, csvReaderConfig); - if (*input == '}') { - closeBracket = true; - } - return true; - } - input++; - } - return false; -} - -// Split map of format: {a=12,b=13} -template -static bool splitCStringMap( - const char* input, uint64_t len, T& state, const CSVReaderConfig& csvReaderConfig) { - auto end = input + len; - bool closeBracket = false; - - skipWhitespace(input, end); - if (input == end || *input != '{') { // start with { - return false; - } - skipWhitespace(++input, end); - if (input == end) { - return false; - } - if (*input == '}') { - skipWhitespace(++input, end); // empty - return input == end; - } - - while (input < end) { - if (!parseKeyOrValue(input, end, state, true, closeBracket, csvReaderConfig)) { - return false; - } - skipWhitespace(++input, end); - if (!parseKeyOrValue(input, end, state, false, closeBracket, csvReaderConfig)) { - return false; - } - skipWhitespace(++input, end); - if (closeBracket) { - return (input == end); - } - } - return false; -} - -static void castStringToMap(const char* input, uint64_t len, ValueVector* vector, uint64_t rowToAdd, - const CSVReaderConfig& csvReaderConfig) { - // count the number of maps in map - CountPartOperation state; - splitCStringMap(input, len, state, csvReaderConfig); - - auto list_entry = ListVector::addList(vector, state.count); - vector->setValue(rowToAdd, list_entry); - auto structVector = common::ListVector::getDataVector(vector); - - SplitStringMapOperation split{list_entry.offset, structVector}; - if (!splitCStringMap(input, len, split, csvReaderConfig)) 
{ - throw ConversionException("Cast failed. " + std::string{input, len} + " is not in " + - LogicalTypeUtils::dataTypeToString(vector->dataType) + " range."); - } -} - -static bool parseStructFieldName(const char*& input, const char* end) { - while (input < end) { - if (*input == ':') { - return true; - } - input++; - } - return false; -} - -static bool parseStructFieldValue( - const char*& input, const char* end, const CSVReaderConfig& csvReaderConfig, bool& closeBrack) { - uint64_t lvl = 0; - while (input < end) { - if (*input == '"' || *input == '\'') { - if (!skipToCloseQuotes(input, end)) { - return false; - } - } else if (*input == '{') { - if (!skipToClose(input, end, lvl, '}', csvReaderConfig)) { - return false; - } - } else if (*input == csvReaderConfig.listBeginChar) { - if (!skipToClose(input, end, ++lvl, csvReaderConfig.listEndChar, csvReaderConfig)) { - return false; - } - } else if (*input == csvReaderConfig.delimiter || *input == '}') { - if (*input == '}') { - closeBrack = true; - } - return (lvl == 0); - } - input++; - } - return false; -} - -static bool tryCastStringToStruct(const char* input, uint64_t len, ValueVector* vector, - uint64_t rowToAdd, const CSVReaderConfig& csvReaderConfig) { - // check if start with { - auto end = input + len; - auto type = vector->dataType; - skipWhitespace(input, end); - if (input == end || *input != '{') { - return false; - } - skipWhitespace(++input, end); - - if (input == end) { // no closing bracket - return false; - } - if (*input == '}') { - skipWhitespace(++input, end); - return input == end; - } - - bool closeBracket = false; - while (input < end) { - auto keyStart = input; - if (!parseStructFieldName(input, end)) { // find key - return false; - } - auto keyEnd = input; - trimRightWhitespace(keyStart, keyEnd); - auto fieldIdx = StructType::getFieldIdx(&type, std::string{keyStart, keyEnd}); - if (fieldIdx == INVALID_STRUCT_FIELD_IDX) { - throw ParserException{"Invalid struct field name: " + 
std::string{keyStart, keyEnd}}; - } - - skipWhitespace(++input, end); - auto valStart = input; - if (!parseStructFieldValue(input, end, csvReaderConfig, closeBracket)) { // find value - return false; - } - auto valEnd = input; - trimRightWhitespace(valStart, valEnd); - skipWhitespace(++input, end); - - copyStringToVector(StructVector::getFieldVector(vector, fieldIdx).get(), rowToAdd, - std::string_view{valStart, (uint32_t)(valEnd - valStart)}, csvReaderConfig); - - if (closeBracket) { - return (input == end); - } - } - return false; -} - -static void castStringToStruct(const char* input, uint64_t len, ValueVector* vector, - uint64_t rowToAdd, const CSVReaderConfig& csvReaderConfig) { - if (!tryCastStringToStruct(input, len, vector, rowToAdd, csvReaderConfig)) { - throw ConversionException("Cast failed. " + std::string{input, len} + " is not in " + - LogicalTypeUtils::dataTypeToString(vector->dataType) + " range."); - } -} - -template -static inline void testAndSetValue(ValueVector* vector, uint64_t rowToAdd, T result, bool success) { - if (success) { - vector->setValue(rowToAdd, result); - } -} - -static bool tryCastUnionField( - ValueVector* vector, uint64_t rowToAdd, const char* input, uint64_t len) { - auto& targetType = vector->dataType; - bool success = false; - switch (targetType.getLogicalTypeID()) { - case LogicalTypeID::BOOL: { - bool result; - success = function::tryCastToBool(input, len, result); - testAndSetValue(vector, rowToAdd, result, success); - } break; - case LogicalTypeID::INT64: { - int64_t result; - success = function::trySimpleIntegerCast(input, len, result); - testAndSetValue(vector, rowToAdd, result, success); - } break; - case LogicalTypeID::INT32: { - int32_t result; - success = function::trySimpleIntegerCast(input, len, result); - testAndSetValue(vector, rowToAdd, result, success); - } break; - case LogicalTypeID::INT16: { - int16_t result; - success = function::trySimpleIntegerCast(input, len, result); - testAndSetValue(vector, 
rowToAdd, result, success); - } break; - case LogicalTypeID::INT8: { - int8_t result; - success = function::trySimpleIntegerCast(input, len, result); - testAndSetValue(vector, rowToAdd, result, success); - } break; - case LogicalTypeID::UINT64: { - uint64_t result; - success = function::trySimpleIntegerCast(input, len, result); - testAndSetValue(vector, rowToAdd, result, success); - } break; - case LogicalTypeID::UINT32: { - uint32_t result; - success = function::trySimpleIntegerCast(input, len, result); - testAndSetValue(vector, rowToAdd, result, success); - } break; - case LogicalTypeID::UINT16: { - uint16_t result; - success = function::trySimpleIntegerCast(input, len, result); - testAndSetValue(vector, rowToAdd, result, success); - } break; - case LogicalTypeID::UINT8: { - uint8_t result; - success = function::trySimpleIntegerCast(input, len, result); - testAndSetValue(vector, rowToAdd, result, success); - } break; - case LogicalTypeID::DOUBLE: { - double_t result; - success = function::tryDoubleCast(input, len, result); - testAndSetValue(vector, rowToAdd, result, success); - } break; - case LogicalTypeID::FLOAT: { - float_t result; - success = function::tryDoubleCast(input, len, result); - testAndSetValue(vector, rowToAdd, result, success); - } break; - case LogicalTypeID::DATE: { - date_t result; - uint64_t pos; - success = Date::tryConvertDate(input, len, pos, result); - testAndSetValue(vector, rowToAdd, result, success); - } break; - case LogicalTypeID::TIMESTAMP: { - timestamp_t result; - success = Timestamp::tryConvertTimestamp(input, len, result); - testAndSetValue(vector, rowToAdd, result, success); - } break; - case LogicalTypeID::STRING: { - if (!utf8proc::Utf8Proc::isValid(input, len)) { - throw common::CopyException{"Invalid UTF8-encoded string."}; - } - StringVector::addString(vector, rowToAdd, input, len); - return true; - } break; - default: { - return false; - } - } - return success; -} - -static void castStringToUnion(ValueVector* vector, 
std::string_view strVal, uint64_t rowToAdd) { - auto& type = vector->dataType; - union_field_idx_t selectedFieldIdx = INVALID_STRUCT_FIELD_IDX; - - for (auto i = 0u; i < UnionType::getNumFields(&type); i++) { - auto internalFieldIdx = UnionType::getInternalFieldIdx(i); - auto fieldVector = StructVector::getFieldVector(vector, internalFieldIdx).get(); - if (tryCastUnionField(fieldVector, rowToAdd, strVal.data(), strVal.length())) { - fieldVector->setNull(rowToAdd, false /* isNull */); - selectedFieldIdx = i; - break; - } else { - fieldVector->setNull(rowToAdd, true /* isNull */); - } - } - - if (selectedFieldIdx == INVALID_STRUCT_FIELD_IDX) { - throw ConversionException{stringFormat("Could not convert to union type {}: {}.", - LogicalTypeUtils::dataTypeToString(type), strVal)}; - } - StructVector::getFieldVector(vector, UnionType::TAG_FIELD_IDX) - ->setValue(rowToAdd, selectedFieldIdx); - StructVector::getFieldVector(vector, UnionType::TAG_FIELD_IDX) - ->setNull(rowToAdd, false /* isNull */); -} - -void copyStringToVector(ValueVector* vector, uint64_t rowToAdd, std::string_view strVal, - const CSVReaderConfig& csvReaderConfig) { - auto& type = vector->dataType; - - if (strVal.empty() || isNull(strVal)) { - vector->setNull(rowToAdd, true /* isNull */); - return; - } else { - vector->setNull(rowToAdd, false /* isNull */); - } - switch (type.getLogicalTypeID()) { - case LogicalTypeID::INT64: { - int64_t val; - function::simpleIntegerCast(strVal.data(), strVal.length(), val, type); - vector->setValue(rowToAdd, val); - } break; - case LogicalTypeID::INT32: { - int32_t val; - function::simpleIntegerCast(strVal.data(), strVal.length(), val, type); - vector->setValue(rowToAdd, val); - } break; - case LogicalTypeID::INT16: { - int16_t val; - function::simpleIntegerCast(strVal.data(), strVal.length(), val, type); - vector->setValue(rowToAdd, val); - } break; - case LogicalTypeID::INT8: { - int8_t val; - function::simpleIntegerCast(strVal.data(), strVal.length(), val, type); - 
vector->setValue(rowToAdd, val); - } break; - case LogicalTypeID::UINT64: { - uint64_t val; - function::simpleIntegerCast(strVal.data(), strVal.length(), val, type); - vector->setValue(rowToAdd, val); - } break; - case LogicalTypeID::UINT32: { - uint32_t val; - function::simpleIntegerCast(strVal.data(), strVal.length(), val, type); - vector->setValue(rowToAdd, val); - } break; - case LogicalTypeID::UINT16: { - uint16_t val; - function::simpleIntegerCast(strVal.data(), strVal.length(), val, type); - vector->setValue(rowToAdd, val); - } break; - case LogicalTypeID::UINT8: { - uint8_t val; - function::simpleIntegerCast(strVal.data(), strVal.length(), val, type); - vector->setValue(rowToAdd, val); - } break; - case LogicalTypeID::INT128: { - int128_t val{}; - function::simpleInt128Cast(strVal.data(), strVal.length(), val); - vector->setValue(rowToAdd, val); - } break; - case LogicalTypeID::FLOAT: { - float_t val; - function::doubleCast(strVal.data(), strVal.length(), val, type); - vector->setValue(rowToAdd, val); - } break; - case LogicalTypeID::DOUBLE: { - double_t val; - function::doubleCast(strVal.data(), strVal.length(), val, type); - vector->setValue(rowToAdd, val); - } break; - case LogicalTypeID::BOOL: { - bool val; - function::castStringToBool(strVal.data(), strVal.length(), val); - vector->setValue(rowToAdd, val); - } break; - case LogicalTypeID::BLOB: { - storage::TableCopyUtils::validateStrLen(strVal.length()); - auto blobBuffer = std::make_unique(strVal.length()); - auto blobLen = Blob::fromString(strVal.data(), strVal.length(), blobBuffer.get()); - StringVector::addString( - vector, rowToAdd, reinterpret_cast(blobBuffer.get()), blobLen); - } break; - case LogicalTypeID::STRING: { - storage::TableCopyUtils::validateStrLen(strVal.length()); - if (!utf8proc::Utf8Proc::isValid(strVal.data(), strVal.length())) { - throw common::CopyException{"Invalid UTF8-encoded string."}; - } - StringVector::addString(vector, rowToAdd, strVal.data(), strVal.length()); - } 
break; - case LogicalTypeID::DATE: { - vector->setValue(rowToAdd, Date::fromCString(strVal.data(), strVal.length())); - } break; - case LogicalTypeID::TIMESTAMP: { - vector->setValue(rowToAdd, Timestamp::fromCString(strVal.data(), strVal.length())); - } break; - case LogicalTypeID::INTERVAL: { - vector->setValue(rowToAdd, Interval::fromCString(strVal.data(), strVal.length())); - } break; - case LogicalTypeID::MAP: { - castStringToMap(strVal.data(), strVal.length(), vector, rowToAdd, csvReaderConfig); - } break; - case LogicalTypeID::VAR_LIST: { - castStringToList(strVal.data(), strVal.length(), vector, rowToAdd, csvReaderConfig); - } break; - case LogicalTypeID::FIXED_LIST: { - castStringToFixedList(strVal.data(), strVal.length(), vector, rowToAdd, csvReaderConfig); - } break; - case LogicalTypeID::STRUCT: { - castStringToStruct(strVal.data(), strVal.length(), vector, rowToAdd, csvReaderConfig); - } break; - case LogicalTypeID::UNION: { - castStringToUnion(vector, strVal, rowToAdd); - } break; - default: { // LCOV_EXCL_START - throw NotImplementedException("BaseCSVReader::copyStringToVector"); - } // LCOV_EXCL_STOP - } -} - bool ParsingDriver::done(uint64_t rowNum) { return rowNum >= DEFAULT_VECTOR_CAPACITY || doneEarly(); } @@ -722,7 +38,7 @@ void ParsingDriver::addValue( stringFormat("Error in file {}, on line {}: expected {} values per row, but got more.", reader->filePath, reader->getLineNumber(), reader->expectedNumColumns)); } - copyStringToVector( + function::CastStringToTypes::operation( chunk.getValueVector(columnIdx).get(), rowNum, value, reader->csvReaderConfig); } diff --git a/src/processor/operator/persistent/writer/parquet/basic_column_writer.cpp b/src/processor/operator/persistent/writer/parquet/basic_column_writer.cpp index 907fd182904..5f2adde93a0 100644 --- a/src/processor/operator/persistent/writer/parquet/basic_column_writer.cpp +++ b/src/processor/operator/persistent/writer/parquet/basic_column_writer.cpp @@ -1,6 +1,6 @@ #include 
"processor/operator/persistent/writer/parquet/basic_column_writer.h" -#include "function/cast/numeric_limits.h" +#include "function/cast/functions/numeric_limits.h" #include "processor/operator/persistent/reader/parquet/parquet_rle_bp_decoder.h" #include "processor/operator/persistent/writer//parquet/parquet_rle_bp_encoder.h" #include "processor/operator/persistent/writer/parquet/parquet_writer.h" diff --git a/src/processor/operator/persistent/writer/parquet/column_writer.cpp b/src/processor/operator/persistent/writer/parquet/column_writer.cpp index c22d18a5c4e..46ae3cc0c21 100644 --- a/src/processor/operator/persistent/writer/parquet/column_writer.cpp +++ b/src/processor/operator/persistent/writer/parquet/column_writer.cpp @@ -2,7 +2,7 @@ #include "common/exception/not_implemented.h" #include "common/string_utils.h" -#include "function/cast/numeric_limits.h" +#include "function/cast/functions/numeric_limits.h" #include "processor/operator/persistent/writer/parquet/boolean_column_writer.h" #include "processor/operator/persistent/writer/parquet/parquet_writer.h" #include "processor/operator/persistent/writer/parquet/standard_column_writer.h" diff --git a/src/storage/in_mem_storage_structure/in_mem_column_chunk.cpp b/src/storage/in_mem_storage_structure/in_mem_column_chunk.cpp index e2cd9e94a49..5aaf1582200 100644 --- a/src/storage/in_mem_storage_structure/in_mem_column_chunk.cpp +++ b/src/storage/in_mem_storage_structure/in_mem_column_chunk.cpp @@ -425,16 +425,6 @@ offset_t InMemFixedListColumnChunk::getOffsetInBuffer(offset_t pos) { return offsetInBuffer; } -// Bool -template<> -void InMemColumnChunk::setValueFromString( - const char* value, uint64_t /*length*/, uint64_t pos) { - std::istringstream boolStream{std::string(value)}; - bool booleanVal; - boolStream >> std::boolalpha >> booleanVal; - setValue(booleanVal, pos); -} - // Fixed list template<> void InMemColumnChunk::setValueFromString( @@ -446,29 +436,5 @@ void InMemColumnChunk::setValueFromString( 
fixedListVal.get(), storage::StorageUtils::getDataTypeSize(dataType)); } -// Interval -template<> -void InMemColumnChunk::setValueFromString( - const char* value, uint64_t length, uint64_t pos) { - auto val = Interval::fromCString(value, length); - setValue(val, pos); -} - -// Date -template<> -void InMemColumnChunk::setValueFromString( - const char* value, uint64_t length, uint64_t pos) { - auto val = Date::fromCString(value, length); - setValue(val, pos); -} - -// Timestamp -template<> -void InMemColumnChunk::setValueFromString( - const char* value, uint64_t length, uint64_t pos) { - auto val = Timestamp::fromCString(value, length); - setValue(val, pos); -} - } // namespace storage } // namespace kuzu diff --git a/src/storage/in_mem_storage_structure/in_mem_lists.cpp b/src/storage/in_mem_storage_structure/in_mem_lists.cpp index 495312514db..938e4a3c968 100644 --- a/src/storage/in_mem_storage_structure/in_mem_lists.cpp +++ b/src/storage/in_mem_storage_structure/in_mem_lists.cpp @@ -163,22 +163,6 @@ void InMemLists::setValue(offset_t nodeOffset, uint64_t pos, uint8_t* val) { numBytesForElement); } -template -void InMemLists::setValueFromString( - offset_t nodeOffset, uint64_t pos, const char* val, uint64_t length) { - auto numericVal = function::castStringToNum(val, length); - setValue(nodeOffset, pos, (uint8_t*)&numericVal); -} - -template<> -void InMemLists::setValueFromString( - offset_t nodeOffset, uint64_t pos, const char* val, uint64_t /*length*/) { - std::istringstream boolStream{std::string(val)}; - bool booleanVal; - boolStream >> std::boolalpha >> booleanVal; - setValue(nodeOffset, pos, (uint8_t*)&booleanVal); -} - // Fixed list template<> void InMemLists::setValueFromString( @@ -188,30 +172,6 @@ void InMemLists::setValueFromString( setValue(nodeOffset, pos, fixedListVal.get()); } -// Interval -template<> -void InMemLists::setValueFromString( - offset_t nodeOffset, uint64_t pos, const char* val, uint64_t length) { - auto intervalVal = 
Interval::fromCString(val, length); - setValue(nodeOffset, pos, (uint8_t*)&intervalVal); -} - -// Date -template<> -void InMemLists::setValueFromString( - offset_t nodeOffset, uint64_t pos, const char* val, uint64_t length) { - auto dateVal = Date::fromCString(val, length); - setValue(nodeOffset, pos, (uint8_t*)&dateVal); -} - -// Timestamp -template<> -void InMemLists::setValueFromString( - offset_t nodeOffset, uint64_t pos, const char* val, uint64_t length) { - auto timestampVal = Timestamp::fromCString(val, length); - setValue(nodeOffset, pos, (uint8_t*)&timestampVal); -} - void InMemLists::initListsMetadataAndAllocatePages( uint64_t numNodes, ListHeaders* listHeaders, ListsMetadata* /*listsMetadata*/) { offset_t nodeOffset = 0u; diff --git a/src/storage/store/table_copy_utils.cpp b/src/storage/store/table_copy_utils.cpp index 8b36adad692..8cacdbd835c 100644 --- a/src/storage/store/table_copy_utils.cpp +++ b/src/storage/store/table_copy_utils.cpp @@ -5,7 +5,7 @@ #include "common/exception/message.h" #include "common/exception/parser.h" #include "common/string_format.h" -#include "function/cast/cast_utils.h" +#include "function/cast/functions/cast_string_to_functions.h" #include "storage/storage_structure/lists/lists.h" #include #include @@ -173,71 +173,32 @@ std::unique_ptr TableCopyUtils::getArrowFixedList(std::string_view l, switch (childDataType->getLogicalTypeID()) { case LogicalTypeID::INT64: { int64_t val; - function::simpleIntegerCast(element.data(), element.length(), val, dataType); + function::CastStringToTypes::operation(element.data(), element.length(), val); memcpy(listVal.get() + numElementsRead * sizeof(int64_t), &val, sizeof(int64_t)); numElementsRead++; } break; case LogicalTypeID::INT32: { int32_t val; - function::simpleIntegerCast(element.data(), element.length(), val, dataType); + function::CastStringToTypes::operation(element.data(), element.length(), val); memcpy(listVal.get() + numElementsRead * sizeof(int32_t), &val, sizeof(int32_t));
numElementsRead++; } break; case LogicalTypeID::INT16: { int16_t val; - function::simpleIntegerCast(element.data(), element.length(), val, dataType); + function::CastStringToTypes::operation(element.data(), element.length(), val); memcpy(listVal.get() + numElementsRead * sizeof(int16_t), &val, sizeof(int16_t)); numElementsRead++; } break; - case LogicalTypeID::INT8: { - int8_t val; - function::simpleIntegerCast(element.data(), element.length(), val, dataType); - memcpy(listVal.get() + numElementsRead * sizeof(int8_t), &val, sizeof(int8_t)); - numElementsRead++; - } break; - case LogicalTypeID::UINT64: { - uint64_t val; - function::simpleIntegerCast( - element.data(), element.length(), val, dataType); - memcpy(listVal.get() + numElementsRead * sizeof(uint64_t), &val, sizeof(uint64_t)); - numElementsRead++; - } - case LogicalTypeID::UINT32: { - uint32_t val; - function::simpleIntegerCast( - element.data(), element.length(), val, dataType); - memcpy(listVal.get() + numElementsRead * sizeof(uint32_t), &val, sizeof(uint32_t)); - numElementsRead++; - } break; - case LogicalTypeID::UINT16: { - uint16_t val; - function::simpleIntegerCast( - element.data(), element.length(), val, dataType); - memcpy(listVal.get() + numElementsRead * sizeof(uint16_t), &val, sizeof(uint16_t)); - numElementsRead++; - } break; - case LogicalTypeID::UINT8: { - uint8_t val; - function::simpleIntegerCast( - element.data(), element.length(), val, dataType); - memcpy(listVal.get() + numElementsRead * sizeof(uint8_t), &val, sizeof(uint8_t)); - numElementsRead++; - } break; - case LogicalTypeID::INT128: { - common::int128_t val{}; - function::simpleInt128Cast(element.data(), element.length(), val); - memcpy(listVal.get() + numElementsRead * sizeof(int128_t), &val, sizeof(int128_t)); - numElementsRead++; - } break; + // TODO: other types not supported - only support int64, int32, int16, double, float case LogicalTypeID::DOUBLE: { double_t val; - function::doubleCast(element.data(), element.length(), 
val, dataType); + function::CastStringToTypes::operation(element.data(), element.length(), val); memcpy(listVal.get() + numElementsRead * sizeof(double_t), &val, sizeof(double_t)); numElementsRead++; } break; case LogicalTypeID::FLOAT: { float_t val; - function::doubleCast(element.data(), element.length(), val, dataType); + function::CastStringToTypes::operation(element.data(), element.length(), val); memcpy(listVal.get() + numElementsRead * sizeof(float_t), &val, sizeof(float_t)); numElementsRead++; } break; @@ -339,75 +300,81 @@ std::unique_ptr TableCopyUtils::convertStringToValue( switch (type.getLogicalTypeID()) { case LogicalTypeID::INT64: { int64_t val; - function::simpleIntegerCast(element.data(), element.length(), val, type); + function::CastStringToTypes::operation(element.data(), element.length(), val); value = std::make_unique(val); } break; case LogicalTypeID::INT32: { int32_t val; - function::simpleIntegerCast(element.data(), element.length(), val, type); + function::CastStringToTypes::operation(element.data(), element.length(), val); value = std::make_unique(val); } break; case LogicalTypeID::INT16: { int16_t val; - function::simpleIntegerCast(element.data(), element.length(), val, type); + function::CastStringToTypes::operation(element.data(), element.length(), val); value = std::make_unique(val); } break; case LogicalTypeID::INT8: { int8_t val; - function::simpleIntegerCast(element.data(), element.length(), val, type); + function::CastStringToTypes::operation(element.data(), element.length(), val); value = std::make_unique(val); } break; case LogicalTypeID::UINT64: { uint64_t val; - function::simpleIntegerCast(element.data(), element.length(), val, type); + function::CastStringToTypes::operation(element.data(), element.length(), val); value = std::make_unique(val); } break; case LogicalTypeID::UINT32: { uint32_t val; - function::simpleIntegerCast(element.data(), element.length(), val, type); + function::CastStringToTypes::operation(element.data(), 
element.length(), val); value = std::make_unique(val); } break; case LogicalTypeID::UINT16: { uint16_t val; - function::simpleIntegerCast(element.data(), element.length(), val, type); + function::CastStringToTypes::operation(element.data(), element.length(), val); value = std::make_unique(val); } break; case LogicalTypeID::UINT8: { uint8_t val; - function::simpleIntegerCast(element.data(), element.length(), val, type); + function::CastStringToTypes::operation(element.data(), element.length(), val); value = std::make_unique(val); } break; case LogicalTypeID::INT128: { int128_t val{}; - function::simpleInt128Cast(element.data(), element.length(), val); + function::CastStringToTypes::operation(element.data(), element.length(), val); value = std::make_unique(val); } break; case LogicalTypeID::FLOAT: { float_t val; - function::doubleCast(element.data(), element.length(), val, type); + function::CastStringToTypes::operation(element.data(), element.length(), val); value = std::make_unique(val); } break; case LogicalTypeID::DOUBLE: { double_t val; - function::doubleCast(element.data(), element.length(), val, type); + function::CastStringToTypes::operation(element.data(), element.length(), val); value = std::make_unique(val); } break; case LogicalTypeID::BOOL: { bool val; - function::castStringToBool(element.data(), element.length(), val); + function::CastStringToTypes::operation(element.data(), element.length(), val); value = std::make_unique(val); } break; case LogicalTypeID::STRING: { value = make_unique(LogicalType{LogicalTypeID::STRING}, std::string(element)); } break; case LogicalTypeID::DATE: { - value = std::make_unique(Date::fromCString(element.data(), element.length())); + date_t val; + function::CastStringToTypes::operation(element.data(), element.length(), val); + value = std::make_unique(val); } break; case LogicalTypeID::TIMESTAMP: { - value = std::make_unique(Timestamp::fromCString(element.data(), element.length())); + timestamp_t val; + 
function::CastStringToTypes::operation(element.data(), element.length(), val); + value = std::make_unique(val); } break; case LogicalTypeID::INTERVAL: { - value = std::make_unique(Interval::fromCString(element.data(), element.length())); + interval_t val; + function::CastStringToTypes::operation(element.data(), element.length(), val); + value = std::make_unique(val); } break; case LogicalTypeID::VAR_LIST: { value = getVarListValue(element, 1, element.length() - 2, type, csvReaderConfig); diff --git a/test/test_files/tinysnb/cast/cast_error.test b/test/test_files/tinysnb/cast/cast_error.test index 1fdf69f9fa0..95ec66cb91c 100644 --- a/test/test_files/tinysnb/cast/cast_error.test +++ b/test/test_files/tinysnb/cast/cast_error.test @@ -8,12 +8,12 @@ -LOG CastUint64ToInt64OutOfRange -STATEMENT MATCH (:person)-[e:studyAt]->(:organisation) return to_int64(e.code); ---- error -Runtime exception: Value 9223372036854775808 is not within INT64 range +Overflow exception: Value 9223372036854775808 is not within INT64 range -LOG CastUint64ToInt32OutOfRange -STATEMENT MATCH (:person)-[e:studyAt]->(:organisation) return to_int32(e.code); ---- error -Runtime exception: Value 9223372036854775808 is not within INT32 range +Overflow exception: Value 9223372036854775808 is not within INT32 range -LOG CastInt64ToInt32OutOfRange -STATEMENT return to_int32("2147483648"); @@ -23,12 +23,12 @@ Conversion exception: Cast failed. 2147483648 is not in INT32 range. 
-LOG CastUint32ToInt32OutOfRange -STATEMENT return to_int32(to_uint32(4294967295)); ---- error -Runtime exception: Value 4294967295 is not within INT32 range +Overflow exception: Value 4294967295 is not within INT32 range -LOG CastUint64ToInt16OutOfRange -STATEMENT MATCH (:person)-[e:studyAt]->(:organisation) return to_int16(e.code); ---- error -Runtime exception: Value 9223372036854775808 is not within INT16 range +Overflow exception: Value 9223372036854775808 is not within INT16 range -LOG CastInt64ToInt16OutOfRange -STATEMENT RETURN to_int16("32768"); @@ -38,102 +38,102 @@ Conversion exception: Cast failed. 32768 is not in INT16 range. -LOG CastUint32ToInt16OutOfRange -STATEMENT MATCH (:person)-[e:studyAt]->(:organisation) return to_int16(e.temprature); ---- error -Runtime exception: Value 32800 is not within INT16 range +Overflow exception: Value 32800 is not within INT16 range -LOG CastInt32ToInt16OutOfRange -STATEMENT RETURN to_int16(to_int32("-32770")); ---- error -Runtime exception: Value -32770 is not within INT16 range +Overflow exception: Value -32770 is not within INT16 range -LOG CastUint16ToInt16OutOfRange -STATEMENT MATCH (:person)-[e:studyAt]->(:organisation) return to_int16(e.ulength); ---- error -Runtime exception: Value 33768 is not within INT16 range +Overflow exception: Value 33768 is not within INT16 range -LOG CastUint64ToInt8OutOfRange -STATEMENT MATCH (:person)-[e:studyAt]->(:organisation) return to_int8(e.code); ---- error -Runtime exception: Value 9223372036854775808 is not within INT8 range +Overflow exception: Value 9223372036854775808 is not within INT8 range -LOG CastInt64ToInt8OutOfRange -STATEMENT return to_int8(-1000); ---- error -Runtime exception: Value -1000 is not within INT8 range +Overflow exception: Value -1000 is not within INT8 range -LOG CastUint32ToInt8OutOfRange -STATEMENT MATCH (:person)-[e:studyAt]->(:organisation) return to_int8(e.temprature); ---- error -Runtime exception: Value 32800 is not within INT8 range 
+Overflow exception: Value 32800 is not within INT8 range -LOG CastInt32ToInt8OutOfRange -STATEMENT return to_int8(to_int32(1000)); ---- error -Runtime exception: Value 1000 is not within INT8 range +Overflow exception: Value 1000 is not within INT8 range -LOG CastUint16ToInt8OutOfRange -STATEMENT MATCH (:person)-[e:studyAt]->(:organisation) return to_int8(e.ulength); ---- error -Runtime exception: Value 33768 is not within INT8 range +Overflow exception: Value 33768 is not within INT8 range -LOG CastInt16ToInt8OutOfRange -STATEMENT return to_int8(to_int16(520)); ---- error -Runtime exception: Value 520 is not within INT8 range +Overflow exception: Value 520 is not within INT8 range -LOG CastUint8ToInt8OutOfRange -STATEMENT MATCH (:person)-[e:studyAt]->(:organisation) return to_int8(e.ulevel); ---- error -Runtime exception: Value 250 is not within INT8 range +Overflow exception: Value 250 is not within INT8 range -LOG CastInt64ToUint64OutOfRange -STATEMENT return to_uint64(-500); ---- error -Runtime exception: Value -500 is not within UINT64 range +Overflow exception: Value -500 is not within UINT64 range -LOG CastInt32ToUint64OutOfRange -STATEMENT return to_uint64(to_int32(-1024)); ---- error -Runtime exception: Value -1024 is not within UINT64 range +Overflow exception: Value -1024 is not within UINT64 range -LOG CastInt16ToUint64OutOfRange -STATEMENT return to_uint64(to_int16(-1)); ---- error -Runtime exception: Value -1 is not within UINT64 range +Overflow exception: Value -1 is not within UINT64 range -LOG CastInt8ToUint64OutOfRange -STATEMENT return to_uint64(to_int8(-2)); ---- error -Runtime exception: Value -2 is not within UINT64 range +Overflow exception: Value -2 is not within UINT64 range -LOG CastUint64ToUint32OutOfRange -STATEMENT return to_uint32(to_uint64(922337203685477580)); ---- error -Runtime exception: Value 922337203685477580 is not within UINT32 range +Overflow exception: Value 922337203685477580 is not within UINT32 range -LOG 
CastInt64ToUint32OutOfRange -STATEMENT return to_uint32(9223372036854775807); ---- error -Runtime exception: Value 9223372036854775807 is not within UINT32 range +Overflow exception: Value 9223372036854775807 is not within UINT32 range -LOG CastInt32ToUint32OutOfRange -STATEMENT return to_uint32(to_int32("-10244")); ---- error -Runtime exception: Value -10244 is not within UINT32 range +Overflow exception: Value -10244 is not within UINT32 range -LOG CastInt16ToUint32OutOfRange -STATEMENT return to_uint32(to_int16(-100)); ---- error -Runtime exception: Value -100 is not within UINT32 range +Overflow exception: Value -100 is not within UINT32 range -LOG CastInt8ToUint32OutOfRange -STATEMENT return to_uint32(to_int8(-110)); ---- error -Runtime exception: Value -110 is not within UINT32 range +Overflow exception: Value -110 is not within UINT32 range -LOG CastUint64ToUint16OutOfRange -STATEMENT return to_uint16(to_uint64(922337203685477580)); ---- error -Runtime exception: Value 922337203685477580 is not within UINT16 range +Overflow exception: Value 922337203685477580 is not within UINT16 range -LOG CastInt64ToUint16OutOfRange -STATEMENT return to_uint16("922337203685"); @@ -143,67 +143,67 @@ Conversion exception: Cast failed. 922337203685 is not in UINT16 range. 
-LOG CastUint32ToUint16OutOfRange -STATEMENT return to_uint16(to_uint32(65536)); ---- error -Runtime exception: Value 65536 is not within UINT16 range +Overflow exception: Value 65536 is not within UINT16 range -LOG CastInt32ToUint16OutOfRange -STATEMENT return to_uint16(to_int32("-10244")); ---- error -Runtime exception: Value -10244 is not within UINT16 range +Overflow exception: Value -10244 is not within UINT16 range -LOG CastInt16ToUint16OutOfRange -STATEMENT return to_uint16(to_int16(-100)); ---- error -Runtime exception: Value -100 is not within UINT16 range +Overflow exception: Value -100 is not within UINT16 range -LOG CastInt8ToUint16OutOfRange -STATEMENT return to_uint16(to_int8(-110)); ---- error -Runtime exception: Value -110 is not within UINT16 range +Overflow exception: Value -110 is not within UINT16 range -LOG CastUint64ToUint8OutOfRange -STATEMENT return to_uint8(to_uint64(922337203685477580)); ---- error -Runtime exception: Value 922337203685477580 is not within UINT8 range +Overflow exception: Value 922337203685477580 is not within UINT8 range -LOG CastInt64ToUint8OutOfRange -STATEMENT return to_uint8(257); ---- error -Runtime exception: Value 257 is not within UINT8 range +Overflow exception: Value 257 is not within UINT8 range -LOG CastUint32ToUint8OutOfRange -STATEMENT return to_uint8(to_uint32(300)); ---- error -Runtime exception: Value 300 is not within UINT8 range +Overflow exception: Value 300 is not within UINT8 range -LOG CastInt32ToUint8OutOfRange -STATEMENT return to_uint8(to_int32("-10244")); ---- error -Runtime exception: Value -10244 is not within UINT8 range +Overflow exception: Value -10244 is not within UINT8 range -LOG CastUint16ToUint8OutOfRange -STATEMENT return to_uint8(to_uint16(312)); ---- error -Runtime exception: Value 312 is not within UINT8 range +Overflow exception: Value 312 is not within UINT8 range -LOG CastInt16ToUint8OutOfRange -STATEMENT return to_uint8(to_int16(-100)); ---- error -Runtime exception: Value -100 
is not within UINT8 range +Overflow exception: Value -100 is not within UINT8 range -LOG CastInt8ToUint8OutOfRange -STATEMENT return to_uint8(to_int8(-3)); ---- error -Runtime exception: Value -3 is not within UINT8 range +Overflow exception: Value -3 is not within UINT8 range -LOG CastDoubleToInt64OutOfRange -STATEMENT return to_int64(9223372038854775807.452313); ---- error -Runtime exception: Value 9223372038854774784.000000 is not within INT64 range +Overflow exception: Value 9223372038854774784.000000 is not within INT64 range -LOG CastFloatToUint8OutOfRange -STATEMENT return to_uint8(-728.923); ---- error -Runtime exception: Value -728.923000 is not within UINT8 range +Overflow exception: Value -728.923000 is not within UINT8 range -STATEMENT RETURN TO_INT32("2147483648"); ---- error Conversion exception: Cast failed. 2147483648 is not in INT32 range. @@ -239,7 +239,7 @@ Conversion exception: Cast failed. 256 is not in UINT8 range. Conversion exception: Cast failed. -1 is not in UINT8 range. -STATEMENT RETURN TO_INT128(170141183460469231731687303715884105728); ---- error -Overflow exception: Cast failed. 170141183460469231731687303715884105728 is not within INT128 range. +Conversion exception: Cast failed. 170141183460469231731687303715884105728 is not within INT128 range. -STATEMENT RETURN TO_INT128(170141183460469231731687303715884105727) + TO_INT128(10); ---- error Overflow exception: INT128 is out of range: cannot add. 
@@ -268,7 +268,7 @@ Conversion exception: Value fal is not a valid boolean -CASE SerialOutOfRange -STATEMENT RETURN TO_SERIAL(TO_UINT64(9223372036854775807) * TO_UINT64(2)); ---- error -Runtime exception: Value 18446744073709551614 is not within INT64 range +Overflow exception: Value 18446744073709551614 is not within INT64 range -CASE NonAsciiStringToBlob -STATEMENT RETURN BLOB('😀') diff --git a/test/test_files/tinysnb/exception/exception.test b/test/test_files/tinysnb/exception/exception.test index 2458f3fe793..15dc76da58e 100644 --- a/test/test_files/tinysnb/exception/exception.test +++ b/test/test_files/tinysnb/exception/exception.test @@ -1,34 +1,35 @@ -GROUP TinySnbExceptionTest --DATASET CSV tinysnb +-DATASET CSV empty -- --CASE DivideBy0Error +-CASE EXCEPTION +-LOG DivideBy0Error -STATEMENT RETURN 1 / 0 ---- error Runtime exception: Divide by zero. --CASE ModuloBy0Error +-LOG ModuloBy0Error -STATEMENT RETURN 1 % 0 ---- error Runtime exception: Modulo by zero. --CASE EmptyQuery +-LOG EmptyQuery -STATEMENT ---- error Connection Exception: Query is empty. --CASE Overflow +-LOG Overflow -STATEMENT RETURN to_int16(10000000000) ---- error -Runtime exception: Value 10000000000 is not within INT16 range +Overflow exception: Value 10000000000 is not within INT16 range --CASE Int32PrimaryKey +-LOG Int32PrimaryKey -STATEMENT CREATE NODE TABLE play(a INT32, PRIMARY KEY (a)) ---- error Binder exception: Invalid primary key type: INT32. Expected STRING or INT64. --CASE UnalignedKeyAndValueList +-LOG UnalignedKeyAndValueList -STATEMENT RETURN MAP([4],[3,2]) ---- error Runtime exception: Unaligned key list and value list.