Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More list functions #1543

Merged
merged 1 commit into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,14 @@ void BuiltInVectorOperations::registerListOperations() {
vectorOperations.insert({LIST_SLICE_FUNC_NAME, ListSliceVectorOperation::getDefinitions()});
vectorOperations.insert({ARRAY_SLICE_FUNC_NAME, ListSliceVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_SORT_FUNC_NAME, ListSortVectorOperation::getDefinitions()});
vectorOperations.insert(
{LIST_REVERSE_SORT_FUNC_NAME, ListReverseSortVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_SUM_FUNC_NAME, ListSumVectorOperation::getDefinitions()});
vectorOperations.insert(
{LIST_DISTINCT_FUNC_NAME, ListDistinctVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_UNIQUE_FUNC_NAME, ListUniqueVectorOperation::getDefinitions()});
vectorOperations.insert(
{LIST_ANY_VALUE_FUNC_NAME, ListAnyValueVectorOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerInternalIDOperation() {
Expand Down
264 changes: 264 additions & 0 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
#include "binder/expression_binder.h"
#include "common/types/ku_list.h"
#include "common/vector/value_vector_utils.h"
#include "function/list/operations/list_any_value_operation.h"
#include "function/list/operations/list_append_operation.h"
#include "function/list/operations/list_concat_operation.h"
#include "function/list/operations/list_contains.h"
#include "function/list/operations/list_distinct_operation.h"
#include "function/list/operations/list_extract_operation.h"
#include "function/list/operations/list_len_operation.h"
#include "function/list/operations/list_position_operation.h"
#include "function/list/operations/list_prepend_operation.h"
#include "function/list/operations/list_reverse_sort_operation.h"
#include "function/list/operations/list_slice_operation.h"
#include "function/list/operations/list_sort_operation.h"
#include "function/list/operations/list_sum_operation.h"
#include "function/list/operations/list_unique_operation.h"
#include "function/list/vector_list_operations.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -343,9 +347,18 @@ std::unique_ptr<FunctionBindData> ListSortVectorOperation::bindFunc(
case INT64: {
vectorOperationDefinition->execFunc = getExecFunction<int64_t>(arguments);
} break;
case INT32: {
vectorOperationDefinition->execFunc = getExecFunction<int32_t>(arguments);
} break;
case INT16: {
vectorOperationDefinition->execFunc = getExecFunction<int16_t>(arguments);
} break;
case DOUBLE: {
vectorOperationDefinition->execFunc = getExecFunction<double_t>(arguments);
} break;
case FLOAT: {
vectorOperationDefinition->execFunc = getExecFunction<float_t>(arguments);
} break;
case BOOL: {
vectorOperationDefinition->execFunc = getExecFunction<uint8_t>(arguments);
} break;
Expand Down Expand Up @@ -384,6 +397,72 @@ scalar_exec_func ListSortVectorOperation::getExecFunction(
}
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
ListReverseSortVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_REVERSE_SORT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, VAR_LIST, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_REVERSE_SORT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST, STRING}, VAR_LIST, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

// Bind step for LIST_REVERSE_SORT.
// Picks the exec function instantiation that matches the child (element) type of the first
// argument and stores it on the definition being bound. The bound result type is the type of
// the first argument itself.
// Throws common::NotImplementedException when the child type has no instantiation below.
std::unique_ptr<FunctionBindData> ListReverseSortVectorOperation::bindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto* opDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
    const auto childTypeID = arguments[0]->dataType.getChildType()->getTypeID();
    switch (childTypeID) {
    case INT64:
        opDefinition->execFunc = getExecFunction<int64_t>(arguments);
        break;
    case INT32:
        opDefinition->execFunc = getExecFunction<int32_t>(arguments);
        break;
    case INT16:
        opDefinition->execFunc = getExecFunction<int16_t>(arguments);
        break;
    case DOUBLE:
        opDefinition->execFunc = getExecFunction<double_t>(arguments);
        break;
    case FLOAT:
        opDefinition->execFunc = getExecFunction<float_t>(arguments);
        break;
    case BOOL:
        // Booleans are dispatched through the uint8_t instantiation.
        opDefinition->execFunc = getExecFunction<uint8_t>(arguments);
        break;
    case STRING:
        opDefinition->execFunc = getExecFunction<ku_string_t>(arguments);
        break;
    case DATE:
        opDefinition->execFunc = getExecFunction<date_t>(arguments);
        break;
    case TIMESTAMP:
        opDefinition->execFunc = getExecFunction<timestamp_t>(arguments);
        break;
    case INTERVAL:
        opDefinition->execFunc = getExecFunction<interval_t>(arguments);
        break;
    default:
        throw common::NotImplementedException("ListReverseSortVectorOperation::bindFunc");
    }
    return std::make_unique<FunctionBindData>(arguments[0]->getDataType());
}

// Returns the exec function for LIST_REVERSE_SORT over element type T, chosen by arity:
//   1 argument  -> unary form  (list in, list out)
//   2 arguments -> binary form (list + ku_string_t in, list out)
// Any other argument count raises common::RuntimeException.
template<typename T>
scalar_exec_func ListReverseSortVectorOperation::getExecFunction(
    const binder::expression_vector& arguments) {
    switch (arguments.size()) {
    case 1:
        return UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListReverseSort<T>>;
    case 2:
        return BinaryListExecFunction<list_entry_t, ku_string_t, list_entry_t,
            operation::ListReverseSort<T>>;
    default:
        throw common::RuntimeException("Invalid number of arguments");
    }
}

std::vector<std::unique_ptr<VectorOperationDefinition>> ListSumVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_SUM_FUNC_NAME,
Expand Down Expand Up @@ -412,5 +491,190 @@ std::unique_ptr<FunctionBindData> ListSumVectorOperation::bindFunc(
return std::make_unique<FunctionBindData>(resultType);
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
ListDistinctVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_DISTINCT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, VAR_LIST, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

// Bind step for LIST_DISTINCT.
// Installs the unary list exec function (list in, list out) instantiated for the child type of
// the first argument. The bound result type is the type of the first argument itself.
// Throws common::NotImplementedException when the child type has no instantiation below.
std::unique_ptr<FunctionBindData> ListDistinctVectorOperation::bindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto* opDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
    const auto childTypeID = arguments[0]->dataType.getChildType()->getTypeID();
    switch (childTypeID) {
    case INT64:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<int64_t>>;
        break;
    case INT32:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<int32_t>>;
        break;
    case INT16:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<int16_t>>;
        break;
    case DOUBLE:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<double_t>>;
        break;
    case FLOAT:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<float_t>>;
        break;
    case BOOL:
        // Booleans are dispatched through the uint8_t instantiation.
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<uint8_t>>;
        break;
    case STRING:
        opDefinition->execFunc = UnaryListExecFunction<list_entry_t, list_entry_t,
            operation::ListDistinct<ku_string_t>>;
        break;
    case DATE:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<date_t>>;
        break;
    case TIMESTAMP:
        opDefinition->execFunc = UnaryListExecFunction<list_entry_t, list_entry_t,
            operation::ListDistinct<timestamp_t>>;
        break;
    case INTERVAL:
        opDefinition->execFunc = UnaryListExecFunction<list_entry_t, list_entry_t,
            operation::ListDistinct<interval_t>>;
        break;
    default:
        throw common::NotImplementedException("ListDistinctVectorOperation::bindFunc");
    }
    return std::make_unique<FunctionBindData>(arguments[0]->getDataType());
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
ListUniqueVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_UNIQUE_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, INT64, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

// Bind step for LIST_UNIQUE.
// Installs the unary list exec function (list in, int64_t out) instantiated for the child type
// of the first argument. The bound result type is always INT64.
// NOTE(review): the INT64 result is presumably the number of distinct elements -- confirm
// against operation::ListUnique.
// Throws common::NotImplementedException when the child type has no instantiation below.
std::unique_ptr<FunctionBindData> ListUniqueVectorOperation::bindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto* opDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
    const auto childTypeID = arguments[0]->dataType.getChildType()->getTypeID();
    switch (childTypeID) {
    case INT64:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<int64_t>>;
        break;
    case INT32:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<int32_t>>;
        break;
    case INT16:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<int16_t>>;
        break;
    case DOUBLE:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<double_t>>;
        break;
    case FLOAT:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<float_t>>;
        break;
    case BOOL:
        // Booleans are dispatched through the uint8_t instantiation.
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<uint8_t>>;
        break;
    case STRING:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<ku_string_t>>;
        break;
    case DATE:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<date_t>>;
        break;
    case TIMESTAMP:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<timestamp_t>>;
        break;
    case INTERVAL:
        opDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<interval_t>>;
        break;
    default:
        throw common::NotImplementedException("ListUniqueVectorOperation::bindFunc");
    }
    return std::make_unique<FunctionBindData>(DataType(INT64));
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
ListAnyValueVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_ANY_VALUE_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, ANY, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

// Bind step for LIST_ANY_VALUE.
// The result type is the child (element) type of the list passed as the first argument. This
// function installs the unary list exec function whose result C++ type corresponds to that
// child type and returns bind data carrying the child type as the function's result type.
// Unlike the other list bind functions in this file, nested lists (VAR_LIST children) are
// handled here as well.
// Throws common::NotImplementedException when the child type has no instantiation below.
std::unique_ptr<FunctionBindData> ListAnyValueVectorOperation::bindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
    auto resultType = *arguments[0]->getDataType().getChildType();
    switch (resultType.getTypeID()) {
    case INT64: {
        vectorOperationDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListAnyValue>;
    } break;
    case INT32: {
        vectorOperationDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int32_t, operation::ListAnyValue>;
    } break;
    case INT16: {
        vectorOperationDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int16_t, operation::ListAnyValue>;
    } break;
    case DOUBLE: {
        vectorOperationDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, double_t, operation::ListAnyValue>;
    } break;
    case FLOAT: {
        vectorOperationDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, float_t, operation::ListAnyValue>;
    } break;
    case BOOL: {
        // Booleans are dispatched through the uint8_t instantiation.
        vectorOperationDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, uint8_t, operation::ListAnyValue>;
    } break;
    case STRING: {
        vectorOperationDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, ku_string_t, operation::ListAnyValue>;
    } break;
    case DATE: {
        vectorOperationDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, date_t, operation::ListAnyValue>;
    } break;
    case TIMESTAMP: {
        vectorOperationDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, timestamp_t, operation::ListAnyValue>;
    } break;
    case INTERVAL: {
        vectorOperationDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, interval_t, operation::ListAnyValue>;
    } break;
    case VAR_LIST: {
        // List-of-lists: the selected value is itself a list entry.
        vectorOperationDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListAnyValue>;
    } break;
    default: {
        throw common::NotImplementedException("ListAnyValueVectorOperation::bindFunc");
    }
    }
    return std::make_unique<FunctionBindData>(resultType);
}

} // namespace function
} // namespace kuzu
4 changes: 4 additions & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ const std::string LIST_SLICE_FUNC_NAME = "LIST_SLICE";
// ARRAY_SLICE is registered with the same definitions as LIST_SLICE.
const std::string ARRAY_SLICE_FUNC_NAME = "ARRAY_SLICE";
const std::string LIST_SUM_FUNC_NAME = "LIST_SUM";
const std::string LIST_SORT_FUNC_NAME = "LIST_SORT";
// Names under which the reverse-sort, distinct, unique and any-value list functions are
// registered among the built-in vector operations.
const std::string LIST_REVERSE_SORT_FUNC_NAME = "LIST_REVERSE_SORT";
const std::string LIST_DISTINCT_FUNC_NAME = "LIST_DISTINCT";
const std::string LIST_UNIQUE_FUNC_NAME = "LIST_UNIQUE";
const std::string LIST_ANY_VALUE_FUNC_NAME = "LIST_ANY_VALUE";

// struct
const std::string STRUCT_PACK_FUNC_NAME = "STRUCT_PACK";
Expand Down
Loading