diff --git a/src/function/vector_list_functions.cpp b/src/function/vector_list_functions.cpp index 586d8b5416..9f53527e73 100644 --- a/src/function/vector_list_functions.cpp +++ b/src/function/vector_list_functions.cpp @@ -139,78 +139,86 @@ function_set SizeFunction::getFunctionSet() { return result; } +template +static void BinaryExecListExtractFunction( + const std::vector>& params, common::ValueVector& result) { + assert(params.size() == 2); + BinaryFunctionExecutor::executeListExtract( + *params[0], *params[1], result); +} + std::unique_ptr ListExtractFunction::bindFunc( const binder::expression_vector& arguments, Function* function) { auto resultType = VarListType::getChildType(&arguments[0]->dataType); auto scalarFunction = reinterpret_cast(function); switch (resultType->getPhysicalType()) { case PhysicalTypeID::BOOL: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::INT64: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::INT32: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::INT16: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::INT8: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::UINT64: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::UINT32: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::UINT16: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::UINT8: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::INT128: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::DOUBLE: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::FLOAT: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::INTERVAL: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::STRING: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::VAR_LIST: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::STRUCT: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; case PhysicalTypeID::INTERNAL_ID: { - scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + scalarFunction->execFunc = + BinaryExecListExtractFunction; } break; default: { throw NotImplementedException("ListExtractFunction::bindFunc"); diff --git a/src/include/function/binary_function_executor.h b/src/include/function/binary_function_executor.h index 818448bea1..613fe9552b 100644 --- a/src/include/function/binary_function_executor.h +++ b/src/include/function/binary_function_executor.h @@ -18,8 +18,7 @@ struct BinaryFunctionWrapper { template static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, common::ValueVector* /*leftValueVector*/, common::ValueVector* /*rightValueVector*/, - common::ValueVector* /*resultValueVector*/, void* /*dataPtr*/) { - + common::ValueVector* /*resultValueVector*/, uint64_t /*resultPos*/, void* /*dataPtr*/) { OP::operation(left, right, result); } }; @@ -28,16 +27,26 @@ struct BinaryListStructFunctionWrapper { template static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, - common::ValueVector* resultValueVector, void* /*dataPtr*/) { + common::ValueVector* resultValueVector, uint64_t /*resultPos*/, void* /*dataPtr*/) { OP::operation(left, right, result, *leftValueVector, *rightValueVector, *resultValueVector); } }; +struct BinaryListExtractFunctionWrapper { + template + static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, + common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, + common::ValueVector* resultValueVector, uint64_t resultPos, void* /*dataPtr*/) { + OP::operation(left, right, result, *leftValueVector, *rightValueVector, *resultValueVector, + resultPos); + } +}; + struct BinaryStringFunctionWrapper { template static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, common::ValueVector* /*leftValueVector*/, common::ValueVector* /*rightValueVector*/, - common::ValueVector* resultValueVector, void* /*dataPtr*/) { + common::ValueVector* resultValueVector, uint64_t /*resultPos*/, void* /*dataPtr*/) { OP::operation(left, right, result, *resultValueVector); } }; @@ -46,7 +55,7 @@ struct BinaryComparisonFunctionWrapper { template static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, common::ValueVector* leftValueVector, common::ValueVector* rightValueVector, - common::ValueVector* /*resultValueVector*/, void* /*dataPtr*/) { + common::ValueVector* /*resultValueVector*/, uint64_t /*resultPos*/, void* /*dataPtr*/) { OP::operation(left, right, result, leftValueVector, rightValueVector); } }; @@ -55,7 +64,7 @@ struct BinaryUDFFunctionWrapper { template static inline void operation(LEFT_TYPE& left, RIGHT_TYPE& right, RESULT_TYPE& result, common::ValueVector* /*leftValueVector*/, common::ValueVector* /*rightValueVector*/, - common::ValueVector* /*resultValueVector*/, void* dataPtr) { + common::ValueVector* /*resultValueVector*/, uint64_t /*resultPos*/, void* dataPtr) { OP::operation(left, right, result, dataPtr); } }; @@ -69,7 +78,7 @@ struct BinaryFunctionExecutor { OP_WRAPPER::template operation( ((LEFT_TYPE*)left.getData())[lPos], ((RIGHT_TYPE*)right.getData())[rPos], ((RESULT_TYPE*)resultValueVector.getData())[resPos], &left, &right, &resultValueVector, - dataPtr); + resPos, dataPtr); } template + static void executeListExtract( + common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) { + executeSwitch( + left, right, result, nullptr /* dataPtr */); + } + template static void executeComparison( common::ValueVector& left, common::ValueVector& right, common::ValueVector& result) { diff --git a/src/include/function/list/functions/list_extract_function.h b/src/include/function/list/functions/list_extract_function.h index 2d82baf134..df03e11688 100644 --- a/src/include/function/list/functions/list_extract_function.h +++ b/src/include/function/list/functions/list_extract_function.h @@ -12,17 +12,12 @@ namespace function { struct ListExtract { public: - template - static inline void setValue(T& src, T& dest, common::ValueVector& /*resultValueVector*/) { - dest = src; - } - // Note: this function takes in a 1-based position (The index of the first value in the list // is 1). template static inline void operation(common::list_entry_t& listEntry, int64_t pos, T& result, common::ValueVector& listVector, common::ValueVector& /*posVector*/, - common::ValueVector& resultVector) { + common::ValueVector& resultVector, uint64_t resPos) { if (pos == 0) { throw common::RuntimeException("List extract takes 1-based position."); } @@ -35,10 +30,13 @@ struct ListExtract { return; // TODO(Xiyang/Ziyi): we should fix when extracting last element of list. } auto listDataVector = common::ListVector::getDataVector(&listVector); - auto listValues = - common::ListVector::getListValuesWithOffset(&listVector, listEntry, upos - 1); - resultVector.copyFromVectorData( - reinterpret_cast(&result), listDataVector, listValues); + resultVector.setNull(resPos, listDataVector->isNull(listEntry.offset + upos - 1)); + if (!resultVector.isNull(resPos)) { + auto listValues = + common::ListVector::getListValuesWithOffset(&listVector, listEntry, upos - 1); + resultVector.copyFromVectorData( + reinterpret_cast(&result), listDataVector, listValues); + } } static inline void operation( @@ -51,11 +49,5 @@ struct ListExtract { } }; -template<> -inline void ListExtract::setValue( - common::ku_string_t& src, common::ku_string_t& dest, common::ValueVector& resultValueVector) { - common::StringVector::addString(&resultValueVector, dest, src); -} - } // namespace function } // namespace kuzu diff --git a/src/include/function/schema/label_functions.h b/src/include/function/schema/label_functions.h index 052503a138..43be0a67d5 100644 --- a/src/include/function/schema/label_functions.h +++ b/src/include/function/schema/label_functions.h @@ -9,10 +9,10 @@ namespace function { struct Label { static inline void operation(common::internalID_t& left, common::list_entry_t& right, common::ku_string_t& result, common::ValueVector& leftVector, - common::ValueVector& rightVector, common::ValueVector& resultVector) { + common::ValueVector& rightVector, common::ValueVector& resultVector, uint64_t resPos) { assert(left.tableID < right.size); ListExtract::operation(right, left.tableID + 1 /* listExtract requires 1-based index */, - result, rightVector, leftVector, resultVector); + result, rightVector, leftVector, resultVector, resPos); } }; diff --git a/src/include/function/schema/vector_label_functions.h b/src/include/function/schema/vector_label_functions.h index 9dbcfe1d37..a4a5da99d3 100644 --- a/src/include/function/schema/vector_label_functions.h +++ b/src/include/function/schema/vector_label_functions.h @@ -10,7 +10,7 @@ struct LabelFunction { static void execFunction(const std::vector>& params, common::ValueVector& result) { assert(params.size() == 2); - BinaryFunctionExecutor::executeListStruct(*params[0], *params[1], result); } }; diff --git a/test/test_files/tinysnb/function/list.test b/test/test_files/tinysnb/function/list.test index 288d9a8858..9ee8177522 100644 --- a/test/test_files/tinysnb/function/list.test +++ b/test/test_files/tinysnb/function/list.test @@ -307,6 +307,20 @@ n sdwe ad +-LOG ListExtractWithNull +-STATEMENT RETURN list_extract([1,3,null,null,2],3) +---- 1 + + +-LOG ExtractNullList +-STATEMENT RETURN list_extract(null,1) +---- 1 + + +-LOG ListExtractNullPos +-STATEMENT RETURN LIST_EXTRACT([3,4,5],NULL) +---- 1 + -LOG SliceUTF8String -STATEMENT Return '这是一个中文句子'[2:5]