Skip to content

Commit

Permalink
Add list functions
Browse files Browse the repository at this point in the history
1. list_reverse_sort(list)
Sorts the elements of the list in reverse order.

2. list_distinct(list)
Removes all duplicates and NULLs from a list.
Does not preserve the original order.

3. list_unique(list)
Counts the unique elements of a list.

4. list_any_value(list)
Returns the first non-null value in the list.
  • Loading branch information
gaurav8297 committed May 16, 2023
1 parent 4f75391 commit b4397a4
Show file tree
Hide file tree
Showing 11 changed files with 751 additions and 95 deletions.
7 changes: 7 additions & 0 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,14 @@ void BuiltInVectorOperations::registerListOperations() {
vectorOperations.insert({LIST_SLICE_FUNC_NAME, ListSliceVectorOperation::getDefinitions()});
vectorOperations.insert({ARRAY_SLICE_FUNC_NAME, ListSliceVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_SORT_FUNC_NAME, ListSortVectorOperation::getDefinitions()});
vectorOperations.insert(
{LIST_REVERSE_SORT_FUNC_NAME, ListReverseSortVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_SUM_FUNC_NAME, ListSumVectorOperation::getDefinitions()});
vectorOperations.insert(
{LIST_DISTINCT_FUNC_NAME, ListDistinctVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_UNIQUE_FUNC_NAME, ListUniqueVectorOperation::getDefinitions()});
vectorOperations.insert(
{LIST_ANY_VALUE_FUNC_NAME, ListAnyValueVectorOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerInternalIDOperation() {
Expand Down
264 changes: 264 additions & 0 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
#include "binder/expression_binder.h"
#include "common/types/ku_list.h"
#include "common/vector/value_vector_utils.h"
#include "function/list/operations/list_any_value_operation.h"
#include "function/list/operations/list_append_operation.h"
#include "function/list/operations/list_concat_operation.h"
#include "function/list/operations/list_contains.h"
#include "function/list/operations/list_distinct_operation.h"
#include "function/list/operations/list_extract_operation.h"
#include "function/list/operations/list_len_operation.h"
#include "function/list/operations/list_position_operation.h"
#include "function/list/operations/list_prepend_operation.h"
#include "function/list/operations/list_reverse_sort_operation.h"
#include "function/list/operations/list_slice_operation.h"
#include "function/list/operations/list_sort_operation.h"
#include "function/list/operations/list_sum_operation.h"
#include "function/list/operations/list_unique_operation.h"
#include "function/list/vector_list_operations.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -340,9 +344,18 @@ std::unique_ptr<FunctionBindData> ListSortVectorOperation::bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
switch (arguments[0]->dataType.getChildType()->getTypeID()) {
case INT16: {
vectorOperationDefinition->execFunc = getExecFunction<int16_t>(arguments);
} break;
case INT32: {
vectorOperationDefinition->execFunc = getExecFunction<int32_t>(arguments);
} break;
case INT64: {
vectorOperationDefinition->execFunc = getExecFunction<int64_t>(arguments);
} break;
case FLOAT: {
vectorOperationDefinition->execFunc = getExecFunction<float_t>(arguments);
} break;
case DOUBLE: {
vectorOperationDefinition->execFunc = getExecFunction<double_t>(arguments);
} break;
Expand Down Expand Up @@ -384,6 +397,72 @@ scalar_exec_func ListSortVectorOperation::getExecFunction(
}
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
ListReverseSortVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_REVERSE_SORT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, VAR_LIST, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_REVERSE_SORT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST, STRING}, VAR_LIST, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

std::unique_ptr<FunctionBindData> ListReverseSortVectorOperation::bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
switch (arguments[0]->dataType.getChildType()->getTypeID()) {
case INT16: {
vectorOperationDefinition->execFunc = getExecFunction<int16_t>(arguments);
} break;
case INT32: {
vectorOperationDefinition->execFunc = getExecFunction<int32_t>(arguments);
} break;
case INT64: {
vectorOperationDefinition->execFunc = getExecFunction<int64_t>(arguments);
} break;
case FLOAT: {
vectorOperationDefinition->execFunc = getExecFunction<float_t>(arguments);
} break;
case DOUBLE: {
vectorOperationDefinition->execFunc = getExecFunction<double_t>(arguments);
} break;
case BOOL: {
vectorOperationDefinition->execFunc = getExecFunction<uint8_t>(arguments);
} break;
case STRING: {
vectorOperationDefinition->execFunc = getExecFunction<ku_string_t>(arguments);
} break;
case DATE: {
vectorOperationDefinition->execFunc = getExecFunction<date_t>(arguments);
} break;
case TIMESTAMP: {
vectorOperationDefinition->execFunc = getExecFunction<timestamp_t>(arguments);
} break;
case INTERVAL: {
vectorOperationDefinition->execFunc = getExecFunction<interval_t>(arguments);
} break;
default: {
throw common::NotImplementedException("ListReverseSortVectorOperation::bindFunc");
}
}
return std::make_unique<FunctionBindData>(arguments[0]->getDataType());
}

template<typename T>
scalar_exec_func ListReverseSortVectorOperation::getExecFunction(
const binder::expression_vector& arguments) {
if (arguments.size() == 1) {
return UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListReverseSort<T>>;
} else if (arguments.size() == 2) {
return BinaryListExecFunction<list_entry_t, ku_string_t, list_entry_t,
operation::ListReverseSort<T>>;
} else {
throw common::RuntimeException("Invalid number of arguments");
}
}

std::vector<std::unique_ptr<VectorOperationDefinition>> ListSumVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_SUM_FUNC_NAME,
Expand Down Expand Up @@ -412,5 +491,190 @@ std::unique_ptr<FunctionBindData> ListSumVectorOperation::bindFunc(
return std::make_unique<FunctionBindData>(resultType);
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
ListDistinctVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_DISTINCT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, VAR_LIST, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

std::unique_ptr<FunctionBindData> ListDistinctVectorOperation::bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
switch (arguments[0]->dataType.getChildType()->getTypeID()) {
case INT16: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<int16_t>>;
} break;
case INT32: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<int32_t>>;
} break;
case INT64: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<int64_t>>;
} break;
case FLOAT: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<float_t>>;
} break;
case DOUBLE: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<double_t>>;
} break;
case BOOL: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<uint8_t>>;
} break;
case STRING: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<ku_string_t>>;
} break;
case DATE: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<date_t>>;
} break;
case TIMESTAMP: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<timestamp_t>>;
} break;
case INTERVAL: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<interval_t>>;
} break;
default: {
throw common::NotImplementedException("ListDistinctVectorOperation::bindFunc");
}
}
return std::make_unique<FunctionBindData>(arguments[0]->getDataType());
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
ListUniqueVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_UNIQUE_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, INT64, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

std::unique_ptr<FunctionBindData> ListUniqueVectorOperation::bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
switch (arguments[0]->dataType.getChildType()->getTypeID()) {
case INT16: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<int16_t>>;
} break;
case INT32: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<int32_t>>;
} break;
case INT64: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<int64_t>>;
} break;
case FLOAT: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<float_t>>;
} break;
case DOUBLE: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<double_t>>;
} break;
case BOOL: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<uint8_t>>;
} break;
case STRING: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<ku_string_t>>;
} break;
case DATE: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<date_t>>;
} break;
case TIMESTAMP: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<timestamp_t>>;
} break;
case INTERVAL: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<interval_t>>;
} break;
default: {
throw common::NotImplementedException("ListUniqueVectorOperation::bindFunc");
}
}
return std::make_unique<FunctionBindData>(DataType(INT64));
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
ListAnyValueVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_ANY_VALUE_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, ANY, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

std::unique_ptr<FunctionBindData> ListAnyValueVectorOperation::bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
auto resultType = *arguments[0]->getDataType().getChildType();
switch (resultType.typeID) {
case INT16: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, int16_t, operation::ListAnyValue>;
} break;
case INT32: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, int32_t, operation::ListAnyValue>;
} break;
case INT64: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, int64_t, operation::ListAnyValue>;
} break;
case FLOAT: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, float_t, operation::ListAnyValue>;
} break;
case DOUBLE: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, double_t, operation::ListAnyValue>;
} break;
case BOOL: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, uint8_t, operation::ListAnyValue>;
} break;
case STRING: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, ku_string_t, operation::ListAnyValue>;
} break;
case DATE: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, date_t, operation::ListAnyValue>;
} break;
case TIMESTAMP: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, timestamp_t, operation::ListAnyValue>;
} break;
case INTERVAL: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, interval_t, operation::ListAnyValue>;
} break;
case VAR_LIST: {
vectorOperationDefinition->execFunc =
UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListAnyValue>;
} break;
default: {
throw common::NotImplementedException("ListAnyValueVectorOperation::bindFunc");
}
}
return std::make_unique<FunctionBindData>(resultType);
}

} // namespace function
} // namespace kuzu
4 changes: 4 additions & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ const std::string LIST_SLICE_FUNC_NAME = "LIST_SLICE";
const std::string ARRAY_SLICE_FUNC_NAME = "ARRAY_SLICE";
const std::string LIST_SUM_FUNC_NAME = "LIST_SUM";
const std::string LIST_SORT_FUNC_NAME = "LIST_SORT";
const std::string LIST_REVERSE_SORT_FUNC_NAME = "LIST_REVERSE_SORT";
const std::string LIST_DISTINCT_FUNC_NAME = "LIST_DISTINCT";
const std::string LIST_UNIQUE_FUNC_NAME = "LIST_UNIQUE";
const std::string LIST_ANY_VALUE_FUNC_NAME = "LIST_ANY_VALUE";

// struct
const std::string STRUCT_PACK_FUNC_NAME = "STRUCT_PACK";
Expand Down
Loading

0 comments on commit b4397a4

Please sign in to comment.