From 09f0cc7b1be80e2de282d92b39e7182da47f001c Mon Sep 17 00:00:00 2001 From: xiyang Date: Fri, 19 May 2023 11:25:26 -0400 Subject: [PATCH] Push recursive join length into operator --- src/binder/bind/bind_copy.cpp | 18 --- src/binder/bind/bind_graph_pattern.cpp | 4 + src/binder/bind/bind_projection_clause.cpp | 4 +- .../bind_function_expression.cpp | 153 ++++++++++-------- src/common/types/types.cpp | 3 +- src/function/built_in_vector_operations.cpp | 10 -- src/function/vector_list_operation.cpp | 3 - src/include/binder/binder.h | 3 - .../binder/expression/rel_expression.h | 8 + src/include/binder/expression_binder.h | 14 +- .../function/built_in_vector_operations.h | 2 - .../projection_push_down_optimizer.h | 8 +- .../logical_recursive_extend.h | 16 +- .../logical_operator/recursive_join_type.h | 14 ++ .../operator/recursive_extend/path_scanner.h | 57 ++++--- .../recursive_extend/recursive_join.h | 22 +-- .../projection_push_down_optimizer.cpp | 17 +- .../operator/logical_recursive_extend.cpp | 16 +- src/planner/projection_planner.cpp | 13 +- src/processor/mapper/map_extend.cpp | 18 ++- .../recursive_extend/path_scanner.cpp | 40 ++--- .../recursive_extend/recursive_join.cpp | 12 +- .../shortest_path_recursive_join.cpp | 9 +- .../variable_length_recursive_join.cpp | 10 +- test/optimizer/optimizer_test.cpp | 10 ++ test/test_files/shortest_path/bfs_sssp.test | 6 +- .../shortest_path/bfs_sssp_large.test | 4 +- 27 files changed, 290 insertions(+), 204 deletions(-) create mode 100644 src/include/planner/logical_plan/logical_operator/recursive_join_type.h diff --git a/src/binder/bind/bind_copy.cpp b/src/binder/bind/bind_copy.cpp index c630929752..99e6c15f1b 100644 --- a/src/binder/bind/bind_copy.cpp +++ b/src/binder/bind/bind_copy.cpp @@ -57,24 +57,6 @@ std::vector Binder::bindFilePaths(const std::vector& f return boundFilePaths; } -std::unordered_map Binder::bindPropertyToNpyMap( - common::table_id_t tableID, const std::vector& filePaths) { - auto catalogContent = 
catalog.getReadOnlyVersion(); - auto tableSchema = catalogContent->getTableSchema(tableID); - if (tableSchema->properties.size() != filePaths.size()) { - throw BinderException(StringUtils::string_format( - "Number of npy files is not equal to number of properties in table {}.", - tableSchema->tableName)); - } - std::unordered_map propertyIDToNpyMap; - for (int i = 0; i < filePaths.size(); i++) { - auto& filePath = filePaths[i]; - auto& propertyID = tableSchema->properties[i].propertyID; - propertyIDToNpyMap[propertyID] = filePath; - } - return propertyIDToNpyMap; -} - CSVReaderConfig Binder::bindParsingOptions( const std::unordered_map>* parsingOptions) { CSVReaderConfig csvReaderConfig; diff --git a/src/binder/bind/bind_graph_pattern.cpp b/src/binder/bind/bind_graph_pattern.cpp index 298c48b045..0fd3ef5b37 100644 --- a/src/binder/bind/bind_graph_pattern.cpp +++ b/src/binder/bind/bind_graph_pattern.cpp @@ -163,6 +163,10 @@ void Binder::bindQueryRel(const RelPattern& relPattern, getUniqueExpressionName(parsedName), parsedName, tableIDs, srcNode, dstNode, relPattern.getDirection() != BOTH, relPattern.getRelType(), lowerBound, upperBound); queryRel->setAlias(parsedName); + if (isVariableLength) { + queryRel->setInternalLengthExpression( + expressionBinder.createInternalLengthExpression(*queryRel)); + } // resolve properties associate with rel table std::vector relTableSchemas; for (auto tableID : tableIDs) { diff --git a/src/binder/bind/bind_projection_clause.cpp b/src/binder/bind/bind_projection_clause.cpp index 982431bf41..cdcc37ff3f 100644 --- a/src/binder/bind/bind_projection_clause.cpp +++ b/src/binder/bind/bind_projection_clause.cpp @@ -80,7 +80,7 @@ expression_vector Binder::rewriteNodeExpression(const kuzu::binder::Expression& expression_vector result; auto& node = (NodeExpression&)expression; result.push_back(node.getInternalIDProperty()); - result.push_back(expressionBinder.bindNodeLabelFunction(node)); + 
result.push_back(expressionBinder.bindLabelFunction(node)); for (auto& property : node.getPropertyExpressions()) { result.push_back(property->copy()); } @@ -92,7 +92,7 @@ expression_vector Binder::rewriteRelExpression(const Expression& expression) { auto& rel = (RelExpression&)expression; result.push_back(rel.getSrcNode()->getInternalIDProperty()); result.push_back(rel.getDstNode()->getInternalIDProperty()); - result.push_back(expressionBinder.bindRelLabelFunction(rel)); + result.push_back(expressionBinder.bindLabelFunction(rel)); for (auto& property : rel.getPropertyExpressions()) { result.push_back(property->copy()); } diff --git a/src/binder/bind_expression/bind_function_expression.cpp b/src/binder/bind_expression/bind_function_expression.cpp index 09a5ca459d..ff6bb1d108 100644 --- a/src/binder/bind_expression/bind_function_expression.cpp +++ b/src/binder/bind_expression/bind_function_expression.cpp @@ -17,11 +17,9 @@ std::shared_ptr ExpressionBinder::bindFunctionExpression( auto& parsedFunctionExpression = (ParsedFunctionExpression&)parsedExpression; auto functionName = parsedFunctionExpression.getFunctionName(); StringUtils::toUpper(functionName); - // check for special function binding - if (functionName == ID_FUNC_NAME) { - return bindInternalIDExpression(parsedExpression); - } else if (functionName == LABEL_FUNC_NAME) { - return bindLabelFunction(parsedExpression); + auto result = rewriteFunctionExpression(parsedExpression, functionName); + if (result != nullptr) { + return result; } auto functionType = binder->catalog.getFunctionType(functionName); if (functionType == FUNCTION) { @@ -123,23 +121,29 @@ std::shared_ptr ExpressionBinder::staticEvaluate( return createLiteralExpression(std::move(value)); } -std::shared_ptr ExpressionBinder::bindInternalIDExpression( - const ParsedExpression& parsedExpression) { - auto child = bindExpression(*parsedExpression.getChild(0)); - validateExpectedDataType( - *child, std::unordered_set{LogicalTypeID::NODE, 
LogicalTypeID::REL}); - return bindInternalIDExpression(*child); -} - -std::shared_ptr ExpressionBinder::bindInternalIDExpression( - const Expression& expression) { - if (expression.dataType.getLogicalTypeID() == LogicalTypeID::NODE) { - auto& node = (NodeExpression&)expression; - return node.getInternalIDProperty(); - } else { - assert(expression.dataType.getLogicalTypeID() == LogicalTypeID::REL); - return bindRelPropertyExpression(expression, INTERNAL_ID_SUFFIX); - } +// Function rewriting happens when we need to expose internal property access through a function so +// that it becomes read-only or the function involves catalog information. Currently we rewrite +// Before | After +// ID(a) | a._id +// LABEL(a) | LIST_EXTRACT(offset(a), [table names from catalog]) +// LENGTH(e) | e._length +std::shared_ptr ExpressionBinder::rewriteFunctionExpression( + const parser::ParsedExpression& parsedExpression, const std::string& functionName) { + if (functionName == ID_FUNC_NAME) { + auto child = bindExpression(*parsedExpression.getChild(0)); + validateExpectedDataType( + *child, std::unordered_set{LogicalTypeID::NODE, LogicalTypeID::REL}); + return bindInternalIDExpression(*child); + } else if (functionName == LABEL_FUNC_NAME) { + auto child = bindExpression(*parsedExpression.getChild(0)); + validateExpectedDataType( + *child, std::unordered_set{LogicalTypeID::NODE, LogicalTypeID::REL}); + return bindLabelFunction(*child); + } else if (functionName == LENGTH_FUNC_NAME) { + auto child = bindExpression(*parsedExpression.getChild(0)); + return bindRecursiveJoinLengthFunction(*child); + } + return nullptr; } std::unique_ptr ExpressionBinder::createInternalNodeIDExpression( @@ -149,20 +153,22 @@ std::unique_ptr ExpressionBinder::createInternalNodeIDExpression( for (auto tableID : node.getTableIDs()) { propertyIDPerTable.insert({tableID, INVALID_PROPERTY_ID}); } - auto result = std::make_unique(LogicalType(LogicalTypeID::INTERNAL_ID), + return 
std::make_unique(LogicalType(LogicalTypeID::INTERNAL_ID), INTERNAL_ID_SUFFIX, node, std::move(propertyIDPerTable), false /* isPrimaryKey */); - return result; } -std::shared_ptr ExpressionBinder::bindLabelFunction( - const ParsedExpression& parsedExpression) { - // bind child node - auto child = bindExpression(*parsedExpression.getChild(0)); - if (child->dataType.getLogicalTypeID() == common::LogicalTypeID::NODE) { - return bindNodeLabelFunction(*child); - } else { - assert(child->dataType.getLogicalTypeID() == common::LogicalTypeID::REL); - return bindRelLabelFunction(*child); +std::shared_ptr ExpressionBinder::bindInternalIDExpression( + const Expression& expression) { + switch (expression.getDataType().getLogicalTypeID()) { + case common::LogicalTypeID::NODE: { + auto& node = (NodeExpression&)expression; + return node.getInternalIDProperty(); + } + case common::LogicalTypeID::REL: { + return bindRelPropertyExpression(expression, INTERNAL_ID_SUFFIX); + } + default: + throw NotImplementedException("ExpressionBinder::bindInternalIDExpression"); } } @@ -183,22 +189,41 @@ static std::vector> populateLabelValues( return labels; } -std::shared_ptr ExpressionBinder::bindNodeLabelFunction(const Expression& expression) { +std::shared_ptr ExpressionBinder::bindLabelFunction(const Expression& expression) { auto catalogContent = binder->catalog.getReadOnlyVersion(); - auto& node = (NodeExpression&)expression; - if (!node.isMultiLabeled()) { - auto labelName = catalogContent->getTableName(node.getSingleTableID()); - return createLiteralExpression(std::make_unique(labelName)); - } - auto nodeTableIDs = catalogContent->getNodeTableIDs(); + auto varListTypeInfo = + std::make_unique(std::make_unique(LogicalTypeID::STRING)); + auto listType = + std::make_unique(LogicalTypeID::VAR_LIST, std::move(varListTypeInfo)); expression_vector children; - children.push_back(node.getInternalIDProperty()); - auto labelsValue = - std::make_unique(LogicalType(LogicalTypeID::VAR_LIST, - 
std::make_unique( - std::make_unique(LogicalTypeID::STRING))), - populateLabelValues(nodeTableIDs, *catalogContent)); - children.push_back(createLiteralExpression(std::move(labelsValue))); + switch (expression.getDataType().getLogicalTypeID()) { + case common::LogicalTypeID::NODE: { + auto& node = (NodeExpression&)expression; + if (!node.isMultiLabeled()) { + auto labelName = catalogContent->getTableName(node.getSingleTableID()); + return createLiteralExpression(std::make_unique(labelName)); + } + auto nodeTableIDs = catalogContent->getNodeTableIDs(); + children.push_back(node.getInternalIDProperty()); + auto labelsValue = + std::make_unique(*listType, populateLabelValues(nodeTableIDs, *catalogContent)); + children.push_back(createLiteralExpression(std::move(labelsValue))); + } break; + case common::LogicalTypeID::REL: { + auto& rel = (RelExpression&)expression; + if (!rel.isMultiLabeled()) { + auto labelName = catalogContent->getTableName(rel.getSingleTableID()); + return createLiteralExpression(std::make_unique(labelName)); + } + auto relTableIDs = catalogContent->getRelTableIDs(); + children.push_back(rel.getInternalIDProperty()); + auto labelsValue = + std::make_unique(*listType, populateLabelValues(relTableIDs, *catalogContent)); + children.push_back(createLiteralExpression(std::move(labelsValue))); + } break; + default: + throw NotImplementedException("ExpressionBinder::bindLabelFunction"); + } auto execFunc = function::LabelVectorOperation::execFunction; auto bindData = std::make_unique(LogicalType(LogicalTypeID::STRING)); @@ -207,28 +232,22 @@ std::shared_ptr ExpressionBinder::bindNodeLabelFunction(const Expres std::move(children), execFunc, nullptr, uniqueExpressionName); } -std::shared_ptr ExpressionBinder::bindRelLabelFunction(const Expression& expression) { - auto catalogContent = binder->catalog.getReadOnlyVersion(); +std::unique_ptr ExpressionBinder::createInternalLengthExpression( + const Expression& expression) { auto& rel = 
(RelExpression&)expression; - if (!rel.isMultiLabeled()) { - auto labelName = catalogContent->getTableName(rel.getSingleTableID()); - return createLiteralExpression(std::make_unique(labelName)); + std::unordered_map propertyIDPerTable; + propertyIDPerTable.insert({rel.getSingleTableID(), INVALID_PROPERTY_ID}); + return std::make_unique(LogicalType(common::LogicalTypeID::INT64), + INTERNAL_LENGTH_SUFFIX, rel, std::move(propertyIDPerTable), false /* isPrimaryKey */); +} + +std::shared_ptr ExpressionBinder::bindRecursiveJoinLengthFunction( + const Expression& expression) { + if (expression.getDataType().getLogicalTypeID() != common::LogicalTypeID::RECURSIVE_REL) { + return nullptr; } - auto relTableIDs = catalogContent->getRelTableIDs(); - expression_vector children; - children.push_back(rel.getInternalIDProperty()); - auto labelsValue = - std::make_unique(LogicalType(LogicalTypeID::VAR_LIST, - std::make_unique( - std::make_unique(LogicalTypeID::STRING))), - populateLabelValues(relTableIDs, *catalogContent)); - children.push_back(createLiteralExpression(std::move(labelsValue))); - auto execFunc = function::LabelVectorOperation::execFunction; - auto bindData = - std::make_unique(LogicalType(LogicalTypeID::STRING)); - auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children); - return make_shared(LABEL_FUNC_NAME, FUNCTION, std::move(bindData), - std::move(children), execFunc, nullptr, uniqueExpressionName); + auto& rel = (RelExpression&)expression; + return rel.getInternalLengthExpression(); } } // namespace binder diff --git a/src/common/types/types.cpp b/src/common/types/types.cpp index db43c9caf3..05dc895c07 100644 --- a/src/common/types/types.cpp +++ b/src/common/types/types.cpp @@ -471,8 +471,7 @@ void LogicalType::setPhysicalType() { physicalType = PhysicalTypeID::STRUCT; } break; default: - throw NotImplementedException{ - "Unsupported LogicalType: " + LogicalTypeUtils::dataTypeToString(typeID) + "."}; + throw 
NotImplementedException{"LogicalType::setPhysicalType()."}; } } diff --git a/src/function/built_in_vector_operations.cpp b/src/function/built_in_vector_operations.cpp index 65b81ae018..bfaf42a3b8 100644 --- a/src/function/built_in_vector_operations.cpp +++ b/src/function/built_in_vector_operations.cpp @@ -25,7 +25,6 @@ void BuiltInVectorOperations::registerVectorOperations() { registerStringOperations(); registerCastOperations(); registerListOperations(); - registerInternalIDOperation(); registerStructOperation(); // register internal offset operation vectorOperations.insert({OFFSET_FUNC_NAME, OffsetVectorOperation::getDefinitions()}); @@ -454,15 +453,6 @@ void BuiltInVectorOperations::registerListOperations() { {LIST_ANY_VALUE_FUNC_NAME, ListAnyValueVectorOperation::getDefinitions()}); } -void BuiltInVectorOperations::registerInternalIDOperation() { - std::vector> definitions; - definitions.push_back(make_unique(ID_FUNC_NAME, - std::vector{LogicalTypeID::NODE}, LogicalTypeID::INTERNAL_ID, nullptr)); - definitions.push_back(make_unique(ID_FUNC_NAME, - std::vector{LogicalTypeID::REL}, LogicalTypeID::INTERNAL_ID, nullptr)); - vectorOperations.insert({ID_FUNC_NAME, std::move(definitions)}); -} - void BuiltInVectorOperations::registerStructOperation() { vectorOperations.insert({STRUCT_PACK_FUNC_NAME, StructPackVectorOperations::getDefinitions()}); vectorOperations.insert( diff --git a/src/function/vector_list_operation.cpp b/src/function/vector_list_operation.cpp index 34874c599a..c8e0cdad68 100644 --- a/src/function/vector_list_operation.cpp +++ b/src/function/vector_list_operation.cpp @@ -105,9 +105,6 @@ std::vector> ListLenVectorOperation:: result.push_back(std::make_unique(LIST_LEN_FUNC_NAME, std::vector{LogicalTypeID::VAR_LIST}, LogicalTypeID::INT64, execFunc, true /* isVarlength*/)); - result.push_back(std::make_unique(LIST_LEN_FUNC_NAME, - std::vector{LogicalTypeID::RECURSIVE_REL}, LogicalTypeID::INT64, execFunc, - true /* isVarlength*/)); return result; } diff 
--git a/src/include/binder/binder.h b/src/include/binder/binder.h index 3a0b9e52a4..3c1ac14409 100644 --- a/src/include/binder/binder.h +++ b/src/include/binder/binder.h @@ -60,9 +60,6 @@ class Binder { std::vector bindFilePaths(const std::vector& filePaths); - std::unordered_map bindPropertyToNpyMap( - common::table_id_t tableId, const std::vector& filePaths); - common::CSVReaderConfig bindParsingOptions( const std::unordered_map>* parsingOptions); diff --git a/src/include/binder/expression/rel_expression.h b/src/include/binder/expression/rel_expression.h index eef1b3c29a..0f3128c615 100644 --- a/src/include/binder/expression/rel_expression.h +++ b/src/include/binder/expression/rel_expression.h @@ -39,6 +39,13 @@ class RelExpression : public NodeOrRelExpression { inline std::shared_ptr getInternalIDProperty() const { return getPropertyExpression(common::INTERNAL_ID_SUFFIX); } + inline void setInternalLengthExpression(std::unique_ptr expression) { + internalLengthExpression = std::move(expression); + } + inline std::shared_ptr getInternalLengthExpression() const { + assert(internalLengthExpression != nullptr); + return internalLengthExpression->copy(); + } private: std::shared_ptr srcNode; @@ -47,6 +54,7 @@ class RelExpression : public NodeOrRelExpression { common::QueryRelType relType; uint64_t lowerBound; uint64_t upperBound; + std::unique_ptr internalLengthExpression; }; } // namespace binder diff --git a/src/include/binder/expression_binder.h b/src/include/binder/expression_binder.h index ebfd813a01..c746f1f793 100644 --- a/src/include/binder/expression_binder.h +++ b/src/include/binder/expression_binder.h @@ -46,6 +46,7 @@ class ExpressionBinder { std::shared_ptr bindFunctionExpression( const parser::ParsedExpression& parsedExpression); + std::shared_ptr bindScalarFunctionExpression( const parser::ParsedExpression& parsedExpression, const std::string& functionName); std::shared_ptr bindScalarFunctionExpression( @@ -53,17 +54,16 @@ class ExpressionBinder { 
std::shared_ptr bindAggregateFunctionExpression( const parser::ParsedExpression& parsedExpression, const std::string& functionName, bool isDistinct); - std::shared_ptr staticEvaluate( const std::string& functionName, const expression_vector& children); - std::shared_ptr bindInternalIDExpression( - const parser::ParsedExpression& parsedExpression); - std::shared_ptr bindInternalIDExpression(const Expression& expression); + std::shared_ptr rewriteFunctionExpression( + const parser::ParsedExpression& parsedExpression, const std::string& functionName); std::unique_ptr createInternalNodeIDExpression(const Expression& node); - std::shared_ptr bindLabelFunction(const parser::ParsedExpression& parsedExpression); - std::shared_ptr bindNodeLabelFunction(const Expression& expression); - std::shared_ptr bindRelLabelFunction(const Expression& expression); + std::shared_ptr bindInternalIDExpression(const Expression& expression); + std::shared_ptr bindLabelFunction(const Expression& expression); + std::unique_ptr createInternalLengthExpression(const Expression& expression); + std::shared_ptr bindRecursiveJoinLengthFunction(const Expression& expression); std::shared_ptr bindParameterExpression( const parser::ParsedExpression& parsedExpression); diff --git a/src/include/function/built_in_vector_operations.h b/src/include/function/built_in_vector_operations.h index 951e741a36..5b9056a21a 100644 --- a/src/include/function/built_in_vector_operations.h +++ b/src/include/function/built_in_vector_operations.h @@ -67,8 +67,6 @@ class BuiltInVectorOperations { void registerStringOperations(); void registerCastOperations(); void registerListOperations(); - void registerInternalIDOperation(); - void registerInternalOffsetOperation(); void registerStructOperation(); private: diff --git a/src/include/optimizer/projection_push_down_optimizer.h b/src/include/optimizer/projection_push_down_optimizer.h index 262376571d..b6b6f538bf 100644 --- a/src/include/optimizer/projection_push_down_optimizer.h 
+++ b/src/include/optimizer/projection_push_down_optimizer.h @@ -8,10 +8,10 @@ namespace optimizer { // ProjectionPushDownOptimizer implements the logic to avoid materializing unnecessary properties // for hash join build. -// Note the optimization is for properties only but not for general expressions. This is because -// it's hard to figure out what expression is in-use, e.g. COUNT(a.age) + 1, it could be either the -// whole expression was evaluated in a WITH clause or only COUNT(a.age) was evaluated or only a.age -// is evaluate. For simplicity, we only consider the push down for property. +// Note the optimization is for properties & variables only but not for general expressions. This is +// because it's hard to figure out what expression is in-use, e.g. COUNT(a.age) + 1, it could be +// either the whole expression was evaluated in a WITH clause or only COUNT(a.age) was evaluated or +// only a.age is evaluated. For simplicity, we only consider the push down for those two. class ProjectionPushDownOptimizer : public LogicalOperatorVisitor { public: void rewrite(planner::LogicalPlan* plan); diff --git a/src/include/planner/logical_plan/logical_operator/logical_recursive_extend.h b/src/include/planner/logical_plan/logical_operator/logical_recursive_extend.h index 3d944715c7..a7ace84217 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_recursive_extend.h +++ b/src/include/planner/logical_plan/logical_operator/logical_recursive_extend.h @@ -1,6 +1,7 @@ #pragma once #include "base_logical_extend.h" +#include "recursive_join_type.h" namespace kuzu { namespace planner { @@ -11,29 +12,30 @@ class LogicalRecursiveExtend : public BaseLogicalExtend { std::shared_ptr nbrNode, std::shared_ptr rel, common::ExtendDirection direction, std::shared_ptr child) : LogicalRecursiveExtend{std::move(boundNode), std::move(nbrNode), std::move(rel), - direction, true /* trackPath */, std::move(child)} {} + direction, RecursiveJoinType::TRACK_PATH, std::move(child)} {} 
LogicalRecursiveExtend(std::shared_ptr boundNode, std::shared_ptr nbrNode, std::shared_ptr rel, - common::ExtendDirection direction, bool trackPath_, std::shared_ptr child) + common::ExtendDirection direction, RecursiveJoinType joinType, + std::shared_ptr child) : BaseLogicalExtend{LogicalOperatorType::RECURSIVE_EXTEND, std::move(boundNode), std::move(nbrNode), std::move(rel), direction, std::move(child)}, - trackPath_{trackPath_} {} + joinType{joinType} {} f_group_pos_set getGroupsPosToFlatten() override; void computeFactorizedSchema() override; void computeFlatSchema() override; - inline void disableTrackPath() { trackPath_ = false; } - inline bool trackPath() { return trackPath_; } + inline void setJoinType(RecursiveJoinType joinType_) { joinType = joinType_; } + inline RecursiveJoinType getJoinType() const { return joinType; } inline std::unique_ptr copy() override { return std::make_unique( - boundNode, nbrNode, rel, direction, trackPath_, children[0]->copy()); + boundNode, nbrNode, rel, direction, joinType, children[0]->copy()); } private: - bool trackPath_; + RecursiveJoinType joinType; }; } // namespace planner diff --git a/src/include/planner/logical_plan/logical_operator/recursive_join_type.h b/src/include/planner/logical_plan/logical_operator/recursive_join_type.h new file mode 100644 index 0000000000..17d0414557 --- /dev/null +++ b/src/include/planner/logical_plan/logical_operator/recursive_join_type.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +namespace kuzu { +namespace planner { + +enum class RecursiveJoinType : uint8_t { + TRACK_NONE = 0, + TRACK_PATH = 1, +}; + +} // namespace planner +} // namespace kuzu diff --git a/src/include/processor/operator/recursive_extend/path_scanner.h b/src/include/processor/operator/recursive_extend/path_scanner.h index 985f6458c7..1d62e5b340 100644 --- a/src/include/processor/operator/recursive_extend/path_scanner.h +++ b/src/include/processor/operator/recursive_extend/path_scanner.h @@ -27,20 +27,28 @@ struct 
BaseFrontierScanner { virtual ~BaseFrontierScanner() = default; size_t scan(common::table_id_t tableID, common::ValueVector* pathVector, - common::ValueVector* dstNodeIDVector, common::sel_t& offsetVectorPos, - common::sel_t& dataVectorPos); + common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector, + common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos); void resetState(const BaseBFSMorsel& bfsMorsel); -private: +protected: virtual void initScanFromDstOffset() = 0; virtual void scanFromDstOffset(common::table_id_t tableID, common::ValueVector* pathVector, - common::ValueVector* dstNodeIDVector, common::sel_t& offsetVectorPos, - common::sel_t& dataVectorPos) = 0; + common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector, + common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos) = 0; + + inline void writeDstNodeOffsetAndLength(common::ValueVector* dstNodeIDVector, + common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos, + common::table_id_t tableID) { + dstNodeIDVector->setValue( + offsetVectorPos, common::nodeID_t{currentDstOffset, tableID}); + pathLengthVector->setValue(offsetVectorPos, (int64_t)k); + } }; /* - * DstNodeScanner scans dst node offset only. + * DstNodeScanner scans dst node offset & length of path. 
*/ struct DstNodeScanner : public BaseFrontierScanner { DstNodeScanner(const std::unordered_set& targetDstOffsets, size_t k) @@ -49,19 +57,20 @@ struct DstNodeScanner : public BaseFrontierScanner { private: inline void initScanFromDstOffset() final {} inline void scanFromDstOffset(common::table_id_t tableID, common::ValueVector* pathVector, - common::ValueVector* dstNodeIDVector, common::sel_t& offsetVectorPos, - common::sel_t& dataVectorPos) final { + common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector, + common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos) final { assert(offsetVectorPos < common::DEFAULT_VECTOR_CAPACITY); - dstNodeIDVector->setValue( - offsetVectorPos++, common::nodeID_t{currentDstOffset, tableID}); + writeDstNodeOffsetAndLength(dstNodeIDVector, pathLengthVector, offsetVectorPos, tableID); + offsetVectorPos++; } }; /* - * PathScanner scans all paths of a fixed length k. This is done by starting - * a backward traversals from only the destination nodes in the k'th frontier (assuming the first - * frontier has index 0) over the backwards edges stored between the frontiers that was used to - * store the data related to the BFS that was computed in the RecursiveJoin operator. + * PathScanner scans all paths of a fixed length k (also dst node offsets & length of path). This is + * done by starting a backward traversal from only the destination nodes in the k'th frontier + * (assuming the first frontier has index 0) over the backwards edges stored between the frontiers + * that were used to store the data related to the BFS that was computed in the RecursiveJoin + * operator. */ struct PathScanner : public BaseFrontierScanner { using nbrs_t = std::vector*; @@ -79,20 +88,20 @@ struct PathScanner : public BaseFrontierScanner { inline void initScanFromDstOffset() final { initDfs(currentDstOffset, k); } // Scan current stacks until exhausted or vector is filled up. 
void scanFromDstOffset(common::table_id_t tableID, common::ValueVector* pathVector, - common::ValueVector* dstNodeIDVector, common::sel_t& offsetVectorPos, - common::sel_t& dataVectorPos) final; + common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector, + common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos) final; // Initialize stacks for given offset. void initDfs(common::offset_t offset, size_t currentDepth); void writePathToVector(common::table_id_t tableID, common::ValueVector* pathVector, - common::ValueVector* dstNodeIDVector, common::sel_t& offsetVectorPos, - common::sel_t& dataVectorPos); + common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector, + common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos); }; /* - * DstNodeWithMultiplicityScanner scans dst node offset and repeat it for multiplicity times in - * value vector. + * DstNodeWithMultiplicityScanner scans dst node offset & length of path and repeat it for + * multiplicity times in value vector. 
*/ struct DstNodeWithMultiplicityScanner : public BaseFrontierScanner { DstNodeWithMultiplicityScanner( @@ -102,8 +111,8 @@ struct DstNodeWithMultiplicityScanner : public BaseFrontierScanner { private: inline void initScanFromDstOffset() final {} void scanFromDstOffset(common::table_id_t tableID, common::ValueVector* pathVector, - common::ValueVector* dstNodeIDVector, common::sel_t& offsetVectorPos, - common::sel_t& dataVectorPos) final; + common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector, + common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos) final; }; /* @@ -124,8 +133,8 @@ struct FrontiersScanner { : scanners{std::move(scanners)}, cursor{0} {} void scan(common::table_id_t tableID, common::ValueVector* pathVector, - common::ValueVector* dstNodeIDVector, common::sel_t& offsetVectorPos, - common::sel_t& dataVectorPos); + common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector, + common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos); inline void resetState(const BaseBFSMorsel& bfsMorsel) { cursor = 0; diff --git a/src/include/processor/operator/recursive_extend/recursive_join.h b/src/include/processor/operator/recursive_extend/recursive_join.h index cbd6faf532..4ee1becc7a 100644 --- a/src/include/processor/operator/recursive_extend/recursive_join.h +++ b/src/include/processor/operator/recursive_extend/recursive_join.h @@ -2,6 +2,7 @@ #include "bfs_state.h" #include "path_scanner.h" +#include "planner/logical_plan/logical_operator/recursive_join_type.h" #include "processor/operator/physical_operator.h" #include "processor/operator/result_collector.h" #include "storage/store/node_table.h" @@ -57,31 +58,31 @@ struct RecursiveJoinDataInfo { DataPos srcNodePos; // Join output info. DataPos dstNodePos; + DataPos pathLengthPos; // Recursive join info. 
DataPos tmpDstNodePos; // Path info - bool trackPath; + planner::RecursiveJoinType joinType; DataPos pathPos; RecursiveJoinDataInfo(std::vector vectorsToScanPos, std::vector colIndicesToScan, const DataPos& srcNodePos, - const DataPos& dstNodePos, const DataPos& tmpDstNodePos, bool trackPath) + const DataPos& dstNodePos, const DataPos& pathLengthPos, const DataPos& tmpDstNodePos, + planner::RecursiveJoinType joinType) : RecursiveJoinDataInfo{std::move(vectorsToScanPos), std::move(colIndicesToScan), - srcNodePos, dstNodePos, tmpDstNodePos, trackPath, DataPos()} { - assert(trackPath == false); - } + srcNodePos, dstNodePos, pathLengthPos, tmpDstNodePos, joinType, DataPos()} {} RecursiveJoinDataInfo(std::vector vectorsToScanPos, std::vector colIndicesToScan, const DataPos& srcNodePos, - const DataPos& dstNodePos, const DataPos& tmpDstNodePos, bool trackPath, - const DataPos& pathPos) + const DataPos& dstNodePos, const DataPos& pathLengthPos, const DataPos& tmpDstNodePos, + planner::RecursiveJoinType joinType, const DataPos& pathPos) : vectorsToScanPos{std::move(vectorsToScanPos)}, colIndicesToScan{std::move( colIndicesToScan)}, - srcNodePos{srcNodePos}, dstNodePos{dstNodePos}, - tmpDstNodePos{tmpDstNodePos}, trackPath{trackPath}, pathPos{pathPos} {} + srcNodePos{srcNodePos}, dstNodePos{dstNodePos}, pathLengthPos{pathLengthPos}, + tmpDstNodePos{tmpDstNodePos}, joinType{joinType}, pathPos{pathPos} {} inline std::unique_ptr copy() { return std::make_unique(vectorsToScanPos, colIndicesToScan, - srcNodePos, dstNodePos, tmpDstNodePos, trackPath, pathPos); + srcNodePos, dstNodePos, pathLengthPos, tmpDstNodePos, joinType, pathPos); } }; @@ -147,6 +148,7 @@ class BaseRecursiveJoin : public PhysicalOperator { std::vector vectorsToScan; common::ValueVector* srcNodeIDVector; common::ValueVector* dstNodeIDVector; + common::ValueVector* pathLengthVector; common::ValueVector* pathVector; common::ValueVector* tmpDstNodeIDVector; // temporary recursive join result. 
diff --git a/src/optimizer/projection_push_down_optimizer.cpp b/src/optimizer/projection_push_down_optimizer.cpp index 75763bf2e3..b75739ca47 100644 --- a/src/optimizer/projection_push_down_optimizer.cpp +++ b/src/optimizer/projection_push_down_optimizer.cpp @@ -40,7 +40,7 @@ void ProjectionPushDownOptimizer::visitRecursiveExtend(LogicalOperator* op) { auto recursiveExtend = (LogicalRecursiveExtend*)op; auto rel = recursiveExtend->getRel(); if (!variablesInUse.contains(rel)) { - recursiveExtend->disableTrackPath(); + recursiveExtend->setJoinType(planner::RecursiveJoinType::TRACK_NONE); } } @@ -195,6 +195,7 @@ void ProjectionPushDownOptimizer::visitSetRelProperty(planner::LogicalOperator* } } +// See comments above this class for how to collect expressions in use. void ProjectionPushDownOptimizer::collectExpressionsInUse( std::shared_ptr expression) { if (expression->expressionType == common::VARIABLE) { @@ -214,8 +215,18 @@ binder::expression_vector ProjectionPushDownOptimizer::pruneExpressions( const binder::expression_vector& expressions) { expression_set expressionsAfterPruning; for (auto& expression : expressions) { - if (expression->expressionType != common::PROPERTY || - propertiesInUse.contains(expression)) { + switch (expression->expressionType) { + case common::VARIABLE: { + if (variablesInUse.contains(expression)) { + expressionsAfterPruning.insert(expression); + } + } break; + case common::PROPERTY: { + if (propertiesInUse.contains(expression)) { + expressionsAfterPruning.insert(expression); + } + } break; + default: // We don't track other expression types so always assume they will be in use. 
expressionsAfterPruning.insert(expression); } } diff --git a/src/planner/operator/logical_recursive_extend.cpp b/src/planner/operator/logical_recursive_extend.cpp index bc9068ff4a..2991b10b09 100644 --- a/src/planner/operator/logical_recursive_extend.cpp +++ b/src/planner/operator/logical_recursive_extend.cpp @@ -18,8 +18,13 @@ f_group_pos_set LogicalRecursiveExtend::getGroupsPosToFlatten() { void LogicalRecursiveExtend::computeFlatSchema() { copyChildSchema(0); schema->insertToGroupAndScope(nbrNode->getInternalIDProperty(), 0); - if (trackPath()) { + schema->insertToGroupAndScope(rel->getInternalLengthExpression(), 0); + switch (joinType) { + case RecursiveJoinType::TRACK_PATH: { schema->insertToGroupAndScope(rel, 0); + } break; + default: + break; } } @@ -29,7 +34,14 @@ void LogicalRecursiveExtend::computeFactorizedSchema() { SinkOperatorUtil::recomputeSchema(*childSchema, childSchema->getExpressionsInScope(), *schema); auto nbrGroupPos = schema->createGroup(); schema->insertToGroupAndScope(nbrNode->getInternalIDProperty(), nbrGroupPos); - schema->insertToGroupAndScope(rel, nbrGroupPos); + schema->insertToGroupAndScope(rel->getInternalLengthExpression(), nbrGroupPos); + switch (joinType) { + case RecursiveJoinType::TRACK_PATH: { + schema->insertToGroupAndScope(rel, nbrGroupPos); + } break; + default: + break; + } } } // namespace planner diff --git a/src/planner/projection_planner.cpp b/src/planner/projection_planner.cpp index 9d6c4993a3..ae33a73f18 100644 --- a/src/planner/projection_planner.cpp +++ b/src/planner/projection_planner.cpp @@ -188,12 +188,19 @@ expression_vector ProjectionPlanner::rewriteExpressionsToProject( const expression_vector& expressionsToProject, const Schema& schema) { expression_vector result; for (auto& expression : expressionsToProject) { - if (expression->dataType.getLogicalTypeID() == LogicalTypeID::NODE || - expression->dataType.getLogicalTypeID() == LogicalTypeID::REL) { + switch (expression->getDataType().getLogicalTypeID()) { + 
case LogicalTypeID::NODE: + case LogicalTypeID::REL: { for (auto& property : rewriteVariableAsAllPropertiesInScope(*expression, schema)) { result.push_back(property); } - } else { + } break; + case LogicalTypeID::RECURSIVE_REL: { + auto& rel = (RelExpression&)*expression; + result.push_back(rel.getInternalLengthExpression()); + result.push_back(expression); + } break; + default: result.push_back(expression); } } diff --git a/src/processor/mapper/map_extend.cpp b/src/processor/mapper/map_extend.cpp index fe1cfd3e9d..7e3aa3572e 100644 --- a/src/processor/mapper/map_extend.cpp +++ b/src/processor/mapper/map_extend.cpp @@ -171,6 +171,8 @@ std::unique_ptr PlanMapper::mapLogicalRecursiveExtendToPhysica DataPos(inSchema->getExpressionPos(*boundNode->getInternalIDProperty())); auto outNodeIDVectorPos = DataPos(outSchema->getExpressionPos(*nbrNode->getInternalIDProperty())); + auto lengthVectorPos = + DataPos(outSchema->getExpressionPos(*rel->getInternalLengthExpression())); auto& relsStore = storageManager.getRelsStore(); auto relTableID = rel->getSingleTableID(); @@ -208,14 +210,20 @@ std::unique_ptr PlanMapper::mapLogicalRecursiveExtendToPhysica getOperatorID(), emptyParamString); } std::unique_ptr dataInfo; - if (extend->trackPath()) { + switch (extend->getJoinType()) { + case planner::RecursiveJoinType::TRACK_PATH: { auto pathVectorPos = DataPos(outSchema->getExpressionPos(*rel)); dataInfo = std::make_unique(outDataPoses, colIndicesToScan, - inNodeIDVectorPos, outNodeIDVectorPos, tmpDstNodePos, true /* trackPath */, - pathVectorPos); - } else { + inNodeIDVectorPos, outNodeIDVectorPos, lengthVectorPos, tmpDstNodePos, + extend->getJoinType(), pathVectorPos); + } break; + case planner::RecursiveJoinType::TRACK_NONE: { dataInfo = std::make_unique(outDataPoses, colIndicesToScan, - inNodeIDVectorPos, outNodeIDVectorPos, tmpDstNodePos, false /* trackPath */); + inNodeIDVectorPos, outNodeIDVectorPos, lengthVectorPos, tmpDstNodePos, + extend->getJoinType()); + } break; + 
default: + throw common::NotImplementedException("PlanMapper::mapLogicalRecursiveExtendToPhysical"); } auto sharedState = std::make_shared(sharedInputFTable, nodeTable); switch (rel->getRelType()) { diff --git a/src/processor/operator/recursive_extend/path_scanner.cpp b/src/processor/operator/recursive_extend/path_scanner.cpp index 35c94a4a21..80556ee188 100644 --- a/src/processor/operator/recursive_extend/path_scanner.cpp +++ b/src/processor/operator/recursive_extend/path_scanner.cpp @@ -4,8 +4,8 @@ namespace kuzu { namespace processor { size_t BaseFrontierScanner::scan(common::table_id_t tableID, common::ValueVector* pathVector, - common::ValueVector* dstNodeIDVector, common::sel_t& offsetVectorPos, - common::sel_t& dataVectorPos) { + common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector, + common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos) { if (k >= frontiers.size()) { // BFS terminate before current depth. No need to scan. return 0; @@ -14,7 +14,8 @@ size_t BaseFrontierScanner::scan(common::table_id_t tableID, common::ValueVector auto lastFrontier = frontiers[k]; while (true) { if (currentDstOffset != common::INVALID_OFFSET) { - scanFromDstOffset(tableID, pathVector, dstNodeIDVector, offsetVectorPos, dataVectorPos); + scanFromDstOffset(tableID, pathVector, dstNodeIDVector, pathLengthVector, + offsetVectorPos, dataVectorPos); } if (offsetVectorPos == common::DEFAULT_VECTOR_CAPACITY) { break; @@ -44,8 +45,8 @@ void BaseFrontierScanner::resetState(const BaseBFSMorsel& bfsMorsel) { } void PathScanner::scanFromDstOffset(common::table_id_t tableID, common::ValueVector* pathVector, - common::ValueVector* dstNodeIDVector, common::sel_t& offsetVectorPos, - common::sel_t& dataVectorPos) { + common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector, + common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos) { auto level = 0; while (!nbrsStack.empty()) { auto& cursor = cursorStack.top(); @@ -54,8 +55,8 @@ void 
PathScanner::scanFromDstOffset(common::table_id_t tableID, common::ValueVec auto offset = nbrsStack.top()->at(cursor); path[level] = offset; if (level == 0) { // Found a new nbr at level 0. Found a new path. - writePathToVector( - tableID, pathVector, dstNodeIDVector, offsetVectorPos, dataVectorPos); + writePathToVector(tableID, pathVector, dstNodeIDVector, pathLengthVector, + offsetVectorPos, dataVectorPos); if (offsetVectorPos == common::DEFAULT_VECTOR_CAPACITY) { return; } @@ -86,14 +87,14 @@ void PathScanner::initDfs(common::offset_t offset, size_t currentDepth) { } void PathScanner::writePathToVector(common::table_id_t tableID, common::ValueVector* pathVector, - common::ValueVector* dstNodeIDVector, common::sel_t& offsetVectorPos, - common::sel_t& dataVectorPos) { + common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector, + common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos) { + assert(offsetVectorPos < common::DEFAULT_VECTOR_CAPACITY); auto pathLength = path.size(); auto listEntry = common::ListVector::addList(pathVector, pathLength); pathVector->setValue(offsetVectorPos, listEntry); - dstNodeIDVector->setValue( - offsetVectorPos, common::nodeID_t{path[pathLength - 1], tableID}); - assert(offsetVectorPos < common::DEFAULT_VECTOR_CAPACITY); + assert(path[pathLength - 1] == currentDstOffset); + writeDstNodeOffsetAndLength(dstNodeIDVector, pathLengthVector, offsetVectorPos, tableID); offsetVectorPos++; auto pathDataVector = common::ListVector::getDataVector(pathVector); for (auto i = 0u; i < pathLength; ++i) { @@ -104,21 +105,22 @@ void PathScanner::writePathToVector(common::table_id_t tableID, common::ValueVec void DstNodeWithMultiplicityScanner::scanFromDstOffset(common::table_id_t tableID, common::ValueVector* pathVector, common::ValueVector* dstNodeIDVector, - common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos) { + common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos, + common::sel_t& dataVectorPos) { 
auto& multiplicity = frontiers[k]->offsetToMultiplicity.at(currentDstOffset); while (multiplicity > 0 && offsetVectorPos < common::DEFAULT_VECTOR_CAPACITY) { - dstNodeIDVector->setValue( - offsetVectorPos++, common::nodeID_t{currentDstOffset, tableID}); + writeDstNodeOffsetAndLength(dstNodeIDVector, pathLengthVector, offsetVectorPos, tableID); + offsetVectorPos++; multiplicity--; } } void FrontiersScanner::scan(common::table_id_t tableID, common::ValueVector* pathVector, - common::ValueVector* dstNodeIDVector, common::sel_t& offsetVectorPos, - common::sel_t& dataVectorPos) { + common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector, + common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos) { while (offsetVectorPos < common::DEFAULT_VECTOR_CAPACITY && cursor < scanners.size()) { - if (scanners[cursor]->scan( - tableID, pathVector, dstNodeIDVector, offsetVectorPos, dataVectorPos) == 0) { + if (scanners[cursor]->scan(tableID, pathVector, dstNodeIDVector, pathLengthVector, + offsetVectorPos, dataVectorPos) == 0) { cursor++; } } diff --git a/src/processor/operator/recursive_extend/recursive_join.cpp b/src/processor/operator/recursive_extend/recursive_join.cpp index 6bf780290b..c2e6ef5b19 100644 --- a/src/processor/operator/recursive_extend/recursive_join.cpp +++ b/src/processor/operator/recursive_extend/recursive_join.cpp @@ -17,10 +17,14 @@ void BaseRecursiveJoin::initLocalStateInternal(ResultSet* resultSet_, ExecutionC } srcNodeIDVector = resultSet->getValueVector(dataInfo->srcNodePos).get(); dstNodeIDVector = resultSet->getValueVector(dataInfo->dstNodePos).get(); - if (dataInfo->trackPath) { + pathLengthVector = resultSet->getValueVector(dataInfo->pathLengthPos).get(); + switch (dataInfo->joinType) { + case planner::RecursiveJoinType::TRACK_PATH: { pathVector = resultSet->getValueVector(dataInfo->pathPos).get(); - } else { + } break; + default: { pathVector = nullptr; + } break; } initLocalRecursivePlan(context); } @@ -60,8 +64,8 @@ bool 
BaseRecursiveJoin::scanOutput() { if (pathVector != nullptr) { pathVector->resetAuxiliaryBuffer(); } - frontiersScanner->scan( - nodeTable->getTableID(), pathVector, dstNodeIDVector, offsetVectorSize, dataVectorSize); + frontiersScanner->scan(nodeTable->getTableID(), pathVector, dstNodeIDVector, pathLengthVector, + offsetVectorSize, dataVectorSize); if (offsetVectorSize == 0) { return false; } diff --git a/src/processor/operator/recursive_extend/shortest_path_recursive_join.cpp b/src/processor/operator/recursive_extend/shortest_path_recursive_join.cpp index 63cb495c48..64facf57be 100644 --- a/src/processor/operator/recursive_extend/shortest_path_recursive_join.cpp +++ b/src/processor/operator/recursive_extend/shortest_path_recursive_join.cpp @@ -8,19 +8,24 @@ void ShortestPathRecursiveJoin::initLocalStateInternal( BaseRecursiveJoin::initLocalStateInternal(resultSet_, context); auto maxNodeOffset = nodeTable->getMaxNodeOffset(transaction); std::vector> scanners; - if (dataInfo->trackPath) { + switch (dataInfo->joinType) { + case planner::RecursiveJoinType::TRACK_PATH: { bfsMorsel = std::make_unique>( maxNodeOffset, lowerBound, upperBound, sharedState->semiMask.get()); for (auto i = lowerBound; i <= upperBound; ++i) { scanners.push_back(std::make_unique(bfsMorsel->targetDstNodeOffsets, i)); } - } else { + } break; + case planner::RecursiveJoinType::TRACK_NONE: { bfsMorsel = std::make_unique>( maxNodeOffset, lowerBound, upperBound, sharedState->semiMask.get()); for (auto i = lowerBound; i <= upperBound; ++i) { scanners.push_back( std::make_unique(bfsMorsel->targetDstNodeOffsets, i)); } + } break; + default: + throw common::NotImplementedException("ShortestPathRecursiveJoin::initLocalStateInternal"); } frontiersScanner = std::make_unique(std::move(scanners)); } diff --git a/src/processor/operator/recursive_extend/variable_length_recursive_join.cpp b/src/processor/operator/recursive_extend/variable_length_recursive_join.cpp index d2225fd8ca..c3b050db66 100644 --- 
a/src/processor/operator/recursive_extend/variable_length_recursive_join.cpp +++ b/src/processor/operator/recursive_extend/variable_length_recursive_join.cpp @@ -8,19 +8,25 @@ void VariableLengthRecursiveJoin::initLocalStateInternal( BaseRecursiveJoin::initLocalStateInternal(resultSet_, context); auto maxNodeOffset = nodeTable->getMaxNodeOffset(transaction); std::vector> scanners; - if (dataInfo->trackPath) { + switch (dataInfo->joinType) { + case planner::RecursiveJoinType::TRACK_PATH: { bfsMorsel = std::make_unique>( maxNodeOffset, lowerBound, upperBound, sharedState->semiMask.get()); for (auto i = lowerBound; i <= upperBound; ++i) { scanners.push_back(std::make_unique(bfsMorsel->targetDstNodeOffsets, i)); } - } else { + } break; + case planner::RecursiveJoinType::TRACK_NONE: { bfsMorsel = std::make_unique>( maxNodeOffset, lowerBound, upperBound, sharedState->semiMask.get()); for (auto i = lowerBound; i <= upperBound; ++i) { scanners.push_back(std::make_unique( bfsMorsel->targetDstNodeOffsets, i)); } + } break; + default: + throw common::NotImplementedException( + "VariableLengthRecursiveJoin::initLocalStateInternal"); } frontiersScanner = std::make_unique(std::move(scanners)); } diff --git a/test/optimizer/optimizer_test.cpp b/test/optimizer/optimizer_test.cpp index bbfb960202..5b3261ec82 100644 --- a/test/optimizer/optimizer_test.cpp +++ b/test/optimizer/optimizer_test.cpp @@ -1,4 +1,5 @@ #include "graph_test/graph_test.h" +#include "planner/logical_plan/logical_operator/logical_recursive_extend.h" #include "planner/logical_plan/logical_plan_util.h" namespace kuzu { @@ -92,5 +93,14 @@ TEST_F(OptimizerTest, RecursiveJoinTest) { ASSERT_STREQ(encodedPlan.c_str(), "HJ(a._id){RE(a)S(b)}{S(a)}"); } +TEST_F(OptimizerTest, RecursiveJoinNoTrackPathTest) { + auto op = getRoot("MATCH (a:person)-[e:knows* SHORTEST 2..3]->(b:person) RETURN length(e);"); + while (op->getOperatorType() != planner::LogicalOperatorType::RECURSIVE_EXTEND) { + op = op->getChild(0); + } + auto 
recursiveExtend = (planner::LogicalRecursiveExtend*)op.get(); + ASSERT_TRUE(recursiveExtend->getJoinType() == planner::RecursiveJoinType::TRACK_NONE); +} + } // namespace testing } // namespace kuzu diff --git a/test/test_files/shortest_path/bfs_sssp.test b/test/test_files/shortest_path/bfs_sssp.test index 8c157b555a..456307dc29 100644 --- a/test/test_files/shortest_path/bfs_sssp.test +++ b/test/test_files/shortest_path/bfs_sssp.test @@ -20,7 +20,7 @@ Farooq|Alice|[0:0,0:7,0:5] Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|Alice|[0:0,0:7] -NAME SingleSourceWithAllProperties --QUERY MATCH (a:person)-[r:knows* SHORTEST 1..30]->(b:person) WHERE a.fName = 'Alice' RETURN len(r) - 1, b, a +-QUERY MATCH (a:person)-[r:knows* SHORTEST 1..30]->(b:person) WHERE a.fName = 'Alice' RETURN length(r), b, a ---- 7 1|(label:person, 0:1, {ID:2, fName:Bob, gender:2, isStudent:True, isWorker:False, age:30, eyeSight:5.100000, birthdate:1900-01-01, registerTime:2008-11-03 15:25:30.000526, lastJobDuration:10 years 5 months 13:00:00.000024, workedHours:[12,8], usedNames:[Bobby], courseScoresPerTerm:[[8,9],[9,10]]})|(label:person, 0:0, {ID:0, fName:Alice, gender:1, isStudent:True, isWorker:False, age:35, eyeSight:5.000000, birthdate:1900-01-01, registerTime:2011-08-20 11:25:30, lastJobDuration:3 years 2 days 13:02:00, workedHours:[10,5], usedNames:[Aida], courseScoresPerTerm:[[10,8],[6,7,8]]}) 1|(label:person, 0:2, {ID:3, fName:Carol, gender:1, isStudent:False, isWorker:True, age:45, eyeSight:5.000000, birthdate:1940-06-22, registerTime:1911-08-20 02:32:21, lastJobDuration:48:24:11, workedHours:[4,5], usedNames:[Carmen,Fred], courseScoresPerTerm:[[8,10]]})|(label:person, 0:0, {ID:0, fName:Alice, gender:1, isStudent:True, isWorker:False, age:35, eyeSight:5.000000, birthdate:1900-01-01, registerTime:2011-08-20 11:25:30, lastJobDuration:3 years 2 days 13:02:00, workedHours:[10,5], usedNames:[Aida], courseScoresPerTerm:[[10,8],[6,7,8]]}) @@ -50,7 +50,7 @@ Elizabeth|Hubert Blaine 
Wolfeschlegelsteinhausenbergerdorff|[0:4,0:7] -NAME MultipleSrcMultipleDstQuery --QUERY MATCH (a:person)-[r:knows* SHORTEST 1..10]->(b:person) WHERE a.isStudent = true AND b.isWorker = true RETURN a.fName, b.fName, len(r) - 1 +-QUERY MATCH (a:person)-[r:knows* SHORTEST 1..10]->(b:person) WHERE a.isStudent = true AND b.isWorker = true RETURN a.fName, b.fName, length(r) ----- 12 Alice|Carol|1 Alice|Dan|1 @@ -71,7 +71,7 @@ Farooq|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|1 Alice|Bob|Elizabeth -NAME MultiPart --QUERY MATCH (a)-[r:knows* SHORTEST 1..30]->(b:person) WHERE b.ID > 6 AND a.fName = 'Alice' WITH a, b, r MATCH (c:person)<-[:knows]-(a:person) RETURN b.fName, len(r) - 1, COUNT(*) +-QUERY MATCH (a)-[r:knows* SHORTEST 1..30]->(b:person) WHERE b.ID > 6 AND a.fName = 'Alice' WITH a, b, r MATCH (c:person)<-[:knows]-(a:person) RETURN b.fName, length(r), COUNT(*) ---- 4 Elizabeth|2|3 Farooq|3|3 diff --git a/test/test_files/shortest_path/bfs_sssp_large.test b/test/test_files/shortest_path/bfs_sssp_large.test index 1d45b012b1..842f207b56 100644 --- a/test/test_files/shortest_path/bfs_sssp_large.test +++ b/test/test_files/shortest_path/bfs_sssp_large.test @@ -4,7 +4,7 @@ # This pattern is continued till Alice2500. Distance between Alice11 and Alice22 is 2, and so on. 
# Distance between Alice11 and Alice100 is 9 because 11 -> 11,20,30,40,50,60,70,80,90,100 -NAME SingleSrcAllDstQueryLarge --QUERY MATCH (a:person)-[r:knows* SHORTEST 1..30]->(b:person) WHERE a.fName = 'Alice11' RETURN a.fName, b.fName, len(r) - 1 +-QUERY MATCH (a:person)-[r:knows* SHORTEST 1..30]->(b:person) WHERE a.fName = 'Alice11' RETURN a.fName, b.fName, length(r) ---- 300 Alice11|Alice100|9 Alice11|Alice101|9 @@ -308,7 +308,7 @@ Alice11|Alice98|9 Alice11|Alice99|9 -NAME MultipleSrcAllDstQueryLarge --QUERY MATCH (a:person)-[r:knows* SHORTEST 1..30]->(b:person) WHERE a.fName = 'Alice11' OR a.fName = 'Alice' RETURN a.fName, b.fName, len(r) - 1 +-QUERY MATCH (a:person)-[r:knows* SHORTEST 1..30]->(b:person) WHERE a.fName = 'Alice11' OR a.fName = 'Alice' RETURN a.fName, b.fName, length(r) ---- 307 Alice11|Alice100|9 Alice11|Alice101|9