Skip to content

Commit

Permalink
Add list_sum and list_sort functions
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav8297 committed May 12, 2023
1 parent 4f02b2d commit 31d9516
Show file tree
Hide file tree
Showing 8 changed files with 394 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,8 @@ void BuiltInVectorOperations::registerListOperations() {
vectorOperations.insert({ARRAY_HAS_FUNC_NAME, ListContainsVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_SLICE_FUNC_NAME, ListSliceVectorOperation::getDefinitions()});
vectorOperations.insert({ARRAY_SLICE_FUNC_NAME, ListSliceVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_SORT_FUNC_NAME, ListSortVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_SUM_FUNC_NAME, ListSumVectorOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerInternalIDOperation() {
Expand Down
92 changes: 92 additions & 0 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "function/list/operations/list_position_operation.h"
#include "function/list/operations/list_prepend_operation.h"
#include "function/list/operations/list_slice_operation.h"
#include "function/list/operations/list_sort_operation.h"
#include "function/list/operations/list_sum_operation.h"
#include "function/list/vector_list_operations.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -320,5 +322,95 @@ std::unique_ptr<FunctionBindData> ListSliceVectorOperation::bindFunc(
return std::make_unique<FunctionBindData>(arguments[0]->getDataType());
}

std::vector<std::unique_ptr<VectorOperationDefinition>> ListSortVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_SORT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, VAR_LIST, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_SORT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST, STRING}, VAR_LIST, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_SORT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST, STRING, STRING}, VAR_LIST, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

// Binds LIST_SORT to a concrete exec function chosen by the element type of the list passed
// as the first argument, and reports the input list type as the result type.
//
// arguments:  bound argument expressions; arguments[0] is the list being sorted.
// definition: the definition whose execFunc is filled in (treated as a
//             VectorOperationDefinition).
// Throws common::NotImplementedException for element types without a sort implementation.
std::unique_ptr<FunctionBindData> ListSortVectorOperation::bindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
    switch (arguments[0]->dataType.getChildType()->getTypeID()) {
    case INT64: {
        vectorOperationDefinition->execFunc = getExecFunction<int64_t>(arguments);
    } break;
    case DOUBLE: {
        // Use `double` rather than `double_t`: the latter is only guaranteed to be at least as
        // wide as double (it may be long double depending on FLT_EVAL_METHOD), and the sort
        // reinterprets the raw list buffer as an array of this type.
        vectorOperationDefinition->execFunc = getExecFunction<double>(arguments);
    } break;
    case BOOL: {
        vectorOperationDefinition->execFunc = getExecFunction<uint8_t>(arguments);
    } break;
    case STRING: {
        vectorOperationDefinition->execFunc = getExecFunction<ku_string_t>(arguments);
    } break;
    case DATE: {
        vectorOperationDefinition->execFunc = getExecFunction<date_t>(arguments);
    } break;
    case TIMESTAMP: {
        vectorOperationDefinition->execFunc = getExecFunction<timestamp_t>(arguments);
    } break;
    case INTERVAL: {
        vectorOperationDefinition->execFunc = getExecFunction<interval_t>(arguments);
    } break;
    default: {
        throw common::NotImplementedException("ListSortVectorOperation::bindFunc");
    }
    }
    // Sorting does not change the list's type, so the result type is the input list type.
    return std::make_unique<FunctionBindData>(arguments[0]->getDataType());
}

// Picks the LIST_SORT exec function whose arity matches the number of bound arguments:
// one argument -> unary, two -> binary (extra string), three -> ternary (two extra strings).
// T is the element type the list values are sorted as.
// Throws common::RuntimeException for any other argument count.
template<typename T>
scalar_exec_func ListSortVectorOperation::getExecFunction(
    const binder::expression_vector& arguments) {
    switch (arguments.size()) {
    case 1:
        return UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListSort<T>>;
    case 2:
        return BinaryListExecFunction<list_entry_t, ku_string_t, list_entry_t,
            operation::ListSort<T>>;
    case 3:
        return TernaryListExecFunction<list_entry_t, ku_string_t, ku_string_t, list_entry_t,
            operation::ListSort<T>>;
    default:
        throw common::RuntimeException("Invalid number of arguments");
    }
}

std::vector<std::unique_ptr<VectorOperationDefinition>> ListSumVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_SUM_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, INT64, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

// Binds LIST_SUM to the exec function matching the element type of the list passed as the
// first argument. The bind data is built from that element type, so summing a DOUBLE list
// is bound as DOUBLE and summing an INT64 list as INT64.
//
// arguments:  bound argument expressions; arguments[0] is the list being summed.
// definition: the definition whose execFunc is filled in (treated as a
//             VectorOperationDefinition).
// Throws common::NotImplementedException for element types other than INT64 and DOUBLE.
std::unique_ptr<FunctionBindData> ListSumVectorOperation::bindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto vectorOperationDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
    auto resultType = *arguments[0]->getDataType().getChildType();
    switch (resultType.getTypeID()) {
    case INT64: {
        vectorOperationDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListSum>;
    } break;
    case DOUBLE: {
        // Use `double` rather than `double_t`: the latter is only guaranteed to be at least as
        // wide as double (it may be long double depending on FLT_EVAL_METHOD), and ListSum
        // reinterprets the raw list buffer as an array of this type.
        vectorOperationDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, double, operation::ListSum>;
    } break;
    default: {
        throw common::NotImplementedException("ListSumVectorOperation::bindFunc");
    }
    }
    return std::make_unique<FunctionBindData>(resultType);
}

} // namespace function
} // namespace kuzu
2 changes: 2 additions & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ const std::string ARRAY_CONTAINS_FUNC_NAME = "ARRAY_CONTAINS";
const std::string ARRAY_HAS_FUNC_NAME = "ARRAY_HAS";
const std::string LIST_SLICE_FUNC_NAME = "LIST_SLICE";
const std::string ARRAY_SLICE_FUNC_NAME = "ARRAY_SLICE";
const std::string LIST_SUM_FUNC_NAME = "LIST_SUM";
const std::string LIST_SORT_FUNC_NAME = "LIST_SORT";

// struct
const std::string STRUCT_PACK_FUNC_NAME = "STRUCT_PACK";
Expand Down
125 changes: 125 additions & 0 deletions src/include/function/list/operations/list_sort_operation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#pragma once

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {
namespace operation {

// Sorts the elements of a single list entry into a newly added list in the result vector.
// T is the in-memory element type the list's value buffer is reinterpreted as for sorting.
// The three `operation` overloads differ only in how the sort direction and null placement
// are supplied; all of them delegate to sortValues.
template<typename T>
struct ListSort {
    // list_sort(list): ascending order, nulls placed first.
    static inline void operation(common::list_entry_t& input, common::list_entry_t& result,
        common::ValueVector& inputVector, common::ValueVector& resultVector) {
        sortValues(
            input, result, inputVector, resultVector, true /* ascOrder */, true /* nullFirst */);
    }

    // list_sort(list, sortOrder): direction from sortOrder ("ASC" / "DESC"), nulls placed first.
    // valueVector is accepted but not used by this overload.
    static inline void operation(common::list_entry_t& input, common::ku_string_t& sortOrder,
        common::list_entry_t& result, common::ValueVector& inputVector,
        common::ValueVector& valueVector, common::ValueVector& resultVector) {
        sortValues(input, result, inputVector, resultVector, isAscOrder(sortOrder.getAsString()),
            true /* nullFirst */);
    }

    // list_sort(list, sortOrder, nullOrder): direction from sortOrder ("ASC" / "DESC") and null
    // placement from nullOrder ("NULLS FIRST" / "NULLS LAST").
    static inline void operation(common::list_entry_t& input, common::ku_string_t& sortOrder,
        common::ku_string_t& nullOrder, common::list_entry_t& result,
        common::ValueVector& inputVector, common::ValueVector& resultVector) {
        sortValues(input, result, inputVector, resultVector, isAscOrder(sortOrder.getAsString()),
            isNullFirst(nullOrder.getAsString()));
    }

    // Returns true for exactly "ASC" and false for exactly "DESC" (the comparison is
    // case-sensitive). Throws common::RuntimeException for any other string.
    static inline bool isAscOrder(const std::string& sortOrder) {
        if (sortOrder == "ASC") {
            return true;
        } else if (sortOrder == "DESC") {
            return false;
        } else {
            throw common::RuntimeException("Invalid sortOrder");
        }
    }

    // Returns true for exactly "NULLS FIRST" and false for exactly "NULLS LAST" (the comparison
    // is case-sensitive). Throws common::RuntimeException for any other string.
    static inline bool isNullFirst(const std::string& nullOrder) {
        if (nullOrder == "NULLS FIRST") {
            return true;
        } else if (nullOrder == "NULLS LAST") {
            return false;
        } else {
            throw common::RuntimeException("Invalid nullOrder");
        }
    }

    // Adds a list of input.size elements to resultVector, assigns it to `result`, and fills it
    // with the input's elements: nulls are grouped at the front (nullFirst) or at the back, and
    // the non-null elements are sorted ascending (ascOrder) or descending.
    static void sortValues(common::list_entry_t& input, common::list_entry_t& result,
        common::ValueVector& inputVector, common::ValueVector& resultVector, bool ascOrder,
        bool nullFirst) {
        // TODO(Ziyi) - Replace this sort implementation with radix_sort implementation:
        // https://github.com/kuzudb/kuzu/issues/1536.
        auto inputValues = common::ListVector::getListValues(&inputVector, input);
        auto inputDataVector = common::ListVector::getDataVector(&inputVector);

        // Count the null elements. Counters are unsigned 64-bit so they are compared against
        // input.size without a signed/unsigned mismatch or narrowing.
        uint64_t nullCount = 0;
        for (uint64_t i = 0; i < input.size; i++) {
            if (inputDataVector->isNull(input.offset + i)) {
                nullCount++;
            }
        }

        result = common::ListVector::addList(&resultVector, input.size);
        auto resultValues = common::ListVector::getListValues(&resultVector, result);
        auto resultDataVector = common::ListVector::getDataVector(&resultVector);
        auto numBytesPerValue = resultDataVector->getNumBytesPerValue();

        // [sortStart, sortEnd) is the slice of the result list that holds non-null values.
        const uint64_t sortStart = nullFirst ? nullCount : 0;
        const uint64_t sortEnd = sortStart + (input.size - nullCount);

        // Leading nulls.
        if (nullFirst) {
            setVectorRangeToNull(*resultDataVector, result.offset, 0, nullCount);
        }

        // Copy the non-null values, in input order, into the non-null slice.
        resultValues += numBytesPerValue * sortStart;
        for (uint64_t i = 0; i < input.size; i++) {
            if (!inputDataVector->isNull(input.offset + i)) {
                common::ValueVectorUtils::copyValue(
                    resultValues, *resultDataVector, inputValues, *inputDataVector);
                resultValues += numBytesPerValue;
            }
            inputValues += numBytesPerValue;
        }

        // Trailing nulls.
        if (!nullFirst) {
            setVectorRangeToNull(*resultDataVector, result.offset, sortEnd, input.size);
        }

        // Sort only the non-null slice, viewing the result list's buffer as an array of T.
        auto sortingValues =
            reinterpret_cast<T*>(common::ListVector::getListValues(&resultVector, result));
        if (ascOrder) {
            std::sort(sortingValues + sortStart, sortingValues + sortEnd, std::less{});
        } else {
            std::sort(sortingValues + sortStart, sortingValues + sortEnd, std::greater{});
        }
    }

    // Marks positions [offset + startPos, offset + endPos) of `vector` as null.
    static void setVectorRangeToNull(
        common::ValueVector& vector, uint64_t offset, uint64_t startPos, uint64_t endPos) {
        for (auto i = startPos; i < endPos; i++) {
            vector.setNull(offset + i, true);
        }
    }
};

} // namespace operation
} // namespace function
} // namespace kuzu
27 changes: 27 additions & 0 deletions src/include/function/list/operations/list_sum_operation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {
namespace operation {

// Sums the non-null elements of a single list entry.
struct ListSum {
    // input:        the list entry to sum.
    // result:       receives the sum of the non-null elements (0 when there are none).
    // inputVector:  the list vector that `input` belongs to.
    // resultVector: accepted but not used by this operation.
    // T is the element type the list's value buffer is reinterpreted as.
    template<typename T>
    static inline void operation(common::list_entry_t& input, T& result,
        common::ValueVector& inputVector, common::ValueVector& resultVector) {
        auto inputValues =
            reinterpret_cast<T*>(common::ListVector::getListValues(&inputVector, input));
        auto inputDataVector = common::ListVector::getDataVector(&inputVector);
        // `result` is a caller-supplied slot that is only ever accumulated into below, so it
        // must be reset first; otherwise the sum would start from whatever the slot held.
        result = 0;
        for (decltype(input.size) i = 0; i < input.size; i++) {
            if (inputDataVector->isNull(input.offset + i)) {
                continue;
            }
            result += inputValues[i];
        }
    }
};

} // namespace operation
} // namespace function
} // namespace kuzu
22 changes: 22 additions & 0 deletions src/include/function/list/vector_list_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ struct VectorListOperations : public VectorOperations {
}
return result;
}

// Exec-function adapter for list operations that take exactly one parameter vector:
// forwards that vector and the result vector to UnaryOperationExecutor::executeList.
template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC>
static void UnaryListExecFunction(
    const std::vector<std::shared_ptr<common::ValueVector>>& params,
    common::ValueVector& result) {
    assert(params.size() == 1);
    auto& operand = *params[0];
    UnaryOperationExecutor::executeList<OPERAND_TYPE, RESULT_TYPE, FUNC>(operand, result);
}
};

struct ListCreationVectorOperation : public VectorListOperations {
Expand Down Expand Up @@ -128,5 +136,19 @@ struct ListSliceVectorOperation : public VectorListOperations {
const binder::expression_vector& arguments, FunctionDefinition* definition);
};

// Vector operation backing LIST_SORT.
struct ListSortVectorOperation : public VectorListOperations {
// Returns the LIST_SORT overload definitions.
static std::vector<std::unique_ptr<VectorOperationDefinition>> getDefinitions();
// Selects the exec function for `definition` from the element type of the list in
// arguments[0] and returns bind data carrying that list's data type.
static std::unique_ptr<FunctionBindData> bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition);
// Returns the unary, binary or ternary exec function for element type T according to the
// number of arguments.
template<typename T>
static scalar_exec_func getExecFunction(const binder::expression_vector& arguments);
};

// Vector operation backing LIST_SUM.
struct ListSumVectorOperation : public VectorListOperations {
// Returns the LIST_SUM overload definition.
static std::vector<std::unique_ptr<VectorOperationDefinition>> getDefinitions();
// Selects the exec function for `definition` from the element type of the list in
// arguments[0] and returns bind data carrying that element type.
static std::unique_ptr<FunctionBindData> bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition);
};

} // namespace function
} // namespace kuzu
14 changes: 14 additions & 0 deletions src/include/function/unary_operation_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ struct UnaryStringOperationWrapper {
}
};

struct UnaryListOperationWrapper {
template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC>
static inline void operation(
OPERAND_TYPE& input, RESULT_TYPE& result, void* leftValueVector, void* resultValueVector) {
FUNC::operation(input, result, *(common::ValueVector*)leftValueVector,
*(common::ValueVector*)resultValueVector);
}
};

struct UnaryCastOperationWrapper {
template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC>
static void operation(
Expand Down Expand Up @@ -106,6 +115,11 @@ struct UnaryOperationExecutor {
operand, result);
}

// Runs FUNC over `operand` into `result` via executeSwitch, using UnaryListOperationWrapper
// so that FUNC::operation also receives the operand and result ValueVectors.
template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC>
static void executeList(common::ValueVector& operand, common::ValueVector& result) {
executeSwitch<OPERAND_TYPE, RESULT_TYPE, FUNC, UnaryListOperationWrapper>(operand, result);
}

template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC>
static void executeCast(common::ValueVector& operand, common::ValueVector& result) {
executeSwitch<OPERAND_TYPE, RESULT_TYPE, FUNC, UnaryCastOperationWrapper>(operand, result);
Expand Down
Loading

0 comments on commit 31d9516

Please sign in to comment.