Skip to content

Commit

Permalink
finish range()
Browse files Browse the repository at this point in the history
  • Loading branch information
AEsir777 committed Sep 21, 2023
1 parent 6943574 commit cda29ea
Show file tree
Hide file tree
Showing 11 changed files with 284 additions and 24 deletions.
5 changes: 5 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,11 @@ std::vector<LogicalTypeID> LogicalTypeUtils::getNumericalLogicalTypeIDs() {
LogicalTypeID::SERIAL};
}

std::vector<LogicalTypeID> LogicalTypeUtils::getIntegerLogicalTypeIDs() {
return std::vector<LogicalTypeID>{LogicalTypeID::INT64, LogicalTypeID::INT32,
LogicalTypeID::INT16, LogicalTypeID::INT8, LogicalTypeID::SERIAL};
}

std::vector<LogicalType> LogicalTypeUtils::getAllValidLogicTypes() {
// TODO(Ziyi): Add FIX_LIST,STRUCT,MAP type to allValidTypeID when we support functions on
// FIXED_LIST,STRUCT,MAP.
Expand Down
1 change: 1 addition & 0 deletions src/function/built_in_vector_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ void BuiltInVectorFunctions::registerCastFunctions() {

void BuiltInVectorFunctions::registerListFunctions() {
vectorFunctions.insert({LIST_CREATION_FUNC_NAME, ListCreationVectorFunction::getDefinitions()});
vectorFunctions.insert({LIST_RANGE_FUNC_NAME, ListRangeVectorFunction::getDefinitions()});
vectorFunctions.insert({LIST_LEN_FUNC_NAME, ListLenVectorFunction::getDefinitions()});
vectorFunctions.insert({LIST_EXTRACT_FUNC_NAME, ListExtractVectorFunction::getDefinitions()});
vectorFunctions.insert({LIST_ELEMENT_FUNC_NAME, ListExtractVectorFunction::getDefinitions()});
Expand Down
33 changes: 30 additions & 3 deletions src/function/vector_list_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "function/list/functions/list_len_function.h"
#include "function/list/functions/list_position_function.h"
#include "function/list/functions/list_prepend_function.h"
#include "function/list/functions/list_range_function.h"
#include "function/list/functions/list_reverse_sort_function.h"
#include "function/list/functions/list_slice_function.h"
#include "function/list/functions/list_sort_function.h"
Expand Down Expand Up @@ -94,6 +95,32 @@ vector_function_definitions ListCreationVectorFunction::getDefinitions() {
return result;
}

std::unique_ptr<FunctionBindData> ListRangeVectorFunction::bindFunc(
const binder::expression_vector& arguments, kuzu::function::FunctionDefinition* definition) {
assert(arguments[0]->dataType == arguments[1]->dataType);
auto varListTypeInfo = std::make_unique<VarListTypeInfo>(
std::make_unique<LogicalType>(arguments[0]->dataType.getLogicalTypeID()));
auto resultType = LogicalType{LogicalTypeID::VAR_LIST, std::move(varListTypeInfo)};
return std::make_unique<FunctionBindData>(resultType);
}

vector_function_definitions ListRangeVectorFunction::getDefinitions() {
vector_function_definitions result;
for (auto typeID : LogicalTypeUtils::getIntegerLogicalTypeIDs()) {
// start, end
result.push_back(std::make_unique<VectorFunctionDefinition>(LIST_RANGE_FUNC_NAME,
std::vector<LogicalTypeID>{typeID, typeID}, LogicalTypeID::VAR_LIST,
getBinaryListExecFuncSwitchAll<Range, list_entry_t>(LogicalType{typeID}), nullptr,
bindFunc, false));
// start, end, step
result.push_back(std::make_unique<VectorFunctionDefinition>(LIST_RANGE_FUNC_NAME,
std::vector<LogicalTypeID>{typeID, typeID, typeID}, LogicalTypeID::VAR_LIST,
getTernaryListExecFuncSwitchAll<Range, list_entry_t>(LogicalType{typeID}), nullptr,
bindFunc, false));
}
return result;
}

Check warning on line 122 in src/function/vector_list_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_list_functions.cpp#L122

Added line #L122 was not covered by tests

vector_function_definitions ListLenVectorFunction::getDefinitions() {
vector_function_definitions result;
auto execFunc = UnaryExecFunction<list_entry_t, int64_t, ListLen>;
Expand Down Expand Up @@ -206,7 +233,7 @@ std::unique_ptr<FunctionBindData> ListAppendVectorFunction::bindFunc(
auto resultType = arguments[0]->getDataType();
auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
vectorFunctionDefinition->execFunc =
getBinaryListExecFunc<ListAppend, list_entry_t>(arguments[1]->getDataType());
getBinaryListExecFuncSwitchRight<ListAppend, list_entry_t>(arguments[1]->getDataType());
return std::make_unique<FunctionBindData>(resultType);
}

Expand Down Expand Up @@ -299,7 +326,7 @@ std::unique_ptr<FunctionBindData> ListPositionVectorFunction::bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
vectorFunctionDefinition->execFunc =
getBinaryListExecFunc<ListPosition, int64_t>(arguments[1]->getDataType());
getBinaryListExecFuncSwitchRight<ListPosition, int64_t>(arguments[1]->getDataType());
return std::make_unique<FunctionBindData>(LogicalType{LogicalTypeID::INT64});
}

Expand All @@ -315,7 +342,7 @@ std::unique_ptr<FunctionBindData> ListContainsVectorFunction::bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
vectorFunctionDefinition->execFunc =
getBinaryListExecFunc<ListContains, uint8_t>(arguments[1]->getDataType());
getBinaryListExecFuncSwitchRight<ListContains, uint8_t>(arguments[1]->getDataType());
return std::make_unique<FunctionBindData>(LogicalType{LogicalTypeID::BOOL});
}

Expand Down
2 changes: 1 addition & 1 deletion src/function/vector_map_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ std::unique_ptr<FunctionBindData> MapExtractVectorFunctions::bindFunc(
validateKeyType(arguments[0], arguments[1]);
auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
vectorFunctionDefinition->execFunc =
VectorListFunction::getBinaryListExecFunc<MapExtract, list_entry_t>(
VectorListFunction::getBinaryListExecFuncSwitchRight<MapExtract, list_entry_t>(
arguments[1]->getDataType());
auto returnListInfo = std::make_unique<VarListTypeInfo>(
std::make_unique<LogicalType>(*MapType::getValueType(&arguments[0]->dataType)));
Expand Down
1 change: 1 addition & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const std::string CAST_TO_BLOB_FUNC_NAME = "BLOB";

// list
const std::string LIST_CREATION_FUNC_NAME = "LIST_CREATION";
const std::string LIST_RANGE_FUNC_NAME = "RANGE";
const std::string LIST_EXTRACT_FUNC_NAME = "LIST_EXTRACT";
const std::string LIST_ELEMENT_FUNC_NAME = "LIST_ELEMENT";
const std::string LIST_LEN_FUNC_NAME = "LEN";
Expand Down
1 change: 1 addition & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ class LogicalTypeUtils {
static bool isNested(const LogicalType& dataType);
static std::vector<LogicalType> getAllValidComparableLogicalTypes();
static std::vector<LogicalTypeID> getNumericalLogicalTypeIDs();
static std::vector<LogicalTypeID> getIntegerLogicalTypeIDs();
static std::vector<LogicalType> getAllValidLogicTypes();

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class VectorArithmeticFunction : public VectorFunction {
throw common::RuntimeException(
"Invalid input data types(" +
common::LogicalTypeUtils::dataTypeToString(operandTypeID) +
") for getUnaryExecFunc.");
") for getBinaryExecFunc.");
}
}
};
Expand Down
45 changes: 45 additions & 0 deletions src/include/function/list/functions/list_range_function.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#pragma once

#include "common/types/ku_list.h"

namespace kuzu {
namespace function {

struct Range {
public:
// range function:
// - include end
// - when start = end: there is only one element in result varlist
// - when end - start are of opposite sign of step, the result will be empty
// - default step = 1
template<typename T>
static inline void operation(T& start, T& end, common::list_entry_t& result,
common::ValueVector& leftVector, common::ValueVector& rightVector,
common::ValueVector& resultVector) {
T step = 1;
operation(start, end, step, result, leftVector, resultVector);
}

template<typename T>
static inline void operation(T& start, T& end, T& step, common::list_entry_t& result,
common::ValueVector& inputVector, common::ValueVector& resultVector) {
if (step == 0) {
throw common::RuntimeException("Step of range cannot be 0.");
}

// start, start + step, start + 2step, ..., end
T number = start;
auto size = (end - start) * 1.0 / step;
size < 0 ? size = 0 : size = (int64_t)(size + 1);

result = common::ListVector::addList(&resultVector, size);
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
for (auto i = 0u; i < size; i++) {
resultDataVector->setValue(result.offset + i, number);
number += step;
}
}
};

} // namespace function
} // namespace kuzu
60 changes: 59 additions & 1 deletion src/include/function/list/vector_list_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace function {

struct VectorListFunction : public VectorFunction {
template<typename OPERATION, typename RESULT_TYPE>
static scalar_exec_func getBinaryListExecFunc(common::LogicalType rightType) {
static scalar_exec_func getBinaryListExecFuncSwitchRight(common::LogicalType rightType) {
scalar_exec_func execFunc;
switch (rightType.getPhysicalType()) {
case common::PhysicalTypeID::BOOL: {
Expand Down Expand Up @@ -65,6 +65,58 @@ struct VectorListFunction : public VectorFunction {
}
return execFunc;
}

template<typename OPERATION, typename RESULT_TYPE>
static scalar_exec_func getBinaryListExecFuncSwitchAll(common::LogicalType type) {
scalar_exec_func execFunc;
switch (type.getPhysicalType()) {
case common::PhysicalTypeID::INT64: {
execFunc = BinaryExecListStructFunction<int64_t, int64_t, RESULT_TYPE, OPERATION>;
} break;
case common::PhysicalTypeID::INT32: {
execFunc = BinaryExecListStructFunction<int32_t, int32_t, RESULT_TYPE, OPERATION>;
} break;

Check warning on line 78 in src/include/function/list/vector_list_functions.h

View check run for this annotation

Codecov / codecov/patch

src/include/function/list/vector_list_functions.h#L77-L78

Added lines #L77 - L78 were not covered by tests
case common::PhysicalTypeID::INT16: {
execFunc = BinaryExecListStructFunction<int16_t, int16_t, RESULT_TYPE, OPERATION>;
} break;
case common::PhysicalTypeID::INT8: {
execFunc = BinaryExecListStructFunction<int8_t, int8_t, RESULT_TYPE, OPERATION>;
} break;
default: {
throw common::NotImplementedException{
"VectorListFunctions::getBinaryListOperationDefinition"};
}
}
return execFunc;
}

template<typename OPERATION, typename RESULT_TYPE>
static scalar_exec_func getTernaryListExecFuncSwitchAll(common::LogicalType type) {
scalar_exec_func execFunc;
switch (type.getPhysicalType()) {
case common::PhysicalTypeID::INT64: {
execFunc =
TernaryExecListStructFunction<int64_t, int64_t, int64_t, RESULT_TYPE, OPERATION>;
} break;
case common::PhysicalTypeID::INT32: {
execFunc =

Check warning on line 102 in src/include/function/list/vector_list_functions.h

View check run for this annotation

Codecov / codecov/patch

src/include/function/list/vector_list_functions.h#L101-L102

Added lines #L101 - L102 were not covered by tests
TernaryExecListStructFunction<int32_t, int32_t, int32_t, RESULT_TYPE, OPERATION>;
} break;
case common::PhysicalTypeID::INT16: {
execFunc =
TernaryExecListStructFunction<int16_t, int16_t, int16_t, RESULT_TYPE, OPERATION>;
} break;
case common::PhysicalTypeID::INT8: {
execFunc =
TernaryExecListStructFunction<int8_t, int8_t, int8_t, RESULT_TYPE, OPERATION>;
} break;
default: {
throw common::NotImplementedException{
"VectorListFunctions::getTernaryListOperationDefinition"};
}
}
return execFunc;
}
};

struct ListCreationVectorFunction : public VectorListFunction {
Expand All @@ -75,6 +127,12 @@ struct ListCreationVectorFunction : public VectorListFunction {
common::ValueVector& result);
};

struct ListRangeVectorFunction : public VectorListFunction {

Check warning on line 130 in src/include/function/list/vector_list_functions.h

View check run for this annotation

Codecov / codecov/patch

src/include/function/list/vector_list_functions.h#L130

Added line #L130 was not covered by tests
static vector_function_definitions getDefinitions();
static std::unique_ptr<FunctionBindData> bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition);
};

struct ListLenVectorFunction : public VectorListFunction {
static vector_function_definitions getDefinitions();
};
Expand Down
22 changes: 4 additions & 18 deletions test/test_files/tck/match/match4.test
Original file line number Diff line number Diff line change
Expand Up @@ -59,31 +59,17 @@
# Matching longer variable length paths
-CASE Scenario4
-SKIP
-STATEMENT CREATE NODE TABLE A(ID SERIAL, var STRING, PRIMARY KEY(ID));
-STATEMENT CREATE NODE TABLE A(ID SERIAL, var INT64, PRIMARY KEY(ID));
---- ok
-STATEMENT CREATE REL TABLE T(FROM A TO A);
---- ok
# indexing on varlist causes segmentaion fault
-STATEMENT CREATE (a {var: 'start'}), (b {var: 'end'})
WITH *
UNWIND ['0', '1', '2'] AS i
CREATE (n {var: i})
WITH a, b, collect(n) AS nodeList
UNWIND [0, 1] AS i
WITH nodeList[i] AS n1, nodeList[i+1] AS n2
CREATE (n1)-[:T]->(n2);
---- ok
-STATEMENT WITH collect(n) AS nodeList
UNWIND [0, 1] AS i
WITH nodeList[i] AS n1, nodeList[i+1] AS n2
CREATE (n1)-[:T]->(n2);
---- ok
-STATEMENT CREATE (a {var: 'start'}), (b {var: 'end'})
-STATEMENT CREATE (a {var: -2}), (b {var: 0})
WITH *
UNWIND range(1, 20) AS i
CREATE (n {var: i})
WITH a, b, [a] + collect(n) + [b] AS nodeList
UNWIND range(0, size(nodeList) - 2, 1) AS i
WITH a, b, list_cat(list_cat([a], collect(n)), [b]) AS nodeList
UNWIND range(0, len(nodeList) - 2, 1) AS i
WITH nodeList[i] AS n1, nodeList[i+1] AS n2
CREATE (n1)-[:T]->(n2);
---- ok
Expand Down
Loading

0 comments on commit cda29ea

Please sign in to comment.