# Patch: add the LIST_SORT and LIST_SUM list functions.
#
# What it changes
#   - built_in_vector_operations.cpp / expression_type.h: register the two functions under the
#     names "LIST_SORT" and "LIST_SUM".
#   - vector_list_operation.cpp / vector_list_operations.h: definitions and bind functions.
#     list_sort accepts (list), (list, order) and (list, order, null order) and binds for INT64,
#     DOUBLE, BOOL, STRING, DATE, TIMESTAMP and INTERVAL elements; list_sum binds for INT64 and
#     DOUBLE elements. Any other element type throws NotImplementedException at bind time.
#     Also adds the UnaryListExecFunction helper.
#   - list_sort_operation.h / list_sum_operation.h: the kernels themselves.
#   - unary_operation_executor.h: UnaryListOperationWrapper and executeList.
#   - list.test: end-to-end cases for both functions.
#
# Review fix applied in this copy
#   - ListSum::operation only ever did `result += ...` and never assigned `result` first, so the
#     sum was correct only if the caller had already zeroed the output slot. It now sets
#     `result = 0` before the loop.
#
# Review notes
#   - Comments were added to list_sort_operation.h and list_sum_operation.h; their hunk headers
#     were updated to the new line counts, but the `index` blob ids on those two files are the
#     original ones and no longer describe the new content.
#   - NOTE(review): this copy appears to have lost every angle-bracketed template argument list
#     (e.g. `std::vector> result;`, `reinterpret_cast(definition)`, bare `template`). They are
#     left exactly as found rather than guessed. The second hunk of vector_list_operation.cpp
#     announces 95 new-side lines but only 93 are present here, presumably for the same reason.
#     Restore them from the original commit before applying.
#   - NOTE(review): leading whitespace was also lost; the indentation below is reconstructed
#     (4-space style) and should be confirmed against the target files.
#   - NOTE(review): ListSumSeq1/ListSumSeq2 run the same queries as ListSumInt/ListSumDouble.

diff --git a/src/function/built_in_vector_operations.cpp b/src/function/built_in_vector_operations.cpp
index 9f6b0f17e08..c93628bb639 100644
--- a/src/function/built_in_vector_operations.cpp
+++ b/src/function/built_in_vector_operations.cpp
@@ -441,6 +441,8 @@ void BuiltInVectorOperations::registerListOperations() {
     vectorOperations.insert({ARRAY_HAS_FUNC_NAME, ListContainsVectorOperation::getDefinitions()});
     vectorOperations.insert({LIST_SLICE_FUNC_NAME, ListSliceVectorOperation::getDefinitions()});
     vectorOperations.insert({ARRAY_SLICE_FUNC_NAME, ListSliceVectorOperation::getDefinitions()});
+    vectorOperations.insert({LIST_SORT_FUNC_NAME, ListSortVectorOperation::getDefinitions()});
+    vectorOperations.insert({LIST_SUM_FUNC_NAME, ListSumVectorOperation::getDefinitions()});
 }
 
 void BuiltInVectorOperations::registerInternalIDOperation() {
diff --git a/src/function/vector_list_operation.cpp b/src/function/vector_list_operation.cpp
index 1411b833c47..c82fa5b5110 100644
--- a/src/function/vector_list_operation.cpp
+++ b/src/function/vector_list_operation.cpp
@@ -9,6 +9,8 @@
 #include "function/list/operations/list_position_operation.h"
 #include "function/list/operations/list_prepend_operation.h"
 #include "function/list/operations/list_slice_operation.h"
+#include "function/list/operations/list_sort_operation.h"
+#include "function/list/operations/list_sum_operation.h"
 #include "function/list/vector_list_operations.h"
 
 using namespace kuzu::common;
@@ -320,5 +322,95 @@ std::unique_ptr ListSliceVectorOperation::bindFunc(
     return std::make_unique(arguments[0]->getDataType());
 }
 
+std::vector> ListSortVectorOperation::getDefinitions() {
+    std::vector> result;
+    result.push_back(std::make_unique(LIST_SORT_FUNC_NAME,
+        std::vector{VAR_LIST}, VAR_LIST, nullptr, nullptr, bindFunc,
+        false /* isVarlength*/));
+    result.push_back(std::make_unique(LIST_SORT_FUNC_NAME,
+        std::vector{VAR_LIST, STRING}, VAR_LIST, nullptr, nullptr, bindFunc,
+        false /* isVarlength*/));
+    result.push_back(std::make_unique(LIST_SORT_FUNC_NAME,
+        std::vector{VAR_LIST, STRING, STRING}, VAR_LIST, nullptr, nullptr, bindFunc,
+        false /* isVarlength*/));
+    return result;
+}
+
+std::unique_ptr ListSortVectorOperation::bindFunc(
+    const binder::expression_vector& arguments, FunctionDefinition* definition) {
+    auto vectorOperationDefinition = reinterpret_cast(definition);
+    switch (arguments[0]->dataType.getChildType()->getTypeID()) {
+    case INT64: {
+        vectorOperationDefinition->execFunc = getExecFunction(arguments);
+    } break;
+    case DOUBLE: {
+        vectorOperationDefinition->execFunc = getExecFunction(arguments);
+    } break;
+    case BOOL: {
+        vectorOperationDefinition->execFunc = getExecFunction(arguments);
+    } break;
+    case STRING: {
+        vectorOperationDefinition->execFunc = getExecFunction(arguments);
+    } break;
+    case DATE: {
+        vectorOperationDefinition->execFunc = getExecFunction(arguments);
+    } break;
+    case TIMESTAMP: {
+        vectorOperationDefinition->execFunc = getExecFunction(arguments);
+    } break;
+    case INTERVAL: {
+        vectorOperationDefinition->execFunc = getExecFunction(arguments);
+    } break;
+    default: {
+        throw common::NotImplementedException("ListSortVectorOperation::bindFunc");
+    }
+    }
+    return std::make_unique(arguments[0]->getDataType());
+}
+
+template
+scalar_exec_func ListSortVectorOperation::getExecFunction(
+    const binder::expression_vector& arguments) {
+    if (arguments.size() == 1) {
+        return UnaryListExecFunction>;
+    } else if (arguments.size() == 2) {
+        return BinaryListExecFunction>;
+    } else if (arguments.size() == 3) {
+        return TernaryListExecFunction>;
+    } else {
+        throw common::RuntimeException("Invalid number of arguments");
+    }
+}
+
+std::vector> ListSumVectorOperation::getDefinitions() {
+    std::vector> result;
+    result.push_back(std::make_unique(LIST_SUM_FUNC_NAME,
+        std::vector{VAR_LIST}, INT64, nullptr, nullptr, bindFunc,
+        false /* isVarlength*/));
+    return result;
+}
+
+std::unique_ptr ListSumVectorOperation::bindFunc(
+    const binder::expression_vector& arguments, FunctionDefinition* definition) {
+    auto vectorOperationDefinition = reinterpret_cast(definition);
+    auto resultType = *arguments[0]->getDataType().getChildType();
+    switch (resultType.getTypeID()) {
+    case INT64: {
+        vectorOperationDefinition->execFunc =
+            UnaryListExecFunction;
+    } break;
+    case DOUBLE: {
+        vectorOperationDefinition->execFunc =
+            UnaryListExecFunction;
+    } break;
+    default: {
+        throw common::NotImplementedException("ListSumVectorOperation::bindFunc");
+    }
+    }
+    return std::make_unique(resultType);
+}
+
 } // namespace function
 } // namespace kuzu
diff --git a/src/include/common/expression_type.h b/src/include/common/expression_type.h
index 53dae3e606c..23237f13d38 100644
--- a/src/include/common/expression_type.h
+++ b/src/include/common/expression_type.h
@@ -55,6 +55,8 @@ const std::string ARRAY_CONTAINS_FUNC_NAME = "ARRAY_CONTAINS";
 const std::string ARRAY_HAS_FUNC_NAME = "ARRAY_HAS";
 const std::string LIST_SLICE_FUNC_NAME = "LIST_SLICE";
 const std::string ARRAY_SLICE_FUNC_NAME = "ARRAY_SLICE";
+const std::string LIST_SUM_FUNC_NAME = "LIST_SUM";
+const std::string LIST_SORT_FUNC_NAME = "LIST_SORT";
 
 // struct
 const std::string STRUCT_PACK_FUNC_NAME = "STRUCT_PACK";
diff --git a/src/include/function/list/operations/list_sort_operation.h b/src/include/function/list/operations/list_sort_operation.h
new file mode 100644
index 00000000000..28a0ba524b1
--- /dev/null
+++ b/src/include/function/list/operations/list_sort_operation.h
@@ -0,0 +1,128 @@
+#pragma once
+
+#include "common/vector/value_vector.h"
+
+namespace kuzu {
+namespace function {
+namespace operation {
+
+// List sort kernel. The three `operation` overloads differ only in which ordering options are
+// supplied; omitted options default to ascending order with nulls first.
+template
+struct ListSort {
+    static inline void operation(common::list_entry_t& input, common::list_entry_t& result,
+        common::ValueVector& inputVector, common::ValueVector& resultVector) {
+        sortValues(
+            input, result, inputVector, resultVector, true /* ascOrder */, true /* nullFirst */);
+    }
+
+    static inline void operation(common::list_entry_t& input, common::ku_string_t& sortOrder,
+        common::list_entry_t& result, common::ValueVector& inputVector,
+        common::ValueVector& valueVector, common::ValueVector& resultVector) {
+        sortValues(
+            input, result, inputVector, resultVector, isAscOrder(sortOrder.getAsString()),
+            true /* nullFirst */);
+    }
+
+    static inline void operation(common::list_entry_t& input, common::ku_string_t& sortOrder,
+        common::ku_string_t& nullOrder, common::list_entry_t& result,
+        common::ValueVector& inputVector, common::ValueVector& resultVector) {
+        sortValues(input, result, inputVector, resultVector, isAscOrder(sortOrder.getAsString()),
+            isNullFirst(nullOrder.getAsString()));
+    }
+
+    // Maps "ASC" to true and "DESC" to false; any other string throws RuntimeException.
+    static inline bool isAscOrder(const std::string& sortOrder) {
+        if (sortOrder == "ASC") {
+            return true;
+        } else if (sortOrder == "DESC") {
+            return false;
+        } else {
+            throw common::RuntimeException("Invalid sortOrder");
+        }
+    }
+
+    // Maps "NULLS FIRST" to true and "NULLS LAST" to false; anything else throws RuntimeException.
+    static inline bool isNullFirst(const std::string& nullOrder) {
+        if (nullOrder == "NULLS FIRST") {
+            return true;
+        } else if (nullOrder == "NULLS LAST") {
+            return false;
+        } else {
+            throw common::RuntimeException("Invalid nullOrder");
+        }
+    }
+
+    // Copies `input` into a new list in `resultVector` with its nulls grouped at the front
+    // (nullFirst) or at the back, then std::sorts the non-null range ascending or descending.
+    static void sortValues(common::list_entry_t& input, common::list_entry_t& result,
+        common::ValueVector& inputVector, common::ValueVector& resultVector, bool ascOrder,
+        bool nullFirst) {
+        // Todo - Replace this sort implementation with radix_sort implementation:
+        // https://github.com/kuzudb/kuzu/issues/1536
+        auto inputValues = common::ListVector::getListValues(&inputVector, input);
+        auto inputDataVector = common::ListVector::getDataVector(&inputVector);
+
+        // Calculate null count
+        uint64_t nullCount = 0;
+        for (auto i = 0; i < input.size; i++) {
+            if (inputDataVector->isNull(input.offset + i)) {
+                nullCount += 1;
+            }
+        }
+
+        result = common::ListVector::addList(&resultVector, input.size);
+        auto resultValues = common::ListVector::getListValues(&resultVector, result);
+        auto resultDataVector = common::ListVector::getDataVector(&resultVector);
+        auto numBytesPerValue = resultDataVector->getNumBytesPerValue();
+
+        // Add nulls first
+        if (nullFirst) {
+            for (auto i = 0; i < nullCount; i++) {
+                resultDataVector->setNull(result.offset + i, true);
+                resultValues += numBytesPerValue;
+            }
+        }
+
+        // Add actual data
+        for (auto i = 0; i < input.size; i++) {
+            if (inputDataVector->isNull(input.offset + i)) {
+                inputValues += numBytesPerValue;
+                continue;
+            }
+            common::ValueVectorUtils::copyValue(
+                resultValues, *resultDataVector, inputValues, *inputDataVector);
+            resultValues += numBytesPerValue;
+            inputValues += numBytesPerValue;
+        }
+
+        // Add nulls in the end
+        if (!nullFirst) {
+            for (auto i = input.size - nullCount; i < input.size; i++) {
+                resultDataVector->setNull(result.offset + i, true);
+                resultValues += numBytesPerValue;
+            }
+        }
+
+        // Determine the starting and ending position of the data to be sorted
+        uint64_t sortStart = nullCount;
+        uint64_t sortEnd = input.size;
+        if (!nullFirst) {
+            sortStart = 0;
+            sortEnd = input.size - nullCount;
+        }
+
+        // Sort the data based on order
+        auto newResultValues =
+            reinterpret_cast(common::ListVector::getListValues(&resultVector, result));
+        if (ascOrder) {
+            std::sort(newResultValues + sortStart, newResultValues + sortEnd, std::less{});
+        } else {
+            std::sort(newResultValues + sortStart, newResultValues + sortEnd, std::greater{});
+        }
+    }
+};
+
+} // namespace operation
+} // namespace function
+} // namespace kuzu
diff --git a/src/include/function/list/operations/list_sum_operation.h b/src/include/function/list/operations/list_sum_operation.h
new file mode 100644
index 00000000000..abb1cac9106
--- /dev/null
+++ b/src/include/function/list/operations/list_sum_operation.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include "common/vector/value_vector.h"
+
+namespace kuzu {
+namespace function {
+namespace operation {
+
+// Sums the non-null elements of one list entry into `result`.
+struct ListSum {
+    template
+    static inline void operation(common::list_entry_t& input, T& result,
+        common::ValueVector& inputVector, common::ValueVector& resultVector) {
+        auto inputValues =
+            reinterpret_cast(common::ListVector::getListValues(&inputVector, input));
+        auto inputDataVector = common::ListVector::getDataVector(&inputVector);
+        // Start from zero: `result` is an out-parameter and must not carry over whatever the
+        // caller's slot previously held.
+        result = 0;
+        for (auto i = 0; i < input.size; i++) {
+            // Null elements contribute nothing to the sum.
+            if (inputDataVector->isNull(input.offset + i)) {
+                continue;
+            }
+            result += inputValues[i];
+        }
+    }
+};
+
+} // namespace operation
+} // namespace function
+} // namespace kuzu
diff --git a/src/include/function/list/vector_list_operations.h b/src/include/function/list/vector_list_operations.h
index 1f07b9f69bb..362bc337e57 100644
--- a/src/include/function/list/vector_list_operations.h
+++ b/src/include/function/list/vector_list_operations.h
@@ -76,6 +76,14 @@ struct VectorListOperations : public VectorOperations {
         }
         return result;
     }
+
+    template
+    static void UnaryListExecFunction(
+        const std::vector>& params,
+        common::ValueVector& result) {
+        assert(params.size() == 1);
+        UnaryOperationExecutor::executeList(*params[0], result);
+    }
 };
 
 struct ListCreationVectorOperation : public VectorListOperations {
@@ -128,5 +136,19 @@ struct ListSliceVectorOperation : public VectorListOperations {
         const binder::expression_vector& arguments, FunctionDefinition* definition);
 };
 
+struct ListSortVectorOperation : public VectorListOperations {
+    static std::vector> getDefinitions();
+    static std::unique_ptr bindFunc(
+        const binder::expression_vector& arguments, FunctionDefinition* definition);
+    template
+    static scalar_exec_func getExecFunction(const binder::expression_vector& arguments);
+};
+
+struct ListSumVectorOperation : public VectorListOperations {
+    static std::vector> getDefinitions();
+    static std::unique_ptr bindFunc(
+        const binder::expression_vector& arguments, FunctionDefinition* definition);
+};
+
 } // namespace function
 } // namespace kuzu
diff --git a/src/include/function/unary_operation_executor.h b/src/include/function/unary_operation_executor.h
index f8701e8d443..11518de0ffa 100644
--- a/src/include/function/unary_operation_executor.h
+++ b/src/include/function/unary_operation_executor.h
@@ -28,6 +28,15 @@ struct UnaryStringOperationWrapper {
     }
 };
 
+struct UnaryListOperationWrapper {
+    template
+    static inline void operation(
+        OPERAND_TYPE& input, RESULT_TYPE& result, void* leftValueVector, void* resultValueVector) {
+        FUNC::operation(input, result, *(common::ValueVector*)leftValueVector,
+            *(common::ValueVector*)resultValueVector);
+    }
+};
+
 struct UnaryCastOperationWrapper {
     template
     static void operation(
@@ -106,6 +115,11 @@ struct UnaryOperationExecutor {
             operand, result);
     }
 
+    template
+    static void executeList(common::ValueVector& operand, common::ValueVector& result) {
+        executeSwitch(operand, result);
+    }
+
     template
     static void executeCast(common::ValueVector& operand, common::ValueVector& result) {
         executeSwitch(operand, result);
diff --git a/test/test_files/tinysnb/function/list.test b/test/test_files/tinysnb/function/list.test
index c89ecaebe23..eecb6ca5474 100644
--- a/test/test_files/tinysnb/function/list.test
+++ b/test/test_files/tinysnb/function/list.test
@@ -717,3 +717,53 @@ ert
 -QUERY RETURN array_contains([[100,200],[200,300],[300,400]], [100,200])
 ---- 1
 True
+
+-NAME ListSumSeq1
+-QUERY Return list_sum([1, 2, 3, NULL]);
+---- 1
+6
+
+-NAME ListSumSeq2
+-QUERY Return list_sum([1.1, 2.2, 3.3, NULL]);
+---- 1
+6.600000
+
+-NAME ListSortIntAsc
+-QUERY Return list_sort([2, 3, 1, NULL, NULL]);
+---- 1
+[,,1,2,3]
+
+-NAME ListSortIntDesc
+-QUERY Return list_sort([2, 3, 1, 5, NULL], 'DESC');
+---- 1
+[,5,3,2,1]
+
+-NAME ListSortIntDescWithNullsLast
+-QUERY Return list_sort([2, 3, 1, NULL], 'DESC', 'NULLS LAST');
+---- 1
+[3,2,1,]
+
+-NAME ListSortStringDesc
+-QUERY Return list_sort(['sss', 'sssss', 'abs', NULL], 'DESC');
+---- 1
+[,sssss,sss,abs]
+
+-NAME ListSortStringAscWithNullsLast
+-QUERY Return list_sort(['sss', 'sssss', 'abs', NULL], 'ASC', 'NULLS LAST');
+---- 1
+[abs,sss,sssss,]
+
+-NAME ListSortDoubleAscWithNullsLast
+-QUERY Return list_sort([1.1, 2.3, 4.5, NULL], 'ASC', 'NULLS LAST');
+---- 1
+[1.100000,2.300000,4.500000,]
+
+-NAME ListSumInt
+-QUERY Return list_sum([1, 2, 3, NULL]);
+---- 1
+6
+
+-NAME ListSumDouble
+-QUERY Return list_sum([1.1, 2.2, 3.3, NULL]);
+---- 1
+6.600000