Skip to content

Commit

Permalink
Add list functions
Browse files Browse the repository at this point in the history
1. list_reverse_sort(list)
Sorts the elements of the list in reverse order.

2. list_distinct(list)
Removes all duplicates and NULLs from a list.
Does not preserve the original order.

3. list_unique(list)
Counts the unique elements of a list.

4. list_any_value(list)
Returns the first non-null value in the list.
  • Loading branch information
gaurav8297 committed May 15, 2023
1 parent 4f75391 commit c9ddb83
Show file tree
Hide file tree
Showing 11 changed files with 653 additions and 95 deletions.
7 changes: 7 additions & 0 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,14 @@ void BuiltInVectorOperations::registerListOperations() {
vectorOperations.insert({LIST_SLICE_FUNC_NAME, ListSliceVectorOperation::getDefinitions()});
vectorOperations.insert({ARRAY_SLICE_FUNC_NAME, ListSliceVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_SORT_FUNC_NAME, ListSortVectorOperation::getDefinitions()});
vectorOperations.insert(
{LIST_REVERSE_SORT_FUNC_NAME, ListReverseSortVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_SUM_FUNC_NAME, ListSumVectorOperation::getDefinitions()});
vectorOperations.insert(
{LIST_DISTINCT_FUNC_NAME, ListDistinctVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_UNIQUE_FUNC_NAME, ListUniqueVectorOperation::getDefinitions()});
vectorOperations.insert(
{LIST_ANY_VALUE_FUNC_NAME, ListAnyValueVectorOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerInternalIDOperation() {
Expand Down
210 changes: 210 additions & 0 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
#include "binder/expression_binder.h"
#include "common/types/ku_list.h"
#include "common/vector/value_vector_utils.h"
#include "function/list/operations/list_any_value_operation.h"
#include "function/list/operations/list_append_operation.h"
#include "function/list/operations/list_concat_operation.h"
#include "function/list/operations/list_contains.h"
#include "function/list/operations/list_distinct_operation.h"
#include "function/list/operations/list_extract_operation.h"
#include "function/list/operations/list_len_operation.h"
#include "function/list/operations/list_position_operation.h"
#include "function/list/operations/list_prepend_operation.h"
#include "function/list/operations/list_reverse_sort_operation.h"
#include "function/list/operations/list_slice_operation.h"
#include "function/list/operations/list_sort_operation.h"
#include "function/list/operations/list_sum_operation.h"
#include "function/list/operations/list_unique_operation.h"
#include "function/list/vector_list_operations.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -384,6 +388,63 @@ scalar_exec_func ListSortVectorOperation::getExecFunction(
}
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
ListReverseSortVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_REVERSE_SORT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, VAR_LIST, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_REVERSE_SORT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST, STRING}, VAR_LIST, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

// Bind step for list_reverse_sort: selects the exec function instantiation that matches the
// element type of the first (list) argument and stores it on the definition.
//
// arguments  - bound argument expressions; arguments[0] is the list being sorted.
// definition - the definition being bound; its execFunc member is overwritten.
// Returns bind data whose result type is the data type of the list argument.
// Throws common::NotImplementedException when the element type is not one of the cases below.
std::unique_ptr<FunctionBindData> ListReverseSortVectorOperation::bindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto& execFunc = reinterpret_cast<VectorOperationDefinition*>(definition)->execFunc;
    const auto childTypeID = arguments[0]->dataType.getChildType()->getTypeID();
    switch (childTypeID) {
    case INT64:
        execFunc = getExecFunction<int64_t>(arguments);
        break;
    case DOUBLE:
        execFunc = getExecFunction<double_t>(arguments);
        break;
    case BOOL:
        execFunc = getExecFunction<uint8_t>(arguments);
        break;
    case STRING:
        execFunc = getExecFunction<ku_string_t>(arguments);
        break;
    case DATE:
        execFunc = getExecFunction<date_t>(arguments);
        break;
    case TIMESTAMP:
        execFunc = getExecFunction<timestamp_t>(arguments);
        break;
    case INTERVAL:
        execFunc = getExecFunction<interval_t>(arguments);
        break;
    default:
        throw common::NotImplementedException("ListReverseSortVectorOperation::bindFunc");
    }
    return std::make_unique<FunctionBindData>(arguments[0]->getDataType());
}

// Picks the exec function for list_reverse_sort over element type T according to how many
// arguments the call was bound with: one argument selects the unary form, two arguments select
// the binary form whose second operand is a ku_string_t.
// Throws common::RuntimeException for any other argument count.
template<typename T>
scalar_exec_func ListReverseSortVectorOperation::getExecFunction(
    const binder::expression_vector& arguments) {
    switch (arguments.size()) {
    case 1:
        return UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListReverseSort<T>>;
    case 2:
        return BinaryListExecFunction<list_entry_t, ku_string_t, list_entry_t,
            operation::ListReverseSort<T>>;
    default:
        throw common::RuntimeException("Invalid number of arguments");
    }
}

std::vector<std::unique_ptr<VectorOperationDefinition>> ListSumVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_SUM_FUNC_NAME,
Expand Down Expand Up @@ -412,5 +473,154 @@ std::unique_ptr<FunctionBindData> ListSumVectorOperation::bindFunc(
return std::make_unique<FunctionBindData>(resultType);
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
ListDistinctVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_DISTINCT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, VAR_LIST, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

// Bind step for list_distinct: installs the ListDistinct<T> exec function whose T matches the
// element type of the list argument.
//
// arguments  - bound argument expressions; arguments[0] is the input list.
// definition - the definition being bound; its execFunc member is overwritten.
// Returns bind data whose result type is the data type of the list argument.
// Throws common::NotImplementedException when the element type is not one of the cases below.
std::unique_ptr<FunctionBindData> ListDistinctVectorOperation::bindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto& execFunc = reinterpret_cast<VectorOperationDefinition*>(definition)->execFunc;
    const auto childTypeID = arguments[0]->dataType.getChildType()->getTypeID();
    switch (childTypeID) {
    case INT64:
        execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<int64_t>>;
        break;
    case DOUBLE:
        execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<double_t>>;
        break;
    case BOOL:
        execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<uint8_t>>;
        break;
    case STRING:
        execFunc = UnaryListExecFunction<list_entry_t, list_entry_t,
            operation::ListDistinct<ku_string_t>>;
        break;
    case DATE:
        execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListDistinct<date_t>>;
        break;
    case TIMESTAMP:
        execFunc = UnaryListExecFunction<list_entry_t, list_entry_t,
            operation::ListDistinct<timestamp_t>>;
        break;
    case INTERVAL:
        execFunc = UnaryListExecFunction<list_entry_t, list_entry_t,
            operation::ListDistinct<interval_t>>;
        break;
    default:
        throw common::NotImplementedException("ListDistinctVectorOperation::bindFunc");
    }
    return std::make_unique<FunctionBindData>(arguments[0]->getDataType());
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
ListUniqueVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_UNIQUE_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, INT64, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

// Bind step for list_unique: installs the ListUnique<T> exec function whose T matches the
// element type of the list argument. The function always produces an INT64 result.
//
// arguments  - bound argument expressions; arguments[0] is the input list.
// definition - the definition being bound; its execFunc member is overwritten.
// Returns bind data with an INT64 result type.
// Throws common::NotImplementedException when the element type is not one of the cases below.
std::unique_ptr<FunctionBindData> ListUniqueVectorOperation::bindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto& execFunc = reinterpret_cast<VectorOperationDefinition*>(definition)->execFunc;
    const auto childTypeID = arguments[0]->dataType.getChildType()->getTypeID();
    switch (childTypeID) {
    case INT64:
        execFunc = UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<int64_t>>;
        break;
    case DOUBLE:
        execFunc = UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<double_t>>;
        break;
    case BOOL:
        execFunc = UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<uint8_t>>;
        break;
    case STRING:
        execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<ku_string_t>>;
        break;
    case DATE:
        execFunc = UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<date_t>>;
        break;
    case TIMESTAMP:
        execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<timestamp_t>>;
        break;
    case INTERVAL:
        execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListUnique<interval_t>>;
        break;
    default:
        throw common::NotImplementedException("ListUniqueVectorOperation::bindFunc");
    }
    return std::make_unique<FunctionBindData>(DataType(INT64));
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
ListAnyValueVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_ANY_VALUE_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, ANY, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

// Bind step for list_any_value: the result type is the element type of the list argument, and
// the exec function is instantiated with the matching physical result type.
//
// arguments  - bound argument expressions; arguments[0] is the input list.
// definition - the definition being bound; its execFunc member is overwritten.
// Returns bind data whose result type is the list's element type.
// Throws common::NotImplementedException when the element type is not one of the cases below.
std::unique_ptr<FunctionBindData> ListAnyValueVectorOperation::bindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto& execFunc = reinterpret_cast<VectorOperationDefinition*>(definition)->execFunc;
    // Copy of the element type; it is also what gets reported as the result type.
    auto resultType = *arguments[0]->getDataType().getChildType();
    switch (resultType.typeID) {
    case INT64:
        execFunc = UnaryListExecFunction<list_entry_t, int64_t, operation::ListAnyValue>;
        break;
    case DOUBLE:
        execFunc = UnaryListExecFunction<list_entry_t, double_t, operation::ListAnyValue>;
        break;
    case BOOL:
        execFunc = UnaryListExecFunction<list_entry_t, uint8_t, operation::ListAnyValue>;
        break;
    case STRING:
        execFunc = UnaryListExecFunction<list_entry_t, ku_string_t, operation::ListAnyValue>;
        break;
    case DATE:
        execFunc = UnaryListExecFunction<list_entry_t, date_t, operation::ListAnyValue>;
        break;
    case TIMESTAMP:
        execFunc = UnaryListExecFunction<list_entry_t, timestamp_t, operation::ListAnyValue>;
        break;
    case INTERVAL:
        execFunc = UnaryListExecFunction<list_entry_t, interval_t, operation::ListAnyValue>;
        break;
    case VAR_LIST:
        // Nested lists: the extracted element is itself a list entry.
        execFunc = UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListAnyValue>;
        break;
    default:
        throw common::NotImplementedException("ListAnyValueVectorOperation::bindFunc");
    }
    return std::make_unique<FunctionBindData>(resultType);
}

} // namespace function
} // namespace kuzu
4 changes: 4 additions & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ const std::string LIST_SLICE_FUNC_NAME = "LIST_SLICE";
const std::string ARRAY_SLICE_FUNC_NAME = "ARRAY_SLICE";
const std::string LIST_SUM_FUNC_NAME = "LIST_SUM";
const std::string LIST_SORT_FUNC_NAME = "LIST_SORT";
// Sorts the elements of a list in reverse order.
const std::string LIST_REVERSE_SORT_FUNC_NAME = "LIST_REVERSE_SORT";
// Removes all duplicates and NULLs from a list; the original order is not preserved.
const std::string LIST_DISTINCT_FUNC_NAME = "LIST_DISTINCT";
// Counts the unique elements of a list.
const std::string LIST_UNIQUE_FUNC_NAME = "LIST_UNIQUE";
// Returns the first non-null value in a list.
const std::string LIST_ANY_VALUE_FUNC_NAME = "LIST_ANY_VALUE";

// struct
const std::string STRUCT_PACK_FUNC_NAME = "STRUCT_PACK";
Expand Down
106 changes: 106 additions & 0 deletions src/include/function/list/operations/base_list_sort_operation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#pragma once

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {
namespace operation {

/**
 * Helpers shared by the list sorting functions: parsing of the textual sort-order and
 * null-order options, and the routine that writes a sorted copy of a list into a result vector.
 */
struct BaseListSortOperation {
public:
    /**
     * Translates a sort-order keyword into a flag.
     * @param sortOrder "ASC" or "DESC"; compared exactly (case-sensitive).
     * @return true for "ASC", false for "DESC".
     * @throws common::RuntimeException for any other value.
     */
    static inline bool isAscOrder(const std::string& sortOrder) {
        if (sortOrder == "ASC") {
            return true;
        } else if (sortOrder == "DESC") {
            return false;
        } else {
            throw common::RuntimeException("Invalid sortOrder");
        }
    }

    /**
     * Translates a null-order keyword into a flag.
     * @param nullOrder "NULLS FIRST" or "NULLS LAST"; compared exactly (case-sensitive).
     * @return true for "NULLS FIRST", false for "NULLS LAST".
     * @throws common::RuntimeException for any other value.
     */
    static inline bool isNullFirst(const std::string& nullOrder) {
        if (nullOrder == "NULLS FIRST") {
            return true;
        } else if (nullOrder == "NULLS LAST") {
            return false;
        } else {
            throw common::RuntimeException("Invalid nullOrder");
        }
    }

    /**
     * Writes a sorted copy of `input` into a new list allocated in `resultVector`.
     *
     * Null elements are grouped at the front (nullFirst) or at the back (!nullFirst) of the
     * result; the remaining elements are copied and then sorted in place as values of type T
     * using std::less (ascOrder) or std::greater (!ascOrder).
     *
     * @tparam T element type the list's values are reinterpreted as for comparison.
     * @param input list entry to sort, located in inputVector.
     * @param result out-parameter; set to the newly added list entry in resultVector.
     * @param inputVector vector holding the input list.
     * @param resultVector vector receiving the sorted list.
     * @param ascOrder true for ascending order, false for descending order.
     * @param nullFirst true to place nulls before the sorted values, false to place them after.
     */
    template<typename T>
    static void sortValues(common::list_entry_t& input, common::list_entry_t& result,
        common::ValueVector& inputVector, common::ValueVector& resultVector, bool ascOrder,
        bool nullFirst) {
        // TODO(Ziyi) - Replace this sort implementation with radix_sort implementation:
        // https://github.com/kuzudb/kuzu/issues/1536.
        auto inputValues = common::ListVector::getListValues(&inputVector, input);
        auto inputDataVector = common::ListVector::getDataVector(&inputVector);

        // Count the null elements. Counters and indices are unsigned 64-bit so that they match
        // the uint64_t positions accepted by setVectorRangeToNull and are never compared
        // against the list size as signed ints.
        uint64_t nullCount = 0;
        for (uint64_t i = 0; i < input.size; i++) {
            if (inputDataVector->isNull(input.offset + i)) {
                nullCount++;
            }
        }

        result = common::ListVector::addList(&resultVector, input.size);
        auto resultValues = common::ListVector::getListValues(&resultVector, result);
        auto resultDataVector = common::ListVector::getDataVector(&resultVector);
        auto numBytesPerValue = resultDataVector->getNumBytesPerValue();

        // Nulls first: mark the leading slots null and start writing values after them.
        if (nullFirst) {
            setVectorRangeToNull(*resultDataVector, result.offset, 0, nullCount);
            resultValues += numBytesPerValue * nullCount;
        }

        // Copy the non-null values contiguously, skipping over null slots in the input.
        for (uint64_t i = 0; i < input.size; i++) {
            if (!inputDataVector->isNull(input.offset + i)) {
                common::ValueVectorUtils::copyValue(
                    resultValues, *resultDataVector, inputValues, *inputDataVector);
                resultValues += numBytesPerValue;
            }
            inputValues += numBytesPerValue;
        }

        // Nulls last: mark the trailing slots null.
        if (!nullFirst) {
            setVectorRangeToNull(
                *resultDataVector, result.offset, input.size - nullCount, input.size);
        }

        // Only the non-null range takes part in the sort.
        const uint64_t sortStart = nullFirst ? nullCount : 0;
        const uint64_t sortEnd = nullFirst ? input.size : input.size - nullCount;

        // Sort the copied values in place in the requested direction.
        auto sortingValues =
            reinterpret_cast<T*>(common::ListVector::getListValues(&resultVector, result));
        if (ascOrder) {
            std::sort(sortingValues + sortStart, sortingValues + sortEnd, std::less{});
        } else {
            std::sort(sortingValues + sortStart, sortingValues + sortEnd, std::greater{});
        }
    }

    /**
     * Marks positions [offset + startPos, offset + endPos) of `vector` as null.
     * @param vector vector whose null flags are updated.
     * @param offset base position added to every index.
     * @param startPos first relative position to mark (inclusive).
     * @param endPos relative position at which marking stops (exclusive).
     */
    static void setVectorRangeToNull(
        common::ValueVector& vector, uint64_t offset, uint64_t startPos, uint64_t endPos) {
        for (auto i = startPos; i < endPos; i++) {
            vector.setNull(offset + i, true);
        }
    }
};

} // namespace operation
} // namespace function
} // namespace kuzu
Loading

0 comments on commit c9ddb83

Please sign in to comment.