Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

range() function #2058

Merged
merged 1 commit into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,11 @@ std::vector<LogicalTypeID> LogicalTypeUtils::getNumericalLogicalTypeIDs() {
LogicalTypeID::SERIAL};
}

std::vector<LogicalTypeID> LogicalTypeUtils::getIntegerLogicalTypeIDs() {
return std::vector<LogicalTypeID>{LogicalTypeID::INT64, LogicalTypeID::INT32,
LogicalTypeID::INT16, LogicalTypeID::INT8, LogicalTypeID::SERIAL};
}

std::vector<LogicalType> LogicalTypeUtils::getAllValidLogicTypes() {
// TODO(Ziyi): Add FIXED_LIST, STRUCT, MAP types to allValidTypeID when we support functions on
// FIXED_LIST, STRUCT, MAP.
Expand Down
1 change: 1 addition & 0 deletions src/function/built_in_vector_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ void BuiltInVectorFunctions::registerCastFunctions() {

void BuiltInVectorFunctions::registerListFunctions() {
vectorFunctions.insert({LIST_CREATION_FUNC_NAME, ListCreationVectorFunction::getDefinitions()});
vectorFunctions.insert({LIST_RANGE_FUNC_NAME, ListRangeVectorFunction::getDefinitions()});
vectorFunctions.insert({LIST_LEN_FUNC_NAME, ListLenVectorFunction::getDefinitions()});
vectorFunctions.insert({LIST_EXTRACT_FUNC_NAME, ListExtractVectorFunction::getDefinitions()});
vectorFunctions.insert({LIST_ELEMENT_FUNC_NAME, ListExtractVectorFunction::getDefinitions()});
Expand Down
33 changes: 30 additions & 3 deletions src/function/vector_list_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "function/list/functions/list_len_function.h"
#include "function/list/functions/list_position_function.h"
#include "function/list/functions/list_prepend_function.h"
#include "function/list/functions/list_range_function.h"
#include "function/list/functions/list_reverse_sort_function.h"
#include "function/list/functions/list_slice_function.h"
#include "function/list/functions/list_sort_function.h"
Expand Down Expand Up @@ -94,6 +95,32 @@
return result;
}

// Binds a RANGE call: the result is a VAR_LIST whose child type is built from the logical type
// ID of the first argument. The first two arguments are expected to have already been resolved
// to the same data type (checked in debug builds only). `definition` is not used here.
std::unique_ptr<FunctionBindData> ListRangeVectorFunction::bindFunc(
    const binder::expression_vector& arguments, kuzu::function::FunctionDefinition* definition) {
    assert(arguments[0]->dataType == arguments[1]->dataType);
    // Child (element) type of the produced list.
    auto elementType = std::make_unique<LogicalType>(arguments[0]->dataType.getLogicalTypeID());
    auto listTypeInfo = std::make_unique<VarListTypeInfo>(std::move(elementType));
    auto listType = LogicalType{LogicalTypeID::VAR_LIST, std::move(listTypeInfo)};
    return std::make_unique<FunctionBindData>(listType);
}

// Builds the overload set registered under LIST_RANGE_FUNC_NAME. For every ID returned by
// LogicalTypeUtils::getIntegerLogicalTypeIDs() two definitions are appended, in this order:
//   1. (start, end)
//   2. (start, end, step)
// All parameters of one overload share the same type ID and the declared return type is VAR_LIST.
vector_function_definitions ListRangeVectorFunction::getDefinitions() {
    vector_function_definitions definitions;
    for (auto integerTypeID : LogicalTypeUtils::getIntegerLogicalTypeIDs()) {
        // Two-argument form: start, end.
        auto twoArgDefinition = std::make_unique<VectorFunctionDefinition>(LIST_RANGE_FUNC_NAME,
            std::vector<LogicalTypeID>{integerTypeID, integerTypeID}, LogicalTypeID::VAR_LIST,
            getBinaryListExecFuncSwitchAll<Range, list_entry_t>(LogicalType{integerTypeID}),
            nullptr, bindFunc, false);
        definitions.push_back(std::move(twoArgDefinition));

        // Three-argument form: start, end, step.
        auto threeArgDefinition = std::make_unique<VectorFunctionDefinition>(LIST_RANGE_FUNC_NAME,
            std::vector<LogicalTypeID>{integerTypeID, integerTypeID, integerTypeID},
            LogicalTypeID::VAR_LIST,
            getTernaryListExecFuncSwitchAll<Range, list_entry_t>(LogicalType{integerTypeID}),
            nullptr, bindFunc, false);
        definitions.push_back(std::move(threeArgDefinition));
    }
    return definitions;
}

Check warning on line 122 in src/function/vector_list_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_list_functions.cpp#L122

Added line #L122 was not covered by tests

vector_function_definitions ListLenVectorFunction::getDefinitions() {
vector_function_definitions result;
auto execFunc = UnaryExecFunction<list_entry_t, int64_t, ListLen>;
Expand Down Expand Up @@ -206,7 +233,7 @@
auto resultType = arguments[0]->getDataType();
auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
vectorFunctionDefinition->execFunc =
getBinaryListExecFunc<ListAppend, list_entry_t>(arguments[1]->getDataType());
getBinaryListExecFuncSwitchRight<ListAppend, list_entry_t>(arguments[1]->getDataType());
return std::make_unique<FunctionBindData>(resultType);
}

Expand Down Expand Up @@ -299,7 +326,7 @@
const binder::expression_vector& arguments, FunctionDefinition* definition) {
auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
vectorFunctionDefinition->execFunc =
getBinaryListExecFunc<ListPosition, int64_t>(arguments[1]->getDataType());
getBinaryListExecFuncSwitchRight<ListPosition, int64_t>(arguments[1]->getDataType());
return std::make_unique<FunctionBindData>(LogicalType{LogicalTypeID::INT64});
}

Expand All @@ -315,7 +342,7 @@
const binder::expression_vector& arguments, FunctionDefinition* definition) {
auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
vectorFunctionDefinition->execFunc =
getBinaryListExecFunc<ListContains, uint8_t>(arguments[1]->getDataType());
getBinaryListExecFuncSwitchRight<ListContains, uint8_t>(arguments[1]->getDataType());
return std::make_unique<FunctionBindData>(LogicalType{LogicalTypeID::BOOL});
}

Expand Down
2 changes: 1 addition & 1 deletion src/function/vector_map_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ std::unique_ptr<FunctionBindData> MapExtractVectorFunctions::bindFunc(
validateKeyType(arguments[0], arguments[1]);
auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
vectorFunctionDefinition->execFunc =
VectorListFunction::getBinaryListExecFunc<MapExtract, list_entry_t>(
VectorListFunction::getBinaryListExecFuncSwitchRight<MapExtract, list_entry_t>(
arguments[1]->getDataType());
auto returnListInfo = std::make_unique<VarListTypeInfo>(
std::make_unique<LogicalType>(*MapType::getValueType(&arguments[0]->dataType)));
Expand Down
1 change: 1 addition & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const std::string CAST_TO_BLOB_FUNC_NAME = "BLOB";

// list
const std::string LIST_CREATION_FUNC_NAME = "LIST_CREATION";
const std::string LIST_RANGE_FUNC_NAME = "RANGE";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR but I guess we can start sorting these function names in alphabetical order. Also sort code when you register.

const std::string LIST_EXTRACT_FUNC_NAME = "LIST_EXTRACT";
const std::string LIST_ELEMENT_FUNC_NAME = "LIST_ELEMENT";
const std::string LIST_LEN_FUNC_NAME = "LEN";
Expand Down
1 change: 1 addition & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ class LogicalTypeUtils {
static bool isNested(const LogicalType& dataType);
static std::vector<LogicalType> getAllValidComparableLogicalTypes();
static std::vector<LogicalTypeID> getNumericalLogicalTypeIDs();
static std::vector<LogicalTypeID> getIntegerLogicalTypeIDs();
static std::vector<LogicalType> getAllValidLogicTypes();

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class VectorArithmeticFunction : public VectorFunction {
throw common::RuntimeException(
"Invalid input data types(" +
common::LogicalTypeUtils::dataTypeToString(operandTypeID) +
") for getUnaryExecFunc.");
") for getBinaryExecFunc.");
}
}
};
Expand Down
45 changes: 45 additions & 0 deletions src/include/function/list/functions/list_range_function.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#pragma once

#include "common/types/ku_list.h"

namespace kuzu {
namespace function {

// Produces an arithmetic sequence of integers as a var-list entry.
struct Range {
public:
    // range function:
    // - the end value is inclusive (it is emitted when reachable from start in whole steps)
    // - when start = end: there is exactly one element in the result var-list
    // - when (end - start) and step have opposite signs: the result is empty
    // - default step = 1
    // - step = 0 is rejected with a RuntimeException

    // Two-argument form: range(start, end) with an implicit step of 1.
    // leftVector is forwarded as the input vector of the three-argument form; rightVector is
    // not used.
    template<typename T>
    static inline void operation(T& start, T& end, common::list_entry_t& result,
        common::ValueVector& leftVector, common::ValueVector& rightVector,
        common::ValueVector& resultVector) {
        T step = 1;
        operation(start, end, step, result, leftVector, resultVector);
    }

    // Three-argument form: range(start, end, step).
    // Writes start, start + step, start + 2 * step, ... (never passing end) into the data vector
    // of resultVector and stores the new list entry in `result`. inputVector is not used.
    // Throws common::RuntimeException when step is 0 or when the element count does not fit in
    // 64 bits.
    template<typename T>
    static inline void operation(T& start, T& end, T& step, common::list_entry_t& result,
        common::ValueVector& inputVector, common::ValueVector& resultVector) {
        if (step == 0) {
            throw common::RuntimeException("Step of range cannot be 0.");
        }

        // Count the elements with unsigned 64-bit integer arithmetic. The former floating-point
        // computation could overflow on (end - start) and lost precision for wide int64 ranges.
        const bool ascending = step > 0;
        uint64_t numElements = 0;
        if (ascending ? start <= end : start >= end) {
            // Sign-extending casts followed by modular subtraction yield the exact magnitudes,
            // even when the signed difference (or -step) would not be representable in T.
            const auto first = static_cast<uint64_t>(start);
            const auto last = static_cast<uint64_t>(end);
            const auto stepBits = static_cast<uint64_t>(step);
            const uint64_t distance = ascending ? last - first : first - last;
            const uint64_t stride = ascending ? stepBits : uint64_t{0} - stepBits;
            const uint64_t numSteps = distance / stride;
            if (numSteps == ~uint64_t{0}) {
                // numSteps + 1 would wrap around to 0.
                throw common::RuntimeException("Range has too many elements.");
            }
            numElements = numSteps + 1;
        }

        result = common::ListVector::addList(&resultVector, numElements);
        auto resultDataVector = common::ListVector::getDataVector(&resultVector);
        // start, start + step, start + 2step, ..., end
        T number = start;
        for (uint64_t i = 0; i < numElements; ++i) {
            resultDataVector->setValue(result.offset + i, number);
            // Only advance while another element follows, so the value never steps past end
            // (which could overflow T).
            if (i + 1 < numElements) {
                number += step;
            }
        }
    }
};

} // namespace function
} // namespace kuzu
60 changes: 59 additions & 1 deletion src/include/function/list/vector_list_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

struct VectorListFunction : public VectorFunction {
template<typename OPERATION, typename RESULT_TYPE>
static scalar_exec_func getBinaryListExecFunc(common::LogicalType rightType) {
static scalar_exec_func getBinaryListExecFuncSwitchRight(common::LogicalType rightType) {
scalar_exec_func execFunc;
switch (rightType.getPhysicalType()) {
case common::PhysicalTypeID::BOOL: {
Expand Down Expand Up @@ -65,6 +65,58 @@
}
return execFunc;
}

// Selects the BinaryExecListStructFunction instantiation for an operation whose two operands
// share one physical type. Only INT64, INT32, INT16 and INT8 are handled; any other physical
// type raises common::NotImplementedException.
template<typename OPERATION, typename RESULT_TYPE>
static scalar_exec_func getBinaryListExecFuncSwitchAll(common::LogicalType type) {
    switch (type.getPhysicalType()) {
    case common::PhysicalTypeID::INT64:
        return BinaryExecListStructFunction<int64_t, int64_t, RESULT_TYPE, OPERATION>;
    case common::PhysicalTypeID::INT32:
        return BinaryExecListStructFunction<int32_t, int32_t, RESULT_TYPE, OPERATION>;
    case common::PhysicalTypeID::INT16:
        return BinaryExecListStructFunction<int16_t, int16_t, RESULT_TYPE, OPERATION>;
    case common::PhysicalTypeID::INT8:
        return BinaryExecListStructFunction<int8_t, int8_t, RESULT_TYPE, OPERATION>;
    default:
        throw common::NotImplementedException{
            "VectorListFunctions::getBinaryListOperationDefinition"};
    }
}

// Selects the TernaryExecListStructFunction instantiation for an operation whose three operands
// share one physical type. Only INT64, INT32, INT16 and INT8 are handled; any other physical
// type raises common::NotImplementedException.
template<typename OPERATION, typename RESULT_TYPE>
static scalar_exec_func getTernaryListExecFuncSwitchAll(common::LogicalType type) {
    switch (type.getPhysicalType()) {
    case common::PhysicalTypeID::INT64:
        return TernaryExecListStructFunction<int64_t, int64_t, int64_t, RESULT_TYPE, OPERATION>;
    case common::PhysicalTypeID::INT32:
        return TernaryExecListStructFunction<int32_t, int32_t, int32_t, RESULT_TYPE, OPERATION>;
    case common::PhysicalTypeID::INT16:
        return TernaryExecListStructFunction<int16_t, int16_t, int16_t, RESULT_TYPE, OPERATION>;
    case common::PhysicalTypeID::INT8:
        return TernaryExecListStructFunction<int8_t, int8_t, int8_t, RESULT_TYPE, OPERATION>;
    default:
        throw common::NotImplementedException{
            "VectorListFunctions::getTernaryListOperationDefinition"};
    }
}
};

struct ListCreationVectorFunction : public VectorListFunction {
Expand All @@ -75,6 +127,12 @@
common::ValueVector& result);
};

struct ListRangeVectorFunction : public VectorListFunction {

Check warning on line 130 in src/include/function/list/vector_list_functions.h

View check run for this annotation

Codecov / codecov/patch

src/include/function/list/vector_list_functions.h#L130

Added line #L130 was not covered by tests
static vector_function_definitions getDefinitions();
static std::unique_ptr<FunctionBindData> bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition);
};

// Vector function registered under LIST_LEN_FUNC_NAME.
struct ListLenVectorFunction : public VectorListFunction {
    // Returns the function definitions (overloads) for this builtin.
    static vector_function_definitions getDefinitions();
};
Expand Down
22 changes: 4 additions & 18 deletions test/test_files/tck/match/match4.test
Original file line number Diff line number Diff line change
Expand Up @@ -59,31 +59,17 @@
# Matching longer variable length paths
-CASE Scenario4
-SKIP
-STATEMENT CREATE NODE TABLE A(ID SERIAL, var STRING, PRIMARY KEY(ID));
-STATEMENT CREATE NODE TABLE A(ID SERIAL, var INT64, PRIMARY KEY(ID));
---- ok
-STATEMENT CREATE REL TABLE T(FROM A TO A);
---- ok
# Indexing on a var-list causes a segmentation fault.
-STATEMENT CREATE (a {var: 'start'}), (b {var: 'end'})
WITH *
UNWIND ['0', '1', '2'] AS i
CREATE (n {var: i})
WITH a, b, collect(n) AS nodeList
UNWIND [0, 1] AS i
WITH nodeList[i] AS n1, nodeList[i+1] AS n2
CREATE (n1)-[:T]->(n2);
---- ok
-STATEMENT WITH collect(n) AS nodeList
UNWIND [0, 1] AS i
WITH nodeList[i] AS n1, nodeList[i+1] AS n2
CREATE (n1)-[:T]->(n2);
---- ok
-STATEMENT CREATE (a {var: 'start'}), (b {var: 'end'})
-STATEMENT CREATE (a {var: -2}), (b {var: 0})
WITH *
UNWIND range(1, 20) AS i
CREATE (n {var: i})
WITH a, b, [a] + collect(n) + [b] AS nodeList
UNWIND range(0, size(nodeList) - 2, 1) AS i
WITH a, b, list_cat(list_cat([a], collect(n)), [b]) AS nodeList
UNWIND range(0, len(nodeList) - 2, 1) AS i
WITH nodeList[i] AS n1, nodeList[i+1] AS n2
CREATE (n1)-[:T]->(n2);
---- ok
Expand Down
Loading