From ac08521d706159ab2a425a9f9fc5692647219c5b Mon Sep 17 00:00:00 2001 From: ziyi chen Date: Fri, 26 Apr 2024 14:36:02 +0800 Subject: [PATCH] Clean up list function implementations --- src/common/types/types.cpp | 1 + src/function/list/CMakeLists.txt | 5 +- src/function/list/list_agg_function.cpp | 100 +++++----- src/function/list/list_any_value_function.cpp | 118 +++--------- src/function/list/list_append_function.cpp | 60 ++++++ .../list_binary_right_switch_function.cpp | 173 ------------------ src/function/list/list_concat_function.cpp | 18 ++ src/function/list/list_contains_function.cpp | 42 +++++ src/function/list/list_distinct_function.cpp | 21 ++- src/function/list/list_extract_function.cpp | 4 +- src/function/list/list_position_function.cpp | 32 ++++ src/function/list/list_prepend_function.cpp | 58 ++++++ src/function/list/list_range_function.cpp | 106 +++++------ src/function/list/list_reverse_function.cpp | 21 ++- src/function/list/list_slice_function.cpp | 49 ++++- src/function/list/list_unique_function.cpp | 32 ++++ src/function/vector_blob_functions.cpp | 4 +- src/include/common/types/types.h | 4 + .../list/functions/list_any_value_function.h | 28 --- .../list/functions/list_append_function.h | 28 --- .../list/functions/list_concat_function.h | 18 +- .../list/functions/list_contains_function.h | 21 --- .../list/functions/list_distinct_function.h | 28 --- .../list/functions/list_len_function.h | 2 +- .../list/functions/list_position_function.h | 2 +- .../list/functions/list_prepend_function.h | 28 --- .../list/functions/list_product_function.h | 24 --- .../list/functions/list_range_function.h | 46 ----- .../list/functions/list_reverse_function.h | 23 --- .../list/functions/list_slice_function.h | 52 ------ .../list/functions/list_sum_function.h | 24 --- .../list/functions/list_unique_function.h | 29 +-- src/include/function/scalar_function.h | 10 + .../string/functions/ltrim_function.h | 2 - src/storage/store/node_table.cpp | 2 +- 35 files changed, 473 insertions(+), 742 deletions(-) create mode 100644 src/function/list/list_append_function.cpp delete mode 100644 src/function/list/list_binary_right_switch_function.cpp create mode 100644 src/function/list/list_contains_function.cpp create mode 100644 src/function/list/list_position_function.cpp create mode 100644 src/function/list/list_prepend_function.cpp delete mode 100644 src/include/function/list/functions/list_any_value_function.h delete mode 100644 src/include/function/list/functions/list_append_function.h delete mode 100644 src/include/function/list/functions/list_contains_function.h delete mode 100644 src/include/function/list/functions/list_distinct_function.h delete mode 100644 src/include/function/list/functions/list_prepend_function.h delete mode 100644 src/include/function/list/functions/list_product_function.h delete mode 100644 src/include/function/list/functions/list_range_function.h delete mode 100644 src/include/function/list/functions/list_reverse_function.h delete mode 100644 src/include/function/list/functions/list_slice_function.h delete mode 100644 src/include/function/list/functions/list_sum_function.h diff --git a/src/common/types/types.cpp b/src/common/types/types.cpp index a902c6a7f7..537e5a3852 100644 --- a/src/common/types/types.cpp +++ b/src/common/types/types.cpp @@ -928,6 +928,7 @@ std::vector LogicalTypeUtils::getNumericalLogicalTypeIDs() { LogicalTypeID::FLOAT, LogicalTypeID::SERIAL}; } +// TODO(Ziyi): Support int128 and uint types here. std::vector LogicalTypeUtils::getIntegerLogicalTypeIDs() { return std::vector{LogicalTypeID::INT64, LogicalTypeID::INT32, LogicalTypeID::INT16, LogicalTypeID::INT8, LogicalTypeID::SERIAL}; diff --git a/src/function/list/CMakeLists.txt b/src/function/list/CMakeLists.txt index e163fa2cbf..a1a54c48a3 100644 --- a/src/function/list/CMakeLists.txt +++ b/src/function/list/CMakeLists.txt @@ -2,8 +2,9 @@ add_library(kuzu_list_function OBJECT list_agg_function.cpp list_any_value_function.cpp - list_binary_right_switch_function.cpp + list_append_function.cpp list_concat_function.cpp + list_contains_function.cpp list_creation.cpp list_distinct_function.cpp list_extract_function.cpp @@ -13,6 +14,8 @@ add_library(kuzu_list_function list_sort_function.cpp list_to_string_function.cpp list_unique_function.cpp + list_prepend_function.cpp + list_position_function.cpp size_function.cpp) set(ALL_OBJECT_FILES diff --git a/src/function/list/list_agg_function.cpp b/src/function/list/list_agg_function.cpp index 5528205fd4..e4192fe0b2 100644 --- a/src/function/list/list_agg_function.cpp +++ b/src/function/list/list_agg_function.cpp @@ -1,6 +1,5 @@ #include "common/exception/binder.h" -#include "function/list/functions/list_product_function.h" -#include "function/list/functions/list_sum_function.h" +#include "common/type_utils.h" #include "function/list/vector_list_functions.h" #include "function/scalar_function.h" @@ -12,75 +11,64 @@ namespace function { template static std::unique_ptr bindFuncListAggr( const binder::expression_vector& arguments, Function* function) { - auto scalarFunction = ku_dynamic_cast(function); + auto scalarFunction = function->ptrCast(); auto resultType = ListType::getChildType(&arguments[0]->dataType); - switch (resultType->getLogicalTypeID()) { - case LogicalTypeID::SERIAL: - case LogicalTypeID::INT64: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::INT32: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::INT16: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::INT8: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::UINT64: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::UINT32: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::UINT16: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::UINT8: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::INT128: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::DOUBLE: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::FLOAT: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - default: { - throw BinderException(stringFormat("Unsupported inner data type for {}: {}", function->name, - LogicalTypeUtils::toString(resultType->getLogicalTypeID()))); - } - } + TypeUtils::visit( + resultType->getLogicalTypeID(), + [&scalarFunction](T) { + scalarFunction->execFunc = + ScalarFunction::UnaryExecNestedTypeFunction; + }, + [&function, &resultType](auto) { + throw BinderException(stringFormat("Unsupported inner data type for {}: {}", + function->name, LogicalTypeUtils::toString(resultType->getLogicalTypeID()))); + }); return FunctionBindData::getSimpleBindData(arguments, *resultType); } +struct ListSum { + template + static void operation(common::list_entry_t& input, T& result, common::ValueVector& inputVector, + common::ValueVector& /*resultVector*/) { + auto inputDataVector = common::ListVector::getDataVector(&inputVector); + result = 0; + for (auto i = 0u; i < input.size; i++) { + if (inputDataVector->isNull(input.offset + i)) { + continue; + } + result += inputDataVector->getValue(input.offset + i); + } + } +}; + function_set ListSumFunction::getFunctionSet() { function_set result; result.push_back( std::make_unique(name, std::vector{LogicalTypeID::LIST}, - LogicalTypeID::INT64, nullptr, nullptr, bindFuncListAggr)); + LogicalTypeID::INT64, bindFuncListAggr)); return result; } +struct ListProduct { + template + static void operation(common::list_entry_t& input, T& result, common::ValueVector& inputVector, + common::ValueVector& /*resultVector*/) { + auto inputDataVector = common::ListVector::getDataVector(&inputVector); + result = 1; + for (auto i = 0u; i < input.size; i++) { + if (inputDataVector->isNull(input.offset + i)) { + continue; + } + result *= inputDataVector->getValue(input.offset + i); + } + } +}; + function_set ListProductFunction::getFunctionSet() { function_set result; result.push_back( std::make_unique(name, std::vector{LogicalTypeID::LIST}, - LogicalTypeID::INT64, nullptr, nullptr, bindFuncListAggr)); + LogicalTypeID::INT64, bindFuncListAggr)); return result; } diff --git a/src/function/list/list_any_value_function.cpp b/src/function/list/list_any_value_function.cpp index 4b2eee5dfa..b1622408fa 100644 --- a/src/function/list/list_any_value_function.cpp +++ b/src/function/list/list_any_value_function.cpp @@ -1,5 +1,4 @@ -#include "function/list/functions/list_any_value_function.h" - +#include "common/type_utils.h" #include "function/list/vector_list_functions.h" #include "function/scalar_function.h" @@ -8,104 +7,33 @@ using namespace kuzu::common; namespace kuzu { namespace function { +struct ListAnyValue { + template + static void operation(common::list_entry_t& input, T& result, common::ValueVector& inputVector, + common::ValueVector& resultVector) { + auto inputValues = common::ListVector::getListValues(&inputVector, input); + auto inputDataVector = common::ListVector::getDataVector(&inputVector); + auto numBytesPerValue = inputDataVector->getNumBytesPerValue(); + + for (auto i = 0u; i < input.size; i++) { + if (!(inputDataVector->isNull(input.offset + i))) { + resultVector.copyFromVectorData(reinterpret_cast(&result), + inputDataVector, inputValues); + break; + } + inputValues += numBytesPerValue; + } + } +}; + static std::unique_ptr bindFunc(const binder::expression_vector& arguments, Function* function) { auto scalarFunction = ku_dynamic_cast(function); auto resultType = ListType::getChildType(&arguments[0]->dataType); - switch (resultType->getLogicalTypeID()) { - case LogicalTypeID::SERIAL: - case LogicalTypeID::INT64: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::INT32: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::INT16: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::INT8: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::UINT64: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::UINT32: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::UINT16: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::UINT8: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::INT128: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::DOUBLE: { + TypeUtils::visit(resultType->getPhysicalType(), [&scalarFunction](T) { scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::FLOAT: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::BOOL: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::STRING: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::DATE: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::TIMESTAMP: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::TIMESTAMP_MS: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::TIMESTAMP_NS: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::TIMESTAMP_SEC: { - scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::TIMESTAMP_TZ: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::INTERVAL: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::LIST: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - case LogicalTypeID::INTERNAL_ID: { - scalarFunction->execFunc = - ScalarFunction::UnaryExecNestedTypeFunction; - } break; - default: { - KU_UNREACHABLE; - } - } + ScalarFunction::UnaryExecNestedTypeFunction; + }); return FunctionBindData::getSimpleBindData(arguments, *resultType); } diff --git a/src/function/list/list_append_function.cpp b/src/function/list/list_append_function.cpp new file mode 100644 index 0000000000..0fd233d435 --- /dev/null +++ b/src/function/list/list_append_function.cpp @@ -0,0 +1,60 @@ +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "common/type_utils.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +struct ListAppend { + template + static void operation(common::list_entry_t& listEntry, T& value, common::list_entry_t& result, + common::ValueVector& listVector, common::ValueVector& valueVector, + common::ValueVector& resultVector) { + result = common::ListVector::addList(&resultVector, listEntry.size + 1); + auto listDataVector = common::ListVector::getDataVector(&listVector); + auto listPos = listEntry.offset; + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + auto resultPos = result.offset; + for (auto i = 0u; i < listEntry.size; i++) { + resultDataVector->copyFromVectorData(resultPos++, listDataVector, listPos++); + } + resultDataVector->copyFromVectorData( + resultDataVector->getData() + resultPos * resultDataVector->getNumBytesPerValue(), + &valueVector, reinterpret_cast(&value)); + } +}; + +static void validateArgumentType(const binder::expression_vector& arguments) { + if (*ListType::getChildType(&arguments[0]->dataType) != arguments[1]->getDataType()) { + throw BinderException( + ExceptionMessage::listFunctionIncompatibleChildrenType(ListAppendFunction::name, + arguments[0]->getDataType().toString(), arguments[1]->getDataType().toString())); + } +} + +static std::unique_ptr bindFunc(const binder::expression_vector& arguments, + Function* function) { + validateArgumentType(arguments); + auto scalarFunction = function->ptrCast(); + TypeUtils::visit(arguments[1]->getDataType().getPhysicalType(), [&scalarFunction]( + T) { + scalarFunction->execFunc = + ScalarFunction::BinaryExecListStructFunction; + }); + return FunctionBindData::getSimpleBindData(arguments, arguments[0]->getDataType()); +} + +function_set ListAppendFunction::getFunctionSet() { + function_set result; + result.push_back(std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::LIST, + nullptr /* execFunc */, nullptr /* selectFunc */, bindFunc)); + return result; +} + +} // namespace function +} // namespace kuzu diff --git a/src/function/list/list_binary_right_switch_function.cpp b/src/function/list/list_binary_right_switch_function.cpp deleted file mode 100644 index ada70d1734..0000000000 --- a/src/function/list/list_binary_right_switch_function.cpp +++ /dev/null @@ -1,173 +0,0 @@ -#include "common/exception/binder.h" -#include "common/exception/message.h" -#include "function/list/functions/list_append_function.h" -#include "function/list/functions/list_contains_function.h" -#include "function/list/functions/list_position_function.h" -#include "function/list/functions/list_prepend_function.h" -#include "function/list/vector_list_functions.h" -#include "function/scalar_function.h" - -using namespace kuzu::common; - -namespace kuzu { -namespace function { - -template -static scalar_func_exec_t getBinaryListExecFuncSwitchRight(const LogicalType& rightType) { - scalar_func_exec_t execFunc; - switch (rightType.getPhysicalType()) { - case PhysicalTypeID::BOOL: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::INT64: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::INT32: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::INT16: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::INT8: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::UINT64: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::UINT32: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::UINT16: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::UINT8: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::INT128: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::DOUBLE: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::FLOAT: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::STRING: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::INTERVAL: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::INTERNAL_ID: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::ARRAY: - case PhysicalTypeID::LIST: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::STRUCT: { - execFunc = ScalarFunction::BinaryExecListStructFunction; - } break; - default: { - KU_UNREACHABLE; - } - } - return execFunc; -} - -static std::unique_ptr ListAppendBindFunc( - const binder::expression_vector& arguments, Function* function) { - if (*ListType::getChildType(&arguments[0]->dataType) != arguments[1]->getDataType()) { - throw BinderException( - ExceptionMessage::listFunctionIncompatibleChildrenType(ListAppendFunction::name, - arguments[0]->getDataType().toString(), arguments[1]->getDataType().toString())); - } - auto resultType = arguments[0]->getDataType(); - auto scalarFunction = ku_dynamic_cast(function); - scalarFunction->execFunc = - getBinaryListExecFuncSwitchRight(arguments[1]->getDataType()); - return FunctionBindData::getSimpleBindData(arguments, resultType); -} - -function_set ListAppendFunction::getFunctionSet() { - function_set result; - result.push_back(std::make_unique(name, - std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::LIST, - nullptr, nullptr, ListAppendBindFunc)); - return result; -} - -static std::unique_ptr ListPrependBindFunc( - const binder::expression_vector& arguments, Function* function) { - if (arguments[0]->getDataType().getLogicalTypeID() != LogicalTypeID::ANY && - arguments[1]->dataType != *ListType::getChildType(&arguments[0]->dataType)) { - throw BinderException( - ExceptionMessage::listFunctionIncompatibleChildrenType(ListPrependFunction::name, - arguments[0]->getDataType().toString(), arguments[1]->getDataType().toString())); - } - auto resultType = arguments[0]->getDataType(); - auto scalarFunction = ku_dynamic_cast(function); - scalarFunction->execFunc = - getBinaryListExecFuncSwitchRight(arguments[1]->getDataType()); - return FunctionBindData::getSimpleBindData(arguments, resultType); -} - -function_set ListPrependFunction::getFunctionSet() { - function_set result; - result.push_back(std::make_unique(name, - std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::LIST, - nullptr, nullptr, ListPrependBindFunc)); - return result; -} - -static std::unique_ptr ListPositionBindFunc( - const binder::expression_vector& arguments, Function* function) { - auto scalarFunction = ku_dynamic_cast(function); - scalarFunction->execFunc = - getBinaryListExecFuncSwitchRight(arguments[1]->getDataType()); - return FunctionBindData::getSimpleBindData(arguments, *LogicalType::INT64()); -} - -function_set ListPositionFunction::getFunctionSet() { - function_set result; - result.push_back(std::make_unique(name, - std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::INT64, - nullptr, nullptr, ListPositionBindFunc)); - return result; -} - -static std::unique_ptr ListContainsBindFunc( - const binder::expression_vector& arguments, Function* function) { - auto scalarFunction = function->ptrCast(); - scalarFunction->execFunc = - getBinaryListExecFuncSwitchRight(arguments[1]->getDataType()); - return FunctionBindData::getSimpleBindData(arguments, *LogicalType::BOOL()); -} - -function_set ListContainsFunction::getFunctionSet() { - function_set result; - result.push_back(std::make_unique(name, - std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::BOOL, - nullptr, nullptr, ListContainsBindFunc)); - return result; -} - -} // namespace function -} // namespace kuzu diff --git a/src/function/list/list_concat_function.cpp b/src/function/list/list_concat_function.cpp index 9cedc21d37..27ac7976d4 100644 --- a/src/function/list/list_concat_function.cpp +++ b/src/function/list/list_concat_function.cpp @@ -10,6 +10,24 @@ using namespace kuzu::common; namespace kuzu { namespace function { +void ListConcat::operation(common::list_entry_t& left, common::list_entry_t& right, + common::list_entry_t& result, common::ValueVector& leftVector, common::ValueVector& rightVector, + common::ValueVector& resultVector) { + result = common::ListVector::addList(&resultVector, left.size + right.size); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + auto resultPos = result.offset; + auto leftDataVector = common::ListVector::getDataVector(&leftVector); + auto leftPos = left.offset; + for (auto i = 0u; i < left.size; i++) { + resultDataVector->copyFromVectorData(resultPos++, leftDataVector, leftPos++); + } + auto rightDataVector = common::ListVector::getDataVector(&rightVector); + auto rightPos = right.offset; + for (auto i = 0u; i < right.size; i++) { + resultDataVector->copyFromVectorData(resultPos++, rightDataVector, rightPos++); + } +} + std::unique_ptr ListConcatFunction::bindFunc( const binder::expression_vector& arguments, Function* /*function*/) { if (arguments[0]->getDataType() != arguments[1]->getDataType()) { diff --git a/src/function/list/list_contains_function.cpp b/src/function/list/list_contains_function.cpp new file mode 100644 index 0000000000..21c409c091 --- /dev/null +++ b/src/function/list/list_contains_function.cpp @@ -0,0 +1,42 @@ +#include "common/type_utils.h" +#include "function/list/functions/list_position_function.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +struct ListContains { + template + static void operation(common::list_entry_t& list, T& element, uint8_t& result, + common::ValueVector& listVector, common::ValueVector& elementVector, + common::ValueVector& resultVector) { + int64_t pos; + ListPosition::operation(list, element, pos, listVector, elementVector, resultVector); + result = (pos != 0); + } +}; + +static std::unique_ptr bindFunc(const binder::expression_vector& arguments, + Function* function) { + auto scalarFunction = function->ptrCast(); + TypeUtils::visit(arguments[1]->getDataType().getPhysicalType(), [&scalarFunction]( + T) { + scalarFunction->execFunc = + ScalarFunction::BinaryExecListStructFunction; + }); + return FunctionBindData::getSimpleBindData(arguments, *LogicalType::BOOL()); +} + +function_set ListContainsFunction::getFunctionSet() { + function_set result; + result.push_back(std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::BOOL, + nullptr, nullptr, bindFunc)); + return result; +} + +} // namespace function +} // namespace kuzu diff --git a/src/function/list/list_distinct_function.cpp b/src/function/list/list_distinct_function.cpp index 90bbdae3ce..bd41ad6418 100644 --- a/src/function/list/list_distinct_function.cpp +++ b/src/function/list/list_distinct_function.cpp @@ -1,5 +1,4 @@ -#include "function/list/functions/list_distinct_function.h" - +#include "function/list/functions/list_unique_function.h" #include "function/list/vector_list_functions.h" #include "function/scalar_function.h" @@ -8,6 +7,24 @@ using namespace kuzu::common; namespace kuzu { namespace function { +struct ListDistinct { + static void operation(common::list_entry_t& input, common::list_entry_t& result, + common::ValueVector& inputVector, common::ValueVector& resultVector) { + auto numUniqueValues = ListUnique::appendListElementsToValueSet(input, inputVector); + result = common::ListVector::addList(&resultVector, numUniqueValues); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + auto resultDataVectorBuffer = + common::ListVector::getListValuesWithOffset(&resultVector, result, 0 /* offset */); + ListUnique::appendListElementsToValueSet(input, inputVector, nullptr, + [&resultDataVector, &resultDataVectorBuffer](common::ValueVector& dataVector, + uint64_t pos) -> void { + resultDataVector->copyFromVectorData(resultDataVectorBuffer, &dataVector, + dataVector.getData() + pos * dataVector.getNumBytesPerValue()); + resultDataVectorBuffer += dataVector.getNumBytesPerValue(); + }); + } +}; + static std::unique_ptr bindFunc(const binder::expression_vector& arguments, Function*) { return FunctionBindData::getSimpleBindData(arguments, arguments[0]->getDataType()); diff --git a/src/function/list/list_extract_function.cpp b/src/function/list/list_extract_function.cpp index fba5fe29fb..996092c46a 100644 --- a/src/function/list/list_extract_function.cpp +++ b/src/function/list/list_extract_function.cpp @@ -19,8 +19,8 @@ static void BinaryExecListExtractFunction(const std::vector ListExtractBindFunc( const binder::expression_vector& arguments, Function* function) { auto resultType = ListType::getChildType(&arguments[0]->dataType); - auto scalarFunction = ku_dynamic_cast(function); - TypeUtils::visit(resultType->getPhysicalType(), [&](T) { + auto scalarFunction = function->ptrCast(); + TypeUtils::visit(resultType->getPhysicalType(), [&scalarFunction](T) { scalarFunction->execFunc = BinaryExecListExtractFunction; }); diff --git a/src/function/list/list_position_function.cpp b/src/function/list/list_position_function.cpp new file mode 100644 index 0000000000..27fed50f83 --- /dev/null +++ b/src/function/list/list_position_function.cpp @@ -0,0 +1,32 @@ +#include "function/list/functions/list_position_function.h" + +#include "common/type_utils.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +static std::unique_ptr ListPositionBindFunc( + const binder::expression_vector& arguments, Function* function) { + auto scalarFunction = function->ptrCast(); + TypeUtils::visit(arguments[1]->getDataType().getPhysicalType(), [&scalarFunction]( + T) { + scalarFunction->execFunc = + ScalarFunction::BinaryExecListStructFunction; + }); + return FunctionBindData::getSimpleBindData(arguments, *LogicalType::INT64()); +} + +function_set ListPositionFunction::getFunctionSet() { + function_set result; + result.push_back(std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::INT64, + ListPositionBindFunc)); + return result; +} + +} // namespace function +} // namespace kuzu diff --git a/src/function/list/list_prepend_function.cpp b/src/function/list/list_prepend_function.cpp new file mode 100644 index 0000000000..c40e8950bf --- /dev/null +++ b/src/function/list/list_prepend_function.cpp @@ -0,0 +1,58 @@ +#include "common/exception/binder.h" +#include "common/exception/message.h" +#include "common/type_utils.h" +#include "function/list/vector_list_functions.h" +#include "function/scalar_function.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace function { + +struct ListPrepend { + template + static void operation(common::list_entry_t& listEntry, T& value, common::list_entry_t& result, + common::ValueVector& listVector, common::ValueVector& valueVector, + common::ValueVector& resultVector) { + result = common::ListVector::addList(&resultVector, listEntry.size + 1); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + resultDataVector->copyFromVectorData( + common::ListVector::getListValues(&resultVector, result), &valueVector, + reinterpret_cast(&value)); + auto resultPos = result.offset + 1; + auto listDataVector = common::ListVector::getDataVector(&listVector); + auto listPos = listEntry.offset; + for (auto i = 0u; i < listEntry.size; i++) { + resultDataVector->copyFromVectorData(resultPos++, listDataVector, listPos++); + } + } +}; + +static std::unique_ptr bindFunc(const binder::expression_vector& arguments, + Function* function) { + if (arguments[0]->getDataType().getLogicalTypeID() != LogicalTypeID::ANY && + arguments[1]->dataType != *ListType::getChildType(&arguments[0]->dataType)) { + throw BinderException( + ExceptionMessage::listFunctionIncompatibleChildrenType(ListPrependFunction::name, + arguments[0]->getDataType().toString(), arguments[1]->getDataType().toString())); + } + auto resultType = arguments[0]->getDataType(); + auto scalarFunction = function->ptrCast(); + TypeUtils::visit(arguments[1]->getDataType().getPhysicalType(), + [&scalarFunction](T) { + scalarFunction->execFunc = ScalarFunction::BinaryExecListStructFunction; + }); + return FunctionBindData::getSimpleBindData(arguments, resultType); +} + +function_set ListPrependFunction::getFunctionSet() { + function_set result; + result.push_back(std::make_unique(name, + std::vector{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::LIST, + bindFunc)); + return result; +} + +} // namespace function +} // namespace kuzu diff --git a/src/function/list/list_range_function.cpp b/src/function/list/list_range_function.cpp index 191f9bf4d1..448217cdb1 100644 --- a/src/function/list/list_range_function.cpp +++ b/src/function/list/list_range_function.cpp @@ -1,5 +1,5 @@ -#include "function/list/functions/list_range_function.h" - +#include "common/exception/runtime.h" +#include "common/type_utils.h" #include "function/list/vector_list_functions.h" #include "function/scalar_function.h" @@ -8,57 +8,59 @@ using namespace kuzu::common; namespace kuzu { namespace function { -template -static scalar_func_exec_t getBinaryListExecFuncSwitchAll(const LogicalType& type) { - scalar_func_exec_t execFunc; - switch (type.getPhysicalType()) { - case PhysicalTypeID::INT64: { - execFunc = - ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::INT32: { - execFunc = - ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::INT16: { - execFunc = - ScalarFunction::BinaryExecListStructFunction; - } break; - case PhysicalTypeID::INT8: { - execFunc = - ScalarFunction::BinaryExecListStructFunction; - } break; - default: { - KU_UNREACHABLE; +struct Range { + // range function: + // - include end + // - when start = end: there is only one element in result list + // - when end - start are of opposite sign of step, the result will be empty + // - default step = 1 + template + static void operation(T& start, T& end, list_entry_t& result, ValueVector& leftVector, + ValueVector& /*rightVector*/, ValueVector& resultVector) { + T step = 1; + operation(start, end, step, result, leftVector, resultVector); } + + template + static void operation(T& start, T& end, T& step, list_entry_t& result, + ValueVector& /*inputVector*/, ValueVector& resultVector) { + if (step == 0) { + throw RuntimeException("Step of range cannot be 0."); + } + + // start, start + step, start + 2step, ..., end + T number = start; + auto size = ((end - start) * 1.0 / step); + size < 0 ? size = 0 : size = (int64_t)(size + 1); + + result = ListVector::addList(&resultVector, (int64_t)size); + auto resultDataVector = ListVector::getDataVector(&resultVector); + for (auto i = 0u; i < (int64_t)size; i++) { + resultDataVector->setValue(result.offset + i, number); + number += step; + } } +}; + +static scalar_func_exec_t getBinaryExecFunc(const LogicalType& type) { + scalar_func_exec_t execFunc; + TypeUtils::visit( + type.getLogicalTypeID(), + [&execFunc](T) { + execFunc = ScalarFunction::BinaryExecListStructFunction; + }, + [](auto) { KU_UNREACHABLE; }); return execFunc; } -template -static scalar_func_exec_t getTernaryListExecFuncSwitchAll(const LogicalType& type) { +static scalar_func_exec_t getTernaryExecFunc(const LogicalType& type) { scalar_func_exec_t execFunc; - switch (type.getPhysicalType()) { - case PhysicalTypeID::INT64: { - execFunc = ScalarFunction::TernaryExecListStructFunction; - } break; - case PhysicalTypeID::INT32: { - execFunc = ScalarFunction::TernaryExecListStructFunction; - } break; - case PhysicalTypeID::INT16: { - execFunc = ScalarFunction::TernaryExecListStructFunction; - } break; - case PhysicalTypeID::INT8: { - execFunc = ScalarFunction::TernaryExecListStructFunction; - } break; - default: { - KU_UNREACHABLE; - } - } + TypeUtils::visit( + type.getLogicalTypeID(), + [&execFunc](T) { + execFunc = ScalarFunction::TernaryExecListStructFunction; + }, + [](auto) { KU_UNREACHABLE; }); return execFunc; } @@ -78,15 +80,13 @@ function_set ListRangeFunction::getFunctionSet() { function_set result; for (auto typeID : LogicalTypeUtils::getIntegerLogicalTypeIDs()) { // start, end - result.push_back(std::make_unique(name, - std::vector{typeID, typeID}, LogicalTypeID::LIST, - getBinaryListExecFuncSwitchAll(LogicalType{typeID}), nullptr, - bindFunc)); + result.push_back( + std::make_unique(name, std::vector{typeID, typeID}, + LogicalTypeID::LIST, getBinaryExecFunc(LogicalType{typeID}), bindFunc)); // start, end, step result.push_back(std::make_unique(name, std::vector{typeID, typeID, typeID}, LogicalTypeID::LIST, - getTernaryListExecFuncSwitchAll(LogicalType{typeID}), nullptr, - bindFunc)); + getTernaryExecFunc(LogicalType{typeID}), bindFunc)); } return result; } diff --git a/src/function/list/list_reverse_function.cpp b/src/function/list/list_reverse_function.cpp index 9502be12ba..99c67b2763 100644 --- a/src/function/list/list_reverse_function.cpp +++ b/src/function/list/list_reverse_function.cpp @@ -1,5 +1,3 @@ -#include "function/list/functions/list_reverse_function.h" - #include "function/list/vector_list_functions.h" #include "function/scalar_function.h" @@ -8,6 +6,20 @@ using namespace kuzu::common; namespace kuzu { namespace function { +struct ListReverse { + static inline void operation(common::list_entry_t& input, common::list_entry_t& result, + common::ValueVector& inputVector, common::ValueVector& resultVector) { + auto inputDataVector = common::ListVector::getDataVector(&inputVector); + auto resultDataVector = common::ListVector::getDataVector(&resultVector); + result = input; // reverse does not change + for (auto i = 0u; i < input.size; i++) { + auto pos = input.offset + i; + auto reversePos = input.offset + input.size - 1 - i; + resultDataVector->copyFromVectorData(reversePos, inputDataVector, pos); + } + } +}; + static std::unique_ptr bindFunc(const binder::expression_vector& arguments, Function* function) { auto scalarFunction = ku_dynamic_cast(function); @@ -19,9 +31,8 @@ static std::unique_ptr bindFunc(const binder::expression_vecto function_set ListReverseFunction::getFunctionSet() { function_set result; - result.push_back( - std::make_unique(name, std::vector{LogicalTypeID::LIST}, - LogicalTypeID::ANY, nullptr, nullptr, bindFunc)); + result.push_back(std::make_unique(name, + std::vector{LogicalTypeID::LIST}, LogicalTypeID::ANY, bindFunc)); return result; } diff --git a/src/function/list/list_slice_function.cpp b/src/function/list/list_slice_function.cpp index f2c1cc9a70..2156265d31 100644 --- a/src/function/list/list_slice_function.cpp +++ b/src/function/list/list_slice_function.cpp @@ -1,13 +1,56 @@ -#include "function/list/functions/list_slice_function.h" #include "function/list/vector_list_functions.h" #include "function/scalar_function.h" +#include "function/string/functions/substr_function.h" using namespace kuzu::common; namespace kuzu { namespace function { +struct ListSlice { + // Note: this function takes in a 1-based begin/end index (The index of the first value in the + // listEntry is 1). + static void operation(common::list_entry_t& listEntry, int64_t& begin, int64_t& end, + common::list_entry_t& result, common::ValueVector& listVector, + common::ValueVector& resultVector) { + auto startIdx = begin; + auto endIdx = end; + normalizeIndices(startIdx, endIdx, listEntry.size); + result = common::ListVector::addList(&resultVector, endIdx - startIdx); + auto srcDataVector = common::ListVector::getDataVector(&listVector); + auto srcPos = listEntry.offset + startIdx - 1; + auto dstDataVector = common::ListVector::getDataVector(&resultVector); + auto dstPos = result.offset; + for (auto i = 0u; i < endIdx - startIdx; i++) { + dstDataVector->copyFromVectorData(dstPos++, srcDataVector, srcPos++); + } + } + + static void operation(common::ku_string_t& str, int64_t& begin, int64_t& end, + common::ku_string_t& result, common::ValueVector& /*listValueVector*/, + common::ValueVector& resultValueVector) { + auto startIdx = begin; + auto endIdx = end; + normalizeIndices(startIdx, endIdx, str.len); + SubStr::operation(str, startIdx, std::min(endIdx - startIdx + 1, str.len - startIdx + 1), + result, resultValueVector); + } + +private: + static void normalizeIndices(int64_t& startIdx, int64_t& endIdx, uint64_t size) { + if (startIdx <= 0) { + startIdx = 1; + } + if (endIdx <= 0 || (uint64_t)endIdx > size) { + endIdx = size + 1; + } + if (startIdx > endIdx) { + endIdx = startIdx; + } + } +}; + static std::unique_ptr bindFunc(const binder::expression_vector& arguments, Function* function) { KU_ASSERT(arguments.size() == 3); @@ -26,14 +69,14 @@ function_set ListSliceFunction::getFunctionSet() { LogicalTypeID::LIST, ScalarFunction::TernaryExecListStructFunction, - nullptr, bindFunc)); + nullptr /* selectFunc */, bindFunc)); result.push_back(std::make_unique(name, std::vector{LogicalTypeID::STRING, LogicalTypeID::INT64, LogicalTypeID::INT64}, LogicalTypeID::STRING, ScalarFunction::TernaryExecListStructFunction, - nullptr, bindFunc)); + nullptr /* selectFunc */, bindFunc)); return result; } diff --git a/src/function/list/list_unique_function.cpp b/src/function/list/list_unique_function.cpp index 44c4fa3ce2..e01ef28e9b 100644 --- a/src/function/list/list_unique_function.cpp +++ b/src/function/list/list_unique_function.cpp @@ -8,6 +8,38 @@ using namespace kuzu::common; namespace kuzu { namespace function { +uint64_t ListUnique::appendListElementsToValueSet(common::list_entry_t& input, + common::ValueVector& inputVector, duplicate_value_handler duplicateValHandler, + unique_value_handler uniqueValueHandler, null_value_handler nullValueHandler) { + ValueSet uniqueKeys; + auto dataVector = common::ListVector::getDataVector(&inputVector); + auto val = common::Value::createDefaultValue(dataVector->dataType); + for (auto i = 0u; i < input.size; i++) { + if (dataVector->isNull(input.offset + i)) { + if (nullValueHandler != nullptr) { + nullValueHandler(); + } + continue; + } + auto entryVal = common::ListVector::getListValuesWithOffset(&inputVector, input, i); + val.copyFromColLayout(entryVal, dataVector); + auto uniqueKey = uniqueKeys.insert(val).second; + if (duplicateValHandler != nullptr && !uniqueKey) { + duplicateValHandler( + common::TypeUtils::entryToString(dataVector->dataType, entryVal, dataVector)); + } + if (uniqueValueHandler != nullptr && uniqueKey) { + uniqueValueHandler(*dataVector, input.offset + i); + } + } + return uniqueKeys.size(); +} + +void ListUnique::operation(common::list_entry_t& input, int64_t& result, + common::ValueVector& inputVector, common::ValueVector& /*resultVector*/) { + result = appendListElementsToValueSet(input, inputVector); +} + static std::unique_ptr bindFunc(const binder::expression_vector& arguments, Function*) { return FunctionBindData::getSimpleBindData(arguments, *LogicalType::INT64()); diff --git a/src/function/vector_blob_functions.cpp b/src/function/vector_blob_functions.cpp index 9d39e3fcfd..740ad18440 100644 --- a/src/function/vector_blob_functions.cpp +++ b/src/function/vector_blob_functions.cpp @@ -23,7 +23,7 @@ function_set EncodeFunctions::getFunctionSet() { function_set definitions; definitions.push_back(make_unique(name, std::vector{LogicalTypeID::STRING}, LogicalTypeID::BLOB, - ScalarFunction::UnaryStringExecFunction, nullptr)); + ScalarFunction::UnaryStringExecFunction)); return definitions; } @@ -31,7 +31,7 @@ function_set DecodeFunctions::getFunctionSet() { function_set definitions; definitions.push_back(make_unique(name, std::vector{LogicalTypeID::BLOB}, LogicalTypeID::STRING, - ScalarFunction::UnaryStringExecFunction, nullptr)); + ScalarFunction::UnaryStringExecFunction)); return definitions; } diff --git a/src/include/common/types/types.h b/src/include/common/types/types.h index 8f9eb5fa4f..4e61641f5b 100644 --- a/src/include/common/types/types.h +++ b/src/include/common/types/types.h @@ -81,6 +81,10 @@ struct union_entry_t { struct int128_t; struct ku_string_t; +template +concept NumericTypes = + std::integral || std::floating_point || std::is_same_v; + template concept HashablePrimitive = ((std::integral && !std::is_same_v) || std::floating_point || std::is_same_v); diff --git a/src/include/function/list/functions/list_any_value_function.h b/src/include/function/list/functions/list_any_value_function.h deleted file mode 100644 index 8adcbb076f..0000000000 --- a/src/include/function/list/functions/list_any_value_function.h +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -#include "common/vector/value_vector.h" - -namespace kuzu { -namespace function { - -struct ListAnyValue { - template - static inline void operation(common::list_entry_t& input, T& result, - common::ValueVector& inputVector, common::ValueVector& resultVector) { - auto inputValues = common::ListVector::getListValues(&inputVector, input); - auto inputDataVector = common::ListVector::getDataVector(&inputVector); - auto numBytesPerValue = inputDataVector->getNumBytesPerValue(); - - for (auto i = 0u; i < input.size; i++) { - if (!(inputDataVector->isNull(input.offset + i))) { - resultVector.copyFromVectorData(reinterpret_cast(&result), - inputDataVector, inputValues); - break; - } - inputValues += numBytesPerValue; - } - } -}; - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/list/functions/list_append_function.h b/src/include/function/list/functions/list_append_function.h deleted file mode 100644 index 408ac77c79..0000000000 --- a/src/include/function/list/functions/list_append_function.h +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -#include "common/vector/value_vector.h" - -namespace kuzu { -namespace function { - -struct ListAppend { - template - static inline void operation(common::list_entry_t& listEntry, T& value, - common::list_entry_t& result, common::ValueVector& listVector, - common::ValueVector& valueVector, common::ValueVector& resultVector) { - result = common::ListVector::addList(&resultVector, listEntry.size + 1); - auto listDataVector = common::ListVector::getDataVector(&listVector); - auto listPos = listEntry.offset; - auto resultDataVector = common::ListVector::getDataVector(&resultVector); - auto resultPos = result.offset; - for (auto i = 0u; i < listEntry.size; i++) { - resultDataVector->copyFromVectorData(resultPos++, listDataVector, listPos++); - } - resultDataVector->copyFromVectorData( - resultDataVector->getData() + resultPos * resultDataVector->getNumBytesPerValue(), - &valueVector, reinterpret_cast(&value)); - } -}; - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/list/functions/list_concat_function.h b/src/include/function/list/functions/list_concat_function.h index 1f14f0836c..cf9a2e1d2f 100644 --- a/src/include/function/list/functions/list_concat_function.h +++ b/src/include/function/list/functions/list_concat_function.h @@ -8,23 +8,9 @@ namespace function { struct ListConcat { public: - static inline void operation(common::list_entry_t& left, common::list_entry_t& right, + static void operation(common::list_entry_t& left, common::list_entry_t& right, common::list_entry_t& result, common::ValueVector& leftVector, - common::ValueVector& rightVector, common::ValueVector& resultVector) { - result = common::ListVector::addList(&resultVector, left.size + right.size); - auto resultDataVector = common::ListVector::getDataVector(&resultVector); - auto resultPos = result.offset; - auto leftDataVector = common::ListVector::getDataVector(&leftVector); - auto leftPos = left.offset; - for (auto i = 0u; i < left.size; i++) { - resultDataVector->copyFromVectorData(resultPos++, leftDataVector, leftPos++); - } - auto rightDataVector = common::ListVector::getDataVector(&rightVector); - auto rightPos = right.offset; - for (auto i = 0u; i < right.size; i++) { - resultDataVector->copyFromVectorData(resultPos++, rightDataVector, rightPos++); - } - } + common::ValueVector& rightVector, common::ValueVector& resultVector); }; } // namespace function diff --git a/src/include/function/list/functions/list_contains_function.h b/src/include/function/list/functions/list_contains_function.h deleted file mode 100644 index b43522872d..0000000000 --- a/src/include/function/list/functions/list_contains_function.h +++ /dev/null @@ -1,21 +0,0 @@ -#pragma once - -#include "common/vector/value_vector.h" -#include "list_position_function.h" - -namespace kuzu { -namespace function { - -struct ListContains { - template - static inline void operation(common::list_entry_t& list, T& element, uint8_t& result, - common::ValueVector& listVector, common::ValueVector& elementVector, - common::ValueVector& resultVector) { - int64_t pos; - ListPosition::operation(list, element, pos, listVector, elementVector, resultVector); - result = (pos != 0); - } -}; - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/list/functions/list_distinct_function.h b/src/include/function/list/functions/list_distinct_function.h deleted file mode 100644 index 8ac0665f8c..0000000000 --- a/src/include/function/list/functions/list_distinct_function.h +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -#include "common/vector/value_vector.h" -#include "list_unique_function.h" - -namespace kuzu { -namespace function { - -struct ListDistinct { - static inline void operation(common::list_entry_t& input, common::list_entry_t& result, - common::ValueVector& inputVector, common::ValueVector& resultVector) { - auto numUniqueValues = ListUnique::appendListElementsToValueSet(input, inputVector); - result = common::ListVector::addList(&resultVector, numUniqueValues); - auto resultDataVector = common::ListVector::getDataVector(&resultVector); - auto resultDataVectorBuffer = - common::ListVector::getListValuesWithOffset(&resultVector, result, 0 /* offset */); - ListUnique::appendListElementsToValueSet(input, inputVector, nullptr, - [&resultDataVector, &resultDataVectorBuffer](common::ValueVector& dataVector, - uint64_t pos) -> void { - resultDataVector->copyFromVectorData(resultDataVectorBuffer, &dataVector, - dataVector.getData() + pos * dataVector.getNumBytesPerValue()); - resultDataVectorBuffer += dataVector.getNumBytesPerValue(); - }); - } -}; - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/list/functions/list_len_function.h b/src/include/function/list/functions/list_len_function.h index 25af99da23..17bc404f73 100644 --- a/src/include/function/list/functions/list_len_function.h +++ b/src/include/function/list/functions/list_len_function.h @@ -11,7 +11,7 @@ namespace function { struct ListLen { public: template - static inline void operation(T& input, int64_t& result) { + static void operation(T& input, int64_t& result) { result = input.size; } }; diff --git a/src/include/function/list/functions/list_position_function.h b/src/include/function/list/functions/list_position_function.h index d47f0558c7..b5ad82b505 100644 --- a/src/include/function/list/functions/list_position_function.h +++ b/src/include/function/list/functions/list_position_function.h @@ -10,7 +10,7 @@ struct ListPosition { // Note: this function takes in a 1-based element (The index of the first element in the list // is 1). template - static inline void operation(common::list_entry_t& list, T& element, int64_t& result, + static void operation(common::list_entry_t& list, T& element, int64_t& result, common::ValueVector& listVector, common::ValueVector& elementVector, common::ValueVector& /*resultVector*/) { if (*common::ListType::getChildType(&listVector.dataType) != elementVector.dataType) { diff --git a/src/include/function/list/functions/list_prepend_function.h b/src/include/function/list/functions/list_prepend_function.h deleted file mode 100644 index fb961afae7..0000000000 --- a/src/include/function/list/functions/list_prepend_function.h +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -#include "common/vector/value_vector.h" - -namespace kuzu { -namespace function { - -struct ListPrepend { - template - static inline void operation(common::list_entry_t& listEntry, T& value, - common::list_entry_t& result, common::ValueVector& listVector, - common::ValueVector& valueVector, common::ValueVector& resultVector) { - result = common::ListVector::addList(&resultVector, listEntry.size + 1); - auto resultDataVector = common::ListVector::getDataVector(&resultVector); - resultDataVector->copyFromVectorData( - common::ListVector::getListValues(&resultVector, result), &valueVector, - reinterpret_cast(&value)); - auto resultPos = result.offset + 1; - auto listDataVector = common::ListVector::getDataVector(&listVector); - auto listPos = listEntry.offset; - for (auto i = 0u; i < listEntry.size; i++) { - resultDataVector->copyFromVectorData(resultPos++, listDataVector, listPos++); - } - } -}; - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/list/functions/list_product_function.h b/src/include/function/list/functions/list_product_function.h deleted file mode 100644 index d7aac72ffa..0000000000 --- a/src/include/function/list/functions/list_product_function.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once - -#include "common/vector/value_vector.h" - -namespace kuzu { -namespace function { - -struct ListProduct { - template - static inline void operation(common::list_entry_t& input, T& result, - common::ValueVector& inputVector, common::ValueVector& /*resultVector*/) { - auto inputDataVector = common::ListVector::getDataVector(&inputVector); - result = 1; - for (auto i = 0u; i < input.size; i++) { - if (inputDataVector->isNull(input.offset + i)) { - continue; - } - result *= inputDataVector->getValue(input.offset + i); - } - } -}; - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/list/functions/list_range_function.h b/src/include/function/list/functions/list_range_function.h deleted file mode 100644 index 9b810d81bf..0000000000 --- a/src/include/function/list/functions/list_range_function.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once - -#include "common/exception/runtime.h" -#include "common/vector/value_vector.h" - -namespace kuzu { -namespace function { - -struct Range { -public: - // range function: - // - include end - // - when start = end: there is only one element in result list - // - when end - start are of opposite sign of step, the result will be empty - // - default step = 1 - template - static inline void operation(T& start, T& end, common::list_entry_t& result, - common::ValueVector& leftVector, common::ValueVector& /*rightVector*/, - common::ValueVector& resultVector) { - T step = 1; - operation(start, end, step, result, leftVector, resultVector); - } - - template - static inline void operation(T& start, T& end, T& step, common::list_entry_t& result, - common::ValueVector& /*inputVector*/, common::ValueVector& resultVector) { - if (step == 0) { - throw common::RuntimeException("Step of range cannot be 0."); - } - - // start, start + step, start + 2step, ..., end - T number = start; - auto size = (end - start) * 1.0 / step; - size < 0 ? size = 0 : size = (int64_t)(size + 1); - - result = common::ListVector::addList(&resultVector, size); - auto resultDataVector = common::ListVector::getDataVector(&resultVector); - for (auto i = 0u; i < size; i++) { - resultDataVector->setValue(result.offset + i, number); - number += step; - } - } -}; - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/list/functions/list_reverse_function.h b/src/include/function/list/functions/list_reverse_function.h deleted file mode 100644 index 3b36615b02..0000000000 --- a/src/include/function/list/functions/list_reverse_function.h +++ /dev/null @@ -1,23 +0,0 @@ -#pragma once - -#include "common/vector/value_vector.h" - -namespace kuzu { -namespace function { - -struct ListReverse { - static inline void operation(common::list_entry_t& input, common::list_entry_t& result, - common::ValueVector& inputVector, common::ValueVector& resultVector) { - auto inputDataVector = common::ListVector::getDataVector(&inputVector); - auto resultDataVector = common::ListVector::getDataVector(&resultVector); - result = input; // reverse does not change - for (auto i = 0u; i < input.size; i++) { - auto pos = input.offset + i; - auto reversePos = input.offset + input.size - 1 - i; - resultDataVector->copyFromVectorData(reversePos, inputDataVector, pos); - } - } -}; - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/list/functions/list_slice_function.h b/src/include/function/list/functions/list_slice_function.h deleted file mode 100644 index d775aa684b..0000000000 --- a/src/include/function/list/functions/list_slice_function.h +++ /dev/null @@ -1,52 +0,0 @@ -#pragma once - -#include "function/string/functions/substr_function.h" - -namespace kuzu { -namespace function { - -struct ListSlice { - // Note: this function takes in a 1-based begin/end index (The index of the first value in the - // listEntry is 1). - static inline void operation(common::list_entry_t& listEntry, int64_t& begin, int64_t& end, - common::list_entry_t& result, common::ValueVector& listVector, - common::ValueVector& resultVector) { - auto startIdx = begin; - auto endIdx = end; - normalizeIndices(startIdx, endIdx, listEntry.size); - result = common::ListVector::addList(&resultVector, endIdx - startIdx); - auto srcDataVector = common::ListVector::getDataVector(&listVector); - auto srcPos = listEntry.offset + startIdx - 1; - auto dstDataVector = common::ListVector::getDataVector(&resultVector); - auto dstPos = result.offset; - for (auto i = 0u; i < endIdx - startIdx; i++) { - dstDataVector->copyFromVectorData(dstPos++, srcDataVector, srcPos++); - } - } - - static inline void operation(common::ku_string_t& str, int64_t& begin, int64_t& end, - common::ku_string_t& result, common::ValueVector& /*listValueVector*/, - common::ValueVector& resultValueVector) { - auto startIdx = begin; - auto endIdx = end; - normalizeIndices(startIdx, endIdx, str.len); - SubStr::operation(str, startIdx, std::min(endIdx - startIdx + 1, str.len - startIdx + 1), - result, resultValueVector); - } - -private: - static inline void normalizeIndices(int64_t& startIdx, int64_t& endIdx, uint64_t size) { - if (startIdx <= 0) { - startIdx = 1; - } - if (endIdx <= 0 || (uint64_t)endIdx > size) { - endIdx = size + 1; - } - if (startIdx > endIdx) { - endIdx = startIdx; - } - } -}; - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/list/functions/list_sum_function.h b/src/include/function/list/functions/list_sum_function.h deleted file mode 100644 index 8c9f23b5bf..0000000000 --- a/src/include/function/list/functions/list_sum_function.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once - -#include "common/vector/value_vector.h" - -namespace kuzu { -namespace function { - -struct ListSum { - template - static inline void operation(common::list_entry_t& input, T& result, - common::ValueVector& inputVector, common::ValueVector& /*resultVector*/) { - auto inputDataVector = common::ListVector::getDataVector(&inputVector); - result = 0; - for (auto i = 0u; i < input.size; i++) { - if (inputDataVector->isNull(input.offset + i)) { - continue; - } - result += inputDataVector->getValue(input.offset + i); - } - } -}; - -} // namespace function -} // namespace kuzu diff --git a/src/include/function/list/functions/list_unique_function.h b/src/include/function/list/functions/list_unique_function.h index 9bc61d7a19..2695f8d6bd 100644 --- a/src/include/function/list/functions/list_unique_function.h +++ b/src/include/function/list/functions/list_unique_function.h @@ -25,35 +25,10 @@ struct ListUnique { static uint64_t appendListElementsToValueSet(common::list_entry_t& input, common::ValueVector& inputVector, duplicate_value_handler duplicateValHandler = nullptr, unique_value_handler uniqueValueHandler = nullptr, - null_value_handler nullValueHandler = nullptr) { - ValueSet uniqueKeys; - auto dataVector = common::ListVector::getDataVector(&inputVector); - auto val = common::Value::createDefaultValue(dataVector->dataType); - for (auto i = 0u; i < input.size; i++) { - if (dataVector->isNull(input.offset + i)) { - if (nullValueHandler != nullptr) { - nullValueHandler(); - } - continue; - } - auto entryVal = common::ListVector::getListValuesWithOffset(&inputVector, input, i); - val.copyFromColLayout(entryVal, dataVector); - auto uniqueKey = uniqueKeys.insert(val).second; - if (duplicateValHandler != nullptr && !uniqueKey) { - duplicateValHandler( - common::TypeUtils::entryToString(dataVector->dataType, entryVal, dataVector)); - } - if (uniqueValueHandler != nullptr && uniqueKey) { - uniqueValueHandler(*dataVector, input.offset + i); - } - } - return uniqueKeys.size(); - } + null_value_handler nullValueHandler = nullptr); static void operation(common::list_entry_t& input, int64_t& result, - common::ValueVector& inputVector, common::ValueVector& /*resultVector*/) { - result = appendListElementsToValueSet(input, inputVector); - } + common::ValueVector& inputVector, common::ValueVector& resultVector); }; } // namespace function diff --git a/src/include/function/scalar_function.h b/src/include/function/scalar_function.h index 0815bf137b..b978c454c6 100644 --- a/src/include/function/scalar_function.h +++ b/src/include/function/scalar_function.h @@ -53,6 +53,16 @@ struct ScalarFunction final : public BaseScalarFunction { execFunc{std::move(execFunc)}, selectFunc(std::move(selectFunc)), compileFunc{std::move(compileFunc)} {} + ScalarFunction(std::string name, std::vector parameterTypeIDs, + common::LogicalTypeID returnTypeID, scalar_bind_func bindFunc) + : ScalarFunction{std::move(name), std::move(parameterTypeIDs), returnTypeID, + nullptr /* execFunc */, nullptr /* selectFunc */, bindFunc} {} + + ScalarFunction(std::string name, std::vector parameterTypeIDs, + common::LogicalTypeID returnTypeID, scalar_func_exec_t execFunc, scalar_bind_func bindFunc) + : ScalarFunction{std::move(name), std::move(parameterTypeIDs), returnTypeID, execFunc, + nullptr /* selectFunc */, bindFunc} {} + template static void TernaryExecFunction(const std::vector>& params, common::ValueVector& result, void* /*dataPtr*/ = nullptr) { diff --git a/src/include/function/string/functions/ltrim_function.h b/src/include/function/string/functions/ltrim_function.h index 5cd121613b..bb41008b77 100644 --- a/src/include/function/string/functions/ltrim_function.h +++ b/src/include/function/string/functions/ltrim_function.h @@ -1,7 +1,5 @@ #pragma once -#include - #include "base_str_function.h" #include "common/types/ku_string.h" diff --git a/src/storage/store/node_table.cpp b/src/storage/store/node_table.cpp index 469e3fcca7..d9cfc3d931 100644 --- a/src/storage/store/node_table.cpp +++ b/src/storage/store/node_table.cpp @@ -208,7 +208,7 @@ void NodeTable::insertPK(const ValueVector& nodeIDVector, const ValueVector& pri [&](ku_string_t) { pkStr = primaryKeyVector.getValue(pkPos).getAsString(); }, - [&]( + [&pkStr, &primaryKeyVector, &pkPos]( T) { pkStr = TypeUtils::toString(primaryKeyVector.getValue(pkPos)); }); throw RuntimeException(ExceptionMessage::duplicatePKException(pkStr)); }