diff --git a/src/function/cast/CMakeLists.txt b/src/function/cast/CMakeLists.txt index e3382bd7ad..2280180a29 100644 --- a/src/function/cast/CMakeLists.txt +++ b/src/function/cast/CMakeLists.txt @@ -1,6 +1,7 @@ add_library(kuzu_function_cast OBJECT - cast_rdf_variant.cpp) + cast_rdf_variant.cpp + cast_fixed_list.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/function/cast/cast_fixed_list.cpp b/src/function/cast/cast_fixed_list.cpp new file mode 100644 index 0000000000..96fc29bb37 --- /dev/null +++ b/src/function/cast/cast_fixed_list.cpp @@ -0,0 +1,401 @@ +#include "function/cast/functions/cast_fixed_list.h" + +#include "common/exception/conversion.h" +#include "common/type_utils.h" +#include "function/cast/functions/cast_from_string_functions.h" +#include "function/cast/functions/cast_functions.h" + +namespace kuzu { +namespace function { + +bool CastFixedListHelper::containsListToFixedList( + const LogicalType* srcType, const LogicalType* dstType) { + if (srcType->getLogicalTypeID() == LogicalTypeID::VAR_LIST && + dstType->getLogicalTypeID() == LogicalTypeID::FIXED_LIST) { + return true; + } + + while (srcType->getLogicalTypeID() == dstType->getLogicalTypeID()) { + switch (srcType->getPhysicalType()) { + case PhysicalTypeID::VAR_LIST: { + return containsListToFixedList( + VarListType::getChildType(srcType), VarListType::getChildType(dstType)); + } + case PhysicalTypeID::STRUCT: { + auto srcFieldTypes = StructType::getFieldTypes(srcType); + auto dstFieldTypes = StructType::getFieldTypes(dstType); + if (srcFieldTypes.size() != dstFieldTypes.size()) { + throw ConversionException{ + stringFormat("Unsupported casting function from {} to {}.", srcType->toString(), + dstType->toString())}; + } + + auto result = false; + std::vector fields; + for (auto i = 0u; i < srcFieldTypes.size(); i++) { + if (containsListToFixedList(srcFieldTypes[i], dstFieldTypes[i])) { + return true; + } + } + } + default: + return false; + } + } + return false; +} + +void CastFixedListHelper::validateListEntry( + ValueVector* inputVector, LogicalType* resultType, uint64_t pos) { + if (inputVector->isNull(pos)) { + return; + } + auto inputTypeID = inputVector->dataType.getPhysicalType(); + + switch (resultType->getPhysicalType()) { + case PhysicalTypeID::FIXED_LIST: { + if (inputTypeID == PhysicalTypeID::VAR_LIST) { + auto listEntry = inputVector->getValue(pos); + if (listEntry.size != FixedListType::getNumValuesInList(resultType)) { + throw ConversionException{stringFormat( + "Unsupported casting VAR_LIST with incorrect list entry to FIXED_LIST. " + "Expected: {}, Actual: {}.", + FixedListType::getNumValuesInList(resultType), + inputVector->getValue(pos).size)}; + } + + auto inputChildVector = ListVector::getDataVector(inputVector); + for (auto i = listEntry.offset; i < listEntry.offset + listEntry.size; i++) { + if (inputChildVector->isNull(i)) { + throw ConversionException("Cast failed. NULL is not allowed for FIXED_LIST."); + } + } + } + } break; + case PhysicalTypeID::VAR_LIST: { + if (inputTypeID == PhysicalTypeID::VAR_LIST) { + auto listEntry = inputVector->getValue(pos); + for (auto i = listEntry.offset; i < listEntry.offset + listEntry.size; i++) { + validateListEntry(ListVector::getDataVector(inputVector), + VarListType::getChildType(resultType), i); + } + } + } break; + case PhysicalTypeID::STRUCT: { + if (inputTypeID == PhysicalTypeID::STRUCT) { + auto fieldVectors = StructVector::getFieldVectors(inputVector); + auto fieldTypes = StructType::getFieldTypes(resultType); + + auto structEntry = inputVector->getValue(pos); + for (auto i = 0u; i < fieldVectors.size(); i++) { + validateListEntry(fieldVectors[i].get(), fieldTypes[i], structEntry.pos); + } + } + } break; + default: { + return; + } + } +} + +static void CastFixedListToString( + ValueVector& param, uint64_t pos, ValueVector& resultVector, uint64_t resultPos) { + resultVector.setNull(resultPos, param.isNull(pos)); + if (param.isNull(pos)) { + return; + } + std::string result = "["; + auto numValuesPerList = FixedListType::getNumValuesInList(¶m.dataType); + auto childType = FixedListType::getChildType(¶m.dataType); + auto values = param.getData() + pos * param.getNumBytesPerValue(); + for (auto i = 0u; i < numValuesPerList - 1; ++i) { + // Note: FixedList can only store numeric types and doesn't allow nulls. + result += TypeUtils::castValueToString(*childType, values, nullptr /* vector */); + result += ","; + values += PhysicalTypeUtils::getFixedTypeSize(childType->getPhysicalType()); + } + result += TypeUtils::castValueToString(*childType, values, nullptr /* vector */); + result += "]"; + resultVector.setValue(resultPos, result); +} + +template<> +void CastFixedList::fixedListToStringCastExecFunction( + const std::vector>& params, ValueVector& result, + void* /*dataPtr*/) { + KU_ASSERT(params.size() == 1); + auto param = params[0]; + if (param->state->isFlat()) { + CastFixedListToString(*param, param->state->selVector->selectedPositions[0], result, + result.state->selVector->selectedPositions[0]); + } else if (param->state->selVector->isUnfiltered()) { + for (auto i = 0u; i < param->state->selVector->selectedSize; i++) { + CastFixedListToString(*param, i, result, i); + } + } else { + for (auto i = 0u; i < param->state->selVector->selectedSize; i++) { + CastFixedListToString(*param, param->state->selVector->selectedPositions[i], result, + result.state->selVector->selectedPositions[i]); + } + } +} + +// LCOV_EXCL_START +template<> +void CastFixedList::fixedListToStringCastExecFunction( + const std::vector>& /*params*/, ValueVector& /*result*/, + void* /*dataPtr*/) { + KU_UNREACHABLE; +} +// LCOV_EXCL_STOP + +template<> +void CastFixedList::fixedListToStringCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr) { + KU_ASSERT(params.size() == 1); + + auto inputVector = params[0].get(); + auto numOfEntries = reinterpret_cast(dataPtr)->numOfEntries; + for (auto i = 0u; i < numOfEntries; i++) { + CastFixedListToString(*inputVector, i, result, i); + } +} + +template<> +void CastFixedList::stringtoFixedListCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr) { + KU_ASSERT(params.size() == 1); + auto param = params[0]; + auto csvReaderConfig = &reinterpret_cast(dataPtr)->csvConfig; + if (param->state->isFlat()) { + auto inputPos = param->state->selVector->selectedPositions[0]; + auto resultPos = result.state->selVector->selectedPositions[0]; + result.setNull(resultPos, param->isNull(inputPos)); + if (!result.isNull(inputPos)) { + CastString::castToFixedList( + param->getValue(inputPos), &result, resultPos, csvReaderConfig); + } + } else if (param->state->selVector->isUnfiltered()) { + for (auto i = 0u; i < param->state->selVector->selectedSize; i++) { + result.setNull(i, param->isNull(i)); + if (!result.isNull(i)) { + CastString::castToFixedList( + param->getValue(i), &result, i, csvReaderConfig); + } + } + } else { + for (auto i = 0u; i < param->state->selVector->selectedSize; i++) { + auto pos = param->state->selVector->selectedPositions[i]; + result.setNull(pos, param->isNull(pos)); + if (!result.isNull(pos)) { + CastString::castToFixedList( + param->getValue(pos), &result, pos, csvReaderConfig); + } + } + } +} + +// LCOV_EXCL_START +template<> +void CastFixedList::stringtoFixedListCastExecFunction( + const std::vector>& /*params*/, ValueVector& /*result*/, + void* /*dataPtr*/) { + KU_UNREACHABLE; +} +// LCOV_EXCL_STOP + +template<> +void CastFixedList::stringtoFixedListCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr) { + KU_ASSERT(params.size() == 1); + auto numOfEntries = reinterpret_cast(dataPtr)->numOfEntries; + auto csvReaderConfig = &reinterpret_cast(dataPtr)->csvConfig; + + auto inputVector = params[0].get(); + for (auto i = 0u; i < numOfEntries; i++) { + result.setNull(i, inputVector->isNull(i)); + if (!result.isNull(i)) { + CastString::castToFixedList( + inputVector->getValue(i), &result, i, csvReaderConfig); + } + } +} + +template<> +void CastFixedList::listToFixedListCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr) { + KU_ASSERT(params.size() == 1); + auto inputVector = params[0]; + + for (auto i = 0u; i < inputVector->state->selVector->selectedSize; i++) { + auto pos = inputVector->state->selVector->selectedPositions[i]; + CastFixedListHelper::validateListEntry(inputVector.get(), &result.dataType, pos); + } + + auto numOfEntries = inputVector->state->selVector + ->selectedPositions[inputVector->state->selVector->selectedSize - 1] + + 1; + reinterpret_cast(dataPtr)->numOfEntries = numOfEntries; + listToFixedListCastExecFunction(params, result, dataPtr); +} + +// LCOV_EXCL_START +template<> +void CastFixedList::listToFixedListCastExecFunction( + const std::vector>& /*params*/, ValueVector& /*result*/, + void* /*dataPtr*/) { + KU_UNREACHABLE; +} +// LCOV_EXCL_STOP + +using scalar_cast_func = std::function; + +template +static void getFixedListChildFuncHelper(scalar_cast_func& func, LogicalTypeID inputTypeID) { + switch (inputTypeID) { + case LogicalTypeID::STRING: { + func = UnaryCastStringFunctionWrapper::operation; + } break; + case LogicalTypeID::INT128: { + func = UnaryFunctionWrapper::operation; + } break; + case LogicalTypeID::INT64: { + func = UnaryFunctionWrapper::operation; + } break; + case LogicalTypeID::INT32: { + func = UnaryFunctionWrapper::operation; + } break; + case LogicalTypeID::INT16: { + func = UnaryFunctionWrapper::operation; + } break; + case LogicalTypeID::INT8: { + func = UnaryFunctionWrapper::operation; + } break; + case LogicalTypeID::UINT8: { + func = UnaryFunctionWrapper::operation; + } break; + case LogicalTypeID::UINT16: { + func = UnaryFunctionWrapper::operation; + } break; + case LogicalTypeID::UINT32: { + func = UnaryFunctionWrapper::operation; + } break; + case LogicalTypeID::UINT64: { + func = UnaryFunctionWrapper::operation; + } break; + case LogicalTypeID::FLOAT: { + func = UnaryFunctionWrapper::operation; + } break; + case LogicalTypeID::DOUBLE: { + func = UnaryFunctionWrapper::operation; + } break; + default: { + throw ConversionException{ + stringFormat("Unsupported casting function from {} to numerical type.", + LogicalTypeUtils::toString(inputTypeID))}; + } + } +} + +static void getFixedListChildCastFunc( + scalar_cast_func& func, LogicalTypeID inputType, LogicalTypeID resultType) { + // only support limited Fixed List Types + switch (resultType) { + case LogicalTypeID::INT64: { + return getFixedListChildFuncHelper(func, inputType); + } + case LogicalTypeID::INT32: { + return getFixedListChildFuncHelper(func, inputType); + } + case LogicalTypeID::INT16: { + return getFixedListChildFuncHelper(func, inputType); + } + case LogicalTypeID::DOUBLE: { + return getFixedListChildFuncHelper(func, inputType); + } + case LogicalTypeID::FLOAT: { + return getFixedListChildFuncHelper(func, inputType); + } + default: { + throw RuntimeException("Unsupported FIXED_LIST type: Function::getFixedListChildCastFunc"); + } + } +} + +template<> +void CastFixedList::listToFixedListCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr) { + auto inputVector = params[0]; + auto numOfEntries = reinterpret_cast(dataPtr)->numOfEntries; + + auto inputChildId = VarListType::getChildType(&inputVector->dataType)->getLogicalTypeID(); + auto outputChildId = FixedListType::getChildType(&result.dataType)->getLogicalTypeID(); + auto numValuesPerList = FixedListType::getNumValuesInList(&result.dataType); + scalar_cast_func func; + getFixedListChildCastFunc(func, inputChildId, outputChildId); + + result.setNullFromBits(inputVector->getNullMaskData(), 0, 0, numOfEntries); + auto inputChildVector = ListVector::getDataVector(inputVector.get()); + for (auto i = 0u; i < numOfEntries; i++) { + if (!result.isNull(i)) { + auto listEntry = inputVector->getValue(i); + if (listEntry.size == numValuesPerList) { + for (auto j = 0u; j < listEntry.size; j++) { + func((void*)(inputChildVector), listEntry.offset + j, (void*)(&result), + i * numValuesPerList + j, nullptr); + } + } + } + } +} + +template<> +void CastFixedList::castBetweenFixedListExecFunc( + const std::vector>& params, ValueVector& result, void* dataPtr) { + auto inputVector = params[0]; + auto numOfEntries = inputVector->state->selVector + ->selectedPositions[inputVector->state->selVector->selectedSize - 1] + + 1; + reinterpret_cast(dataPtr)->numOfEntries = numOfEntries; + castBetweenFixedListExecFunc(params, result, dataPtr); +} + +// LCOV_EXCL_START +template<> +void CastFixedList::castBetweenFixedListExecFunc( + const std::vector>& /*params*/, ValueVector& /*result*/, + void* /*dataPtr*/) { + KU_UNREACHABLE; +} +// LCOV_EXCL_STOP + +template<> +void CastFixedList::castBetweenFixedListExecFunc( + const std::vector>& params, ValueVector& result, void* dataPtr) { + auto inputVector = params[0]; + auto numOfEntries = reinterpret_cast(dataPtr)->numOfEntries; + + auto inputChildId = FixedListType::getChildType(&inputVector->dataType)->getLogicalTypeID(); + auto outputChildId = FixedListType::getChildType(&result.dataType)->getLogicalTypeID(); + auto numValuesPerList = FixedListType::getNumValuesInList(&result.dataType); + if (FixedListType::getNumValuesInList(&inputVector->dataType) != numValuesPerList) { + throw ConversionException(stringFormat("Unsupported casting function from {} to {}.", + inputVector->dataType.toString(), result.dataType.toString())); + } + + scalar_cast_func func; + getFixedListChildCastFunc(func, inputChildId, outputChildId); + + result.setNullFromBits(inputVector->getNullMaskData(), 0, 0, numOfEntries); + for (auto i = 0u; i < numOfEntries; i++) { + if (!result.isNull(i)) { + for (auto j = 0u; j < numValuesPerList; j++) { + func((void*)(inputVector.get()), i * numValuesPerList + j, (void*)(&result), + i * numValuesPerList + j, nullptr); + } + } + } +} + +} // namespace function +} // namespace kuzu diff --git a/src/function/vector_cast_functions.cpp b/src/function/vector_cast_functions.cpp index d2540e94bc..28f324ac42 100644 --- a/src/function/vector_cast_functions.cpp +++ b/src/function/vector_cast_functions.cpp @@ -4,7 +4,7 @@ #include "binder/expression/literal_expression.h" #include "common/exception/binder.h" #include "common/exception/conversion.h" -#include "common/exception/runtime.h" +#include "function/cast/functions/cast_fixed_list.h" #include "function/cast/functions/cast_from_string_functions.h" #include "function/cast/functions/cast_functions.h" #include "function/cast/functions/cast_rdf_variant.h" @@ -14,111 +14,6 @@ using namespace kuzu::common; namespace kuzu { namespace function { -static void castFixedListToString( - ValueVector& param, uint64_t pos, ValueVector& resultVector, uint64_t resultPos) { - resultVector.setNull(resultPos, param.isNull(pos)); - if (param.isNull(pos)) { - return; - } - std::string result = "["; - auto numValuesPerList = FixedListType::getNumValuesInList(¶m.dataType); - auto childType = FixedListType::getChildType(¶m.dataType); - auto values = param.getData() + pos * param.getNumBytesPerValue(); - for (auto i = 0u; i < numValuesPerList - 1; ++i) { - // Note: FixedList can only store numeric types and doesn't allow nulls. - result += TypeUtils::castValueToString(*childType, values, nullptr /* vector */); - result += ","; - values += PhysicalTypeUtils::getFixedTypeSize(childType->getPhysicalType()); - } - result += TypeUtils::castValueToString(*childType, values, nullptr /* vector */); - result += "]"; - resultVector.setValue(resultPos, result); -} - -template -static void fixedListToStringCastExecFunction( - const std::vector>& params, ValueVector& result, - void* /*dataPtr*/) { - KU_ASSERT(params.size() == 1); - auto param = params[0]; - if (param->state->isFlat()) { - castFixedListToString(*param, param->state->selVector->selectedPositions[0], result, - result.state->selVector->selectedPositions[0]); - } else if (param->state->selVector->isUnfiltered()) { - for (auto i = 0u; i < param->state->selVector->selectedSize; i++) { - castFixedListToString(*param, i, result, i); - } - } else { - for (auto i = 0u; i < param->state->selVector->selectedSize; i++) { - castFixedListToString(*param, param->state->selVector->selectedPositions[i], result, - result.state->selVector->selectedPositions[i]); - } - } -} - -template<> -void fixedListToStringCastExecFunction( - const std::vector>& params, ValueVector& result, void* dataPtr) { - KU_ASSERT(params.size() == 1); - - auto inputVector = params[0].get(); - auto numOfEntries = reinterpret_cast(dataPtr)->numOfEntries; - for (auto i = 0u; i < numOfEntries; i++) { - castFixedListToString(*inputVector, i, result, i); - } -} - -template -static void stringtoFixedListCastExecFunction( - const std::vector>& params, ValueVector& result, void* dataPtr) { - KU_ASSERT(params.size() == 1); - auto param = params[0]; - auto csvReaderConfig = &reinterpret_cast(dataPtr)->csvConfig; - if (param->state->isFlat()) { - auto inputPos = param->state->selVector->selectedPositions[0]; - auto resultPos = result.state->selVector->selectedPositions[0]; - result.setNull(resultPos, param->isNull(inputPos)); - if (!result.isNull(inputPos)) { - CastString::castToFixedList( - param->getValue(inputPos), &result, resultPos, csvReaderConfig); - } - } else if (param->state->selVector->isUnfiltered()) { - for (auto i = 0u; i < param->state->selVector->selectedSize; i++) { - result.setNull(i, param->isNull(i)); - if (!result.isNull(i)) { - CastString::castToFixedList( - param->getValue(i), &result, i, csvReaderConfig); - } - } - } else { - for (auto i = 0u; i < param->state->selVector->selectedSize; i++) { - auto pos = param->state->selVector->selectedPositions[i]; - result.setNull(pos, param->isNull(pos)); - if (!result.isNull(pos)) { - CastString::castToFixedList( - param->getValue(pos), &result, pos, csvReaderConfig); - } - } - } -} - -template<> -void stringtoFixedListCastExecFunction( - const std::vector>& params, ValueVector& result, void* dataPtr) { - KU_ASSERT(params.size() == 1); - auto numOfEntries = reinterpret_cast(dataPtr)->numOfEntries; - auto csvReaderConfig = &reinterpret_cast(dataPtr)->csvConfig; - - auto inputVector = params[0].get(); - for (auto i = 0u; i < numOfEntries; i++) { - result.setNull(i, inputVector->isNull(i)); - if (!result.isNull(i)) { - CastString::castToFixedList( - inputVector->getValue(i), &result, i, csvReaderConfig); - } - } -} - template static void fixedListToListCastExecFunction( const std::vector>& params, ValueVector& result, void* dataPtr) { @@ -175,206 +70,6 @@ void fixedListToListCastExecFunction( func(params, *resultVector, dataPtr); } -static bool containsListToFixedList(const LogicalType* srcType, const LogicalType* dstType) { - if (srcType->getLogicalTypeID() == LogicalTypeID::VAR_LIST && - dstType->getLogicalTypeID() == LogicalTypeID::FIXED_LIST) { - return true; - } - - while (srcType->getLogicalTypeID() == dstType->getLogicalTypeID()) { - switch (srcType->getPhysicalType()) { - case PhysicalTypeID::VAR_LIST: { - return containsListToFixedList( - VarListType::getChildType(srcType), VarListType::getChildType(dstType)); - } - case PhysicalTypeID::STRUCT: { - auto srcFieldTypes = StructType::getFieldTypes(srcType); - auto dstFieldTypes = StructType::getFieldTypes(dstType); - if (srcFieldTypes.size() != dstFieldTypes.size()) { - throw ConversionException{ - stringFormat("Unsupported casting function from {} to {}.", srcType->toString(), - dstType->toString())}; - } - - auto result = false; - std::vector fields; - for (auto i = 0u; i < srcFieldTypes.size(); i++) { - if (containsListToFixedList(srcFieldTypes[i], dstFieldTypes[i])) { - return true; - } - } - } - default: - return false; - } - } - return false; -} - -static void validateListEntry(ValueVector* inputVector, LogicalType* resultType, uint64_t pos) { - if (inputVector->isNull(pos)) { - return; - } - - switch (resultType->getPhysicalType()) { - case PhysicalTypeID::FIXED_LIST: { - auto listEntry = inputVector->getValue(pos); - if (listEntry.size != FixedListType::getNumValuesInList(resultType)) { - throw ConversionException{stringFormat( - "Unsupported casting VAR_LIST with incorrect list entry to FIXED_LIST. " - "Expected: {}, Actual: {}.", - FixedListType::getNumValuesInList(resultType), - inputVector->getValue(pos).size)}; - } - - auto inputChildVector = ListVector::getDataVector(inputVector); - for (auto i = listEntry.offset; i < listEntry.offset + listEntry.size; i++) { - if (inputChildVector->isNull(i)) { - throw ConversionException("Cast failed. NULL is not allowed for FIXED_LIST."); - } - } - } break; - case PhysicalTypeID::VAR_LIST: { - auto listEntry = inputVector->getValue(pos); - for (auto i = listEntry.offset; i < listEntry.offset + listEntry.size; i++) { - validateListEntry( - ListVector::getDataVector(inputVector), VarListType::getChildType(resultType), i); - } - } break; - case PhysicalTypeID::STRUCT: { - auto fieldVectors = StructVector::getFieldVectors(inputVector); - auto fieldTypes = StructType::getFieldTypes(resultType); - - auto structEntry = inputVector->getValue(pos); - for (auto i = 0u; i < fieldVectors.size(); i++) { - validateListEntry(fieldVectors[i].get(), fieldTypes[i], structEntry.pos); - } - } break; - default: { - return; - } - } -} - -template -static void listToFixedListCastExecFunction( - const std::vector>& params, ValueVector& result, void* dataPtr) { - KU_ASSERT(params.size() == 1); - auto inputVector = params[0]; - - for (auto i = 0u; i < inputVector->state->selVector->selectedSize; i++) { - auto pos = inputVector->state->selVector->selectedPositions[i]; - validateListEntry(inputVector.get(), &result.dataType, pos); - } - - auto numOfEntries = inputVector->state->selVector - ->selectedPositions[inputVector->state->selVector->selectedSize - 1] + - 1; - reinterpret_cast(dataPtr)->numOfEntries = numOfEntries; - listToFixedListCastExecFunction(params, result, dataPtr); -} - -using scalar_cast_func = std::function; - -template -static void getFixedListChildFuncHelper(scalar_cast_func& func, LogicalTypeID inputTypeID) { - switch (inputTypeID) { - case LogicalTypeID::STRING: { - func = UnaryCastStringFunctionWrapper::operation; - } break; - case LogicalTypeID::INT128: { - func = UnaryFunctionWrapper::operation; - } break; - case LogicalTypeID::INT64: { - func = UnaryFunctionWrapper::operation; - } break; - case LogicalTypeID::INT32: { - func = UnaryFunctionWrapper::operation; - } break; - case LogicalTypeID::INT16: { - func = UnaryFunctionWrapper::operation; - } break; - case LogicalTypeID::INT8: { - func = UnaryFunctionWrapper::operation; - } break; - case LogicalTypeID::UINT8: { - func = UnaryFunctionWrapper::operation; - } break; - case LogicalTypeID::UINT16: { - func = UnaryFunctionWrapper::operation; - } break; - case LogicalTypeID::UINT32: { - func = UnaryFunctionWrapper::operation; - } break; - case LogicalTypeID::UINT64: { - func = UnaryFunctionWrapper::operation; - } break; - case LogicalTypeID::FLOAT: { - func = UnaryFunctionWrapper::operation; - } break; - case LogicalTypeID::DOUBLE: { - func = UnaryFunctionWrapper::operation; - } break; - default: { - throw ConversionException{ - stringFormat("Unsupported casting function from {} to numerical type.", - LogicalTypeUtils::toString(inputTypeID))}; - } - } -} - -static void getFixedListChildCastFunc( - scalar_cast_func& func, LogicalTypeID inputType, LogicalTypeID resultType) { - // only support limited Fixed List Types - switch (resultType) { - case LogicalTypeID::INT64: { - return getFixedListChildFuncHelper(func, inputType); - } - case LogicalTypeID::INT32: { - return getFixedListChildFuncHelper(func, inputType); - } - case LogicalTypeID::INT16: { - return getFixedListChildFuncHelper(func, inputType); - } - case LogicalTypeID::DOUBLE: { - return getFixedListChildFuncHelper(func, inputType); - } - case LogicalTypeID::FLOAT: { - return getFixedListChildFuncHelper(func, inputType); - } - default: { - throw RuntimeException("Unsupported FIXED_LIST type: Function::getFixedListChildCastFunc"); - } - } -} - -template<> -void listToFixedListCastExecFunction( - const std::vector>& params, ValueVector& result, void* dataPtr) { - auto inputVector = params[0]; - auto numOfEntries = reinterpret_cast(dataPtr)->numOfEntries; - - auto inputChildId = VarListType::getChildType(&inputVector->dataType)->getLogicalTypeID(); - auto outputChildID = FixedListType::getChildType(&result.dataType)->getLogicalTypeID(); - auto numValuesPerList = FixedListType::getNumValuesInList(&result.dataType); - scalar_cast_func func; - getFixedListChildCastFunc(func, inputChildId, outputChildID); - - result.setNullFromBits(inputVector->getNullMaskData(), 0, 0, numOfEntries); - auto inputChildVector = ListVector::getDataVector(inputVector.get()); - for (auto i = 0u; i < numOfEntries; i++) { - if (!result.isNull(i)) { - auto listEntry = inputVector->getValue(i); - if (listEntry.size == numValuesPerList) { - for (auto j = 0u; j < listEntry.size; j++) { - func((void*)(inputChildVector), listEntry.offset + j, (void*)(&result), - i * numValuesPerList + j, nullptr); - } - } - } - } -} - static void resolveNestedVector(std::shared_ptr inputVector, ValueVector* resultVector, uint64_t numOfEntries, CastFunctionBindData* dataPtr) { auto inputType = &inputVector->dataType; @@ -443,10 +138,10 @@ static void nestedTypesCastExecFunction( auto inputVector = params[0]; // check if all selcted list entry have the requried fixed list size - if (containsListToFixedList(&inputVector->dataType, &result.dataType)) { + if (CastFixedListHelper::containsListToFixedList(&inputVector->dataType, &result.dataType)) { for (auto i = 0u; i < inputVector->state->selVector->selectedSize; i++) { auto pos = inputVector->state->selVector->selectedPositions[i]; - validateListEntry(inputVector.get(), &result.dataType, pos); + CastFixedListHelper::validateListEntry(inputVector.get(), &result.dataType, pos); } }; @@ -569,7 +264,7 @@ static std::unique_ptr bindCastFromStringFunction( CastString, EXECUTOR>; } break; case LogicalTypeID::FIXED_LIST: { - execFunc = stringtoFixedListCastExecFunction; + execFunc = CastFixedList::stringtoFixedListCastExecFunction; } break; case LogicalTypeID::MAP: { execFunc = ScalarFunction::UnaryCastStringExecFunction bindCastToStringFunction( EXECUTOR>; } break; case LogicalTypeID::FIXED_LIST: { - func = fixedListToStringCastExecFunction; + func = CastFixedList::fixedListToStringCastExecFunction; } break; case LogicalTypeID::MAP: { func = @@ -816,7 +511,7 @@ static std::unique_ptr bindCastBetweenNested( if (targetTypeID == LogicalTypeID::FIXED_LIST) { return std::make_unique(functionName, std::vector{sourceTypeID}, targetTypeID, - listToFixedListCastExecFunction); + CastFixedList::listToFixedListCastExecFunction); } } case LogicalTypeID::MAP: @@ -832,14 +527,15 @@ static std::unique_ptr bindCastBetweenNested( return std::make_unique(functionName, std::vector{sourceTypeID}, targetTypeID, fixedListToListCastExecFunction); + } else if (sourceTypeID == targetTypeID) { + return std::make_unique(functionName, + std::vector{sourceTypeID}, targetTypeID, + CastFixedList::castBetweenFixedListExecFunc); } } default: - // lcov_excl_start - // TODO(kebing): implement more throw ConversionException{stringFormat("Unsupported casting function from {} to {}.", LogicalTypeUtils::toString(sourceTypeID), LogicalTypeUtils::toString(targetTypeID))}; - // lcov_excl_end } } diff --git a/src/include/function/cast/functions/cast_fixed_list.h b/src/include/function/cast/functions/cast_fixed_list.h new file mode 100644 index 0000000000..b313a9209b --- /dev/null +++ b/src/include/function/cast/functions/cast_fixed_list.h @@ -0,0 +1,79 @@ +#pragma once + +#include "function/unary_function_executor.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +struct CastFixedListHelper { + static bool containsListToFixedList(const LogicalType* srcType, const LogicalType* dstType); + + static void validateListEntry(ValueVector* inputVector, LogicalType* resultType, uint64_t pos); +}; + +struct CastFixedList { + template + static void fixedListToStringCastExecFunction( + const std::vector>& params, ValueVector& result, + void* dataPtr); + + template + static void stringtoFixedListCastExecFunction( + const std::vector>& params, ValueVector& result, + void* dataPtr); + + template + static void listToFixedListCastExecFunction( + const std::vector>& params, ValueVector& result, + void* dataPtr); + + template + static void castBetweenFixedListExecFunc( + const std::vector>& params, ValueVector& result, + void* dataPtr); +}; + +template<> +void CastFixedList::fixedListToStringCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr); +template<> +void CastFixedList::fixedListToStringCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr); +template<> +void CastFixedList::fixedListToStringCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr); + +template<> +void CastFixedList::stringtoFixedListCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr); +template<> +void CastFixedList::stringtoFixedListCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr); +template<> +void CastFixedList::stringtoFixedListCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr); + +template<> +void CastFixedList::listToFixedListCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr); +template<> +void CastFixedList::listToFixedListCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr); +template<> +void CastFixedList::listToFixedListCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr); + +template<> +void CastFixedList::castBetweenFixedListExecFunc( + const std::vector>& params, ValueVector& result, void* dataPtr); +template<> +void CastFixedList::castBetweenFixedListExecFunc( + const std::vector>& params, ValueVector& result, void* dataPtr); +template<> +void CastFixedList::castBetweenFixedListExecFunc( + const std::vector>& params, ValueVector& result, void* dataPtr); + +} // namespace function +} // namespace kuzu diff --git a/test/test_files/tinysnb/cast/cast_error.test b/test/test_files/tinysnb/cast/cast_error.test index aadf885ddc..4883a06267 100644 --- a/test/test_files/tinysnb/cast/cast_error.test +++ b/test/test_files/tinysnb/cast/cast_error.test @@ -747,6 +747,9 @@ Conversion exception: Unsupported casting function from STRUCT(a:STRING, b:STRIN -STATEMENT RETURN cast(cast("{a: 12, b: 0}", "STRUCT(a STRING, b STRING)"), "STRUCT(a STRING, c STRING)"); ---- error Conversion exception: Unsupported casting function from STRUCT(a:STRING, b:STRING) to STRUCT(a:STRING, c:STRING). +-STATEMENT RETURN cast(cast("{a: 12, b: 0}", "STRUCT(a STRING, b STRING)"), "MAP(STRING, STRING)"); +---- error +Conversion exception: Unsupported casting function from STRUCT to MAP. -LOG InvalidFixedListToList -STATEMENT RETURN cast(cast("[1, -1]", "INT64[2]"), "UINT8[]"); @@ -787,3 +790,15 @@ Conversion exception: Cast failed. NULL is not allowed for FIXED_LIST. -STATEMENT RETURN cast(cast("{3={}, 3= {a: 12}, 3={a:32, b:[1, 2, 3]}}", "MAP(STRING, STRUCT(a INT64, b INT64[]))"), "MAP(STRING, STRUCT(a INT64, b INT64[1]))"); ---- error Conversion exception: Unsupported casting VAR_LIST with incorrect list entry to FIXED_LIST. Expected: 1, Actual: 3. + +-LOG InvalidFixedListToList +-STATEMENT RETURN cast(cast([4, 1], "INT16[2]"), "FLOAT[3]"); +---- error +Conversion exception: Unsupported casting function from INT16[2] to FLOAT[3]. +-STATEMENT RETURN cast(cast([4, 1], "INT16[2]"), "FLOAT[1]"); +---- error +Conversion exception: Unsupported casting function from INT16[2] to FLOAT[1]. +-STATEMENT RETURN cast(cast([4, 1], "INT16[2]"), "UINT8[2]"); +---- error +Runtime exception: Unsupported FIXED_LIST type: Function::getFixedListChildCastFunc + diff --git a/test/test_files/tinysnb/cast/cast_to_nested_types.test b/test/test_files/tinysnb/cast/cast_to_nested_types.test index 2b913f2972..e7c0358ff9 100644 --- a/test/test_files/tinysnb/cast/cast_to_nested_types.test +++ b/test/test_files/tinysnb/cast/cast_to_nested_types.test @@ -330,3 +330,17 @@ False|-4325|14|18446744073709551616.000000| dfsa {a: 1999, b: {[]=true}} {a: 2341, b: {[,[,[1],[2],[3],,[4]],[[8],,[7]],,[,[5]],[[6],],]=true}} +-LOG CastFixedListToFixedList +-STATEMENT Return cast(Cast([1,2,3], "INT32[3]"), "INT64[3]"), cast(Cast([1,2,3], "INT32[3]"), "INT32[3]"), cast(Cast([1,2,3], "INT32[3]"), "INT16[3]"), cast(Cast([1,2,3], "INT32[3]"), "DOUBLE[3]"), cast(Cast([1,2,3], "INT32[3]"), "FLOAT[3]"); +---- 1 +[1,2,3]|[1,2,3]|[1,2,3]|[1.000000,2.000000,3.000000]|[1.000000,2.000000,3.000000] +-STATEMENT Return cast(cast(cast(cast(Cast([0,-1,1], "INT32[3]"), "INT64[3]"), "INT16[3]"), "DOUBLE[3]"), "FLOAT[3]"); +---- 1 +[0.000000,-1.000000,1.000000] +-STATEMENT LOAD WITH HEADERS (struct STRUCT(a INT64, b MAP(INT64[][][], STRING))) FROM "${KUZU_ROOT_DIRECTORY}/dataset/load-from-test/struct/struct_with_fixed_list.csv" where struct_extract(struct, 'a') > 0 RETURN cast(cast(struct, "STRUCT(a INT64,b MAP(DOUBLE[1][][], STRING))"), "STRUCT(a INT64,b MAP(INT16[1][][], STRING))"); +---- 4 +{a: 1999, b: {[[[1],[2],[3]]]=true}} +{a: 1999, b: {[[]]=true}} +{a: 1999, b: {[]=true}} +{a: 2341, b: {[,[,[1],[2],[3],,[4]],[[8],,[7]],,[,[5]],[[6],],]=true}} + diff --git a/test/test_files/tinysnb/function/cast.test b/test/test_files/tinysnb/function/cast.test index 5bd2513410..042fe38ad9 100644 --- a/test/test_files/tinysnb/function/cast.test +++ b/test/test_files/tinysnb/function/cast.test @@ -450,7 +450,7 @@ Hubert Blaine Wolfeschlegelsteinhausenbergerdorff [43,83,67,43] [77,64,100,54] --LOG CastFixedListToFixedList +-LOG CastFixedListToList -STATEMENT MATCH(p:person) where p.ID > 4 RETURN cast(p.grades, "UINT64[]"), cast(p.grades, "UINT32[]"), cast(p.grades, "UINT16[]"), cast(p.grades, "UINT8[]"), cast(p.grades, "INT8[]"), cast(p.grades, "INT16[]"), cast(p.grades, "INT32[]"), cast(p.grades, "INT64[]"), cast(p.grades, "INT128[]"), cast(p.grades, "STRING[]"), cast(p.grades, "DOUBLE[]"), cast(p.grades, "FLOAT[]"); ---- 6 [43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43.000000,83.000000,67.000000,43.000000]|[43.000000,83.000000,67.000000,43.000000] @@ -460,6 +460,13 @@ Hubert Blaine Wolfeschlegelsteinhausenbergerdorff [96,59,65,88]|[96,59,65,88]|[96,59,65,88]|[96,59,65,88]|[96,59,65,88]|[96,59,65,88]|[96,59,65,88]|[96,59,65,88]|[96,59,65,88]|[96,59,65,88]|[96.000000,59.000000,65.000000,88.000000]|[96.000000,59.000000,65.000000,88.000000] ||||||||||| +-LOG CastFixedListToFixedList +-STATEMENT MATCH(p:person) where p.ID > 8 RETURN cast(p.grades, "UINT64[]"), cast(p.grades, "UINT32[]"), cast(p.grades, "UINT16[]"), cast(p.grades, "UINT8[]"), cast(p.grades, "INT8[]"), cast(p.grades, "INT16[]"), cast(p.grades, "INT32[]"), cast(p.grades, "INT64[]"), cast(p.grades, "INT128[]"), cast(p.grades, "STRING[]"), cast(p.grades, "DOUBLE[]"), cast(p.grades, "FLOAT[]"); +---- 3 +[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43,83,67,43]|[43.000000,83.000000,67.000000,43.000000]|[43.000000,83.000000,67.000000,43.000000] +[77,64,100,54]|[77,64,100,54]|[77,64,100,54]|[77,64,100,54]|[77,64,100,54]|[77,64,100,54]|[77,64,100,54]|[77,64,100,54]|[77,64,100,54]|[77,64,100,54]|[77.000000,64.000000,100.000000,54.000000]|[77.000000,64.000000,100.000000,54.000000] +||||||||||| + -LOG CastMapToString -STATEMENT MATCH (m:movies) RETURN string(m.audience) ---- 3