Skip to content

Commit

Permalink
Merge pull request #2069 from kuzudb/list_product
Browse files Browse the repository at this point in the history
finish list_product
  • Loading branch information
andyfengHKU committed Sep 27, 2023
2 parents eb22aa0 + 11700f8 commit 17723ab
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 147 deletions.
1 change: 1 addition & 0 deletions src/function/built_in_vector_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,7 @@ void BuiltInVectorFunctions::registerListFunctions() {
vectorFunctions.insert(
{LIST_REVERSE_SORT_FUNC_NAME, ListReverseSortVectorFunction::getDefinitions()});
vectorFunctions.insert({LIST_SUM_FUNC_NAME, ListSumVectorFunction::getDefinitions()});
vectorFunctions.insert({LIST_PRODUCT_FUNC_NAME, ListProductVectorFunction::getDefinitions()});
vectorFunctions.insert({LIST_DISTINCT_FUNC_NAME, ListDistinctVectorFunction::getDefinitions()});
vectorFunctions.insert({LIST_UNIQUE_FUNC_NAME, ListUniqueVectorFunction::getDefinitions()});
vectorFunctions.insert(
Expand Down
135 changes: 14 additions & 121 deletions src/function/vector_list_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "function/list/functions/list_len_function.h"
#include "function/list/functions/list_position_function.h"
#include "function/list/functions/list_prepend_function.h"
#include "function/list/functions/list_product_function.h"
#include "function/list/functions/list_range_function.h"
#include "function/list/functions/list_reverse_sort_function.h"
#include "function/list/functions/list_slice_function.h"
Expand Down Expand Up @@ -264,84 +265,21 @@ vector_function_definitions ListAppendVectorFunction::getDefinitions() {
std::unique_ptr<FunctionBindData> ListPrependVectorFunction::bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
if (arguments[0]->getDataType().getLogicalTypeID() != LogicalTypeID::ANY &&
arguments[0]->dataType != *VarListType::getChildType(&arguments[1]->dataType)) {
arguments[1]->dataType != *VarListType::getChildType(&arguments[0]->dataType)) {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_APPEND_FUNC_NAME, arguments[0]->getDataType(), arguments[1]->getDataType()));
LIST_PREPEND_FUNC_NAME, arguments[0]->getDataType(), arguments[1]->getDataType()));
}
auto resultType = arguments[1]->getDataType();
auto resultType = arguments[0]->getDataType();
auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
switch (arguments[0]->getDataType().getPhysicalType()) {
case PhysicalTypeID::INT64: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<int64_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::INT32: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<int32_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::INT16: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<int16_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::INT8: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<int8_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::UINT64: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<uint64_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::UINT32: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<uint32_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::UINT16: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<uint16_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::UINT8: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<uint8_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::DOUBLE: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<double_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::FLOAT: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<float_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::BOOL: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<uint8_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::STRING: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<ku_string_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::INTERVAL: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<interval_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::VAR_LIST: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<list_entry_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
case PhysicalTypeID::INTERNAL_ID: {
vectorFunctionDefinition->execFunc =
BinaryExecListStructFunction<internalID_t, list_entry_t, list_entry_t, ListPrepend>;
} break;
default: {
throw NotImplementedException("ListPrependVectorFunction::bindFunc");
}
}
vectorFunctionDefinition->execFunc =
getBinaryListExecFuncSwitchRight<ListPrepend, list_entry_t>(arguments[1]->getDataType());
return std::make_unique<FunctionBindData>(resultType);
}

vector_function_definitions ListPrependVectorFunction::getDefinitions() {
vector_function_definitions result;
result.push_back(std::make_unique<VectorFunctionDefinition>(LIST_PREPEND_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::ANY, LogicalTypeID::VAR_LIST},
std::vector<LogicalTypeID>{LogicalTypeID::VAR_LIST, LogicalTypeID::ANY},
LogicalTypeID::VAR_LIST, nullptr, nullptr, bindFunc, false /* isVarlength */));
return result;
}
Expand Down Expand Up @@ -577,61 +515,16 @@ vector_function_definitions ListSumVectorFunction::getDefinitions() {
vector_function_definitions result;
result.push_back(std::make_unique<VectorFunctionDefinition>(LIST_SUM_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::VAR_LIST}, LogicalTypeID::INT64, nullptr, nullptr,
bindFunc, false /* isVarlength*/));
bindFuncListAggre<ListSum>, false /* isVarlength*/));
return result;
}

std::unique_ptr<FunctionBindData> ListSumVectorFunction::bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
auto resultType = VarListType::getChildType(&arguments[0]->dataType);
switch (resultType->getLogicalTypeID()) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::INT64: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int64_t, ListSum>;
} break;
case LogicalTypeID::INT32: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int32_t, ListSum>;
} break;
case LogicalTypeID::INT16: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int16_t, ListSum>;
} break;
case LogicalTypeID::INT8: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, int8_t, ListSum>;
} break;
case LogicalTypeID::UINT64: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, uint64_t, ListSum>;
} break;
case LogicalTypeID::UINT32: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, uint32_t, ListSum>;
} break;
case LogicalTypeID::UINT16: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, uint16_t, ListSum>;
} break;
case LogicalTypeID::UINT8: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, uint8_t, ListSum>;
} break;
case LogicalTypeID::DOUBLE: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, double_t, ListSum>;
} break;
case LogicalTypeID::FLOAT: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<list_entry_t, float_t, ListSum>;
} break;
default: {
throw NotImplementedException("ListSumVectorFunction::bindFunc");
}
}
return std::make_unique<FunctionBindData>(*resultType);
vector_function_definitions ListProductVectorFunction::getDefinitions() {
vector_function_definitions result;
result.push_back(std::make_unique<VectorFunctionDefinition>(LIST_PRODUCT_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::VAR_LIST}, LogicalTypeID::INT64, nullptr, nullptr,
bindFuncListAggre<ListProduct>, false /* isVarlength*/));
return result;
}

vector_function_definitions ListDistinctVectorFunction::getDefinitions() {
Expand Down
1 change: 1 addition & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ const std::string ARRAY_HAS_FUNC_NAME = "ARRAY_HAS";
const std::string LIST_SLICE_FUNC_NAME = "LIST_SLICE";
const std::string ARRAY_SLICE_FUNC_NAME = "ARRAY_SLICE";
const std::string LIST_SUM_FUNC_NAME = "LIST_SUM";
const std::string LIST_PRODUCT_FUNC_NAME = "LIST_PRODUCT";
const std::string LIST_SORT_FUNC_NAME = "LIST_SORT";
const std::string LIST_REVERSE_SORT_FUNC_NAME = "LIST_REVERSE_SORT";
const std::string LIST_DISTINCT_FUNC_NAME = "LIST_DISTINCT";
Expand Down
6 changes: 3 additions & 3 deletions src/include/function/list/functions/list_prepend_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ namespace function {

struct ListPrepend {
template<typename T>
static inline void operation(T& value, common::list_entry_t& listEntry,
common::list_entry_t& result, common::ValueVector& valueVector,
common::ValueVector& listVector, common::ValueVector& resultVector) {
static inline void operation(common::list_entry_t& listEntry, T& value,
common::list_entry_t& result, common::ValueVector& listVector,
common::ValueVector& valueVector, common::ValueVector& resultVector) {
result = common::ListVector::addList(&resultVector, listEntry.size + 1);
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
resultDataVector->copyFromVectorData(
Expand Down
24 changes: 24 additions & 0 deletions src/include/function/list/functions/list_product_function.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {

struct ListProduct {
template<typename T>
static inline void operation(common::list_entry_t& input, T& result,
common::ValueVector& inputVector, common::ValueVector& resultVector) {
auto inputDataVector = common::ListVector::getDataVector(&inputVector);
result = 1;
for (auto i = 0; i < input.size; i++) {
if (inputDataVector->isNull(input.offset + i)) {
continue;
}
result *= inputDataVector->getValue<T>(input.offset + i);
}
}
};

} // namespace function
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ struct ListReverseSort : BaseListSortOperation {
sortValues<T>(input, result, inputVector, resultVector, false /* ascOrder */,
isNullFirst(nullOrder.getAsString()) /* nullFirst */);
}

static inline void operation(common::list_entry_t& input, common::ku_string_t& sortOrder,
common::ku_string_t& nullOrder, common::list_entry_t& result,
common::ValueVector& inputVector, common::ValueVector& resultVector) {
throw common::RuntimeException("Invalid number of arguments");
}
};

} // namespace function
Expand Down
4 changes: 1 addition & 3 deletions src/include/function/list/functions/list_sum_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@ struct ListSum {
template<typename T>
static inline void operation(common::list_entry_t& input, T& result,
common::ValueVector& inputVector, common::ValueVector& resultVector) {
auto inputValues =
reinterpret_cast<T*>(common::ListVector::getListValues(&inputVector, input));
auto inputDataVector = common::ListVector::getDataVector(&inputVector);
result = 0;
for (auto i = 0; i < input.size; i++) {
if (inputDataVector->isNull(input.offset + i)) {
continue;
}
result += inputValues[i];
result += inputDataVector->getValue<T>(input.offset + i);
}
}
};
Expand Down
60 changes: 58 additions & 2 deletions src/include/function/list/vector_list_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,60 @@ struct VectorListFunction : public VectorFunction {
}
return execFunc;
}

template<typename OPERATION>
static std::unique_ptr<FunctionBindData> bindFuncListAggre(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
auto resultType = common::VarListType::getChildType(&arguments[0]->dataType);
switch (resultType->getLogicalTypeID()) {
case common::LogicalTypeID::SERIAL:
case common::LogicalTypeID::INT64: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<common::list_entry_t, int64_t, OPERATION>;
} break;
case common::LogicalTypeID::INT32: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<common::list_entry_t, int32_t, OPERATION>;
} break;
case common::LogicalTypeID::INT16: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<common::list_entry_t, int16_t, OPERATION>;
} break;
case common::LogicalTypeID::INT8: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<common::list_entry_t, int8_t, OPERATION>;
} break;
case common::LogicalTypeID::UINT64: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<common::list_entry_t, uint64_t, OPERATION>;
} break;
case common::LogicalTypeID::UINT32: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<common::list_entry_t, uint32_t, OPERATION>;
} break;
case common::LogicalTypeID::UINT16: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<common::list_entry_t, uint16_t, OPERATION>;
} break;
case common::LogicalTypeID::UINT8: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<common::list_entry_t, uint8_t, OPERATION>;
} break;
case common::LogicalTypeID::DOUBLE: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<common::list_entry_t, double_t, OPERATION>;
} break;
case common::LogicalTypeID::FLOAT: {
vectorFunctionDefinition->execFunc =
UnaryExecListStructFunction<common::list_entry_t, float_t, OPERATION>;
} break;
default: {
throw common::NotImplementedException(definition->name + "::bindFunc");
}
}
return std::make_unique<FunctionBindData>(*resultType);
}
};

struct ListCreationVectorFunction : public VectorListFunction {
Expand Down Expand Up @@ -213,8 +267,10 @@ struct ListReverseSortVectorFunction : public VectorListFunction {

struct ListSumVectorFunction : public VectorListFunction {
static vector_function_definitions getDefinitions();
static std::unique_ptr<FunctionBindData> bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition);
};

struct ListProductVectorFunction : public VectorListFunction {
static vector_function_definitions getDefinitions();
};

struct ListDistinctVectorFunction : public VectorListFunction {
Expand Down
Loading

0 comments on commit 17723ab

Please sign in to comment.