NOTE(review): this patch was recovered from a copy whose line breaks were collapsed and whose
<...> template argument lists were stripped. Indentation, context lines and template arguments
below are reconstructed -- confirm them against the repository before applying. The blob hashes
on the "index" lines of list_range_function.h and range.test predate the edits described next.

Adds the RANGE(start, end[, step]) list function for the integer logical types and renames
getBinaryListExecFunc to getBinaryListExecFuncSwitchRight.

Changes relative to the recovered patch (list_range_function.h, range.test):
  - The element count is computed with exact unsigned 64-bit arithmetic instead of double
    division, which is inexact above 2^53 and could emit an element past end.
  - end - start is no longer evaluated in the signed type, so wide INT64 ranges do not overflow.
  - The loop no longer increments past the last element and uses a 64-bit index.
  - Two regression cases are added to range.test.

diff --git a/src/common/types/types.cpp b/src/common/types/types.cpp
index 8c313cda70..e99d3d91ec 100644
--- a/src/common/types/types.cpp
+++ b/src/common/types/types.cpp
@@ -705,6 +705,11 @@ std::vector<LogicalTypeID> LogicalTypeUtils::getNumericalLogicalTypeIDs() {
         LogicalTypeID::SERIAL};
 }
 
+std::vector<LogicalTypeID> LogicalTypeUtils::getIntegerLogicalTypeIDs() {
+    return std::vector<LogicalTypeID>{LogicalTypeID::INT64, LogicalTypeID::INT32,
+        LogicalTypeID::INT16, LogicalTypeID::INT8, LogicalTypeID::SERIAL};
+}
+
 std::vector<LogicalType> LogicalTypeUtils::getAllValidLogicTypes() {
     // TODO(Ziyi): Add FIX_LIST,STRUCT,MAP type to allValidTypeID when we support functions on
     // FIXED_LIST,STRUCT,MAP.
diff --git a/src/function/built_in_vector_functions.cpp b/src/function/built_in_vector_functions.cpp
index 58396734d6..747f08f62f 100644
--- a/src/function/built_in_vector_functions.cpp
+++ b/src/function/built_in_vector_functions.cpp
@@ -505,6 +505,7 @@ void BuiltInVectorFunctions::registerCastFunctions() {
 
 void BuiltInVectorFunctions::registerListFunctions() {
     vectorFunctions.insert({LIST_CREATION_FUNC_NAME, ListCreationVectorFunction::getDefinitions()});
+    vectorFunctions.insert({LIST_RANGE_FUNC_NAME, ListRangeVectorFunction::getDefinitions()});
     vectorFunctions.insert({LIST_LEN_FUNC_NAME, ListLenVectorFunction::getDefinitions()});
     vectorFunctions.insert({LIST_EXTRACT_FUNC_NAME, ListExtractVectorFunction::getDefinitions()});
     vectorFunctions.insert({LIST_ELEMENT_FUNC_NAME, ListExtractVectorFunction::getDefinitions()});
diff --git a/src/function/vector_list_functions.cpp b/src/function/vector_list_functions.cpp
index 0f1d9782c5..ae6506d87d 100644
--- a/src/function/vector_list_functions.cpp
+++ b/src/function/vector_list_functions.cpp
@@ -15,6 +15,7 @@
 #include "function/list/functions/list_len_function.h"
 #include "function/list/functions/list_position_function.h"
 #include "function/list/functions/list_prepend_function.h"
+#include "function/list/functions/list_range_function.h"
 #include "function/list/functions/list_reverse_sort_function.h"
 #include "function/list/functions/list_slice_function.h"
 #include "function/list/functions/list_sort_function.h"
@@ -94,6 +95,32 @@ vector_function_definitions ListCreationVectorFunction::getDefinitions() {
     return result;
 }
 
+std::unique_ptr<FunctionBindData> ListRangeVectorFunction::bindFunc(
+    const binder::expression_vector& arguments, kuzu::function::FunctionDefinition* definition) {
+    assert(arguments[0]->dataType == arguments[1]->dataType);
+    auto varListTypeInfo = std::make_unique<VarListTypeInfo>(
+        std::make_unique<LogicalType>(arguments[0]->dataType.getLogicalTypeID()));
+    auto resultType = LogicalType{LogicalTypeID::VAR_LIST, std::move(varListTypeInfo)};
+    return std::make_unique<FunctionBindData>(resultType);
+}
+
+vector_function_definitions ListRangeVectorFunction::getDefinitions() {
+    vector_function_definitions result;
+    for (auto typeID : LogicalTypeUtils::getIntegerLogicalTypeIDs()) {
+        // start, end
+        result.push_back(std::make_unique<VectorFunctionDefinition>(LIST_RANGE_FUNC_NAME,
+            std::vector<LogicalTypeID>{typeID, typeID}, LogicalTypeID::VAR_LIST,
+            getBinaryListExecFuncSwitchAll<Range, list_entry_t>(LogicalType{typeID}), nullptr,
+            bindFunc, false));
+        // start, end, step
+        result.push_back(std::make_unique<VectorFunctionDefinition>(LIST_RANGE_FUNC_NAME,
+            std::vector<LogicalTypeID>{typeID, typeID, typeID}, LogicalTypeID::VAR_LIST,
+            getTernaryListExecFuncSwitchAll<Range, list_entry_t>(LogicalType{typeID}), nullptr,
+            bindFunc, false));
+    }
+    return result;
+}
+
 vector_function_definitions ListLenVectorFunction::getDefinitions() {
     vector_function_definitions result;
     auto execFunc = UnaryExecFunction<list_entry_t, int64_t, ListLen>;
@@ -222,7 +249,7 @@ std::unique_ptr<FunctionBindData> ListAppendVectorFunction::bindFunc(
     auto resultType = arguments[0]->getDataType();
     auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
     vectorFunctionDefinition->execFunc =
-        getBinaryListExecFunc<ListAppend, list_entry_t>(arguments[1]->getDataType());
+        getBinaryListExecFuncSwitchRight<ListAppend, list_entry_t>(arguments[1]->getDataType());
     return std::make_unique<FunctionBindData>(resultType);
 }
 
@@ -331,7 +358,7 @@ std::unique_ptr<FunctionBindData> ListPositionVectorFunction::bindFunc(
     const binder::expression_vector& arguments, FunctionDefinition* definition) {
     auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
     vectorFunctionDefinition->execFunc =
-        getBinaryListExecFunc<ListPosition, int64_t>(arguments[1]->getDataType());
+        getBinaryListExecFuncSwitchRight<ListPosition, int64_t>(arguments[1]->getDataType());
     return std::make_unique<FunctionBindData>(LogicalType{LogicalTypeID::INT64});
 }
 
@@ -347,7 +374,7 @@ std::unique_ptr<FunctionBindData> ListContainsVectorFunction::bindFunc(
     const binder::expression_vector& arguments, FunctionDefinition* definition) {
     auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
     vectorFunctionDefinition->execFunc =
-        getBinaryListExecFunc<ListContains, uint8_t>(arguments[1]->getDataType());
+        getBinaryListExecFuncSwitchRight<ListContains, uint8_t>(arguments[1]->getDataType());
     return std::make_unique<FunctionBindData>(LogicalType{LogicalTypeID::BOOL});
 }
 
diff --git a/src/function/vector_map_functions.cpp b/src/function/vector_map_functions.cpp
index a468c11c79..ada45fe705 100644
--- a/src/function/vector_map_functions.cpp
+++ b/src/function/vector_map_functions.cpp
@@ -58,7 +58,7 @@ std::unique_ptr<FunctionBindData> MapExtractVectorFunctions::bindFunc(
     validateKeyType(arguments[0], arguments[1]);
     auto vectorFunctionDefinition = reinterpret_cast<VectorFunctionDefinition*>(definition);
     vectorFunctionDefinition->execFunc =
-        VectorListFunction::getBinaryListExecFunc<MapExtract, list_entry_t>(
+        VectorListFunction::getBinaryListExecFuncSwitchRight<MapExtract, list_entry_t>(
             arguments[1]->getDataType());
     auto returnListInfo = std::make_unique<VarListTypeInfo>(
         std::make_unique<LogicalType>(*MapType::getValueType(&arguments[0]->dataType)));
diff --git a/src/include/common/expression_type.h b/src/include/common/expression_type.h
index cd81ff1bd2..8490ec3c9f 100644
--- a/src/include/common/expression_type.h
+++ b/src/include/common/expression_type.h
@@ -39,6 +39,7 @@ const std::string CAST_TO_BLOB_FUNC_NAME = "BLOB";
 
 // list
 const std::string LIST_CREATION_FUNC_NAME = "LIST_CREATION";
+const std::string LIST_RANGE_FUNC_NAME = "RANGE";
 const std::string LIST_EXTRACT_FUNC_NAME = "LIST_EXTRACT";
 const std::string LIST_ELEMENT_FUNC_NAME = "LIST_ELEMENT";
 const std::string LIST_LEN_FUNC_NAME = "LEN";
diff --git a/src/include/common/types/types.h b/src/include/common/types/types.h
index a6544f1723..6ad412b4f6 100644
--- a/src/include/common/types/types.h
+++ b/src/include/common/types/types.h
@@ -434,6 +434,7 @@ class LogicalTypeUtils {
     static bool isNested(const LogicalType& dataType);
     static std::vector<LogicalType> getAllValidComparableLogicalTypes();
     static std::vector<LogicalTypeID> getNumericalLogicalTypeIDs();
+    static std::vector<LogicalTypeID> getIntegerLogicalTypeIDs();
     static std::vector<LogicalType> getAllValidLogicTypes();
 
 private:
diff --git a/src/include/function/arithmetic/vector_arithmetic_functions.h b/src/include/function/arithmetic/vector_arithmetic_functions.h
index 79ec0bdb9b..49e86a4f1e 100644
--- a/src/include/function/arithmetic/vector_arithmetic_functions.h
+++ b/src/include/function/arithmetic/vector_arithmetic_functions.h
@@ -144,7 +144,7 @@ class VectorArithmeticFunction : public VectorFunction {
             throw common::RuntimeException(
                 "Invalid input data types(" +
                 common::LogicalTypeUtils::dataTypeToString(operandTypeID) +
-                ") for getUnaryExecFunc.");
+                ") for getBinaryExecFunc.");
         }
     }
 };
diff --git a/src/include/function/list/functions/list_range_function.h b/src/include/function/list/functions/list_range_function.h
new file mode 100644
index 0000000000..6b8cb0a82e
--- /dev/null
+++ b/src/include/function/list/functions/list_range_function.h
@@ -0,0 +1,62 @@
+#pragma once
+
+#include <cstdint>
+
+#include "common/types/ku_list.h"
+
+namespace kuzu {
+namespace function {
+
+struct Range {
+public:
+    // range function:
+    // - include end
+    // - when start = end: there is only one element in result varlist
+    // - when end - start are of opposite sign of step, the result will be empty
+    // - default step = 1
+    // - step = 0 throws a RuntimeException
+    template<typename T>
+    static inline void operation(T& start, T& end, common::list_entry_t& result,
+        common::ValueVector& leftVector, common::ValueVector& rightVector,
+        common::ValueVector& resultVector) {
+        T step = 1;
+        operation(start, end, step, result, leftVector, resultVector);
+    }
+
+    template<typename T>
+    static inline void operation(T& start, T& end, T& step, common::list_entry_t& result,
+        common::ValueVector& inputVector, common::ValueVector& resultVector) {
+        if (step == 0) {
+            throw common::RuntimeException("Step of range cannot be 0.");
+        }
+
+        // Count the elements of start, start + step, start + 2step, ..., end with exact unsigned
+        // 64-bit arithmetic. Dividing as double is inexact once the operands exceed 2^53, and
+        // end - start can overflow a signed 64-bit T.
+        const bool ascending = step > 0;
+        uint64_t size = 0;
+        if (ascending ? start <= end : start >= end) {
+            // Converting to uint64_t is modulo 2^64, so high - low is the true distance and
+            // 0 - stepBits is |step| even when step is the minimum value of T.
+            const auto low = static_cast<uint64_t>(ascending ? start : end);
+            const auto high = static_cast<uint64_t>(ascending ? end : start);
+            const auto stepBits = static_cast<uint64_t>(step);
+            const uint64_t stride = ascending ? stepBits : uint64_t{0} - stepBits;
+            size = (high - low) / stride + 1;
+        }
+
+        result = common::ListVector::addList(&resultVector, size);
+        auto resultDataVector = common::ListVector::getDataVector(&resultVector);
+        T number = start;
+        for (uint64_t i = 0; i < size; i++) {
+            if (i > 0) {
+                // Step before every element except the first, so number never moves past end.
+                number += step;
+            }
+            resultDataVector->setValue(result.offset + i, number);
+        }
+    }
+};
+
+} // namespace function
+} // namespace kuzu
diff --git a/src/include/function/list/vector_list_functions.h b/src/include/function/list/vector_list_functions.h
index dbad442410..7126d2d081 100644
--- a/src/include/function/list/vector_list_functions.h
+++ b/src/include/function/list/vector_list_functions.h
@@ -7,7 +7,7 @@ namespace function {
 
 struct VectorListFunction : public VectorFunction {
     template<typename OPERATION, typename RESULT_TYPE>
-    static scalar_exec_func getBinaryListExecFunc(common::LogicalType rightType) {
+    static scalar_exec_func getBinaryListExecFuncSwitchRight(common::LogicalType rightType) {
         scalar_exec_func execFunc;
         switch (rightType.getPhysicalType()) {
         case common::PhysicalTypeID::BOOL: {
@@ -81,6 +81,58 @@ struct VectorListFunction : public VectorFunction {
         }
         return execFunc;
     }
+
+    template<typename OPERATION, typename RESULT_TYPE>
+    static scalar_exec_func getBinaryListExecFuncSwitchAll(common::LogicalType type) {
+        scalar_exec_func execFunc;
+        switch (type.getPhysicalType()) {
+        case common::PhysicalTypeID::INT64: {
+            execFunc = BinaryExecListStructFunction<int64_t, int64_t, RESULT_TYPE, OPERATION>;
+        } break;
+        case common::PhysicalTypeID::INT32: {
+            execFunc = BinaryExecListStructFunction<int32_t, int32_t, RESULT_TYPE, OPERATION>;
+        } break;
+        case common::PhysicalTypeID::INT16: {
+            execFunc = BinaryExecListStructFunction<int16_t, int16_t, RESULT_TYPE, OPERATION>;
+        } break;
+        case common::PhysicalTypeID::INT8: {
+            execFunc = BinaryExecListStructFunction<int8_t, int8_t, RESULT_TYPE, OPERATION>;
+        } break;
+        default: {
+            throw common::NotImplementedException{
+                "VectorListFunctions::getBinaryListOperationDefinition"};
+        }
+        }
+        return execFunc;
+    }
+
+    template<typename OPERATION, typename RESULT_TYPE>
+    static scalar_exec_func getTernaryListExecFuncSwitchAll(common::LogicalType type) {
+        scalar_exec_func execFunc;
+        switch (type.getPhysicalType()) {
+        case common::PhysicalTypeID::INT64: {
+            execFunc =
+                TernaryExecListStructFunction<int64_t, int64_t, int64_t, RESULT_TYPE, OPERATION>;
+        } break;
+        case common::PhysicalTypeID::INT32: {
+            execFunc =
+                TernaryExecListStructFunction<int32_t, int32_t, int32_t, RESULT_TYPE, OPERATION>;
+        } break;
+        case common::PhysicalTypeID::INT16: {
+            execFunc =
+                TernaryExecListStructFunction<int16_t, int16_t, int16_t, RESULT_TYPE, OPERATION>;
+        } break;
+        case common::PhysicalTypeID::INT8: {
+            execFunc =
+                TernaryExecListStructFunction<int8_t, int8_t, int8_t, RESULT_TYPE, OPERATION>;
+        } break;
+        default: {
+            throw common::NotImplementedException{
+                "VectorListFunctions::getTernaryListOperationDefinition"};
+        }
+        }
+        return execFunc;
+    }
 };
 
 struct ListCreationVectorFunction : public VectorListFunction {
@@ -91,6 +143,12 @@ struct ListCreationVectorFunction : public VectorListFunction {
         common::ValueVector& result);
 };
 
+struct ListRangeVectorFunction : public VectorListFunction {
+    static vector_function_definitions getDefinitions();
+    static std::unique_ptr<FunctionBindData> bindFunc(
+        const binder::expression_vector& arguments, FunctionDefinition* definition);
+};
+
 struct ListLenVectorFunction : public VectorListFunction {
     static vector_function_definitions getDefinitions();
 };
diff --git a/test/test_files/tck/match/match4.test b/test/test_files/tck/match/match4.test
index 889bc0525d..0fa7038e92 100644
--- a/test/test_files/tck/match/match4.test
+++ b/test/test_files/tck/match/match4.test
@@ -59,31 +59,17 @@
 # Matching longer variable length paths
 -CASE Scenario4
 -SKIP
--STATEMENT CREATE NODE TABLE A(ID SERIAL, var STRING, PRIMARY KEY(ID));
+-STATEMENT CREATE NODE TABLE A(ID SERIAL, var INT64, PRIMARY KEY(ID));
 ---- ok
 -STATEMENT CREATE REL TABLE T(FROM A TO A);
 ---- ok
 # indexing on varlist causes segmentaion fault
--STATEMENT CREATE (a {var: 'start'}), (b {var: 'end'})
-           WITH *
-           UNWIND ['0', '1', '2'] AS i
-           CREATE (n {var: i})
-           WITH a, b, collect(n) AS nodeList
-           UNWIND [0, 1] AS i
-           WITH nodeList[i] AS n1, nodeList[i+1] AS n2
-           CREATE (n1)-[:T]->(n2);
----- ok
--STATEMENT WITH collect(n) AS nodeList
-           UNWIND [0, 1] AS i
-           WITH nodeList[i] AS n1, nodeList[i+1] AS n2
-           CREATE (n1)-[:T]->(n2);
----- ok
--STATEMENT CREATE (a {var: 'start'}), (b {var: 'end'})
+-STATEMENT CREATE (a {var: -2}), (b {var: 0})
            WITH *
            UNWIND range(1, 20) AS i
            CREATE (n {var: i})
-           WITH a, b, [a] + collect(n) + [b] AS nodeList
-           UNWIND range(0, size(nodeList) - 2, 1) AS i
+           WITH a, b, list_cat(list_cat([a], collect(n)), [b]) AS nodeList
+           UNWIND range(0, len(nodeList) - 2, 1) AS i
            WITH nodeList[i] AS n1, nodeList[i+1] AS n2
            CREATE (n1)-[:T]->(n2);
 ---- ok
diff --git a/test/test_files/tinysnb/function/range.test b/test/test_files/tinysnb/function/range.test
new file mode 100644
index 0000000000..befac8dc22
--- /dev/null
+++ b/test/test_files/tinysnb/function/range.test
@@ -0,0 +1,142 @@
+-GROUP TinySnbReadTest
+-DATASET CSV tck
+
+--
+
+-CASE Range
+-STATEMENT CREATE NODE TABLE A(ID SERIAL, AGE INT64, NAME STRING, NUMBER INT32, PRIMARY KEY(ID));
+---- ok
+-STATEMENT CREATE REL TABLE T(FROM A TO A);
+---- ok
+-STATEMENT CREATE ({AGE: 1, NAME: "Alice", NUMBER: 1})-[]-({AGE: 2, NAME: "Alice", NUMBER: 2});
+---- ok
+-STATEMENT CREATE ({AGE: 0, NAME:"Alice", NUMBER:3});
+---- ok
+-STATEMENT MATCH (a) RETURN range(a.AGE, 3);
+---- 3
+[1,2,3]
+[2,3]
+[0,1,2,3]
+-STATEMENT MATCH(a) RETURN range(1, a.AGE, 1);
+---- 3
+[1]
+[1,2]
+[]
+-STATEMENT MATCH(a) RETURN range(a.ID, a.ID, 1);
+---- 3
+[0]
+[1]
+[2]
+-STATEMENT MATCH (a)--() RETURN range(a.AGE, a.AGE, a.AGE);
+---- 2
+[1]
+[2]
+-STATEMENT MATCH (a)--() RETURN range(1, 5, a.AGE);
+---- 2
+[1,2,3,4,5]
+[1,3,5]
+-STATEMENT MATCH (a) RETURN range(a.NUMBER, a.AGE, -1);
+---- 3
+[1]
+[2]
+[3,2,1,0]
+-STATEMENT MATCH (a) RETURN range(-4, a.AGE, a.NUMBER);
+---- 3
+[-4,-3,-2,-1,0,1]
+[-4,-2,0,2]
+[-4,-1]
+-STATEMENT MATCH (b)--() RETURN range(b.AGE, 4, b.NUMBER);
+---- 2
+[1,2,3,4]
+[2,4]
+-STATEMENT MATCH (a) RETURN range(a.NUMBER, 5, a.AGE);
+---- error
+Runtime exception: Step of range cannot be 0.
+-STATEMENT MATCH (a) RETURN range(a.NAME, 4, 1);
+---- error
+Binder exception: Cannot match a built-in function for given function RANGE(STRING,INT64,INT64). Supported inputs are
+(INT64,INT64) -> VAR_LIST
+(INT64,INT64,INT64) -> VAR_LIST
+(INT32,INT32) -> VAR_LIST
+(INT32,INT32,INT32) -> VAR_LIST
+(INT16,INT16) -> VAR_LIST
+(INT16,INT16,INT16) -> VAR_LIST
+(INT8,INT8) -> VAR_LIST
+(INT8,INT8,INT8) -> VAR_LIST
+(SERIAL,SERIAL) -> VAR_LIST
+(SERIAL,SERIAL,SERIAL) -> VAR_LIST
+-STATEMENT RETURN range(4, 3);
+---- 1
+[]
+-STATEMENT RETURN range(0, 0);
+---- 1
+[0]
+-STATEMENT RETURN range(1, 5);
+---- 1
+[1,2,3,4,5]
+-STATEMENT RETURN range(To_INT32(-3), TO_INT32(0));
+---- 1
+[-3,-2,-1,0]
+-STATEMENT RETURN range(To_INT16(-3), TO_INT16(0));
+---- 1
+[-3,-2,-1,0]
+-STATEMENT RETURN range(To_INT8(-3), TO_INT16(0));
+---- 1
+[-3,-2,-1,0]
+-STATEMENT RETURN range(To_INT8(-3), TO_INT8(0));
+---- 1
+[-3,-2,-1,0]
+-STATEMENT RETURN range(4, 4, 2);
+---- 1
+[4]
+-STATEMENT RETURN range(4, 4, -2);
+---- 1
+[4]
+-STATEMENT RETURN range(5, 1, -1);
+---- 1
+[5,4,3,2,1]
+-STATEMENT RETURN range(5, 1, -2);
+---- 1
+[5,3,1]
+-STATEMENT RETURN range(5, 1, -3);
+---- 1
+[5,2]
+-STATEMENT RETURN range(6, 2, 2);
+---- 1
+[]
+-STATEMENT RETURN range(3, 6, -5);
+---- 1
+[]
+-STATEMENT RETURN range(3, 6, -1);
+---- 1
+[]
+-STATEMENT RETURN range(0, -1, 9223372036854775807);
+---- 1
+[]
+-STATEMENT RETURN range(TO_INT32(3), 8, 2);
+---- 1
+[3,5,7]
+-STATEMENT RETURN range(TO_INT8(0), TO_INT16(-9), TO_INT8(-2));
+---- 1
+[0,-2,-4,-6,-8]
+-STATEMENT RETURN range(TO_INT16(5), TO_INT16(6), TO_INT16(1));
+---- 1
+[5,6]
+-STATEMENT RETURN range(TO_INT32(5), 25, TO_INT32(10));
+---- 1
+[5,15,25]
+-STATEMENT RETURN range(TO_INT32(5), TO_INT32(26), TO_INT32(10));
+---- 1
+[5,15,25]
+-STATEMENT RETURN range(TO_INT8(-128), TO_INT8(127), TO_INT8(127));
+---- 1
+[-128,-1,126]
+-STATEMENT RETURN range(0, 9007199254740992, 9007199254740993);
+---- 1
+[0]
+-STATEMENT RETURN range(-9223372036854775807, 9223372036854775807, 9223372036854775807);
+---- 1
+[-9223372036854775807,0,9223372036854775807]
+-STATEMENT RETURN range(3, 4, 0);
+---- error
+Runtime exception: Step of range cannot be 0.