Skip to content

Commit

Permalink
Clean up list function implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Apr 26, 2024
1 parent 1b64173 commit ac08521
Show file tree
Hide file tree
Showing 35 changed files with 473 additions and 742 deletions.
1 change: 1 addition & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,7 @@ std::vector<LogicalTypeID> LogicalTypeUtils::getNumericalLogicalTypeIDs() {
LogicalTypeID::FLOAT, LogicalTypeID::SERIAL};
}

// TODO(Ziyi): Support int128 and uint types here.
std::vector<LogicalTypeID> LogicalTypeUtils::getIntegerLogicalTypeIDs() {
return std::vector<LogicalTypeID>{LogicalTypeID::INT64, LogicalTypeID::INT32,
LogicalTypeID::INT16, LogicalTypeID::INT8, LogicalTypeID::SERIAL};
Expand Down
5 changes: 4 additions & 1 deletion src/function/list/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ add_library(kuzu_list_function
OBJECT
list_agg_function.cpp
list_any_value_function.cpp
list_binary_right_switch_function.cpp
list_append_function.cpp
list_concat_function.cpp
list_contains_function.cpp
list_creation.cpp
list_distinct_function.cpp
list_extract_function.cpp
Expand All @@ -13,6 +14,8 @@ add_library(kuzu_list_function
list_sort_function.cpp
list_to_string_function.cpp
list_unique_function.cpp
list_prepend_function.cpp
list_position_function.cpp
size_function.cpp)

set(ALL_OBJECT_FILES
Expand Down
100 changes: 44 additions & 56 deletions src/function/list/list_agg_function.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "common/exception/binder.h"
#include "function/list/functions/list_product_function.h"
#include "function/list/functions/list_sum_function.h"
#include "common/type_utils.h"
#include "function/list/vector_list_functions.h"
#include "function/scalar_function.h"

Expand All @@ -12,75 +11,64 @@ namespace function {
template<typename OPERATION>
static std::unique_ptr<FunctionBindData> bindFuncListAggr(
const binder::expression_vector& arguments, Function* function) {
auto scalarFunction = ku_dynamic_cast<Function*, ScalarFunction*>(function);
auto scalarFunction = function->ptrCast<ScalarFunction>();
auto resultType = ListType::getChildType(&arguments[0]->dataType);
switch (resultType->getLogicalTypeID()) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, OPERATION>;
} break;
case LogicalTypeID::INT32: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int32_t, OPERATION>;
} break;
case LogicalTypeID::INT16: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int16_t, OPERATION>;
} break;
case LogicalTypeID::INT8: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int8_t, OPERATION>;
} break;
case LogicalTypeID::UINT64: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, uint64_t, OPERATION>;
} break;
case LogicalTypeID::UINT32: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, uint32_t, OPERATION>;
} break;
case LogicalTypeID::UINT16: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, uint16_t, OPERATION>;
} break;
case LogicalTypeID::UINT8: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, uint8_t, OPERATION>;
} break;
case LogicalTypeID::INT128: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int128_t, OPERATION>;
} break;
case LogicalTypeID::DOUBLE: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, double, OPERATION>;
} break;
case LogicalTypeID::FLOAT: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, float, OPERATION>;
} break;
default: {
throw BinderException(stringFormat("Unsupported inner data type for {}: {}", function->name,
LogicalTypeUtils::toString(resultType->getLogicalTypeID())));
}
}
TypeUtils::visit(
resultType->getLogicalTypeID(),
[&scalarFunction]<NumericTypes T>(T) {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, T, OPERATION>;
},
[&function, &resultType](auto) {
throw BinderException(stringFormat("Unsupported inner data type for {}: {}",
function->name, LogicalTypeUtils::toString(resultType->getLogicalTypeID())));
});
return FunctionBindData::getSimpleBindData(arguments, *resultType);
}

struct ListSum {
template<typename T>
static void operation(common::list_entry_t& input, T& result, common::ValueVector& inputVector,
common::ValueVector& /*resultVector*/) {
auto inputDataVector = common::ListVector::getDataVector(&inputVector);
result = 0;
for (auto i = 0u; i < input.size; i++) {
if (inputDataVector->isNull(input.offset + i)) {
continue;
}
result += inputDataVector->getValue<T>(input.offset + i);
}
}
};

function_set ListSumFunction::getFunctionSet() {
function_set result;
result.push_back(
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::LIST},
LogicalTypeID::INT64, nullptr, nullptr, bindFuncListAggr<ListSum>));
LogicalTypeID::INT64, bindFuncListAggr<ListSum>));
return result;
}

struct ListProduct {
template<typename T>
static void operation(common::list_entry_t& input, T& result, common::ValueVector& inputVector,
common::ValueVector& /*resultVector*/) {
auto inputDataVector = common::ListVector::getDataVector(&inputVector);
result = 1;
for (auto i = 0u; i < input.size; i++) {
if (inputDataVector->isNull(input.offset + i)) {
continue;
}
result *= inputDataVector->getValue<T>(input.offset + i);
}
}
};

function_set ListProductFunction::getFunctionSet() {
function_set result;
result.push_back(
std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::LIST},
LogicalTypeID::INT64, nullptr, nullptr, bindFuncListAggr<ListProduct>));
LogicalTypeID::INT64, bindFuncListAggr<ListProduct>));
return result;
}

Expand Down
118 changes: 23 additions & 95 deletions src/function/list/list_any_value_function.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "function/list/functions/list_any_value_function.h"

#include "common/type_utils.h"
#include "function/list/vector_list_functions.h"
#include "function/scalar_function.h"

Expand All @@ -8,104 +7,33 @@ using namespace kuzu::common;
namespace kuzu {
namespace function {

// Kernel that picks the first non-null element of a list.
struct ListAnyValue {
    // Copies the first non-null element of the list described by `input` into `result`.
    // If the list is empty or every element is null, `result` is left unwritten.
    template<typename T>
    static void operation(common::list_entry_t& input, T& result, common::ValueVector& inputVector,
        common::ValueVector& resultVector) {
        auto dataVector = common::ListVector::getDataVector(&inputVector);
        const auto stride = dataVector->getNumBytesPerValue();
        // `cursor` tracks the raw bytes of the element at index `idx`; both advance together.
        auto cursor = common::ListVector::getListValues(&inputVector, input);
        for (auto idx = 0u; idx < input.size; ++idx, cursor += stride) {
            if (dataVector->isNull(input.offset + idx)) {
                continue;
            }
            resultVector.copyFromVectorData(reinterpret_cast<uint8_t*>(&result), dataVector,
                cursor);
            return;
        }
    }
};

static std::unique_ptr<FunctionBindData> bindFunc(const binder::expression_vector& arguments,
Function* function) {
auto scalarFunction = ku_dynamic_cast<Function*, ScalarFunction*>(function);
auto resultType = ListType::getChildType(&arguments[0]->dataType);
switch (resultType->getLogicalTypeID()) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int64_t, ListAnyValue>;
} break;
case LogicalTypeID::INT32: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int32_t, ListAnyValue>;
} break;
case LogicalTypeID::INT16: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int16_t, ListAnyValue>;
} break;
case LogicalTypeID::INT8: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int8_t, ListAnyValue>;
} break;
case LogicalTypeID::UINT64: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, uint64_t, ListAnyValue>;
} break;
case LogicalTypeID::UINT32: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, uint32_t, ListAnyValue>;
} break;
case LogicalTypeID::UINT16: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, uint16_t, ListAnyValue>;
} break;
case LogicalTypeID::UINT8: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, uint8_t, ListAnyValue>;
} break;
case LogicalTypeID::INT128: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, int128_t, ListAnyValue>;
} break;
case LogicalTypeID::DOUBLE: {
TypeUtils::visit(resultType->getPhysicalType(), [&scalarFunction]<typename T>(T) {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, double, ListAnyValue>;
} break;
case LogicalTypeID::FLOAT: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, float, ListAnyValue>;
} break;
case LogicalTypeID::BOOL: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, uint8_t, ListAnyValue>;
} break;
case LogicalTypeID::STRING: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, ku_string_t, ListAnyValue>;
} break;
case LogicalTypeID::DATE: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, date_t, ListAnyValue>;
} break;
case LogicalTypeID::TIMESTAMP: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, timestamp_t, ListAnyValue>;
} break;
case LogicalTypeID::TIMESTAMP_MS: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, timestamp_ms_t, ListAnyValue>;
} break;
case LogicalTypeID::TIMESTAMP_NS: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, timestamp_ns_t, ListAnyValue>;
} break;
case LogicalTypeID::TIMESTAMP_SEC: {
scalarFunction->execFunc = ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t,
timestamp_sec_t, ListAnyValue>;
} break;
case LogicalTypeID::TIMESTAMP_TZ: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, timestamp_tz_t, ListAnyValue>;
} break;
case LogicalTypeID::INTERVAL: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, interval_t, ListAnyValue>;
} break;
case LogicalTypeID::LIST: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, list_entry_t, ListAnyValue>;
} break;
case LogicalTypeID::INTERNAL_ID: {
scalarFunction->execFunc =
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, internalID_t, ListAnyValue>;
} break;
default: {
KU_UNREACHABLE;
}
}
ScalarFunction::UnaryExecNestedTypeFunction<list_entry_t, T, ListAnyValue>;
});
return FunctionBindData::getSimpleBindData(arguments, *resultType);
}

Expand Down
60 changes: 60 additions & 0 deletions src/function/list/list_append_function.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "common/exception/binder.h"
#include "common/exception/message.h"
#include "common/type_utils.h"
#include "function/list/vector_list_functions.h"
#include "function/scalar_function.h"

using namespace kuzu::common;

namespace kuzu {
namespace function {

struct ListAppend {
template<typename T>
static void operation(common::list_entry_t& listEntry, T& value, common::list_entry_t& result,
common::ValueVector& listVector, common::ValueVector& valueVector,
common::ValueVector& resultVector) {
result = common::ListVector::addList(&resultVector, listEntry.size + 1);
auto listDataVector = common::ListVector::getDataVector(&listVector);
auto listPos = listEntry.offset;
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
auto resultPos = result.offset;
for (auto i = 0u; i < listEntry.size; i++) {
resultDataVector->copyFromVectorData(resultPos++, listDataVector, listPos++);
}
resultDataVector->copyFromVectorData(
resultDataVector->getData() + resultPos * resultDataVector->getNumBytesPerValue(),
&valueVector, reinterpret_cast<uint8_t*>(&value));
}
};

// Checks that the value being appended (arguments[1]) has exactly the child type of the list
// (arguments[0]). Throws a BinderException carrying the list-append function name and both
// argument types when they differ.
static void validateArgumentType(const binder::expression_vector& arguments) {
    const auto& listType = arguments[0]->getDataType();
    const auto& elementType = arguments[1]->getDataType();
    if (*ListType::getChildType(&arguments[0]->dataType) != elementType) {
        throw BinderException(ExceptionMessage::listFunctionIncompatibleChildrenType(
            ListAppendFunction::name, listType.toString(), elementType.toString()));
    }
}

// Bind step for list_append: validates the argument types, then selects the exec function
// instantiated for the physical type of the appended value. The bound result type is the type
// of the list argument.
static std::unique_ptr<FunctionBindData> bindFunc(const binder::expression_vector& arguments,
    Function* function) {
    validateArgumentType(arguments);
    auto scalarFunction = function->ptrCast<ScalarFunction>();
    const auto elementPhysicalType = arguments[1]->getDataType().getPhysicalType();
    // The visitor is invoked with a value whose type T corresponds to the physical type;
    // only the type is used, to pick the template instantiation.
    TypeUtils::visit(elementPhysicalType, [scalarFunction]<typename T>(T) {
        scalarFunction->execFunc =
            ScalarFunction::BinaryExecListStructFunction<list_entry_t, T, list_entry_t, ListAppend>;
    });
    return FunctionBindData::getSimpleBindData(arguments, arguments[0]->getDataType());
}

// Returns the single (LIST, ANY) -> LIST overload of list_append. The exec function is left
// null here and is filled in by bindFunc once the argument types are known.
function_set ListAppendFunction::getFunctionSet() {
    function_set functions;
    functions.emplace_back(std::make_unique<ScalarFunction>(name,
        std::vector<LogicalTypeID>{LogicalTypeID::LIST, LogicalTypeID::ANY}, LogicalTypeID::LIST,
        nullptr /* execFunc */, nullptr /* selectFunc */, bindFunc));
    return functions;
}

} // namespace function
} // namespace kuzu
Loading

0 comments on commit ac08521

Please sign in to comment.