diff --git a/src/expression_evaluator/function_evaluator.cpp b/src/expression_evaluator/function_evaluator.cpp index 033b58ab942..886e46d70c9 100644 --- a/src/expression_evaluator/function_evaluator.cpp +++ b/src/expression_evaluator/function_evaluator.cpp @@ -22,10 +22,8 @@ void FunctionExpressionEvaluator::evaluate() { child->evaluate(); } auto expr = reinterpret_cast(expression.get()); - if (expr->getFunctionName() == CAST_FUNC_NAME && - parameters[0]->dataType.getLogicalTypeID() == LogicalTypeID::STRING) { - execFunc(parameters, *resultVector, - reinterpret_cast(expr->getBindData())); + if (expr->getFunctionName() == CAST_FUNC_NAME) { + execFunc(parameters, *resultVector, expr->getBindData()); return; } if (execFunc != nullptr) { diff --git a/src/function/vector_cast_functions.cpp b/src/function/vector_cast_functions.cpp index e4e62e260d5..0c8a4b37a7b 100644 --- a/src/function/vector_cast_functions.cpp +++ b/src/function/vector_cast_functions.cpp @@ -34,9 +34,10 @@ static void castFixedListToString( resultVector.setValue(resultPos, result); } +template static void fixedListCastExecFunction(const std::vector>& params, ValueVector& result, void* /*dataPtr*/ = nullptr) { - assert(params.size() == 1); + KU_ASSERT(params.size() == 1); auto param = params[0]; if (param->state->isFlat()) { castFixedListToString(*param, param->state->selVector->selectedPositions[0], result, @@ -53,9 +54,25 @@ static void fixedListCastExecFunction(const std::vector +void fixedListCastExecFunction( + const std::vector>& params, ValueVector& result, + void* /*dataPtr*/) { + KU_ASSERT(params.size() == 1); + + auto inputVector = params[0].get(); + auto numOfChild = ListVector::getDataVectorSize(inputVector); + auto inputChildVector = (ListVector::getDataVector(inputVector)); + auto resultChildVector = (ListVector::getDataVector(&result)); + for (auto i = 0u; i < numOfChild; i++) { + castFixedListToString(*inputChildVector, i, *resultChildVector, i); + } +} + +template static void StringtoFixedListCastExecFunction( const std::vector>& params, ValueVector& result, void* dataPtr) { - assert(params.size() == 1); + KU_ASSERT(params.size() == 1); auto param = params[0]; auto csvReaderConfig = &reinterpret_cast(dataPtr)->csvConfig; if (param->state->isFlat()) { @@ -86,6 +103,67 @@ static void StringtoFixedListCastExecFunction( } } +template<> +void StringtoFixedListCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr) { + KU_ASSERT(params.size() == 1 && + params[0]->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST && + result.dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST); + auto csvReaderConfig = &reinterpret_cast(dataPtr)->csvConfig; + + auto inputVector = params[0].get(); + auto numOfChild = ListVector::getDataVectorSize(inputVector); + auto inputChildVector = (ListVector::getDataVector(inputVector)); + auto resultChildVector = (ListVector::getDataVector(&result)); + for (auto i = 0u; i < numOfChild; i++) { + resultChildVector->setNull(i, inputChildVector->isNull(i)); + if (!resultChildVector->isNull(i)) { + CastString::castToFixedList( + inputChildVector->getValue(i), resultChildVector, i, csvReaderConfig); + } + } +} + +static void varListCastExecFunction( + const std::vector>& params, ValueVector& result, void* dataPtr) { + KU_ASSERT(params.size() == 1); + result.resetAuxiliaryBuffer(); + auto inputVector = params[0]; + auto resultVector = &result; + + auto numOfChild = ListVector::getDataVectorSize(inputVector.get()); + ListVector::resizeDataVector(&result, numOfChild); + auto numOfListEntry = inputVector->state->selVector + ->selectedPositions[inputVector->state->selVector->selectedSize - 1] + + 1; + memcpy(resultVector->getData(), inputVector->getData(), + numOfListEntry * resultVector->getNumBytesPerValue()); + resultVector->setNullFromBits(inputVector->getNullMaskData(), 0, 0, numOfListEntry); + + // resolve to the lowest level dataVector + auto inputChildTypeID = VarListType::getChildType(&inputVector->dataType)->getLogicalTypeID(); + auto resultChildTypeID = VarListType::getChildType(&resultVector->dataType)->getLogicalTypeID(); + while (inputChildTypeID == LogicalTypeID::VAR_LIST && + resultChildTypeID == LogicalTypeID::VAR_LIST) { + inputVector = ListVector::getSharedDataVector(inputVector.get()); + resultVector = ListVector::getDataVector(resultVector); + inputChildTypeID = VarListType::getChildType(&inputVector->dataType)->getLogicalTypeID(); + resultChildTypeID = VarListType::getChildType(&resultVector->dataType)->getLogicalTypeID(); + + // copy NULL musk and list entry + memcpy(resultVector->getData(), inputVector->getData(), + numOfChild * resultVector->getNumBytesPerValue()); + resultVector->setNullFromBits(inputVector->getNullMaskData(), 0, 0, numOfChild); + numOfChild = ListVector::getDataVectorSize(inputVector.get()); + ListVector::resizeDataVector(resultVector, numOfChild); + } + scalar_exec_func func = CastFunction::bindCastFunction( + "CAST", inputChildTypeID, resultChildTypeID) + ->execFunc; + std::vector> childParams{inputVector}; + func(childParams, *resultVector, dataPtr); +} + bool CastFunction::hasImplicitCast(const LogicalType& srcType, const LogicalType& dstType) { // We allow cast between any numerical types if (LogicalTypeUtils::isNumerical(srcType) && LogicalTypeUtils::isNumerical(dstType)) { @@ -119,81 +197,98 @@ bool CastFunction::hasImplicitCast(const LogicalType& srcType, const LogicalType } } +template static std::unique_ptr bindCastFromStringFunction( const std::string& functionName, LogicalTypeID targetTypeID) { scalar_exec_func execFunc; switch (targetTypeID) { case LogicalTypeID::DATE: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = + ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::TIMESTAMP: { - execFunc = - ScalarFunction::UnaryCastStringExecFunction; + execFunc = ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::INTERVAL: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::BLOB: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = + ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::STRING: { - execFunc = ScalarFunction::UnaryCastExecFunction; + execFunc = + ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::BOOL: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = + ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::DOUBLE: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::FLOAT: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = + ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::INT128: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::SERIAL: case LogicalTypeID::INT64: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = + ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::INT32: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = + ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::INT16: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = + ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::INT8: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = + ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::UINT64: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::UINT32: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::UINT16: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::UINT8: { - execFunc = ScalarFunction::UnaryCastStringExecFunction; + execFunc = + ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::VAR_LIST: { - execFunc = - ScalarFunction::UnaryCastStringExecFunction; + execFunc = ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::FIXED_LIST: { - execFunc = StringtoFixedListCastExecFunction; + execFunc = StringtoFixedListCastExecFunction; } break; case LogicalTypeID::MAP: { - execFunc = - ScalarFunction::UnaryCastStringExecFunction; + execFunc = ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::STRUCT: { - execFunc = - ScalarFunction::UnaryCastStringExecFunction; + execFunc = ScalarFunction::UnaryCastStringExecFunction; } break; case LogicalTypeID::UNION: { - execFunc = - ScalarFunction::UnaryCastStringExecFunction; + execFunc = ScalarFunction::UnaryCastStringExecFunction; } break; default: throw ConversionException{stringFormat("Unsupported casting function from STRING to {}.", @@ -203,74 +298,75 @@ static std::unique_ptr bindCastFromStringFunction( functionName, std::vector{LogicalTypeID::STRING}, targetTypeID, execFunc); } +template static std::unique_ptr bindCastFromRdfVariantFunction( const std::string& functionName, LogicalTypeID targetTypeID) { scalar_exec_func execFunc; switch (targetTypeID) { case LogicalTypeID::DATE: { - execFunc = - ScalarFunction::UnaryTryCastExecFunction; + execFunc = ScalarFunction::UnaryTryCastExecFunction; } break; case LogicalTypeID::TIMESTAMP: { execFunc = ScalarFunction::UnaryTryCastExecFunction; + CastFromRdfVariant, EXECUTOR>; } break; case LogicalTypeID::INTERVAL: { execFunc = ScalarFunction::UnaryTryCastExecFunction; + CastFromRdfVariant, EXECUTOR>; } break; case LogicalTypeID::BLOB: { - execFunc = - ScalarFunction::UnaryTryCastExecFunction; + execFunc = ScalarFunction::UnaryTryCastExecFunction; } break; case LogicalTypeID::STRING: { execFunc = ScalarFunction::UnaryTryCastExecFunction; + CastFromRdfVariant, EXECUTOR>; } break; case LogicalTypeID::BOOL: { - execFunc = - ScalarFunction::UnaryTryCastExecFunction; + execFunc = ScalarFunction::UnaryTryCastExecFunction; } break; case LogicalTypeID::DOUBLE: { - execFunc = - ScalarFunction::UnaryTryCastExecFunction; + execFunc = ScalarFunction::UnaryTryCastExecFunction; } break; case LogicalTypeID::FLOAT: { - execFunc = - ScalarFunction::UnaryTryCastExecFunction; + execFunc = ScalarFunction::UnaryTryCastExecFunction; } break; case LogicalTypeID::SERIAL: case LogicalTypeID::INT64: { - execFunc = - ScalarFunction::UnaryTryCastExecFunction; + execFunc = ScalarFunction::UnaryTryCastExecFunction; } break; case LogicalTypeID::INT32: { - execFunc = - ScalarFunction::UnaryTryCastExecFunction; + execFunc = ScalarFunction::UnaryTryCastExecFunction; } break; case LogicalTypeID::INT16: { - execFunc = - ScalarFunction::UnaryTryCastExecFunction; + execFunc = ScalarFunction::UnaryTryCastExecFunction; } break; case LogicalTypeID::INT8: { - execFunc = - ScalarFunction::UnaryTryCastExecFunction; + execFunc = ScalarFunction::UnaryTryCastExecFunction; } break; case LogicalTypeID::UINT64: { - execFunc = - ScalarFunction::UnaryTryCastExecFunction; + execFunc = ScalarFunction::UnaryTryCastExecFunction; } break; case LogicalTypeID::UINT32: { - execFunc = - ScalarFunction::UnaryTryCastExecFunction; + execFunc = ScalarFunction::UnaryTryCastExecFunction; } break; case LogicalTypeID::UINT16: { - execFunc = - ScalarFunction::UnaryTryCastExecFunction; + execFunc = ScalarFunction::UnaryTryCastExecFunction; } break; case LogicalTypeID::UINT8: { - execFunc = - ScalarFunction::UnaryTryCastExecFunction; + execFunc = ScalarFunction::UnaryTryCastExecFunction; } break; // LCOV_EXCL_START default: @@ -283,78 +379,86 @@ static std::unique_ptr bindCastFromRdfVariantFunction( std::vector{LogicalTypeID::RDF_VARIANT}, targetTypeID, execFunc); } +template static std::unique_ptr bindCastToStringFunction( const std::string& functionName, LogicalTypeID sourceTypeID) { scalar_exec_func func; switch (sourceTypeID) { case LogicalTypeID::BOOL: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::SERIAL: case LogicalTypeID::INT64: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::INT32: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::INT16: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::INT8: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::UINT64: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::UINT32: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::UINT16: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::INT128: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::UINT8: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::DOUBLE: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::FLOAT: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::DATE: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::TIMESTAMP: { - func = ScalarFunction::UnaryCastExecFunction; + func = + ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::INTERVAL: { - func = ScalarFunction::UnaryCastExecFunction; + func = + ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::INTERNAL_ID: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::BLOB: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::VAR_LIST: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::FIXED_LIST: { - func = fixedListCastExecFunction; + func = fixedListCastExecFunction; } break; case LogicalTypeID::MAP: { - func = ScalarFunction::UnaryCastExecFunction; + func = + ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::NODE: case LogicalTypeID::REL: case LogicalTypeID::STRUCT: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; case LogicalTypeID::UNION: { - func = ScalarFunction::UnaryCastExecFunction; + func = ScalarFunction::UnaryCastExecFunction; } break; // ToDo(Kebing): RECURSIVE_REL to string default: @@ -364,44 +468,44 @@ static std::unique_ptr bindCastToStringFunction( functionName, std::vector{sourceTypeID}, LogicalTypeID::STRING, func); } -template +template static std::unique_ptr bindCastToNumericFunction( const std::string& functionName, LogicalTypeID sourceTypeID, LogicalTypeID targetTypeID) { scalar_exec_func func; switch (sourceTypeID) { case LogicalTypeID::INT8: { - func = ScalarFunction::UnaryExecFunction; + func = ScalarFunction::UnaryExecFunction; } break; case LogicalTypeID::INT16: { - func = ScalarFunction::UnaryExecFunction; + func = ScalarFunction::UnaryExecFunction; } break; case LogicalTypeID::INT32: { - func = ScalarFunction::UnaryExecFunction; + func = ScalarFunction::UnaryExecFunction; } break; case LogicalTypeID::SERIAL: case LogicalTypeID::INT64: { - func = ScalarFunction::UnaryExecFunction; + func = ScalarFunction::UnaryExecFunction; } break; case LogicalTypeID::UINT8: { - func = ScalarFunction::UnaryExecFunction; + func = ScalarFunction::UnaryExecFunction; } break; case LogicalTypeID::UINT16: { - func = ScalarFunction::UnaryExecFunction; + func = ScalarFunction::UnaryExecFunction; } break; case LogicalTypeID::UINT32: { - func = ScalarFunction::UnaryExecFunction; + func = ScalarFunction::UnaryExecFunction; } break; case LogicalTypeID::UINT64: { - func = ScalarFunction::UnaryExecFunction; + func = ScalarFunction::UnaryExecFunction; } break; case LogicalTypeID::INT128: { - func = ScalarFunction::UnaryExecFunction; + func = ScalarFunction::UnaryExecFunction; } break; case LogicalTypeID::FLOAT: { - func = ScalarFunction::UnaryExecFunction; + func = ScalarFunction::UnaryExecFunction; } break; case LogicalTypeID::DOUBLE: { - func = ScalarFunction::UnaryExecFunction; + func = ScalarFunction::UnaryExecFunction; } break; default: throw ConversionException{stringFormat("Unsupported casting function from {} to {}.", @@ -411,12 +515,33 @@ static std::unique_ptr bindCastToNumericFunction( functionName, std::vector{sourceTypeID}, targetTypeID, func); } +template +static std::unique_ptr bindCastBetweenNested( + const std::string& functionName, LogicalTypeID sourceTypeID, LogicalTypeID targetTypeID) { + scalar_exec_func func; + switch (sourceTypeID) { + case LogicalTypeID::VAR_LIST: { + func = varListCastExecFunction; + } break; + default: + // lcov_excl_start + // TODO(kebing): implement more + throw ConversionException{stringFormat("Unsupported casting function from {} to {}.", + LogicalTypeUtils::toString(sourceTypeID), LogicalTypeUtils::toString(targetTypeID))}; + // lcov_excl_end + } + return std::make_unique( + functionName, std::vector{sourceTypeID}, targetTypeID, func); +} + +template static std::unique_ptr bindCastToTimestampFunction( const std::string& functionName, LogicalTypeID sourceTypeID) { scalar_exec_func func; switch (sourceTypeID) { case LogicalTypeID::DATE: { - func = ScalarFunction::UnaryExecFunction; + func = + ScalarFunction::UnaryExecFunction; } break; default: throw ConversionException{stringFormat("Unsupported casting function from {} to TIMESTAMP.", @@ -426,68 +551,72 @@ static std::unique_ptr bindCastToTimestampFunction( functionName, std::vector{sourceTypeID}, LogicalTypeID::TIMESTAMP, func); } +template std::unique_ptr CastFunction::bindCastFunction( const std::string& functionName, LogicalTypeID sourceTypeID, LogicalTypeID targetTypeID) { if (sourceTypeID == LogicalTypeID::STRING) { - return bindCastFromStringFunction(functionName, targetTypeID); + return bindCastFromStringFunction(functionName, targetTypeID); } if (sourceTypeID == LogicalTypeID::RDF_VARIANT) { - return bindCastFromRdfVariantFunction(functionName, targetTypeID); + return bindCastFromRdfVariantFunction(functionName, targetTypeID); } switch (targetTypeID) { case LogicalTypeID::STRING: { - return bindCastToStringFunction(functionName, sourceTypeID); + return bindCastToStringFunction(functionName, sourceTypeID); } case LogicalTypeID::DOUBLE: { - return bindCastToNumericFunction( + return bindCastToNumericFunction( functionName, sourceTypeID, targetTypeID); } case LogicalTypeID::FLOAT: { - return bindCastToNumericFunction( + return bindCastToNumericFunction( functionName, sourceTypeID, targetTypeID); } case LogicalTypeID::INT128: { - return bindCastToNumericFunction( + return bindCastToNumericFunction( functionName, sourceTypeID, targetTypeID); } case LogicalTypeID::SERIAL: { - return bindCastToNumericFunction( + return bindCastToNumericFunction( functionName, sourceTypeID, targetTypeID); } case LogicalTypeID::INT64: { - return bindCastToNumericFunction( + return bindCastToNumericFunction( functionName, sourceTypeID, targetTypeID); } case LogicalTypeID::INT32: { - return bindCastToNumericFunction( + return bindCastToNumericFunction( functionName, sourceTypeID, targetTypeID); } case LogicalTypeID::INT16: { - return bindCastToNumericFunction( + return bindCastToNumericFunction( functionName, sourceTypeID, targetTypeID); } case LogicalTypeID::INT8: { - return bindCastToNumericFunction( + return bindCastToNumericFunction( functionName, sourceTypeID, targetTypeID); } case LogicalTypeID::UINT64: { - return bindCastToNumericFunction( + return bindCastToNumericFunction( functionName, sourceTypeID, targetTypeID); } case LogicalTypeID::UINT32: { - return bindCastToNumericFunction( + return bindCastToNumericFunction( functionName, sourceTypeID, targetTypeID); } case LogicalTypeID::UINT16: { - return bindCastToNumericFunction( + return bindCastToNumericFunction( functionName, sourceTypeID, targetTypeID); } case LogicalTypeID::UINT8: { - return bindCastToNumericFunction( + return bindCastToNumericFunction( functionName, sourceTypeID, targetTypeID); } case LogicalTypeID::TIMESTAMP: { - return bindCastToTimestampFunction(functionName, sourceTypeID); + return bindCastToTimestampFunction(functionName, sourceTypeID); + } + case LogicalTypeID::VAR_LIST: { + return bindCastBetweenNested(functionName, sourceTypeID, targetTypeID); } default: { throw ConversionException{stringFormat("Unsupported casting function from {} to {}.", @@ -726,10 +855,7 @@ std::unique_ptr CastAnyFunction::bindFunc( func->execFunc = CastFunction::bindCastFunction(func->name, inputTypeID, outputType->getLogicalTypeID()) ->execFunc; - if (inputTypeID == LogicalTypeID::STRING) { - return std::make_unique(*outputType); - } - return std::make_unique(*outputType); + return std::make_unique(*outputType); } function_set CastAnyFunction::getFunctionSet() { diff --git a/src/include/common/null_mask.h b/src/include/common/null_mask.h index 7739bafd40a..e6e6827b74e 100644 --- a/src/include/common/null_mask.h +++ b/src/include/common/null_mask.h @@ -115,7 +115,7 @@ class NullMask { // const because updates to the data must set mayContainNulls if any value // becomes non-null - // Modifying the underlying data shuld be done with setNull or copyFromNullData + // Modifying the underlying data should be done with setNull or copyFromNullData inline const uint64_t* getData() { return data; } static inline uint64_t getNumNullEntries(uint64_t numNullBits) { diff --git a/src/include/common/vector/auxiliary_buffer.h b/src/include/common/vector/auxiliary_buffer.h index 16babdcbee0..3f281ad9a6a 100644 --- a/src/include/common/vector/auxiliary_buffer.h +++ b/src/include/common/vector/auxiliary_buffer.h @@ -73,6 +73,7 @@ class ListAuxiliaryBuffer : public AuxiliaryBuffer { dataVector = std::move(vector); } inline ValueVector* getDataVector() const { return dataVector.get(); } + inline std::shared_ptr getSharedDataVector() const { return dataVector; } list_entry_t addList(uint64_t listSize); diff --git a/src/include/common/vector/value_vector.h b/src/include/common/vector/value_vector.h index 19baa4ff0f6..6286083deba 100644 --- a/src/include/common/vector/value_vector.h +++ b/src/include/common/vector/value_vector.h @@ -147,6 +147,11 @@ class ListVector { return reinterpret_cast(vector->auxiliaryBuffer.get()) ->getDataVector(); } + static inline std::shared_ptr getSharedDataVector(const ValueVector* vector) { + KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::VAR_LIST); + return reinterpret_cast(vector->auxiliaryBuffer.get()) + ->getSharedDataVector(); + } static inline uint64_t getDataVectorSize(const ValueVector* vector) { KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::VAR_LIST); return reinterpret_cast(vector->auxiliaryBuffer.get())->getSize(); diff --git a/src/include/function/cast/vector_cast_functions.h b/src/include/function/cast/vector_cast_functions.h index f969748ecb1..cc32ce9918a 100644 --- a/src/include/function/cast/vector_cast_functions.h +++ b/src/include/function/cast/vector_cast_functions.h @@ -16,6 +16,7 @@ struct CastFunction { static bool hasImplicitCast( const common::LogicalType& srcType, const common::LogicalType& dstType); + template static std::unique_ptr bindCastFunction(const std::string& functionName, common::LogicalTypeID sourceTypeID, common::LogicalTypeID targetTypeID); }; diff --git a/src/include/function/scalar_function.h b/src/include/function/scalar_function.h index cffddca1138..c9b5eddbb62 100644 --- a/src/include/function/scalar_function.h +++ b/src/include/function/scalar_function.h @@ -92,11 +92,12 @@ struct ScalarFunction : public BaseScalarFunction { *params[0], *params[1], selVector); } - template + template static void UnaryExecFunction(const std::vector>& params, common::ValueVector& result, void* /*dataPtr*/) { KU_ASSERT(params.size() == 1); - UnaryFunctionExecutor::executeSwitch( + EXECUTOR::template executeSwitch( *params[0], result, nullptr /* dataPtr */); } @@ -109,30 +110,33 @@ struct ScalarFunction : public BaseScalarFunction { UnaryStringFunctionWrapper>(*params[0], result, nullptr /* dataPtr */); } - template + template static void UnaryCastStringExecFunction( const std::vector>& params, common::ValueVector& result, void* dataPtr) { KU_ASSERT(params.size() == 1); - UnaryFunctionExecutor::executeCastString( - *params[0], result, dataPtr); + EXECUTOR::template executeSwitch(*params[0], result, dataPtr); } - template + template static void UnaryCastExecFunction( const std::vector>& params, common::ValueVector& result, void* /*dataPtr*/ = nullptr) { KU_ASSERT(params.size() == 1); - UnaryFunctionExecutor::executeSwitch(*params[0], result, nullptr /* dataPtr */); + EXECUTOR::template executeSwitch( + *params[0], result, nullptr /* dataPtr */); } - template + template static void UnaryTryCastExecFunction( const std::vector>& params, common::ValueVector& result, void* /*dataPtr*/ = nullptr) { KU_ASSERT(params.size() == 1); - UnaryFunctionExecutor::executeSwitch(*params[0], result, nullptr /* dataPtr */); } diff --git a/src/include/function/unary_function_executor.h b/src/include/function/unary_function_executor.h index 8faf583038a..a10300de033 100644 --- a/src/include/function/unary_function_executor.h +++ b/src/include/function/unary_function_executor.h @@ -89,6 +89,27 @@ struct UnaryUDFFunctionWrapper { } }; +struct CastChildFunctionExecutor { + template + static void executeSwitch( + common::ValueVector& operand, common::ValueVector& result, void* dataPtr) { + // this vector is of var list type and the child vector is of non-nested types then cast + KU_ASSERT(operand.dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST && + result.dataType.getLogicalTypeID() == common::LogicalTypeID::VAR_LIST); + auto childNum = common::ListVector::getDataVectorSize(&operand); + auto inputChildVector = common::ListVector::getDataVector(&operand); + auto resultChildVector = (common::ListVector::getDataVector(&result)); + for (auto i = 0u; i < childNum; i++) { + resultChildVector->setNull(i, inputChildVector->isNull(i)); + if (!resultChildVector->isNull(i)) { + // cast position i in child data vector + OP_WRAPPER::template operation( + (void*)(inputChildVector), i, (void*)(resultChildVector), i, dataPtr); + } + } + } +}; + struct UnaryFunctionExecutor { template static void executeOnValue(common::ValueVector& inputVector, uint64_t inputPos, @@ -152,19 +173,6 @@ struct UnaryFunctionExecutor { operand, result, nullptr /* dataPtr */); } - template - static void executeString(common::ValueVector& operand, common::ValueVector& result) { - executeSwitch( - operand, result, nullptr /* dataPtr */); - } - - template - static void executeCastString( - common::ValueVector& operand, common::ValueVector& result, void* dataPtr) { - executeSwitch( - operand, result, dataPtr); - } - template static void executeUDF( common::ValueVector& operand, common::ValueVector& result, void* dataPtr) { diff --git a/test/test_files/tinysnb/cast/cast_error.test b/test/test_files/tinysnb/cast/cast_error.test index ca28f9c42b4..a7d8bd4b45f 100644 --- a/test/test_files/tinysnb/cast/cast_error.test +++ b/test/test_files/tinysnb/cast/cast_error.test @@ -702,3 +702,14 @@ Conversion exception: Unsupported casting function from REL to TIMESTAMP. -STATEMENT MATCH (:person)-[e:studyAt*1..3]->(:organisation) return cast(e, "INT64"); ---- error Conversion exception: Unsupported casting function from RECURSIVE_REL to INT64. + +-LOG InvalidVarListToVarList +-STATEMENT RETURN cast([31231], "INT64[][]"); +---- error +Conversion exception: Unsupported casting function from INT64 to VAR_LIST. +-STATEMENT RETURN cast([-1], "UINT8[]"); +---- error +Overflow exception: Value -1 is not within UINT8 range +-STATEMENT RETURN cast([[1, 1]], "UINT8[]"); +---- error +Conversion exception: Unsupported casting function from VAR_LIST to UINT8. diff --git a/test/test_files/tinysnb/cast/cast_to_nested_types.test b/test/test_files/tinysnb/cast/cast_to_nested_types.test index 6527fb5b673..c83b84379eb 100644 --- a/test/test_files/tinysnb/cast/cast_to_nested_types.test +++ b/test/test_files/tinysnb/cast/cast_to_nested_types.test @@ -124,6 +124,12 @@ False|-4325||18446744073709551616.000000| dfsa ---- 2 [3324.123047,342423.437500,432.122986] [1.000000,4231.000000,432.122986] +-STATEMENT RETURN cast("[423, 321, 423]", "INT64[3]"), cast(null, "INT64[5]"), cast("[432.43214]", "FLOAT[1]"), cast("[4, -5]", "double[2]"), cast("[4234, 42312, 432, 1321]", "INT32[4]"), cast("[-32768]", "INT16[1]") +---- 1 +[423,321,423]||[432.432129]|[4.000000,-5.000000]|[4234,42312,432,1321]|[-32768] +-STATEMENT Return cast(cast(-4324324, "int128"), "int64") +---- 1 +-4324324 -LOG CastToMap -STATEMENT RETURN cast(" { c= {a = 3423 }, b = { g = 3421 } } ", "MAP(STRING, MAP(STRING, INT16))"), cast("{}", "MAP(STRING, MAP(STRING, INT16))"), cast("{d = {}}", "MAP(STRING, MAP(STRING, INT16))"); @@ -168,3 +174,31 @@ True|2019-03-19|-12.343200|32768| 1|-2147483648|1970-01-01 10:00:00.004666|-32769.000000|fsdxcv 0|0|2014-05-12 12:11:59|4324254534123134124032.000000|fsaf False|-4325|14|18446744073709551616.000000| dfsa + +-LOG CastVarListToVarList +-STATEMENT RETURN cast([321, 432], "DOUBLE[]"), cast([321, 432], "FLOAT[]"), cast([321, 432], "INT128[]"), cast([321, 432], "INT64[]"), cast([321, 432], "INT32[]"), cast([321, 432], "INT16[]"), cast([-1, -43], "INT8[]"), cast([0, 23], "UINT8[]"), cast([0, 23], "UINT16[]"), cast([0, 23], "UINT32[]"), cast([0, 23], "UINT64[]"), cast([5435234412435123, -432425341231], "STRING[]"); +---- 1 +[321.000000,432.000000]|[321.000000,432.000000]|[321,432]|[321,432]|[321,432]|[321,432]|[-1,-43]|[0,23]|[0,23]|[0,23]|[0,23]|[5435234412435123,-432425341231] +-STATEMENT RETURN cast([], "UINT64[]"), cast([NULL,], "UINT64[]"), cast(NULL, "UINT64[]"), cast([NULL, 432124, 0, NULL], "UINT64[]"); +---- 1 +[]|[,]||[,432124,0,] + +-LOG CastNestedVarListToNestedVarList +-STATEMENT RETURN cast([[4324.2312, 432.321, 43242.543], [31214.59,4132.72], NULL, [NULL,,4324.32]], "INT64[][]"); +---- 1 +[[4324,432,43243],[31215,4133],,[,,4324]] +-STATEMENT RETURN cast(["[123, 3234]", "[124, 3241]", NULL, "[0, -4324234]"], "INT64[2][]"), cast(cast(["[123, 3234]", "[124, 3241]", NULL, "[0, -4324234]"], "DOUBLE[2][]"), "STRING[]"); +---- 1 +[[123,3234],[124,3241],,[0,-4324234]]|[[123.000000,3234.000000],[124.000000,3241.000000],,[0.000000,-4324234.000000]] +-STATEMENT RETURN cast([NULL, NULL, NULL], "INT8[][][]"), cast([NULL], "STRING[]"), cast([], "UINT8[]"); +---- 1 +[,,]|[]|[] +-STATEMENT RETURN cast(cast([NULL, [NULL, 13], NULL, [14, 14], NULL], "INT32[][]"), "INT128[][]"), cast([NULL, 1], "INT16[]"), cast("[1, NULL, NULL]", "UINT32[]"), cast("[NULL, 1, NULL]", "UINT64[]"); +---- 1 +[,[,13],,[14,14],]|[,1]|[1,,]|[,1,] +-STATEMENT RETURN cast(NULL, "INT32[][]"); +---- 1 + +-STATEMENT RETURN cast(cast(cast(cast(["[NULL, [NULL, 1, 0, 2], NULL, [1, 2, 3, 4, 5], NULL]", "[[1, 2, 3], [4, 5, 6]]"], "UINT8[][][]"), "UINT16[][][]"), "INT32[][][]"), "DOUBLE[][][]"); +---- 1 +[[,[,1.000000,0.000000,2.000000],,[1.000000,2.000000,3.000000,4.000000,5.000000],],[[1.000000,2.000000,3.000000],[4.000000,5.000000,6.000000]]] diff --git a/test/test_files/tinysnb/function/cast.test b/test/test_files/tinysnb/function/cast.test index f1d3ee58da5..b7ab4adb01e 100644 --- a/test/test_files/tinysnb/function/cast.test +++ b/test/test_files/tinysnb/function/cast.test @@ -335,6 +335,26 @@ Hubert Blaine Wolfeschlegelsteinhausenbergerdorff [1] [10,11,12,3,4,5,6,7] +-LOG CastListOfIntsToList +-STATEMENT MATCH (p:person) RETURN cast(p.workedHours, "DOUBLE[]"), cast(p.workedHours, "FLOAT[]"), cast(p.workedHours, "INT128[]"), cast(p.workedHours, "INT64[]"), cast(p.workedHours, "INT32[]"), cast(p.workedHours, "INT16[]"), cast(p.workedHours, "INT8[]"), cast(p.workedHours, "UINT8[]"), cast(p.workedHours, "UINT16[]"), cast(p.workedHours, "UINT32[]"), cast(p.workedHours, "UINT64[]"), cast(p.workedHours, "STRING[]") +---- 9 +[10.000000,5.000000]|[10.000000,5.000000]|[10,5]|[10,5]|[10,5]|[10,5]|[10,5]|[10,5]|[10,5]|[10,5]|[10,5]|[10,5] +[12.000000,8.000000]|[12.000000,8.000000]|[12,8]|[12,8]|[12,8]|[12,8]|[12,8]|[12,8]|[12,8]|[12,8]|[12,8]|[12,8] +[4.000000,5.000000]|[4.000000,5.000000]|[4,5]|[4,5]|[4,5]|[4,5]|[4,5]|[4,5]|[4,5]|[4,5]|[4,5]|[4,5] +[1.000000,9.000000]|[1.000000,9.000000]|[1,9]|[1,9]|[1,9]|[1,9]|[1,9]|[1,9]|[1,9]|[1,9]|[1,9]|[1,9] +[2.000000]|[2.000000]|[2]|[2]|[2]|[2]|[2]|[2]|[2]|[2]|[2]|[2] +[3.000000,4.000000,5.000000,6.000000,7.000000]|[3.000000,4.000000,5.000000,6.000000,7.000000]|[3,4,5,6,7]|[3,4,5,6,7]|[3,4,5,6,7]|[3,4,5,6,7]|[3,4,5,6,7]|[3,4,5,6,7]|[3,4,5,6,7]|[3,4,5,6,7]|[3,4,5,6,7]|[3,4,5,6,7] +[1.000000]|[1.000000]|[1]|[1]|[1]|[1]|[1]|[1]|[1]|[1]|[1]|[1] +[10.000000,11.000000,12.000000,3.000000,4.000000,5.000000,6.000000,7.000000]|[10.000000,11.000000,12.000000,3.000000,4.000000,5.000000,6.000000,7.000000]|[10,11,12,3,4,5,6,7]|[10,11,12,3,4,5,6,7]|[10,11,12,3,4,5,6,7]|[10,11,12,3,4,5,6,7]|[10,11,12,3,4,5,6,7]|[10,11,12,3,4,5,6,7]|[10,11,12,3,4,5,6,7]|[10,11,12,3,4,5,6,7]|[10,11,12,3,4,5,6,7]|[10,11,12,3,4,5,6,7] +||||||||||| +-STATEMENT MATCH (p:person) WHERE size(p.workedHours) > 1 RETURN cast(p.workedHours, "STRING[]") +---- 6 +[1,9] +[10,5] +[12,8] +[3,4,5,6,7] +[4,5] +[10,11,12,3,4,5,6,7] -LOG CastListOfListOfIntsToString -STATEMENT MATCH (p:person) RETURN string(p.courseScoresPerTerm) @@ -359,6 +379,28 @@ Hubert Blaine Wolfeschlegelsteinhausenbergerdorff [[10]] [[7],[10],[6,7]] +-LOG CastListOfListOfIntsToListOfLists +-STATEMENT MATCH (p:person) RETURN cast(p.courseScoresPerTerm, "INT128[][]"), cast(p.courseScoresPerTerm, "INT64[][]"), cast(p.courseScoresPerTerm, "INT32[][]"), cast(p.courseScoresPerTerm, "INT16[][]"), cast(p.courseScoresPerTerm, "INT8[][]"), cast(p.courseScoresPerTerm, "UINT8[][]"), cast(p.courseScoresPerTerm, "UINT16[][]"), cast(p.courseScoresPerTerm, "UINT32[][]"), cast(p.courseScoresPerTerm, "UINT64[][]"), cast(p.courseScoresPerTerm, "DOUBLE[][]"), cast(p.courseScoresPerTerm, "FLOAT[][]") +---- 9 +[[10,8],[6,7,8]]|[[10,8],[6,7,8]]|[[10,8],[6,7,8]]|[[10,8],[6,7,8]]|[[10,8],[6,7,8]]|[[10,8],[6,7,8]]|[[10,8],[6,7,8]]|[[10,8],[6,7,8]]|[[10,8],[6,7,8]]|[[10.000000,8.000000],[6.000000,7.000000,8.000000]]|[[10.000000,8.000000],[6.000000,7.000000,8.000000]] +[[8,9],[9,10]]|[[8,9],[9,10]]|[[8,9],[9,10]]|[[8,9],[9,10]]|[[8,9],[9,10]]|[[8,9],[9,10]]|[[8,9],[9,10]]|[[8,9],[9,10]]|[[8,9],[9,10]]|[[8.000000,9.000000],[9.000000,10.000000]]|[[8.000000,9.000000],[9.000000,10.000000]] +[[8,10]]|[[8,10]]|[[8,10]]|[[8,10]]|[[8,10]]|[[8,10]]|[[8,10]]|[[8,10]]|[[8,10]]|[[8.000000,10.000000]]|[[8.000000,10.000000]] +[[7,4],[8,8],[9]]|[[7,4],[8,8],[9]]|[[7,4],[8,8],[9]]|[[7,4],[8,8],[9]]|[[7,4],[8,8],[9]]|[[7,4],[8,8],[9]]|[[7,4],[8,8],[9]]|[[7,4],[8,8],[9]]|[[7,4],[8,8],[9]]|[[7.000000,4.000000],[8.000000,8.000000],[9.000000]]|[[7.000000,4.000000],[8.000000,8.000000],[9.000000]] +[[6],[7],[8]]|[[6],[7],[8]]|[[6],[7],[8]]|[[6],[7],[8]]|[[6],[7],[8]]|[[6],[7],[8]]|[[6],[7],[8]]|[[6],[7],[8]]|[[6],[7],[8]]|[[6.000000],[7.000000],[8.000000]]|[[6.000000],[7.000000],[8.000000]] +[[8]]|[[8]]|[[8]]|[[8]]|[[8]]|[[8]]|[[8]]|[[8]]|[[8]]|[[8.000000]]|[[8.000000]] +[[10]]|[[10]]|[[10]]|[[10]]|[[10]]|[[10]]|[[10]]|[[10]]|[[10]]|[[10.000000]]|[[10.000000]] +[[7],[10],[6,7]]|[[7],[10],[6,7]]|[[7],[10],[6,7]]|[[7],[10],[6,7]]|[[7],[10],[6,7]]|[[7],[10],[6,7]]|[[7],[10],[6,7]]|[[7],[10],[6,7]]|[[7],[10],[6,7]]|[[7.000000],[10.000000],[6.000000,7.000000]]|[[7.000000],[10.000000],[6.000000,7.000000]] +|||||||||| +-STATEMENT MATCH (p:person) WHERE size(p.courseScoresPerTerm) > 2 RETURN cast(p.courseScoresPerTerm, "STRING[]"); +---- 3 +[[7,4],[8,8],[9]] +[[6],[7],[8]] +[[7],[10],[6,7]] +-STATEMENT MATCH (p:person) WHERE size(p.courseScoresPerTerm) > 2 RETURN cast(cast(p.courseScoresPerTerm, "INT32[][]"), "UINT8[][]"); +---- 3 +[[7,4],[8,8],[9]] +[[6],[7],[8]] +[[7],[10],[6,7]] -LOG CastFixedListToString -STATEMENT MATCH (p:person) where p.ID > 1 RETURN string(p.grades) @@ -1083,15 +1125,6 @@ False -STATEMENT Return cast(cast(-15, "float"), "int128"), cast(cast(-1, "double"), "int128"), cast(cast(15, "float"), "int128"), cast(cast(1, "double"), "int128") ---- 1 -15|-1|15|1 - --LOG CastStringToFixedList --STATEMENT RETURN cast("[423, 321, 423]", "INT64[3]"), cast(null, "INT64[5]"), cast("[432.43214]", "FLOAT[1]"), cast("[4, -5]", "double[2]"), cast("[4234, 42312, 432, 1321]", "INT32[4]"), cast("[-32768]", "INT16[1]") ----- 1 -[423,321,423]||[432.432129]|[4.000000,-5.000000]|[4234,42312,432,1321]|[-32768] --STATEMENT Return cast(cast(-4324324, "int128"), "int64") ----- 1 --4324324 - -STATEMENT Return to_int64(to_int128(-4324324)) ---- 1 -4324324