Skip to content

Commit

Permalink
Add list_sum and list_sort functions
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav8297 committed May 11, 2023
1 parent 4f02b2d commit 8a3bb51
Show file tree
Hide file tree
Showing 8 changed files with 360 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,8 @@ void BuiltInVectorOperations::registerListOperations() {
vectorOperations.insert({ARRAY_HAS_FUNC_NAME, ListContainsVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_SLICE_FUNC_NAME, ListSliceVectorOperation::getDefinitions()});
vectorOperations.insert({ARRAY_SLICE_FUNC_NAME, ListSliceVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_SORT_FUNC_NAME, ListSortVectorOperation::getDefinitions()});
vectorOperations.insert({LIST_SUM_FUNC_NAME, ListSumVectorOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerInternalIDOperation() {
Expand Down
161 changes: 161 additions & 0 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "function/list/operations/list_position_operation.h"
#include "function/list/operations/list_prepend_operation.h"
#include "function/list/operations/list_slice_operation.h"
#include "function/list/operations/list_sort_operation.h"
#include "function/list/operations/list_sum_operation.h"
#include "function/list/vector_list_operations.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -320,5 +322,164 @@ std::unique_ptr<FunctionBindData> ListSliceVectorOperation::bindFunc(
return std::make_unique<FunctionBindData>(arguments[0]->getDataType());
}

std::vector<std::unique_ptr<VectorOperationDefinition>> ListSortVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_SORT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, VAR_LIST, nullptr, nullptr, unaryBindFunc,
false /* isVarlength*/));
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_SORT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST, STRING}, VAR_LIST, nullptr, nullptr, binaryBindFunc,
false /* isVarlength*/));
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_SORT_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST, STRING, STRING}, VAR_LIST, nullptr, nullptr,
ternaryBindFunc, false /* isVarlength*/));
return result;
}

/**
 * Bind function for the single-argument overload list_sort(list).
 *
 * Picks the UnaryListExecFunction instantiation that matches the child type of the list
 * argument and stores it in the definition's execFunc.
 *
 * @param arguments  bound argument expressions; arguments[0] is the list to sort.
 * @param definition the function definition being bound; treated as a
 *                   VectorOperationDefinition.
 * @return bind data carrying the data type of arguments[0].
 * @throws common::NotImplementedException if the child type is not one of INT64, DOUBLE, BOOL,
 *         STRING, DATE, TIMESTAMP or INTERVAL.
 */
std::unique_ptr<FunctionBindData> ListSortVectorOperation::unaryBindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto* sortDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
    const auto elementTypeID = arguments[0]->dataType.getChildType()->getTypeID();
    switch (elementTypeID) {
    case INT64:
        sortDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListSort<int64_t>>;
        break;
    case DOUBLE:
        sortDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListSort<double_t>>;
        break;
    case BOOL:
        // Booleans are sorted through their uint8_t representation.
        sortDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListSort<uint8_t>>;
        break;
    case STRING:
        sortDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListSort<ku_string_t>>;
        break;
    case DATE:
        sortDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListSort<date_t>>;
        break;
    case TIMESTAMP:
        sortDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListSort<timestamp_t>>;
        break;
    case INTERVAL:
        sortDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, list_entry_t, operation::ListSort<interval_t>>;
        break;
    default:
        throw common::NotImplementedException("ListSortVectorOperation::unaryBindFunc");
    }
    return std::make_unique<FunctionBindData>(arguments[0]->getDataType());
}

/**
 * Bind function for the two-argument overload list_sort(list, sortOrder).
 *
 * Picks the BinaryListExecFunction instantiation (list operand, string operand, list result)
 * that matches the child type of the list argument and stores it in the definition's execFunc.
 *
 * @param arguments  bound argument expressions; arguments[0] is the list to sort.
 * @param definition the function definition being bound; treated as a
 *                   VectorOperationDefinition.
 * @return bind data carrying the data type of arguments[0].
 * @throws common::NotImplementedException if the child type is not one of INT64, DOUBLE, BOOL,
 *         STRING, DATE, TIMESTAMP or INTERVAL.
 */
std::unique_ptr<FunctionBindData> ListSortVectorOperation::binaryBindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto* sortDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
    const auto elementTypeID = arguments[0]->dataType.getChildType()->getTypeID();
    switch (elementTypeID) {
    case INT64:
        sortDefinition->execFunc = BinaryListExecFunction<list_entry_t, ku_string_t,
            list_entry_t, operation::ListSort<int64_t>>;
        break;
    case DOUBLE:
        sortDefinition->execFunc = BinaryListExecFunction<list_entry_t, ku_string_t,
            list_entry_t, operation::ListSort<double_t>>;
        break;
    case BOOL:
        // Booleans are sorted through their uint8_t representation.
        sortDefinition->execFunc = BinaryListExecFunction<list_entry_t, ku_string_t,
            list_entry_t, operation::ListSort<uint8_t>>;
        break;
    case STRING:
        sortDefinition->execFunc = BinaryListExecFunction<list_entry_t, ku_string_t,
            list_entry_t, operation::ListSort<ku_string_t>>;
        break;
    case DATE:
        sortDefinition->execFunc = BinaryListExecFunction<list_entry_t, ku_string_t,
            list_entry_t, operation::ListSort<date_t>>;
        break;
    case TIMESTAMP:
        sortDefinition->execFunc = BinaryListExecFunction<list_entry_t, ku_string_t,
            list_entry_t, operation::ListSort<timestamp_t>>;
        break;
    case INTERVAL:
        sortDefinition->execFunc = BinaryListExecFunction<list_entry_t, ku_string_t,
            list_entry_t, operation::ListSort<interval_t>>;
        break;
    default:
        throw common::NotImplementedException("ListSortVectorOperation::binaryBindFunc");
    }
    return std::make_unique<FunctionBindData>(arguments[0]->getDataType());
}

/**
 * Bind function for the three-argument overload list_sort(list, sortOrder, nullOrder).
 *
 * Picks the TernaryListExecFunction instantiation (list operand, two string operands, list
 * result) that matches the child type of the list argument and stores it in the definition's
 * execFunc.
 *
 * @param arguments  bound argument expressions; arguments[0] is the list to sort.
 * @param definition the function definition being bound; treated as a
 *                   VectorOperationDefinition.
 * @return bind data carrying the data type of arguments[0].
 * @throws common::NotImplementedException if the child type is not one of INT64, DOUBLE, BOOL,
 *         STRING, DATE, TIMESTAMP or INTERVAL.
 */
std::unique_ptr<FunctionBindData> ListSortVectorOperation::ternaryBindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto* sortDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
    const auto elementTypeID = arguments[0]->dataType.getChildType()->getTypeID();
    switch (elementTypeID) {
    case INT64:
        sortDefinition->execFunc = TernaryListExecFunction<list_entry_t, ku_string_t,
            ku_string_t, list_entry_t, operation::ListSort<int64_t>>;
        break;
    case DOUBLE:
        sortDefinition->execFunc = TernaryListExecFunction<list_entry_t, ku_string_t,
            ku_string_t, list_entry_t, operation::ListSort<double_t>>;
        break;
    case BOOL:
        // Booleans are sorted through their uint8_t representation.
        sortDefinition->execFunc = TernaryListExecFunction<list_entry_t, ku_string_t,
            ku_string_t, list_entry_t, operation::ListSort<uint8_t>>;
        break;
    case STRING:
        sortDefinition->execFunc = TernaryListExecFunction<list_entry_t, ku_string_t,
            ku_string_t, list_entry_t, operation::ListSort<ku_string_t>>;
        break;
    case DATE:
        sortDefinition->execFunc = TernaryListExecFunction<list_entry_t, ku_string_t,
            ku_string_t, list_entry_t, operation::ListSort<date_t>>;
        break;
    case TIMESTAMP:
        sortDefinition->execFunc = TernaryListExecFunction<list_entry_t, ku_string_t,
            ku_string_t, list_entry_t, operation::ListSort<timestamp_t>>;
        break;
    case INTERVAL:
        sortDefinition->execFunc = TernaryListExecFunction<list_entry_t, ku_string_t,
            ku_string_t, list_entry_t, operation::ListSort<interval_t>>;
        break;
    default:
        throw common::NotImplementedException("ListSortVectorOperation::ternaryBindFunc");
    }
    return std::make_unique<FunctionBindData>(arguments[0]->getDataType());
}

std::vector<std::unique_ptr<VectorOperationDefinition>> ListSumVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> result;
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_SUM_FUNC_NAME,
std::vector<DataTypeID>{VAR_LIST}, INT64, nullptr, nullptr, bindFunc,
false /* isVarlength*/));
return result;
}

/**
 * Bind function for list_sum(list).
 *
 * Selects the UnaryListExecFunction instantiation matching the child type of the list argument
 * and stores it in the definition's execFunc.
 *
 * @param arguments  bound argument expressions; arguments[0] is the list to sum.
 * @param definition the function definition being bound; treated as a
 *                   VectorOperationDefinition.
 * @return bind data carrying the list's child type (the sum has the element type).
 * @throws common::NotImplementedException if the child type is neither INT64 nor DOUBLE.
 */
std::unique_ptr<FunctionBindData> ListSumVectorOperation::bindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto* sumDefinition = reinterpret_cast<VectorOperationDefinition*>(definition);
    auto elementType = *arguments[0]->getDataType().getChildType();
    const auto elementTypeID = elementType.getTypeID();
    if (elementTypeID == INT64) {
        sumDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, int64_t, operation::ListSum<int64_t>>;
    } else if (elementTypeID == DOUBLE) {
        sumDefinition->execFunc =
            UnaryListExecFunction<list_entry_t, double_t, operation::ListSum<double_t>>;
    } else {
        throw common::NotImplementedException("ListSumVectorOperation::bindFunc");
    }
    return std::make_unique<FunctionBindData>(elementType);
}

} // namespace function
} // namespace kuzu
2 changes: 2 additions & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ const std::string ARRAY_CONTAINS_FUNC_NAME = "ARRAY_CONTAINS";
// Names under which the built-in list functions are registered. The ARRAY_* names are
// registered with the same definitions as their LIST_* counterparts.
const std::string ARRAY_HAS_FUNC_NAME = "ARRAY_HAS";
const std::string LIST_SLICE_FUNC_NAME = "LIST_SLICE";
const std::string ARRAY_SLICE_FUNC_NAME = "ARRAY_SLICE";
// LIST_SUM(list): sums the non-null elements of a list.
const std::string LIST_SUM_FUNC_NAME = "LIST_SUM";
// LIST_SORT(list [, sortOrder [, nullOrder]]): sorts the elements of a list.
const std::string LIST_SORT_FUNC_NAME = "LIST_SORT";

// struct
const std::string STRUCT_PACK_FUNC_NAME = "STRUCT_PACK";
Expand Down
120 changes: 120 additions & 0 deletions src/include/function/list/operations/list_sort_operation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#pragma once

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {
namespace operation {

template<typename T>
struct ListSort {
static inline void operation(common::list_entry_t& input, common::list_entry_t& result,
common::ValueVector& inputVector, common::ValueVector& resultVector) {
sortValues(input, result, inputVector, resultVector, true, true);
}

static inline void operation(common::list_entry_t& input, common::ku_string_t& sortOrder,
common::list_entry_t& result, common::ValueVector& inputVector,
common::ValueVector& valueVector, common::ValueVector& resultVector) {
sortValues(
input, result, inputVector, resultVector, isAscOrder(sortOrder.getAsString()), true);
}

static inline void operation(common::list_entry_t& input, common::ku_string_t& sortOrder,
common::ku_string_t& nullOrder, common::list_entry_t& result,
common::ValueVector& inputVector, common::ValueVector& resultVector) {
sortValues(input, result, inputVector, resultVector, isAscOrder(sortOrder.getAsString()),
isNullFirst(nullOrder.getAsString()));
}

static inline bool isAscOrder(const std::string& sortOrder) {
if (sortOrder == "ASC") {
return true;
} else if (sortOrder == "DESC") {
return false;
} else {
throw common::RuntimeException("Invalid sortOrder");
}
}

static inline bool isNullFirst(const std::string& nullOrder) {
if (nullOrder == "NULLS FIRST") {
return true;
} else if (nullOrder == "NULLS LAST") {
return false;
} else {
throw common::RuntimeException("Invalid nullOrder");
}
}

static void sortValues(common::list_entry_t& input, common::list_entry_t& result,
common::ValueVector& inputVector, common::ValueVector& resultVector, bool ascOrder,
bool nullFirst) {
auto inputValues = common::ListVector::getListValues(&inputVector, input);
auto inputDataVector = common::ListVector::getDataVector(&inputVector);

// Calculate null count
uint64_t nullCount = 0;
for (auto i = 0; i < input.size; i++) {
if (inputDataVector->isNull(input.offset + i)) {
nullCount += 1;
}
}

result = common::ListVector::addList(&resultVector, input.size);
auto resultValues = common::ListVector::getListValues(&resultVector, result);
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
auto numBytesPerValue = resultDataVector->getNumBytesPerValue();

// Add nulls first
if (nullFirst) {
addNulls(resultValues, resultDataVector, result.offset, numBytesPerValue, nullCount);
}

// Add actual data
for (auto i = 0; i < input.size; i++) {
if (inputDataVector->isNull(input.offset + i)) {
inputValues += numBytesPerValue;
continue;
}
common::ValueVectorUtils::copyValue(
resultValues, *resultDataVector, inputValues, *inputDataVector);
resultValues += numBytesPerValue;
inputValues += numBytesPerValue;
}

// Determine the starting and ending position of the data to be sorted
uint64_t sortStart = nullCount;
uint64_t sortEnd = input.size;
if (!nullFirst) {
sortStart = 0;
sortEnd = input.size - nullCount;
}

// Sort the data based on order
auto newResultValues =
reinterpret_cast<T*>(common::ListVector::getListValues(&resultVector, result));
if (ascOrder) {
std::sort(newResultValues + sortStart, newResultValues + sortEnd, std::less{});
} else {
std::sort(newResultValues + sortStart, newResultValues + sortEnd, std::greater{});
}

// Add nulls in the end
if (!nullFirst) {
addNulls(resultValues, resultDataVector, result.offset, numBytesPerValue, nullCount);
}
}

static inline void addNulls(const uint8_t* values, common::ValueVector* valueVector,
uint64_t offset, uint32_t numBytesPerValue, uint64_t nullCount) {
for (auto i = 0; i < nullCount; i++) {
valueVector->setNull(offset + i, true);
values += numBytesPerValue;
}
}
};

} // namespace operation
} // namespace function
} // namespace kuzu
27 changes: 27 additions & 0 deletions src/include/function/list/operations/list_sum_operation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {
namespace operation {

/**
 * Computes the sum of the non-null elements of one list entry.
 *
 * T is both the in-memory element type and the type of the result. Null elements are skipped;
 * a list with no non-null elements yields 0.
 */
template<typename T>
struct ListSum {
    /**
     * @param input        the list entry (offset and size) whose elements are summed.
     * @param result       receives the sum; overwritten, never read.
     * @param inputVector  the list vector that `input` belongs to.
     * @param resultVector the vector that `result` belongs to; not used here.
     */
    static inline void operation(common::list_entry_t& input, T& result,
        common::ValueVector& inputVector, common::ValueVector& resultVector) {
        auto inputValues =
            reinterpret_cast<T*>(common::ListVector::getListValues(&inputVector, input));
        auto inputDataVector = common::ListVector::getDataVector(&inputVector);
        // Start from zero: the result slot may hold a stale value, and accumulating on top of
        // it would make the sum depend on earlier contents of the result vector.
        result = 0;
        for (uint64_t i = 0; i < input.size; i++) {
            if (inputDataVector->isNull(input.offset + i)) {
                continue;
            }
            result += inputValues[i];
        }
    }
};

} // namespace operation
} // namespace function
} // namespace kuzu
24 changes: 24 additions & 0 deletions src/include/function/list/vector_list_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ struct VectorListOperations : public VectorOperations {
}
return result;
}

// Exec function for list operations with a single operand: checks that exactly one parameter
// vector was supplied and hands it, together with the result vector, to
// UnaryOperationExecutor::executeList instantiated with the given operand/result types and
// operation functor.
template<typename OPERAND_TYPE, typename RESULT_TYPE, typename FUNC>
static void UnaryListExecFunction(
    const std::vector<std::shared_ptr<common::ValueVector>>& params,
    common::ValueVector& result) {
    assert(params.size() == 1);
    auto& operand = *params[0];
    UnaryOperationExecutor::executeList<OPERAND_TYPE, RESULT_TYPE, FUNC>(operand, result);
}
};

struct ListCreationVectorOperation : public VectorListOperations {
Expand Down Expand Up @@ -128,5 +136,21 @@ struct ListSliceVectorOperation : public VectorListOperations {
const binder::expression_vector& arguments, FunctionDefinition* definition);
};

// Vector operation for LIST_SORT. Provides one definition per supported arity and a matching
// bind function that selects the exec function from the child type of the list argument.
struct ListSortVectorOperation : public VectorListOperations {
// Returns the definitions for the one-, two- and three-argument overloads.
static std::vector<std::unique_ptr<VectorOperationDefinition>> getDefinitions();
// Binds list_sort(list).
static std::unique_ptr<FunctionBindData> unaryBindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition);
// Binds list_sort(list, sortOrder).
static std::unique_ptr<FunctionBindData> binaryBindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition);
// Binds list_sort(list, sortOrder, nullOrder).
static std::unique_ptr<FunctionBindData> ternaryBindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition);
};

// Vector operation for LIST_SUM. Provides a single definition taking one list argument and a
// bind function that selects the exec function from the child type of that list.
struct ListSumVectorOperation : public VectorListOperations {
// Returns the definition for list_sum(list).
static std::vector<std::unique_ptr<VectorOperationDefinition>> getDefinitions();
// Binds list_sum(list); the returned bind data carries the list's child type.
static std::unique_ptr<FunctionBindData> bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition);
};

} // namespace function
} // namespace kuzu
Loading

0 comments on commit 8a3bb51

Please sign in to comment.