Skip to content

Commit

Permalink
replace cast string to union function in driver.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
AEsir777 committed Oct 17, 2023
1 parent fd38139 commit 1e9d5b1
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 91 deletions.
9 changes: 9 additions & 0 deletions dataset/load-from-test/union/union_correct.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"false"
" true "
" 34234 "
" -42342345 "
" T "
"null"
""
"0"
" F"
2 changes: 0 additions & 2 deletions src/include/storage/store/table_copy_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ class TableCopyUtils {

static std::shared_ptr<arrow::DataType> toArrowDataType(const common::LogicalType& dataType);

static bool tryCast(const common::LogicalType& targetType, const char* value, uint64_t length);

static std::vector<StructFieldIdxAndValue> parseStructFieldNameAndValues(
common::LogicalType& type, std::string_view structString,
const common::CSVReaderConfig& csvReaderConfig);
Expand Down
140 changes: 116 additions & 24 deletions src/processor/operator/persistent/reader/csv/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,121 @@ static void castStringToStruct(const char* input, uint64_t len, ValueVector* vec
}
}

static bool tryCastUnionField(std::shared_ptr<ValueVector> vector, uint64_t rowToAdd,
const char* input, uint64_t len, LogicalType& targetType) {
// auto& targetType = vector->dataType;
bool success = false;
switch (targetType.getLogicalTypeID()) {
case LogicalTypeID::BOOL: {
bool result;
success = function::tryCastToBool(input, len, result);
vector->setValue(rowToAdd, result);
} break;
case LogicalTypeID::INT64: {
int64_t result;
success = function::trySimpleIntegerCast(input, len, result);
vector->setValue(rowToAdd, result);
} break;
case LogicalTypeID::INT32: {
int32_t result;
success = function::trySimpleIntegerCast(input, len, result);
vector->setValue(rowToAdd, result);
} break;
case LogicalTypeID::INT16: {
int16_t result;
success = function::trySimpleIntegerCast(input, len, result);
vector->setValue(rowToAdd, result);
} break;
case LogicalTypeID::INT8: {
int8_t result;
success = function::trySimpleIntegerCast(input, len, result);
vector->setValue(rowToAdd, result);
} break;
case LogicalTypeID::UINT64: {
uint64_t result;
success = function::trySimpleIntegerCast<uint64_t, false>(input, len, result);
vector->setValue(rowToAdd, result);
} break;
case LogicalTypeID::UINT32: {
uint32_t result;
success = function::trySimpleIntegerCast<uint32_t, false>(input, len, result);
vector->setValue(rowToAdd, result);
} break;
case LogicalTypeID::UINT16: {
uint16_t result;
success = function::trySimpleIntegerCast<uint16_t, false>(input, len, result);
vector->setValue(rowToAdd, result);
} break;
case LogicalTypeID::UINT8: {
uint8_t result;
success = function::trySimpleIntegerCast<uint8_t, false>(input, len, result);
vector->setValue(rowToAdd, result);
} break;
case LogicalTypeID::DOUBLE: {
double_t result;
success = function::tryDoubleCast(input, len, result);
vector->setValue(rowToAdd, result);
} break;
case LogicalTypeID::FLOAT: {
float_t result;
success = function::tryDoubleCast(input, len, result);
vector->setValue(rowToAdd, result);
} break;
case LogicalTypeID::DATE: {
date_t result;
uint64_t pos;
success = Date::tryConvertDate(input, len, pos, result);
vector->setValue(rowToAdd, result);
} break;
case LogicalTypeID::TIMESTAMP: {
timestamp_t result;
success = Timestamp::tryConvertTimestamp(input, len, result);
vector->setValue(rowToAdd, result);
} break;
case LogicalTypeID::STRING: {
if (!utf8proc::Utf8Proc::isValid(input, len)) {
throw common::CopyException{"Invalid UTF8-encoded string."};
}
StringVector::addString(vector.get(), rowToAdd, input, len);
return true;
} break;
default: {
return false;
}
}
return success;
}

/**
 * Parses `strVal` into row `rowToAdd` of the union-typed `vector`.
 *
 * The union's member fields are tried in declaration order; the first field whose type can
 * parse the text receives the value and its index is written to the tag field. Every other
 * member field is marked null for this row, including the ones that come after the selected
 * field, so exactly one member is non-null once this function returns normally.
 *
 * @param vector   union vector to write into.
 * @param strVal   text to parse.
 * @param rowToAdd position inside `vector` to write to.
 * @throws ConversionException if no member field accepts the text. Empty or null-literal text
 *         is never offered to any field, so it also ends up here.
 *         NOTE(review): callers presumably filter null/empty input before calling — confirm.
 */
static void castStringToUnion(ValueVector* vector, std::string_view strVal, uint64_t rowToAdd) {
    auto& type = vector->dataType;
    const bool null = strVal.empty() || isNull(strVal);
    const auto numFields = UnionType::getNumFields(&type);
    union_field_idx_t selectedFieldIdx = INVALID_STRUCT_FIELD_IDX;
    bool found = false;
    for (auto i = 0u; i < numFields; i++) {
        auto internalFieldIdx = UnionType::getInternalFieldIdx(i);
        auto fieldVector = StructVector::getFieldVector(vector, internalFieldIdx);
        auto fieldType = UnionType::getFieldType(&type, i);
        // Only attempt a cast while no earlier field has matched; later fields are still
        // visited so that their null flag is set instead of being left stale.
        const bool selected =
            !null && !found &&
            tryCastUnionField(fieldVector, rowToAdd, strVal.data(), strVal.length(), *fieldType);
        fieldVector->setNull(rowToAdd, !selected /* isNull */);
        if (selected) {
            selectedFieldIdx = i;
            found = true;
        }
    }
    if (!found) {
        throw ConversionException{stringFormat("Could not convert to union type {}: {}.",
            LogicalTypeUtils::dataTypeToString(type), strVal)};
    }
    auto tagVector = StructVector::getFieldVector(vector, UnionType::TAG_FIELD_IDX);
    tagVector->setValue(rowToAdd, selectedFieldIdx);
    tagVector->setNull(rowToAdd, false /* isNull */);
}

void copyStringToVector(ValueVector* vector, uint64_t rowToAdd, std::string_view strVal,
const CSVReaderConfig& csvReaderConfig) {
auto& type = vector->dataType;
Expand Down Expand Up @@ -569,30 +684,7 @@ void copyStringToVector(ValueVector* vector, uint64_t rowToAdd, std::string_view
castStringToStruct(strVal.data(), strVal.length(), vector, rowToAdd, csvReaderConfig);
} break;
case LogicalTypeID::UNION: {
union_field_idx_t selectedFieldIdx = INVALID_STRUCT_FIELD_IDX;
for (auto i = 0u; i < UnionType::getNumFields(&type); i++) {
auto internalFieldIdx = UnionType::getInternalFieldIdx(i);
if (storage::TableCopyUtils::tryCast(
*UnionType::getFieldType(&type, i), strVal.data(), strVal.length())) {
StructVector::getFieldVector(vector, internalFieldIdx)
->setNull(rowToAdd, false /* isNull */);
copyStringToVector(StructVector::getFieldVector(vector, internalFieldIdx).get(),
rowToAdd, strVal, csvReaderConfig);
selectedFieldIdx = i;
break;
} else {
StructVector::getFieldVector(vector, internalFieldIdx)
->setNull(rowToAdd, true /* isNull */);
}
}
if (selectedFieldIdx == INVALID_STRUCT_FIELD_IDX) {
throw ConversionException{stringFormat("Could not convert to union type {}: {}.",
LogicalTypeUtils::dataTypeToString(type), strVal)};
}
StructVector::getFieldVector(vector, UnionType::TAG_FIELD_IDX)
->setValue(rowToAdd, selectedFieldIdx);
StructVector::getFieldVector(vector, UnionType::TAG_FIELD_IDX)
->setNull(rowToAdd, false /* isNull */);
castStringToUnion(vector, strVal, rowToAdd);
} break;
default: { // LCOV_EXCL_START
throw NotImplementedException("BaseCSVReader::copyStringToVector");
Expand Down
65 changes: 0 additions & 65 deletions src/storage/store/table_copy_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,71 +306,6 @@ std::shared_ptr<arrow::DataType> TableCopyUtils::toArrowDataType(const LogicalTy
}
}

/**
 * Reports whether the text `value[0, length)` can be parsed as `targetType`.
 *
 * The parsed value itself is discarded; only the outcome of the parse is returned. STRING is
 * always reported as castable without inspecting the text, and any type not listed below is
 * reported as not castable.
 */
bool TableCopyUtils::tryCast(
    const common::LogicalType& targetType, const char* value, uint64_t length) {
    bool castable = false;
    switch (targetType.getLogicalTypeID()) {
    case LogicalTypeID::BOOL: {
        bool scratch;
        castable = function::tryCastToBool(value, length, scratch);
    } break;
    case LogicalTypeID::INT64: {
        int64_t scratch;
        castable = function::trySimpleIntegerCast(value, length, scratch);
    } break;
    case LogicalTypeID::INT32: {
        int32_t scratch;
        castable = function::trySimpleIntegerCast(value, length, scratch);
    } break;
    case LogicalTypeID::INT16: {
        int16_t scratch;
        castable = function::trySimpleIntegerCast(value, length, scratch);
    } break;
    case LogicalTypeID::INT8: {
        int8_t scratch;
        castable = function::trySimpleIntegerCast(value, length, scratch);
    } break;
    case LogicalTypeID::UINT64: {
        uint64_t scratch;
        castable = function::trySimpleIntegerCast<uint64_t, false>(value, length, scratch);
    } break;
    case LogicalTypeID::UINT32: {
        uint32_t scratch;
        castable = function::trySimpleIntegerCast<uint32_t, false>(value, length, scratch);
    } break;
    case LogicalTypeID::UINT16: {
        uint16_t scratch;
        castable = function::trySimpleIntegerCast<uint16_t, false>(value, length, scratch);
    } break;
    case LogicalTypeID::UINT8: {
        uint8_t scratch;
        castable = function::trySimpleIntegerCast<uint8_t, false>(value, length, scratch);
    } break;
    case LogicalTypeID::DOUBLE: {
        double_t scratch;
        castable = function::tryDoubleCast(value, length, scratch);
    } break;
    case LogicalTypeID::FLOAT: {
        float_t scratch;
        castable = function::tryDoubleCast(value, length, scratch);
    } break;
    case LogicalTypeID::DATE: {
        date_t scratch;
        uint64_t endPos;
        castable = Date::tryConvertDate(value, length, endPos, scratch);
    } break;
    case LogicalTypeID::TIMESTAMP: {
        timestamp_t scratch;
        castable = Timestamp::tryConvertTimestamp(value, length, scratch);
    } break;
    case LogicalTypeID::STRING: {
        // Strings are accepted unconditionally.
        castable = true;
    } break;
    default:
        // Unhandled types are never castable.
        break;
    }
    return castable;
}

std::vector<StructFieldIdxAndValue> TableCopyUtils::parseStructFieldNameAndValues(
LogicalType& type, std::string_view structString, const CSVReaderConfig& csvReaderConfig) {
std::vector<StructFieldIdxAndValue> structFieldIdxAndValueParis;
Expand Down
11 changes: 11 additions & 0 deletions test/test_files/tinysnb/cast/cast_string_to_nested_types.test
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@
-STATEMENT LOAD WITH HEADERS (fixedList DOUBLE[83]) FROM "${KUZU_ROOT_DIRECTORY}/dataset/load-from-test/fixed_list/long_fixed_list.csv" RETURN *;
---- 1
[4123.120000,0.000000,0.000000,42432.120000,435435.231000,432423.123340,43424.213400,4325245.130000,432423.123000,3242.123000,4123.120000,42432.120000,435435.231000,432423.123340,43424.213400,4325245.130000,432423.123000,3242.123000,4123.120000,42432.120000,435435.231000,432423.123340,43424.213400,4325245.130000,432423.123000,3242.123000,4123.120000,42432.120000,435435.231000,432423.123340,43424.213400,4325245.130000,432423.123000,3242.123000,4123.120000,42432.120000,435435.231000,432423.123340,43424.213400,4325245.130000,432423.123000,3242.123000,4123.120000,42432.120000,435435.231000,432423.123340,43424.213400,4325245.130000,432423.123000,3242.123000,4123.120000,42432.120000,435435.231000,432423.123340,43424.213400,4325245.130000,432423.123000,3242.123000,4123.120000,42432.120000,435435.231000,432423.123340,43424.213400,4325245.130000,432423.123000,3242.123000,4123.120000,42432.120000,435435.231000,432423.123340,43424.213400,4325245.130000,432423.123000,3242.123000,4123.120000,42432.120000,435435.231000,432423.123340,43424.213400,4325245.130000,432423.123000,3242.123000,543523442342.434326]
-STATEMENT LOAD WITH HEADERS (u UNION(v1 INT64, v2 BOOLEAN)) FROM "${KUZU_ROOT_DIRECTORY}/dataset/load-from-test/union/union_correct.csv" RETURN *;
---- 9
False
True
34234
-42342345
True


0
False

-CASE ErrorTest
-STATEMENT LOAD WITH HEADERS (list STRING[][]) FROM "${KUZU_ROOT_DIRECTORY}/dataset/load-from-test/delim_fail.csv" (DELIM="|", ESCAPE="~", QUOTE="'", LIST_BEGIN="(", LIST_END=")") RETURN * ;
Expand Down

0 comments on commit 1e9d5b1

Please sign in to comment.