diff --git a/src/binder/bind/bind_graph_pattern.cpp b/src/binder/bind/bind_graph_pattern.cpp index b835b44683..544f5f7d94 100644 --- a/src/binder/bind/bind_graph_pattern.cpp +++ b/src/binder/bind/bind_graph_pattern.cpp @@ -1,6 +1,7 @@ #include #include "binder/binder.h" +#include "binder/expression/property_expression.h" using namespace kuzu::common; using namespace kuzu::parser; @@ -129,6 +130,39 @@ getNodePropertyNameAndPropertiesPairs(const std::vector& nodeT return getPropertyNameAndSchemasPairs(propertyNames, propertyNamesToSchemas); } +static std::unique_ptr getRecursiveRelLogicalType( + const NodeExpression& node, const RelExpression& rel) { + std::vector> nodeFields; + nodeFields.push_back(std::make_unique( + InternalKeyword::ID, std::make_unique(LogicalTypeID::INTERNAL_ID))); + for (auto& expression : node.getPropertyExpressions()) { + auto propertyExpression = (PropertyExpression*)expression.get(); + nodeFields.push_back(std::make_unique( + propertyExpression->getPropertyName(), propertyExpression->getDataType().copy())); + } + auto nodeType = std::make_unique( + LogicalTypeID::STRUCT, std::make_unique(std::move(nodeFields))); + auto nodesType = std::make_unique( + LogicalTypeID::VAR_LIST, std::make_unique(std::move(nodeType))); + std::vector> relFields; + for (auto& expression : rel.getPropertyExpressions()) { + auto propertyExpression = (PropertyExpression*)expression.get(); + relFields.push_back(std::make_unique( + propertyExpression->getPropertyName(), propertyExpression->getDataType().copy())); + } + auto relType = std::make_unique( + LogicalTypeID::STRUCT, std::make_unique(std::move(relFields))); + auto relsType = std::make_unique( + LogicalTypeID::VAR_LIST, std::make_unique(std::move(relType))); + std::vector> recursiveRelFields; + recursiveRelFields.push_back( + std::make_unique(InternalKeyword::NODES, std::move(nodesType))); + recursiveRelFields.push_back( + std::make_unique(InternalKeyword::RELS, std::move(relsType))); + return std::make_unique(LogicalTypeID::RECURSIVE_REL, + std::make_unique(std::move(recursiveRelFields))); +} + void Binder::bindQueryRel(const RelPattern& relPattern, const std::shared_ptr& leftNode, const std::shared_ptr& rightNode, QueryGraph& queryGraph, @@ -170,61 +204,21 @@ void Binder::bindQueryRel(const RelPattern& relPattern, throw BinderException("Self-loop rel " + parsedName + " is not supported."); } // bind variable length - auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern); - auto isVariableLength = !(lowerBound == 1 && upperBound == 1); - if (!isVariableLength) { + std::shared_ptr queryRel; + if (QueryRelTypeUtils::isRecursive(relPattern.getRelType())) { + auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern); + queryRel = createRecursiveQueryRel(relPattern.getVariableName(), relPattern.getRelType(), + lowerBound, upperBound, tableIDs, srcNode, dstNode, directionType); + } else { tableIDs = pruneRelTableIDs(catalog, tableIDs, *srcNode, *dstNode); if (tableIDs.empty()) { throw BinderException("Nodes " + srcNode->toString() + " and " + dstNode->toString() + " are not connected through rel " + parsedName + "."); } - } - common::LogicalType dataType; - if (isVariableLength) { - std::vector> structFields; - auto varListTypeInfo = std::make_unique( - std::make_unique(LogicalTypeID::INTERNAL_ID)); - auto nodeStructField = std::make_unique(InternalKeyword::NODES, - std::make_unique(LogicalTypeID::VAR_LIST, varListTypeInfo->copy())); - auto relStructField = std::make_unique(InternalKeyword::RELS, - std::make_unique(LogicalTypeID::VAR_LIST, varListTypeInfo->copy())); - structFields.push_back(std::move(nodeStructField)); - structFields.push_back(std::move(relStructField)); - auto structTypeInfo = std::make_unique(std::move(structFields)); - dataType = - common::LogicalType(common::LogicalTypeID::RECURSIVE_REL, std::move(structTypeInfo)); - } else { - dataType = common::LogicalType(common::LogicalTypeID::REL); - } - auto queryRel = make_shared(dataType, getUniqueExpressionName(parsedName), - parsedName, tableIDs, srcNode, dstNode, directionType, relPattern.getRelType()); - if (isVariableLength) { - std::unordered_set recursiveNodeTableIDs; - for (auto relTableID : tableIDs) { - auto relTableSchema = catalog.getReadOnlyVersion()->getRelTableSchema(relTableID); - recursiveNodeTableIDs.insert(relTableSchema->srcTableID); - recursiveNodeTableIDs.insert(relTableSchema->dstTableID); - } - auto recursiveNode = - createQueryNode("", std::vector{ - recursiveNodeTableIDs.begin(), recursiveNodeTableIDs.end()}); - auto lengthExpression = expressionBinder.createInternalLengthExpression(*queryRel); - auto recursiveInfo = std::make_unique( - lowerBound, upperBound, std::move(recursiveNode), std::move(lengthExpression)); - queryRel->setRecursiveInfo(std::move(recursiveInfo)); + queryRel = createNonRecursiveQueryRel( + relPattern.getVariableName(), tableIDs, srcNode, dstNode, directionType); } queryRel->setAlias(parsedName); - // resolve properties associate with rel table - std::vector relTableSchemas; - for (auto tableID : tableIDs) { - relTableSchemas.push_back(catalog.getReadOnlyVersion()->getRelTableSchema(tableID)); - } - for (auto& [propertyName, propertySchemas] : - getRelPropertyNameAndPropertiesPairs(relTableSchemas)) { - auto propertyExpression = expressionBinder.createPropertyExpression( - *queryRel, propertySchemas, false /* isPrimaryKey */); - queryRel->addPropertyExpression(propertyName, std::move(propertyExpression)); - } if (!parsedName.empty()) { variableScope->addExpression(parsedName, queryRel); } @@ -238,6 +232,42 @@ void Binder::bindQueryRel(const RelPattern& relPattern, queryGraph.addQueryRel(queryRel); } +std::shared_ptr Binder::createNonRecursiveQueryRel(const std::string& parsedName, + const std::vector& tableIDs, std::shared_ptr srcNode, + std::shared_ptr dstNode, RelDirectionType directionType) { + auto queryRel = make_shared(LogicalType(LogicalTypeID::REL), + getUniqueExpressionName(parsedName), parsedName, tableIDs, std::move(srcNode), + std::move(dstNode), directionType, QueryRelType::NON_RECURSIVE); + bindQueryRelProperties(*queryRel); + return queryRel; +} + +std::shared_ptr Binder::createRecursiveQueryRel(const std::string& parsedName, + common::QueryRelType relType, uint32_t lowerBound, uint32_t upperBound, + const std::vector& tableIDs, std::shared_ptr srcNode, + std::shared_ptr dstNode, RelDirectionType directionType) { + std::unordered_set recursiveNodeTableIDs; + for (auto relTableID : tableIDs) { + auto relTableSchema = catalog.getReadOnlyVersion()->getRelTableSchema(relTableID); + recursiveNodeTableIDs.insert(relTableSchema->srcTableID); + recursiveNodeTableIDs.insert(relTableSchema->dstTableID); + } + auto tmpNode = createQueryNode( + InternalKeyword::ANONYMOUS, std::vector{recursiveNodeTableIDs.begin(), + recursiveNodeTableIDs.end()}); + auto tmpRel = createNonRecursiveQueryRel( + InternalKeyword::ANONYMOUS, tableIDs, nullptr, nullptr, directionType); + auto queryRel = make_shared(*getRecursiveRelLogicalType(*tmpNode, *tmpRel), + getUniqueExpressionName(parsedName), parsedName, tableIDs, std::move(srcNode), + std::move(dstNode), directionType, relType); + auto lengthExpression = expressionBinder.createInternalLengthExpression(*queryRel); + auto recursiveInfo = std::make_unique( + lowerBound, upperBound, std::move(tmpNode), std::move(tmpRel), std::move(lengthExpression)); + queryRel->setRecursiveInfo(std::move(recursiveInfo)); + bindQueryRelProperties(*queryRel); + return queryRel; +} + std::pair Binder::bindVariableLengthRelBound( const kuzu::parser::RelPattern& relPattern) { auto lowerBound = std::min(TypeUtils::convertToUint32(relPattern.getLowerBound().c_str()), @@ -254,6 +284,19 @@ std::pair Binder::bindVariableLengthRelBound( return std::make_pair(lowerBound, upperBound); } +void Binder::bindQueryRelProperties(RelExpression& rel) { + std::vector tableSchemas; + for (auto tableID : rel.getTableIDs()) { + tableSchemas.push_back(catalog.getReadOnlyVersion()->getRelTableSchema(tableID)); + } + for (auto& [propertyName, propertySchemas] : + getRelPropertyNameAndPropertiesPairs(tableSchemas)) { + auto propertyExpression = expressionBinder.createPropertyExpression( + rel, propertySchemas, false /* isPrimaryKey */); + rel.addPropertyExpression(propertyName, std::move(propertyExpression)); + } +} + std::shared_ptr Binder::bindQueryNode( const NodePattern& nodePattern, QueryGraph& queryGraph, PropertyKeyValCollection& collection) { auto parsedName = nodePattern.getVariableName(); @@ -294,26 +337,31 @@ std::shared_ptr Binder::createQueryNode( make_shared(getUniqueExpressionName(parsedName), parsedName, tableIDs); queryNode->setAlias(parsedName); queryNode->setInternalIDProperty(expressionBinder.createInternalNodeIDExpression(*queryNode)); - // resolve properties associate with node table - std::vector nodeTableSchemas; - for (auto tableID : tableIDs) { - nodeTableSchemas.push_back(catalog.getReadOnlyVersion()->getNodeTableSchema(tableID)); - } - auto isSingleTable = nodeTableSchemas.size() == 1; - for (auto& [propertyName, propertySchemas] : - getNodePropertyNameAndPropertiesPairs(nodeTableSchemas)) { - auto isPrimaryKey = isSingleTable && nodeTableSchemas[0]->getPrimaryKey().propertyID == - propertySchemas[0].propertyID; - auto propertyExpression = - expressionBinder.createPropertyExpression(*queryNode, propertySchemas, isPrimaryKey); - queryNode->addPropertyExpression(propertyName, std::move(propertyExpression)); - } + bindQueryNodeProperties(*queryNode); if (!parsedName.empty()) { variableScope->addExpression(parsedName, queryNode); } return queryNode; } +void Binder::bindQueryNodeProperties(NodeExpression& node) { + std::vector tableSchemas; + for (auto tableID : node.getTableIDs()) { + tableSchemas.push_back(catalog.getReadOnlyVersion()->getNodeTableSchema(tableID)); + } + for (auto& [propertyName, propertySchemas] : + getNodePropertyNameAndPropertiesPairs(tableSchemas)) { + bool isPrimaryKey = false; + if (!node.isMultiLabeled()) { + isPrimaryKey = + tableSchemas[0]->getPrimaryKey().propertyID == propertySchemas[0].propertyID; + } + auto propertyExpression = + expressionBinder.createPropertyExpression(node, propertySchemas, isPrimaryKey); + node.addPropertyExpression(propertyName, std::move(propertyExpression)); + } +} + std::vector Binder::bindTableIDs( const std::vector& tableNames, LogicalTypeID nodeOrRelType) { std::unordered_set tableIDs; diff --git a/src/binder/bind_expression/bind_function_expression.cpp b/src/binder/bind_expression/bind_function_expression.cpp index 43b49a3295..ce975ceb64 100644 --- a/src/binder/bind_expression/bind_function_expression.cpp +++ b/src/binder/bind_expression/bind_function_expression.cpp @@ -83,7 +83,7 @@ std::shared_ptr ExpressionBinder::bindAggregateFunctionExpression( // e.g. COUNT(a) -> COUNT(a._id) if (child->dataType.getLogicalTypeID() == LogicalTypeID::NODE || child->dataType.getLogicalTypeID() == LogicalTypeID::REL) { - child = bindInternalIDExpression(*child); + child = bindInternalIDExpression(child); } childrenTypes.push_back(child->dataType); children.push_back(std::move(child)); @@ -131,9 +131,9 @@ std::shared_ptr ExpressionBinder::rewriteFunctionExpression( const parser::ParsedExpression& parsedExpression, const std::string& functionName) { if (functionName == ID_FUNC_NAME) { auto child = bindExpression(*parsedExpression.getChild(0)); - validateExpectedDataType( - *child, std::vector{LogicalTypeID::NODE, LogicalTypeID::REL}); - return bindInternalIDExpression(*child); + validateExpectedDataType(*child, std::vector{LogicalTypeID::NODE, + LogicalTypeID::REL, LogicalTypeID::STRUCT}); + return bindInternalIDExpression(child); } else if (functionName == LABEL_FUNC_NAME) { auto child = bindExpression(*parsedExpression.getChild(0)); validateExpectedDataType( @@ -158,14 +158,21 @@ std::unique_ptr ExpressionBinder::createInternalNodeIDExpression( } std::shared_ptr ExpressionBinder::bindInternalIDExpression( - const Expression& expression) { - switch (expression.getDataType().getLogicalTypeID()) { + std::shared_ptr expression) { + switch (expression->getDataType().getLogicalTypeID()) { case common::LogicalTypeID::NODE: { - auto& node = (NodeExpression&)expression; + auto& node = (NodeExpression&)*expression; return node.getInternalIDProperty(); } case common::LogicalTypeID::REL: { - return bindRelPropertyExpression(expression, InternalKeyword::ID); + return bindRelPropertyExpression(*expression, InternalKeyword::ID); + } + case common::LogicalTypeID::STRUCT: { + auto stringValue = + std::make_unique(LogicalType{LogicalTypeID::STRING}, InternalKeyword::ID); + return bindScalarFunctionExpression( + expression_vector{expression, createLiteralExpression(std::move(stringValue))}, + STRUCT_EXTRACT_FUNC_NAME); } default: throw NotImplementedException("ExpressionBinder::bindInternalIDExpression"); diff --git a/src/binder/bind_expression/bind_variable_expression.cpp b/src/binder/bind_expression/bind_variable_expression.cpp index aa0421ab87..f7aa3a1695 100644 --- a/src/binder/bind_expression/bind_variable_expression.cpp +++ b/src/binder/bind_expression/bind_variable_expression.cpp @@ -1,4 +1,5 @@ #include "binder/binder.h" +#include "binder/expression/variable_expression.h" #include "binder/expression_binder.h" #include "parser/expression/parsed_variable_expression.h" @@ -18,5 +19,11 @@ std::shared_ptr ExpressionBinder::bindVariableExpression( "Variable " + parsedExpression.getRawName() + " is not in scope."); } +std::shared_ptr ExpressionBinder::createVariableExpression( + common::LogicalType logicalType, std::string uniqueName, std::string name) { + return std::make_shared( + std::move(logicalType), std::move(uniqueName), std::move(name)); +} + } // namespace binder } // namespace kuzu diff --git a/src/binder/binder.cpp b/src/binder/binder.cpp index d4a8223a65..2eebbf9f69 100644 --- a/src/binder/binder.cpp +++ b/src/binder/binder.cpp @@ -70,10 +70,10 @@ std::shared_ptr Binder::createVariable( throw BinderException("Variable " + name + " already exists."); } auto uniqueName = getUniqueExpressionName(name); - auto variable = make_shared(dataType, uniqueName, name); - variable->setAlias(name); - variableScope->addExpression(name, variable); - return variable; + auto expression = expressionBinder.createVariableExpression(dataType, uniqueName, name); + expression->setAlias(name); + variableScope->addExpression(name, expression); + return expression; } void Binder::validateFirstMatchIsNotOptional(const SingleQuery& singleQuery) { diff --git a/src/common/types/types.cpp b/src/common/types/types.cpp index 244d2420a8..f039f82075 100644 --- a/src/common/types/types.cpp +++ b/src/common/types/types.cpp @@ -110,6 +110,14 @@ struct_field_idx_t StructTypeInfo::getStructFieldIdx(std::string fieldName) cons return INVALID_STRUCT_FIELD_IDX; } +StructField* StructTypeInfo::getStructField(const std::string& fieldName) const { + auto idx = getStructFieldIdx(fieldName); + if (idx == INVALID_STRUCT_FIELD_IDX) { + throw BinderException("Cannot find field " + fieldName + " in STRUCT."); + } + return fields[idx].get(); +} + std::vector StructTypeInfo::getChildrenTypes() const { std::vector childrenTypesToReturn{fields.size()}; for (auto i = 0u; i < fields.size(); i++) { diff --git a/src/include/binder/binder.h b/src/include/binder/binder.h index 759e413444..f84c49911c 100644 --- a/src/include/binder/binder.h +++ b/src/include/binder/binder.h @@ -170,12 +170,22 @@ class Binder { const std::shared_ptr& leftNode, const std::shared_ptr& rightNode, QueryGraph& queryGraph, PropertyKeyValCollection& collection); + std::shared_ptr createNonRecursiveQueryRel(const std::string& parsedName, + const std::vector& tableIDs, std::shared_ptr srcNode, + std::shared_ptr dstNode, RelDirectionType directionType); + std::shared_ptr createRecursiveQueryRel(const std::string& parsedName, + common::QueryRelType relType, uint32_t lowerBound, uint32_t upperBound, + const std::vector& tableIDs, std::shared_ptr srcNode, + std::shared_ptr dstNode, RelDirectionType directionType); std::pair bindVariableLengthRelBound(const parser::RelPattern& relPattern); + void bindQueryRelProperties(RelExpression& rel); + std::shared_ptr bindQueryNode(const parser::NodePattern& nodePattern, QueryGraph& queryGraph, PropertyKeyValCollection& collection); std::shared_ptr createQueryNode(const parser::NodePattern& nodePattern); std::shared_ptr createQueryNode( const std::string& parsedName, const std::vector& tableIDs); + void bindQueryNodeProperties(NodeExpression& node); inline std::vector bindNodeTableIDs( const std::vector& tableNames) { return bindTableIDs(tableNames, common::LogicalTypeID::NODE); diff --git a/src/include/binder/expression/rel_expression.h b/src/include/binder/expression/rel_expression.h index cb193a2396..d18ebc8ca5 100644 --- a/src/include/binder/expression/rel_expression.h +++ b/src/include/binder/expression/rel_expression.h @@ -12,16 +12,19 @@ enum class RelDirectionType : uint8_t { BOTH = 1, }; +class RelExpression; + struct RecursiveInfo { uint64_t lowerBound; uint64_t upperBound; - std::shared_ptr recursiveNode; + std::shared_ptr node; + std::shared_ptr rel; std::shared_ptr lengthExpression; - RecursiveInfo(size_t lowerBound, size_t upperBound, - std::shared_ptr recursiveNode, std::shared_ptr lengthExpression) - : lowerBound{lowerBound}, upperBound{upperBound}, recursiveNode{std::move(recursiveNode)}, - lengthExpression{std::move(lengthExpression)} {} + RecursiveInfo(size_t lowerBound, size_t upperBound, std::shared_ptr node, + std::shared_ptr rel, std::shared_ptr lengthExpression) + : lowerBound{lowerBound}, upperBound{upperBound}, node{std::move(node)}, + rel{std::move(rel)}, lengthExpression{std::move(lengthExpression)} {} }; class RelExpression : public NodeOrRelExpression { @@ -58,9 +61,6 @@ class RelExpression : public NodeOrRelExpression { inline RecursiveInfo* getRecursiveInfo() const { return recursiveInfo.get(); } inline size_t getLowerBound() const { return recursiveInfo->lowerBound; } inline size_t getUpperBound() const { return recursiveInfo->upperBound; } - inline std::shared_ptr getRecursiveNode() const { - return recursiveInfo->recursiveNode; - } inline std::shared_ptr getLengthExpression() const { return recursiveInfo->lengthExpression; } diff --git a/src/include/binder/expression_binder.h b/src/include/binder/expression_binder.h index ca6076d18c..12c745209b 100644 --- a/src/include/binder/expression_binder.h +++ b/src/include/binder/expression_binder.h @@ -60,7 +60,7 @@ class ExpressionBinder { std::shared_ptr rewriteFunctionExpression( const parser::ParsedExpression& parsedExpression, const std::string& functionName); std::unique_ptr createInternalNodeIDExpression(const Expression& node); - std::shared_ptr bindInternalIDExpression(const Expression& expression); + std::shared_ptr bindInternalIDExpression(std::shared_ptr expression); std::shared_ptr bindLabelFunction(const Expression& expression); std::unique_ptr createInternalLengthExpression(const Expression& expression); std::shared_ptr bindRecursiveJoinLengthFunction(const Expression& expression); @@ -75,6 +75,8 @@ class ExpressionBinder { std::shared_ptr bindVariableExpression( const parser::ParsedExpression& parsedExpression); + std::shared_ptr createVariableExpression( + common::LogicalType logicalType, std::string uniqueName, std::string name); std::shared_ptr bindExistentialSubqueryExpression( const parser::ParsedExpression& parsedExpression); diff --git a/src/include/common/constants.h b/src/include/common/constants.h index 0717b51e09..09661dfea5 100644 --- a/src/include/common/constants.h +++ b/src/include/common/constants.h @@ -20,6 +20,7 @@ constexpr uint64_t THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS = 500; constexpr uint64_t DEFAULT_CHECKPOINT_WAIT_TIMEOUT_FOR_TRANSACTIONS_TO_LEAVE_IN_MICROS = 5000000; struct InternalKeyword { + static constexpr char ANONYMOUS[] = ""; static constexpr char ID[] = "_id"; static constexpr char LENGTH[] = "_length"; static constexpr char NODES[] = "_nodes"; diff --git a/src/include/common/types/types.h b/src/include/common/types/types.h index b0f8ec6710..917370ae08 100644 --- a/src/include/common/types/types.h +++ b/src/include/common/types/types.h @@ -178,6 +178,7 @@ class StructTypeInfo : public ExtraTypeInfo { explicit StructTypeInfo(std::vector> fields); struct_field_idx_t getStructFieldIdx(std::string fieldName) const; + StructField* getStructField(const std::string& fieldName) const; std::vector getChildrenTypes() const; std::vector getChildrenNames() const; std::vector getStructFields() const; @@ -279,6 +280,12 @@ struct StructType { return structTypeInfo->getStructFields(); } + static inline StructField* getField(const LogicalType* type, const std::string& key) { + assert(type->getPhysicalType() == PhysicalTypeID::STRUCT); + auto structTypeInfo = reinterpret_cast(type->extraTypeInfo.get()); + return structTypeInfo->getStructField(key); + } + static inline struct_field_idx_t getFieldIdx(const LogicalType* type, const std::string& key) { assert(type->getPhysicalType() == PhysicalTypeID::STRUCT); auto structTypeInfo = reinterpret_cast(type->extraTypeInfo.get()); diff --git a/src/include/common/vector/auxiliary_buffer.h b/src/include/common/vector/auxiliary_buffer.h index 06b90e7c14..1220cf1b12 100644 --- a/src/include/common/vector/auxiliary_buffer.h +++ b/src/include/common/vector/auxiliary_buffer.h @@ -68,6 +68,8 @@ class ListAuxiliaryBuffer : public AuxiliaryBuffer { list_entry_t addList(uint64_t listSize); + inline uint64_t getSize() const { return size; } + inline void resetSize() { size = 0; } private: diff --git a/src/include/common/vector/value_vector.h b/src/include/common/vector/value_vector.h index 784772b766..0f3ec29c5b 100644 --- a/src/include/common/vector/value_vector.h +++ b/src/include/common/vector/value_vector.h @@ -115,6 +115,12 @@ class ListVector { return reinterpret_cast(vector->auxiliaryBuffer.get()) ->getDataVector(); } + + static inline uint64_t getDataVectorSize(const ValueVector* vector) { + assert(vector->dataType.getPhysicalType() == PhysicalTypeID::VAR_LIST); + return reinterpret_cast(vector->auxiliaryBuffer.get())->getSize(); + } + static inline uint8_t* getListValues(const ValueVector* vector, const list_entry_t& listEntry) { assert(vector->dataType.getPhysicalType() == PhysicalTypeID::VAR_LIST); auto dataVector = getDataVector(vector); diff --git a/src/include/optimizer/projection_push_down_optimizer.h b/src/include/optimizer/projection_push_down_optimizer.h index b6b6f538bf..9a7e93b200 100644 --- a/src/include/optimizer/projection_push_down_optimizer.h +++ b/src/include/optimizer/projection_push_down_optimizer.h @@ -20,6 +20,7 @@ class ProjectionPushDownOptimizer : public LogicalOperatorVisitor { void visitOperator(planner::LogicalOperator* op); void visitRecursiveExtend(planner::LogicalOperator* op) override; + void visitExtend(planner::LogicalOperator* op) override; void visitAccumulate(planner::LogicalOperator* op) override; void visitFilter(planner::LogicalOperator* op) override; void visitHashJoin(planner::LogicalOperator* op) override; diff --git a/src/include/planner/join_order/cardinality_estimator.h b/src/include/planner/join_order/cardinality_estimator.h index 77cdef71e3..e01fe4438f 100644 --- a/src/include/planner/join_order/cardinality_estimator.h +++ b/src/include/planner/join_order/cardinality_estimator.h @@ -31,6 +31,11 @@ class CardinalityEstimator { private: inline uint64_t atLeastOne(uint64_t x) { return x == 0 ? 1 : x; } + inline void addNodeIDDom(const binder::NodeExpression& node) { + if (!nodeIDName2dom.contains(node.getInternalIDPropertyName())) { + nodeIDName2dom.insert({node.getInternalIDPropertyName(), getNumNodes(node)}); + } + } uint64_t getNodeIDDom(const std::string& nodeIDName) { assert(nodeIDName2dom.contains(nodeIDName)); return nodeIDName2dom.at(nodeIDName); diff --git a/src/include/planner/join_order_enumerator.h b/src/include/planner/join_order_enumerator.h index d8aae1a6e2..9895c76988 100644 --- a/src/include/planner/join_order_enumerator.h +++ b/src/include/planner/join_order_enumerator.h @@ -91,7 +91,12 @@ class JoinOrderEnumerator { std::shared_ptr nbrNode, std::shared_ptr rel, ExtendDirection direction, LogicalPlan& plan); void createRecursivePlan(std::shared_ptr boundNode, - std::shared_ptr recursiveNode, std::shared_ptr rel, + std::shared_ptr recursiveNode, std::shared_ptr recursiveRel, + ExtendDirection direction, LogicalPlan& plan); + void createRecursiveNodePropertyScanPlan( + std::shared_ptr recursiveNode, LogicalPlan& plan); + void createRecursiveRelPropertyScanPlan(std::shared_ptr recursiveNode, + std::shared_ptr nbrNode, std::shared_ptr recursiveRel, ExtendDirection direction, LogicalPlan& plan); void planJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType, diff --git a/src/include/planner/logical_plan/logical_operator/base_logical_extend.h b/src/include/planner/logical_plan/logical_operator/base_logical_extend.h index 15017f88c0..9c7e68d4d5 100644 --- a/src/include/planner/logical_plan/logical_operator/base_logical_extend.h +++ b/src/include/planner/logical_plan/logical_operator/base_logical_extend.h @@ -12,8 +12,8 @@ class BaseLogicalExtend : public LogicalOperator { BaseLogicalExtend(LogicalOperatorType operatorType, std::shared_ptr boundNode, std::shared_ptr nbrNode, std::shared_ptr rel, - ExtendDirection direction, std::shared_ptr child) - : LogicalOperator{operatorType, std::move(child)}, boundNode{std::move(boundNode)}, + ExtendDirection direction, std::vector> children) + : LogicalOperator{operatorType, std::move(children)}, boundNode{std::move(boundNode)}, nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, direction{direction} {} inline std::shared_ptr getBoundNode() const { return boundNode; } diff --git a/src/include/planner/logical_plan/logical_operator/base_logical_operator.h b/src/include/planner/logical_plan/logical_operator/base_logical_operator.h index db007b9518..ae9aafea0e 100644 --- a/src/include/planner/logical_plan/logical_operator/base_logical_operator.h +++ b/src/include/planner/logical_plan/logical_operator/base_logical_operator.h @@ -68,13 +68,14 @@ class LogicalOperator { inline uint32_t getNumChildren() const { return children.size(); } - // Used for operators with more than two children e.g. Union - inline void addChild(std::shared_ptr op) { children.push_back(std::move(op)); } inline std::shared_ptr getChild(uint64_t idx) const { return children[idx]; } inline std::vector> getChildren() const { return children; } inline void setChild(uint64_t idx, std::shared_ptr child) { children[idx] = std::move(child); } + inline void setChildren(std::vector> children_) { + children = std::move(children_); + } inline LogicalOperatorType getOperatorType() const { return operatorType; } diff --git a/src/include/planner/logical_plan/logical_operator/logical_extend.h b/src/include/planner/logical_plan/logical_operator/logical_extend.h index 0478aff89b..6a72d041b5 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_extend.h +++ b/src/include/planner/logical_plan/logical_operator/logical_extend.h @@ -12,7 +12,8 @@ class LogicalExtend : public BaseLogicalExtend { ExtendDirection direction, binder::expression_vector properties, bool hasAtMostOneNbr, std::shared_ptr child) : BaseLogicalExtend{LogicalOperatorType::EXTEND, std::move(boundNode), std::move(nbrNode), - std::move(rel), direction, std::move(child)}, + std::move(rel), direction, + std::vector>{std::move(child)}}, properties{std::move(properties)}, hasAtMostOneNbr{hasAtMostOneNbr} {} f_group_pos_set getGroupsPosToFlatten() override; diff --git a/src/include/planner/logical_plan/logical_operator/logical_recursive_extend.h b/src/include/planner/logical_plan/logical_operator/logical_recursive_extend.h index dafd3a1c8e..59b0b4001d 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_recursive_extend.h +++ b/src/include/planner/logical_plan/logical_operator/logical_recursive_extend.h @@ -8,20 +8,14 @@ namespace planner { class LogicalRecursiveExtend : public BaseLogicalExtend { public: - LogicalRecursiveExtend(std::shared_ptr boundNode, - std::shared_ptr nbrNode, std::shared_ptr rel, - ExtendDirection direction, std::shared_ptr child, - std::shared_ptr recursivePlanRoot) - : LogicalRecursiveExtend{std::move(boundNode), std::move(nbrNode), std::move(rel), - direction, RecursiveJoinType::TRACK_PATH, std::move(child), - std::move(recursivePlanRoot)} {} LogicalRecursiveExtend(std::shared_ptr boundNode, std::shared_ptr nbrNode, std::shared_ptr rel, ExtendDirection direction, RecursiveJoinType joinType, - std::shared_ptr child, std::shared_ptr recursivePlanRoot) + std::vector> children, + std::shared_ptr recursiveChild) : BaseLogicalExtend{LogicalOperatorType::RECURSIVE_EXTEND, std::move(boundNode), - std::move(nbrNode), std::move(rel), direction, std::move(child)}, - joinType{joinType}, recursivePlanRoot{std::move(recursivePlanRoot)} {} + std::move(nbrNode), std::move(rel), direction, std::move(children)}, + joinType{joinType}, recursiveChild{std::move(recursiveChild)} {} f_group_pos_set getGroupsPosToFlatten() override; @@ -30,18 +24,21 @@ class LogicalRecursiveExtend : public BaseLogicalExtend { inline void setJoinType(RecursiveJoinType joinType_) { joinType = joinType_; } inline RecursiveJoinType getJoinType() const { return joinType; } - inline std::shared_ptr getRecursivePlanRoot() const { - return recursivePlanRoot; - } + inline std::shared_ptr getRecursiveChild() const { return recursiveChild; } inline std::unique_ptr copy() override { + std::vector> copiedChildren; + copiedChildren.reserve(children.size()); + for (auto& child : children) { + copiedChildren.push_back(child->copy()); + } return std::make_unique(boundNode, nbrNode, rel, direction, - joinType, children[0]->copy(), recursivePlanRoot->copy()); + joinType, std::move(copiedChildren), recursiveChild->copy()); } private: RecursiveJoinType joinType; - std::shared_ptr recursivePlanRoot; + std::shared_ptr recursiveChild; }; class LogicalScanFrontier : public LogicalOperator { diff --git a/src/include/processor/operator/hash_join/hash_join_build.h b/src/include/processor/operator/hash_join/hash_join_build.h index ec16627518..adf971a986 100644 --- a/src/include/processor/operator/hash_join/hash_join_build.h +++ b/src/include/processor/operator/hash_join/hash_join_build.h @@ -75,6 +75,8 @@ class HashJoinBuild : public Sink { sharedState{std::move(sharedState)}, info{std::move(info)} {} ~HashJoinBuild() override = default; + inline std::shared_ptr getSharedState() const { return sharedState; } + void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override; void executeInternal(ExecutionContext* context) override; diff --git a/src/include/processor/operator/physical_operator.h b/src/include/processor/operator/physical_operator.h index 76c00ad15d..eb6c694f16 100644 --- a/src/include/processor/operator/physical_operator.h +++ b/src/include/processor/operator/physical_operator.h @@ -37,6 +37,7 @@ enum class PhysicalOperatorType : uint8_t { INTERSECT, LIMIT, MULTIPLICITY_REDUCER, + PATH_PROPERTY_PROBE, PROJECTION, RECURSIVE_JOIN, RENAME_PROPERTY, diff --git a/src/include/processor/operator/recursive_extend/path_property_probe.h b/src/include/processor/operator/recursive_extend/path_property_probe.h new file mode 100644 index 0000000000..ace3ae732a --- /dev/null +++ b/src/include/processor/operator/recursive_extend/path_property_probe.h @@ -0,0 +1,97 @@ +#pragma once + +#include "processor/operator/hash_join/hash_join_build.h" +#include "processor/operator/physical_operator.h" + +namespace kuzu { +namespace processor { + +struct PathPropertyProbeSharedState { + std::shared_ptr nodeHashTableState; + std::shared_ptr relHashTableState; + + PathPropertyProbeSharedState(std::shared_ptr nodeHashTableState, + std::shared_ptr relHashTableState) + : nodeHashTableState{std::move(nodeHashTableState)}, relHashTableState{ + std::move(relHashTableState)} {} +}; + +struct PathPropertyProbeLocalState { + std::unique_ptr hashes; + std::unique_ptr probedTuples; + std::unique_ptr matchedTuples; + + PathPropertyProbeLocalState() { + hashes = std::make_unique(common::DEFAULT_VECTOR_CAPACITY); + probedTuples = std::make_unique(common::DEFAULT_VECTOR_CAPACITY); + matchedTuples = std::make_unique(common::DEFAULT_VECTOR_CAPACITY); + } +}; + +struct PathPropertyProbeDataInfo { + DataPos pathPos; + std::vector nodeHashTableColIndicesToScan; + std::vector relHashTableColIndicesToScan; + + PathPropertyProbeDataInfo(const DataPos& pathPos, + std::vector nodeHashTableColIndicesToScan, + std::vector relHashTableColIndicesToScan) + : pathPos{pathPos}, nodeHashTableColIndicesToScan{std::move(nodeHashTableColIndicesToScan)}, + relHashTableColIndicesToScan{std::move(relHashTableColIndicesToScan)} {} + PathPropertyProbeDataInfo(const PathPropertyProbeDataInfo& other) + : pathPos{other.pathPos}, + nodeHashTableColIndicesToScan{other.nodeHashTableColIndicesToScan}, + relHashTableColIndicesToScan{other.relHashTableColIndicesToScan} {} + + std::unique_ptr copy() const { + return std::make_unique(*this); + } +}; + +class PathPropertyProbe : public PhysicalOperator { +public: + PathPropertyProbe(std::unique_ptr info, + std::shared_ptr sharedState, + std::vector> children, uint32_t id, + const std::string& paramsString) + : PhysicalOperator{PhysicalOperatorType::PATH_PROPERTY_PROBE, std::move(children), id, + paramsString}, + info{std::move(info)}, sharedState{std::move(sharedState)} {} + PathPropertyProbe(std::unique_ptr info, + std::shared_ptr sharedState, + std::unique_ptr probeChild, uint32_t id, const std::string& paramsString) + : PhysicalOperator{PhysicalOperatorType::PATH_PROPERTY_PROBE, std::move(probeChild), id, + paramsString}, + info{std::move(info)}, sharedState{std::move(sharedState)} {} + + void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) final; + + bool getNextTuplesInternal(ExecutionContext* context) final; + + inline std::unique_ptr clone() final { + return std::make_unique( + info->copy(), sharedState, children[0]->clone(), id, paramsString); + } + +private: + void probe(JoinHashTable* hashTable, uint64_t sizeProbed, uint64_t sizeToProbe, + common::ValueVector* idVector, const std::vector& propertyVectors, + const std::vector& colIndicesToScan); + + struct Vectors { + common::ValueVector* pathNodesVector; + common::ValueVector* pathRelsVector; + common::ValueVector* pathNodesIDDataVector; + common::ValueVector* pathRelsIDDataVector; + std::vector pathNodesPropertyDataVectors; + std::vector pathRelsPropertyDataVectors; + }; + + std::unique_ptr info; + std::shared_ptr sharedState; + std::unique_ptr localState; + std::unique_ptr vectors; +}; + +} // namespace processor +} // namespace kuzu diff --git a/src/include/processor/operator/recursive_extend/recursive_join.h b/src/include/processor/operator/recursive_extend/recursive_join.h index ed9ddc6d2c..99f66a21fb 100644 --- a/src/include/processor/operator/recursive_extend/recursive_join.h +++ b/src/include/processor/operator/recursive_extend/recursive_join.h @@ -34,15 +34,6 @@ struct RecursiveJoinDataInfo { // Path info DataPos pathPos; - RecursiveJoinDataInfo(const DataPos& srcNodePos, const DataPos& dstNodePos, - std::unordered_set dstNodeTableIDs, const DataPos& pathLengthPos, - std::unique_ptr localResultSetDescriptor, - const DataPos& recursiveDstNodeIDPos, - std::unordered_set recursiveDstNodeTableIDs, - const DataPos& recursiveEdgeIDPos) - : RecursiveJoinDataInfo{srcNodePos, dstNodePos, std::move(dstNodeTableIDs), pathLengthPos, - std::move(localResultSetDescriptor), recursiveDstNodeIDPos, - std::move(recursiveDstNodeTableIDs), recursiveEdgeIDPos, DataPos()} {} RecursiveJoinDataInfo(const DataPos& srcNodePos, const DataPos& dstNodePos, std::unordered_set dstNodeTableIDs, const DataPos& pathLengthPos, std::unique_ptr localResultSetDescriptor, @@ -67,9 +58,11 @@ struct RecursiveJoinVectors { common::ValueVector* srcNodeIDVector = nullptr; common::ValueVector* dstNodeIDVector = nullptr; common::ValueVector* pathLengthVector = nullptr; - common::ValueVector* pathVector = nullptr; - common::ValueVector* pathNodeIDVector = nullptr; - common::ValueVector* pathRelIDVector = nullptr; + common::ValueVector* pathVector = nullptr; // STRUCT(LIST(STRUCT), LIST(INTERNAL_ID)) + common::ValueVector* pathNodesVector = nullptr; // LIST(STRUCT) + common::ValueVector* pathNodesIDDataVector = nullptr; // INTERNAL_ID + common::ValueVector* pathRelsVector = nullptr; // LIST(STRUCT) + common::ValueVector* pathRelsIDDataVector = nullptr; // INTERNAL_ID common::ValueVector* recursiveEdgeIDVector = nullptr; common::ValueVector* recursiveDstNodeIDVector = nullptr; diff --git a/src/include/processor/result/factorized_table.h b/src/include/processor/result/factorized_table.h index 5da7f63c91..52f8519109 100644 --- a/src/include/processor/result/factorized_table.h +++ b/src/include/processor/result/factorized_table.h @@ -174,6 +174,7 @@ class FactorizedTable { friend FlatTupleIterator; friend class JoinHashTable; friend class IntersectHashTable; + friend class PathPropertyProbe; public: FactorizedTable( @@ -306,16 +307,12 @@ class FactorizedTable { uint8_t** tuplesToRead, ft_col_idx_t colIdx, common::ValueVector& vector) const; void readUnflatCol(const uint8_t* tupleToRead, const common::SelectionVector* selVector, ft_col_idx_t colIdx, common::ValueVector& vector) const; - void readFlatColToFlatVector( - uint8_t** tuplesToRead, ft_col_idx_t colIdx, common::ValueVector& vector) const; + void readFlatColToFlatVector(uint8_t* tupleToRead, ft_col_idx_t colIdx, + common::ValueVector& vector, common::sel_t pos) const; void readFlatColToUnflatVector(uint8_t** tuplesToRead, ft_col_idx_t colIdx, common::ValueVector& vector, uint64_t numTuplesToRead) const; - inline void readFlatCol(uint8_t** tuplesToRead, ft_col_idx_t colIdx, - common::ValueVector& vector, uint64_t numTuplesToRead) const { - vector.state->isFlat() ? - readFlatColToFlatVector(tuplesToRead, colIdx, vector) : - readFlatColToUnflatVector(tuplesToRead, colIdx, vector, numTuplesToRead); - } + void readFlatCol(uint8_t** tuplesToRead, ft_col_idx_t colIdx, common::ValueVector& vector, + uint64_t numTuplesToRead) const; static void copyOverflowIfNecessary(uint8_t* dst, uint8_t* src, const common::LogicalType& type, storage::DiskOverflowFile* diskOverflowFile); diff --git a/src/optimizer/projection_push_down_optimizer.cpp b/src/optimizer/projection_push_down_optimizer.cpp index a2d200c990..235c98bb49 100644 --- a/src/optimizer/projection_push_down_optimizer.cpp +++ b/src/optimizer/projection_push_down_optimizer.cpp @@ -3,6 +3,7 @@ #include "planner/logical_plan/logical_operator/logical_accumulate.h" #include "planner/logical_plan/logical_operator/logical_create.h" #include "planner/logical_plan/logical_operator/logical_delete.h" +#include "planner/logical_plan/logical_operator/logical_extend.h" #include "planner/logical_plan/logical_operator/logical_filter.h" #include "planner/logical_plan/logical_operator/logical_hash_join.h" #include "planner/logical_plan/logical_operator/logical_intersect.h" @@ -38,12 +39,31 @@ void ProjectionPushDownOptimizer::visitOperator(LogicalOperator* op) { void ProjectionPushDownOptimizer::visitRecursiveExtend(LogicalOperator* op) { auto recursiveExtend = (LogicalRecursiveExtend*)op; + auto boundNodeID = recursiveExtend->getBoundNode()->getInternalIDProperty(); + collectExpressionsInUse(boundNodeID); auto rel = recursiveExtend->getRel(); + auto recursiveInfo = rel->getRecursiveInfo(); if (!variablesInUse.contains(rel)) { recursiveExtend->setJoinType(planner::RecursiveJoinType::TRACK_NONE); + // Remove build size + recursiveExtend->setChildren( + std::vector>{recursiveExtend->getChild(0)}); + } else { + // Pre-append projection to rel property build. + expression_vector properties; + for (auto& expression : recursiveInfo->rel->getPropertyExpressions()) { + properties.push_back(expression->copy()); + } + preAppendProjection(op, 2, properties); } } +void ProjectionPushDownOptimizer::visitExtend(planner::LogicalOperator* op) { + auto extend = (LogicalExtend*)op; + auto boundNodeID = extend->getBoundNode()->getInternalIDProperty(); + collectExpressionsInUse(boundNodeID); +} + void ProjectionPushDownOptimizer::visitAccumulate(planner::LogicalOperator* op) { auto accumulate = (LogicalAccumulate*)op; auto expressionsBeforePruning = accumulate->getExpressions(); diff --git a/src/planner/join_order/append_extend.cpp b/src/planner/join_order/append_extend.cpp index 433d7c0619..47d169b9c9 100644 --- a/src/planner/join_order/append_extend.cpp +++ b/src/planner/join_order/append_extend.cpp @@ -48,10 +48,25 @@ void JoinOrderEnumerator::appendNonRecursiveExtend(std::shared_ptr boundNode, std::shared_ptr nbrNode, std::shared_ptr rel, ExtendDirection direction, LogicalPlan& plan) { + auto recursiveInfo = rel->getRecursiveInfo(); + queryPlanner->appendAccumulate(plan); + std::vector> children; + children.push_back(plan.getLastOperator()); + // Recursive node property scan plan + auto recursiveNodePropertyScanPlan = std::make_unique(); + createRecursiveNodePropertyScanPlan(recursiveInfo->node, *recursiveNodePropertyScanPlan); + children.push_back(recursiveNodePropertyScanPlan->getLastOperator()); + // Recursive rel property scan plan + auto recursiveRelPropertyScanPlan = std::make_unique(); + createRecursiveRelPropertyScanPlan( + recursiveInfo->node, nbrNode, recursiveInfo->rel, direction, *recursiveRelPropertyScanPlan); + children.push_back(recursiveRelPropertyScanPlan->getLastOperator()); + // Recursive plan auto recursivePlan = std::make_unique(); - createRecursivePlan(boundNode, rel->getRecursiveNode(), rel, direction, *recursivePlan); + createRecursivePlan( + boundNode, recursiveInfo->node, recursiveInfo->rel, direction, *recursivePlan); auto extend = std::make_shared(boundNode, nbrNode, rel, direction, - plan.getLastOperator(), recursivePlan->getLastOperator()); + RecursiveJoinType::TRACK_PATH, std::move(children), recursivePlan->getLastOperator()); queryPlanner->appendFlattens(extend->getGroupsPosToFlatten(), plan); extend->setChild(0, plan.getLastOperator()); extend->computeFactorizedSchema(); @@ -66,13 +81,34 @@ void JoinOrderEnumerator::appendRecursiveExtend(std::shared_ptr } void JoinOrderEnumerator::createRecursivePlan(std::shared_ptr boundNode, - std::shared_ptr recursiveNode, std::shared_ptr rel, + std::shared_ptr recursiveNode, std::shared_ptr recursiveRel, ExtendDirection direction, LogicalPlan& plan) { auto scanFrontier = std::make_shared(boundNode); scanFrontier->computeFactorizedSchema(); plan.setLastOperator(std::move(scanFrontier)); - auto properties = expression_vector{rel->getInternalIDProperty()}; - appendNonRecursiveExtend(boundNode, recursiveNode, rel, direction, properties, plan); + auto properties = expression_vector{recursiveRel->getInternalIDProperty()}; + appendNonRecursiveExtend(boundNode, recursiveNode, recursiveRel, direction, properties, plan); +} + +void JoinOrderEnumerator::createRecursiveNodePropertyScanPlan( + std::shared_ptr recursiveNode, LogicalPlan& plan) { + appendScanNodeID(recursiveNode, plan); + expression_vector properties; + for (auto& property : recursiveNode->getPropertyExpressions()) { + properties.push_back(property->copy()); + } + queryPlanner->appendScanNodePropIfNecessary(properties, recursiveNode, plan); +} + +void JoinOrderEnumerator::createRecursiveRelPropertyScanPlan( + std::shared_ptr recursiveNode, std::shared_ptr nbrNode, + std::shared_ptr recursiveRel, ExtendDirection direction, LogicalPlan& plan) { + appendScanNodeID(recursiveNode, plan); + expression_vector properties; + for (auto& property : recursiveRel->getPropertyExpressions()) { + properties.push_back(property->copy()); + } + appendNonRecursiveExtend(recursiveNode, nbrNode, recursiveRel, direction, properties, plan); } } // namespace planner diff --git a/src/planner/join_order/cardinality_estimator.cpp b/src/planner/join_order/cardinality_estimator.cpp index 2e8d3ba97a..c203ee09ec 100644 --- a/src/planner/join_order/cardinality_estimator.cpp +++ b/src/planner/join_order/cardinality_estimator.cpp @@ -9,9 +9,12 @@ namespace planner { void CardinalityEstimator::initNodeIDDom(binder::QueryGraph* queryGraph) { for (auto i = 0u; i < queryGraph->getNumQueryNodes(); ++i) { - auto node = queryGraph->getQueryNode(i); - if (!nodeIDName2dom.contains(node->getInternalIDPropertyName())) { - nodeIDName2dom.insert({node->getInternalIDPropertyName(), getNumNodes(*node)}); + addNodeIDDom(*queryGraph->getQueryNode(i)); + } + for (auto i = 0u; i < queryGraph->getNumQueryRels(); ++i) { + auto rel = queryGraph->getQueryRel(i); + if (common::QueryRelTypeUtils::isRecursive(rel->getRelType())) { + addNodeIDDom(*rel->getRecursiveInfo()->node); } } } diff --git a/src/planner/operator/logical_recursive_extend.cpp b/src/planner/operator/logical_recursive_extend.cpp index 4d881ee9b8..760cc662c5 100644 --- a/src/planner/operator/logical_recursive_extend.cpp +++ b/src/planner/operator/logical_recursive_extend.cpp @@ -29,13 +29,11 @@ void LogicalRecursiveExtend::computeFlatSchema() { break; } auto rewriter = optimizer::RemoveFactorizationRewriter(); - rewriter.visitOperator(recursivePlanRoot); + rewriter.visitOperator(recursiveChild); } void LogicalRecursiveExtend::computeFactorizedSchema() { - createEmptySchema(); - auto childSchema = children[0]->getSchema(); - SinkOperatorUtil::recomputeSchema(*childSchema, childSchema->getExpressionsInScope(), *schema); + copyChildSchema(0); auto nbrGroupPos = schema->createGroup(); schema->insertToGroupAndScope(nbrNode->getInternalIDProperty(), nbrGroupPos); schema->insertToGroupAndScope(rel->getLengthExpression(), nbrGroupPos); @@ -47,7 +45,7 @@ void LogicalRecursiveExtend::computeFactorizedSchema() { break; } auto rewriter = optimizer::FactorizationRewriter(); - rewriter.visitOperator(recursivePlanRoot.get()); + rewriter.visitOperator(recursiveChild.get()); } void LogicalScanFrontier::computeFlatSchema() { diff --git a/src/processor/mapper/map_recursive_extend.cpp b/src/processor/mapper/map_recursive_extend.cpp index dc965ef06b..b6ba71038e 100644 --- a/src/processor/mapper/map_recursive_extend.cpp +++ b/src/processor/mapper/map_recursive_extend.cpp @@ -1,80 +1,137 @@ +#include "common/string_utils.h" #include "planner/logical_plan/logical_operator/logical_recursive_extend.h" #include "processor/mapper/plan_mapper.h" +#include "processor/operator/hash_join/hash_join_build.h" +#include "processor/operator/recursive_extend/path_property_probe.h" #include "processor/operator/recursive_extend/recursive_join.h" #include "processor/operator/table_scan/factorized_table_scan.h" +using namespace kuzu::binder; using namespace kuzu::planner; namespace kuzu { namespace processor { +static std::shared_ptr createSharedState( + const binder::NodeExpression& nbrNode, const storage::StorageManager& storageManager) { + std::vector> semiMasks; + for (auto tableID : nbrNode.getTableIDs()) { + auto nodeTable = storageManager.getNodesStore().getNodeTable(tableID); + semiMasks.push_back(std::make_unique(nodeTable)); + } + return std::make_shared(std::move(semiMasks)); +} + +static std::vector getColIdxToScan( + const expression_vector& payloads, uint32_t numKeys, const common::LogicalType& structType) { + std::unordered_map propertyNameToColumnIdx; + for (auto i = 0u; i < payloads.size(); ++i) { + assert(payloads[i]->expressionType == common::PROPERTY); + auto propertyName = ((PropertyExpression*)payloads[i].get())->getPropertyName(); + common::StringUtils::toUpper(propertyName); + propertyNameToColumnIdx.insert({propertyName, i + numKeys}); + } + auto nodeStructFields = common::StructType::getFields(&structType); + std::vector colIndicesToScan; + for (auto i = 1u; i < nodeStructFields.size(); ++i) { + auto field = nodeStructFields[i]; + colIndicesToScan.push_back(propertyNameToColumnIdx.at(field->getName())); + } + return colIndicesToScan; +} + std::unique_ptr PlanMapper::mapLogicalRecursiveExtendToPhysical( planner::LogicalOperator* logicalOperator) { auto extend = (LogicalRecursiveExtend*)logicalOperator; auto boundNode = extend->getBoundNode(); auto nbrNode = extend->getNbrNode(); auto rel = extend->getRel(); - auto recursiveNode = rel->getRecursiveNode(); + auto recursiveInfo = rel->getRecursiveInfo(); auto lengthExpression = rel->getLengthExpression(); - // map recursive plan - auto logicalRecursiveRoot = extend->getRecursivePlanRoot(); + // Map recursive plan + auto logicalRecursiveRoot = extend->getRecursiveChild(); auto recursiveRoot = mapLogicalOperatorToPhysical(logicalRecursiveRoot); auto recursivePlanSchema = logicalRecursiveRoot->getSchema(); auto recursivePlanResultSetDescriptor = std::make_unique(recursivePlanSchema); - auto recursiveDstNodeIDPos = - DataPos(recursivePlanSchema->getExpressionPos(*recursiveNode->getInternalIDProperty())); - auto recursiveEdgeIDPos = - DataPos(recursivePlanSchema->getExpressionPos(*rel->getInternalIDProperty())); - // map child plan + auto recursiveDstNodeIDPos = DataPos( + recursivePlanSchema->getExpressionPos(*recursiveInfo->node->getInternalIDProperty())); + auto recursiveEdgeIDPos = DataPos( + recursivePlanSchema->getExpressionPos(*recursiveInfo->rel->getInternalIDProperty())); + // Map child plan + auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0)); + // Generate RecursiveJoin auto outSchema = extend->getSchema(); auto inSchema = extend->getChild(0)->getSchema(); - auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0)); - auto expressions = inSchema->getExpressionsInScope(); - auto resultCollector = appendResultCollector(expressions, inSchema, std::move(prevOperator)); - auto sharedFTable = resultCollector->getSharedState(); - sharedFTable->setMaxMorselSize(1); - std::vector outDataPoses; - std::vector colIndicesToScan; - for (auto i = 0u; i < expressions.size(); ++i) { - outDataPoses.emplace_back(outSchema->getExpressionPos(*expressions[i])); - colIndicesToScan.push_back(i); + auto boundNodeIDPos = DataPos(inSchema->getExpressionPos(*boundNode->getInternalIDProperty())); + auto nbrNodeIDPos = DataPos(outSchema->getExpressionPos(*nbrNode->getInternalIDProperty())); + auto lengthPos = DataPos(outSchema->getExpressionPos(*lengthExpression)); + auto sharedState = createSharedState(*nbrNode, storageManager); + auto pathPos = DataPos(); + if (extend->getJoinType() == planner::RecursiveJoinType::TRACK_PATH) { + pathPos = DataPos(outSchema->getExpressionPos(*rel)); } - auto fTableScan = make_unique(std::move(outDataPoses), - std::move(colIndicesToScan), sharedFTable, std::move(resultCollector), getOperatorID(), ""); - // Generate RecursiveJoinDataInfo - auto boundNodeIDVectorPos = - DataPos(inSchema->getExpressionPos(*boundNode->getInternalIDProperty())); - auto nbrNodeIDVectorPos = - DataPos(outSchema->getExpressionPos(*nbrNode->getInternalIDProperty())); - auto lengthVectorPos = DataPos(outSchema->getExpressionPos(*lengthExpression)); - std::unique_ptr dataInfo; + auto dataInfo = std::make_unique(boundNodeIDPos, nbrNodeIDPos, + nbrNode->getTableIDsSet(), lengthPos, std::move(recursivePlanResultSetDescriptor), + recursiveDstNodeIDPos, recursiveInfo->node->getTableIDsSet(), recursiveEdgeIDPos, pathPos); + auto recursiveJoin = std::make_unique(rel->getLowerBound(), rel->getUpperBound(), + rel->getRelType(), extend->getJoinType(), sharedState, std::move(dataInfo), + std::move(prevOperator), getOperatorID(), extend->getExpressionsForPrinting(), + std::move(recursiveRoot)); switch (extend->getJoinType()) { case planner::RecursiveJoinType::TRACK_PATH: { - auto pathVectorPos = DataPos(outSchema->getExpressionPos(*rel)); - dataInfo = std::make_unique(boundNodeIDVectorPos, nbrNodeIDVectorPos, - nbrNode->getTableIDsSet(), lengthVectorPos, std::move(recursivePlanResultSetDescriptor), - recursiveDstNodeIDPos, recursiveNode->getTableIDsSet(), recursiveEdgeIDPos, - pathVectorPos); - } break; + // Map build node property + auto nodeBuildPrevOperator = mapLogicalOperatorToPhysical(extend->getChild(1)); + auto nodeBuildSchema = extend->getChild(1)->getSchema(); + auto nodeKeys = expression_vector{recursiveInfo->node->getInternalIDProperty()}; + auto nodePayloads = + ExpressionUtil::excludeExpressions(nodeBuildSchema->getExpressionsInScope(), nodeKeys); + auto nodeBuildInfo = createHashBuildInfo(*nodeBuildSchema, nodeKeys, nodePayloads); + auto nodeHashTable = std::make_unique( + *memoryManager, nodeBuildInfo->getNumKeys(), nodeBuildInfo->getTableSchema()->copy()); + auto nodeBuildSharedState = std::make_shared(std::move(nodeHashTable)); + auto nodeBuild = make_unique( + std::make_unique(nodeBuildSchema), nodeBuildSharedState, + std::move(nodeBuildInfo), std::move(nodeBuildPrevOperator), getOperatorID(), ""); + // Map build rel property + auto relBuildPrvOperator = mapLogicalOperatorToPhysical(extend->getChild(2)); + auto relBuildSchema = extend->getChild(2)->getSchema(); + auto relKeys = expression_vector{recursiveInfo->rel->getInternalIDProperty()}; + auto relPayloads = + ExpressionUtil::excludeExpressions(relBuildSchema->getExpressionsInScope(), relKeys); + auto relBuildInfo = createHashBuildInfo(*relBuildSchema, relKeys, relPayloads); + auto relHashTable = std::make_unique( + *memoryManager, relBuildInfo->getNumKeys(), relBuildInfo->getTableSchema()->copy()); + auto relBuildSharedState = std::make_shared(std::move(relHashTable)); + auto relBuild = std::make_unique( + std::make_unique(relBuildSchema), relBuildSharedState, + std::move(relBuildInfo), std::move(relBuildPrvOperator), getOperatorID(), ""); + // Map probe + auto relDataType = rel->getDataType(); + auto nodesField = + common::StructType::getField(&relDataType, common::InternalKeyword::NODES); + auto nodeStructType = common::VarListType::getChildType(nodesField->getType()); + auto nodeColIndicesToScan = getColIdxToScan(nodePayloads, nodeKeys.size(), *nodeStructType); + auto relsField = common::StructType::getField(&relDataType, common::InternalKeyword::RELS); + auto relStructType = common::VarListType::getChildType(relsField->getType()); + auto relColIndicesToScan = getColIdxToScan(relPayloads, relKeys.size(), *relStructType); + auto pathProbeInfo = std::make_unique( + pathPos, std::move(nodeColIndicesToScan), std::move(relColIndicesToScan)); + auto pathProbeSharedState = std::make_shared( + nodeBuildSharedState, relBuildSharedState); + std::vector> children; + children.push_back(std::move(recursiveJoin)); + children.push_back(std::move(nodeBuild)); + children.push_back(std::move(relBuild)); + return std::make_unique(std::move(pathProbeInfo), pathProbeSharedState, + std::move(children), getOperatorID(), ""); + } case planner::RecursiveJoinType::TRACK_NONE: { - dataInfo = std::make_unique(boundNodeIDVectorPos, nbrNodeIDVectorPos, - nbrNode->getTableIDsSet(), lengthVectorPos, std::move(recursivePlanResultSetDescriptor), - recursiveDstNodeIDPos, recursiveNode->getTableIDsSet(), recursiveEdgeIDPos); - } break; + return recursiveJoin; + } default: throw common::NotImplementedException("PlanMapper::mapLogicalRecursiveExtendToPhysical"); } - std::vector> semiMasks; - for (auto tableID : nbrNode->getTableIDs()) { - auto nodeTable = storageManager.getNodesStore().getNodeTable(tableID); - semiMasks.push_back(std::make_unique(nodeTable)); - } - auto sharedState = std::make_shared(std::move(semiMasks)); - return std::make_unique(rel->getLowerBound(), rel->getUpperBound(), - rel->getRelType(), extend->getJoinType(), sharedState, std::move(dataInfo), - std::move(fTableScan), getOperatorID(), extend->getExpressionsForPrinting(), - std::move(recursiveRoot)); } } // namespace processor diff --git a/src/processor/mapper/map_semi_masker.cpp b/src/processor/mapper/map_semi_masker.cpp index 4fa078f391..fb2c78e8e0 100644 --- a/src/processor/mapper/map_semi_masker.cpp +++ b/src/processor/mapper/map_semi_masker.cpp @@ -32,6 +32,13 @@ std::unique_ptr PlanMapper::mapLogicalSemiMaskerToPhysical( tableState->getSemiMask(), 0 /* initial mask idx */); } } break; + case PhysicalOperatorType::PATH_PROPERTY_PROBE: { + auto recursiveJoin = (RecursiveJoin*)physicalOp->getChild(0); + for (auto& semiMask : recursiveJoin->getSharedState()->semiMasks) { + auto tableID = semiMask->getNodeTable()->getTableID(); + masksPerTable.at(tableID).emplace_back(semiMask.get(), 0 /* initial mask idx */); + } + } break; case PhysicalOperatorType::RECURSIVE_JOIN: { auto recursiveJoin = (RecursiveJoin*)physicalOp; for (auto& semiMask : recursiveJoin->getSharedState()->semiMasks) { diff --git a/src/processor/operator/physical_operator.cpp b/src/processor/operator/physical_operator.cpp index 146ecb3372..e704c15afe 100644 --- a/src/processor/operator/physical_operator.cpp +++ b/src/processor/operator/physical_operator.cpp @@ -95,6 +95,9 @@ std::string PhysicalOperatorUtils::operatorTypeToString(PhysicalOperatorType ope case PhysicalOperatorType::MULTIPLICITY_REDUCER: { return "MULTIPLICITY_REDUCER"; } + case PhysicalOperatorType::PATH_PROPERTY_PROBE: { + return "PATH_PROPERTY_PROBE"; + } case PhysicalOperatorType::PROJECTION: { return "PROJECTION"; } diff --git a/src/processor/operator/recursive_extend/CMakeLists.txt b/src/processor/operator/recursive_extend/CMakeLists.txt index 7bf992a210..8cd174c29f 100644 --- a/src/processor/operator/recursive_extend/CMakeLists.txt +++ b/src/processor/operator/recursive_extend/CMakeLists.txt @@ -3,6 +3,7 @@ add_library(kuzu_processor_operator_ver_length_extend frontier.cpp frontier_scanner.cpp recursive_join.cpp + path_property_probe.cpp scan_frontier.cpp) set(ALL_OBJECT_FILES diff --git a/src/processor/operator/recursive_extend/frontier_scanner.cpp b/src/processor/operator/recursive_extend/frontier_scanner.cpp index 8368f2aa6b..b06f66ba57 100644 --- a/src/processor/operator/recursive_extend/frontier_scanner.cpp +++ b/src/processor/operator/recursive_extend/frontier_scanner.cpp @@ -97,19 +97,18 @@ void PathScanner::initDfs(const frontier::node_rel_id_t& nodeAndRelID, size_t cu void PathScanner::writePathToVector(RecursiveJoinVectors* vectors, common::sel_t& vectorPos, common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos) { assert(vectorPos < common::DEFAULT_VECTOR_CAPACITY); - auto nodeIDEntry = common::ListVector::addList(vectors->pathNodeIDVector, k + 1); - auto relIDEntry = common::ListVector::addList(vectors->pathRelIDVector, k); - vectors->pathNodeIDVector->setValue(vectorPos, nodeIDEntry); - vectors->pathRelIDVector->setValue(vectorPos, relIDEntry); + auto nodeEntry = common::ListVector::addList(vectors->pathNodesVector, k + 1); + auto relEntry = common::ListVector::addList(vectors->pathRelsVector, k); + vectors->pathNodesVector->setValue(vectorPos, nodeEntry); + vectors->pathRelsVector->setValue(vectorPos, relEntry); writeDstNodeOffsetAndLength(vectors->dstNodeIDVector, vectors->pathLengthVector, vectorPos); vectorPos++; - auto nodeIDDataVector = common::ListVector::getDataVector(vectors->pathNodeIDVector); - auto relIDDataVector = common::ListVector::getDataVector(vectors->pathRelIDVector); for (auto i = 0u; i < k; ++i) { - nodeIDDataVector->setValue(nodeIDDataVectorPos++, nodeIDs[i]); - relIDDataVector->setValue(relIDDataVectorPos++, relIDs[i]); + vectors->pathNodesIDDataVector->setValue( + nodeIDDataVectorPos++, nodeIDs[i]); + vectors->pathRelsIDDataVector->setValue(relIDDataVectorPos++, relIDs[i]); } - nodeIDDataVector->setValue(nodeIDDataVectorPos++, nodeIDs[k]); + vectors->pathNodesIDDataVector->setValue(nodeIDDataVectorPos++, nodeIDs[k]); } void DstNodeWithMultiplicityScanner::scanFromDstOffset(RecursiveJoinVectors* vectors, diff --git a/src/processor/operator/recursive_extend/path_property_probe.cpp b/src/processor/operator/recursive_extend/path_property_probe.cpp new file mode 100644 index 0000000000..19c82a6915 --- /dev/null +++ b/src/processor/operator/recursive_extend/path_property_probe.cpp @@ -0,0 +1,106 @@ +#include "processor/operator/recursive_extend/path_property_probe.h" + +#include "function/hash/vector_hash_operations.h" +using namespace kuzu::common; + +namespace kuzu { +namespace processor { + +void PathPropertyProbe::initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) { + localState = std::make_unique(); + vectors = std::make_unique(); + auto pathVector = resultSet->getValueVector(info->pathPos); + auto pathNodesFieldIdx = StructType::getFieldIdx(&pathVector->dataType, InternalKeyword::NODES); + auto pathRelsFieldIdx = StructType::getFieldIdx(&pathVector->dataType, InternalKeyword::RELS); + vectors->pathNodesVector = + StructVector::getFieldVector(pathVector.get(), pathNodesFieldIdx).get(); + vectors->pathRelsVector = + StructVector::getFieldVector(pathVector.get(), pathRelsFieldIdx).get(); + auto pathNodesDataVector = ListVector::getDataVector(vectors->pathNodesVector); + auto pathRelsDataVector = ListVector::getDataVector(vectors->pathRelsVector); + auto pathNodesIDFieldIdx = + StructType::getFieldIdx(&pathNodesDataVector->dataType, InternalKeyword::ID); + auto pathRelsIDFieldIdx = + StructType::getFieldIdx(&pathRelsDataVector->dataType, InternalKeyword::ID); + assert(pathNodesFieldIdx == 0 && pathRelsIDFieldIdx == 0); + vectors->pathNodesIDDataVector = + StructVector::getFieldVector(pathNodesDataVector, pathNodesIDFieldIdx).get(); + vectors->pathRelsIDDataVector = + StructVector::getFieldVector(pathRelsDataVector, pathRelsIDFieldIdx).get(); + for (auto i = 1u; i < StructType::getNumFields(&pathNodesDataVector->dataType); ++i) { + vectors->pathNodesPropertyDataVectors.push_back( + StructVector::getFieldVector(pathNodesDataVector, i).get()); + } + for (auto i = 1u; i < StructType::getNumFields(&pathRelsDataVector->dataType); ++i) { + vectors->pathRelsPropertyDataVectors.push_back( + StructVector::getFieldVector(pathRelsDataVector, i).get()); + } +} + +bool PathPropertyProbe::getNextTuplesInternal(ExecutionContext* context) { + if (!children[0]->getNextTuple(context)) { + return false; + } + // Scan node property + auto nodeHashTable = sharedState->nodeHashTableState->getHashTable(); + auto nodeDataSize = ListVector::getDataVectorSize(vectors->pathNodesVector); + auto sizeProbed = 0u; + while (sizeProbed < nodeDataSize) { + auto sizeToProbe = std::min(DEFAULT_VECTOR_CAPACITY, nodeDataSize - sizeProbed); + probe(nodeHashTable, sizeProbed, sizeToProbe, vectors->pathNodesIDDataVector, + vectors->pathNodesPropertyDataVectors, info->nodeHashTableColIndicesToScan); + sizeProbed += sizeToProbe; + } + // Scan rel property + auto relHashTable = sharedState->relHashTableState->getHashTable(); + auto relDataSize = ListVector::getDataVectorSize(vectors->pathRelsVector); + sizeProbed = 0u; + while (sizeProbed < relDataSize) { + auto sizeToProbe = std::min(DEFAULT_VECTOR_CAPACITY, relDataSize - sizeProbed); + probe(relHashTable, sizeProbed, sizeToProbe, vectors->pathRelsIDDataVector, + vectors->pathRelsPropertyDataVectors, info->relHashTableColIndicesToScan); + sizeProbed += sizeToProbe; + } + return true; +} + +void PathPropertyProbe::probe(kuzu::processor::JoinHashTable* hashTable, uint64_t sizeProbed, + uint64_t sizeToProbe, common::ValueVector* idVector, + const std::vector& propertyVectors, + const std::vector& colIndicesToScan) { + // Hash + for (auto i = 0u; i < sizeToProbe; ++i) { + function::operation::Hash::operation( + idVector->getValue(sizeProbed + i), localState->hashes[i]); + } + // Probe hash + for (auto i = 0u; i < sizeToProbe; ++i) { + localState->probedTuples[i] = hashTable->getTupleForHash(localState->hashes[i]); + } + // Match value + for (auto i = 0u; i < sizeToProbe; ++i) { + while (localState->probedTuples[i]) { + auto currentTuple = localState->probedTuples[i]; + if (*(internalID_t*)currentTuple == idVector->getValue(sizeProbed + i)) { + localState->matchedTuples[i] = currentTuple; + break; + } + localState->probedTuples[i] = *hashTable->getPrevTuple(currentTuple); + } + assert(localState->matchedTuples[i] != nullptr); + } + // Scan table + auto factorizedTable = hashTable->getFactorizedTable(); + for (auto i = 0u; i < sizeToProbe; ++i) { + auto tuple = localState->matchedTuples[i]; + for (auto j = 0u; j < propertyVectors.size(); ++j) { + auto propertyVector = propertyVectors[j]; + auto colIdx = colIndicesToScan[j]; + factorizedTable->readFlatColToFlatVector( + tuple, colIdx, *propertyVector, sizeProbed + i); + } + } +} + +} // namespace processor +} // namespace kuzu diff --git a/src/processor/operator/recursive_extend/recursive_join.cpp b/src/processor/operator/recursive_extend/recursive_join.cpp index f6ef6f92ea..c76f28d0f7 100644 --- a/src/processor/operator/recursive_extend/recursive_join.cpp +++ b/src/processor/operator/recursive_extend/recursive_join.cpp @@ -5,6 +5,8 @@ #include "processor/operator/recursive_extend/shortest_path_state.h" #include "processor/operator/recursive_extend/variable_length_state.h" +using namespace kuzu::common; + namespace kuzu { namespace processor { @@ -85,15 +87,28 @@ void RecursiveJoin::initLocalStateInternal(ResultSet* resultSet_, ExecutionConte throw common::NotImplementedException("BaseRecursiveJoin::initLocalStateInternal"); } if (vectors->pathVector != nullptr) { - assert(vectors->pathVector->dataType.getPhysicalType() == common::PhysicalTypeID::STRUCT); - auto nodeIDFieldIdx = common::StructType::getFieldIdx( + auto pathNodesFieldIdx = common::StructType::getFieldIdx( &vectors->pathVector->dataType, common::InternalKeyword::NODES); - auto relIDFieldIdx = common::StructType::getFieldIdx( + vectors->pathNodesVector = + StructVector::getFieldVector(vectors->pathVector, pathNodesFieldIdx).get(); + auto pathNodesDataVector = ListVector::getDataVector(vectors->pathNodesVector); + auto pathNodesIDFieldIdx = + StructType::getFieldIdx(&pathNodesDataVector->dataType, InternalKeyword::ID); + vectors->pathNodesIDDataVector = + StructVector::getFieldVector(pathNodesDataVector, pathNodesIDFieldIdx).get(); + assert(vectors->pathNodesIDDataVector->dataType.getPhysicalType() == + common::PhysicalTypeID::INTERNAL_ID); + auto pathRelsFieldIdx = common::StructType::getFieldIdx( &vectors->pathVector->dataType, common::InternalKeyword::RELS); - vectors->pathNodeIDVector = - common::StructVector::getFieldVector(vectors->pathVector, nodeIDFieldIdx).get(); - vectors->pathRelIDVector = - common::StructVector::getFieldVector(vectors->pathVector, relIDFieldIdx).get(); + vectors->pathRelsVector = + StructVector::getFieldVector(vectors->pathVector, pathRelsFieldIdx).get(); + auto pathRelsDataVector = ListVector::getDataVector(vectors->pathRelsVector); + auto pathRelsIDFieldIdx = + StructType::getFieldIdx(&pathRelsDataVector->dataType, InternalKeyword::ID); + vectors->pathRelsIDDataVector = + StructVector::getFieldVector(pathRelsDataVector, pathRelsIDFieldIdx).get(); + assert(vectors->pathRelsIDDataVector->dataType.getPhysicalType() == + common::PhysicalTypeID::INTERNAL_ID); } frontiersScanner = std::make_unique(std::move(scanners)); initLocalRecursivePlan(context); diff --git a/src/processor/result/factorized_table.cpp b/src/processor/result/factorized_table.cpp index a22c8b82ec..c8be556a9f 100644 --- a/src/processor/result/factorized_table.cpp +++ b/src/processor/result/factorized_table.cpp @@ -602,14 +602,22 @@ void FactorizedTable::readUnflatCol(const uint8_t* tupleToRead, const SelectionV } void FactorizedTable::readFlatColToFlatVector( - uint8_t** tuplesToRead, ft_col_idx_t colIdx, ValueVector& vector) const { - assert(vector.state->isFlat()); - auto pos = vector.state->selVector->selectedPositions[0]; - if (isNonOverflowColNull(tuplesToRead[0] + tableSchema->getNullMapOffset(), colIdx)) { + uint8_t* tupleToRead, ft_col_idx_t colIdx, ValueVector& vector, common::sel_t pos) const { + if (isNonOverflowColNull(tupleToRead + tableSchema->getNullMapOffset(), colIdx)) { vector.setNull(pos, true); } else { vector.setNull(pos, false); - vector.copyFromRowData(pos, tuplesToRead[0] + tableSchema->getColOffset(colIdx)); + vector.copyFromRowData(pos, tupleToRead + tableSchema->getColOffset(colIdx)); + } +} + +void FactorizedTable::readFlatCol(uint8_t** tuplesToRead, ft_col_idx_t colIdx, + common::ValueVector& vector, uint64_t numTuplesToRead) const { + if (vector.state->isFlat()) { + auto pos = vector.state->selVector->selectedPositions[0]; + readFlatColToFlatVector(tuplesToRead[0], colIdx, vector, pos); + } else { + readFlatColToUnflatVector(tuplesToRead, colIdx, vector, numTuplesToRead); } } diff --git a/test/test_files/demo_db/demo_db.test b/test/test_files/demo_db/demo_db.test index 0f04911e2b..6e844b81e1 100644 --- a/test/test_files/demo_db/demo_db.test +++ b/test/test_files/demo_db/demo_db.test @@ -114,10 +114,10 @@ Kitchener|2 -NAME ReturnVarLen -QUERY MATCH (a:User)-[e:Follows*1..2]->(b:User) WHERE a.name = 'Adam' RETURN b.name, e; ---- 4 -Karissa|{_NODES: [0:0,0:1], _RELS: [2:0]} -Noura|{_NODES: [0:0,0:2,0:3], _RELS: [2:1,2:3]} -Zhang|{_NODES: [0:0,0:1,0:2], _RELS: [2:0,2:2]} -Zhang|{_NODES: [0:0,0:2], _RELS: [2:1]} +Karissa|{_NODES: [{_ID: 0:0, NAME: Adam, AGE: 30},{_ID: 0:1, NAME: Karissa, AGE: 40}], _RELS: [{_ID: 2:0, SINCE: 2020}]} +Noura|{_NODES: [{_ID: 0:0, NAME: Adam, AGE: 30},{_ID: 0:2, NAME: Zhang, AGE: 50},{_ID: 0:3, NAME: Noura, AGE: 25}], _RELS: [{_ID: 2:1, SINCE: 2020},{_ID: 2:3, SINCE: 2022}]} +Zhang|{_NODES: [{_ID: 0:0, NAME: Adam, AGE: 30},{_ID: 0:1, NAME: Karissa, AGE: 40},{_ID: 0:2, NAME: Zhang, AGE: 50}], _RELS: [{_ID: 2:0, SINCE: 2020},{_ID: 2:2, SINCE: 2021}]} +Zhang|{_NODES: [{_ID: 0:0, NAME: Adam, AGE: 30},{_ID: 0:2, NAME: Zhang, AGE: 50}], _RELS: [{_ID: 2:1, SINCE: 2020}]} -NAME ShortestPath -QUERY MATCH (a:User)-[e* SHORTEST 1..4]->(b:City) WHERE a.name = 'Adam' RETURN b.name, length(e) AS length; diff --git a/test/test_files/shortest_path/bfs_sssp.test b/test/test_files/shortest_path/bfs_sssp.test index f5ae5f3b4e..dd784486e0 100644 --- a/test/test_files/shortest_path/bfs_sssp.test +++ b/test/test_files/shortest_path/bfs_sssp.test @@ -5,27 +5,40 @@ -CASE Bfs +#Alice|Bob|{_NODES: [0:0,0:1], _RELS: [1:0]} +#Alice|Carol|{_NODES: [0:0,0:2], _RELS: [1:1]} +#Alice|Dan|{_NODES: [0:0,0:3], _RELS: [1:2]} +#Alice|Elizabeth|{_NODES: [0:0,0:1,0:4], _RELS: [1:0,1:6]} +#Alice|Farooq|{_NODES: [0:0,0:1,0:4,0:5], _RELS: [1:0,1:6,1:13]} +#Alice|Greg|{_NODES: [0:0,0:1,0:4,0:6], _RELS: [1:0,1:6,1:14]} +#Alice|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|{_NODES: [0:0,0:1,0:4,0:7], _RELS: [1:0,1:6,1:15]} -NAME SingleSourceAllDestinationsSSP --QUERY MATCH (a:person)-[r:knows* SHORTEST 1..30]->(b:person) WHERE a.fName = 'Alice' RETURN a.fName, b.fName, r +-QUERY MATCH (a:person)-[r:knows* SHORTEST 1..30]->(b:person) WHERE a.fName = 'Alice' RETURN a.fName, b.fName, rels(r), (nodes(r)[2]).fName ---- 7 -Alice|Bob|{_NODES: [0:0,0:1], _RELS: [1:0]} -Alice|Carol|{_NODES: [0:0,0:2], _RELS: [1:1]} -Alice|Dan|{_NODES: [0:0,0:3], _RELS: [1:2]} -Alice|Elizabeth|{_NODES: [0:0,0:1,0:4], _RELS: [1:0,1:6]} -Alice|Farooq|{_NODES: [0:0,0:1,0:4,0:5], _RELS: [1:0,1:6,1:13]} -Alice|Greg|{_NODES: [0:0,0:1,0:4,0:6], _RELS: [1:0,1:6,1:14]} -Alice|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|{_NODES: [0:0,0:1,0:4,0:7], _RELS: [1:0,1:6,1:15]} - +Alice|Bob|[{_ID: 1:0}]|Bob +Alice|Carol|[{_ID: 1:1}]|Carol +Alice|Dan|[{_ID: 1:2}]|Dan +Alice|Elizabeth|[{_ID: 1:0},{_ID: 1:6}]|Bob +Alice|Farooq|[{_ID: 1:0},{_ID: 1:6},{_ID: 1:13}]|Bob +Alice|Greg|[{_ID: 1:0},{_ID: 1:6},{_ID: 1:14}]|Bob +Alice|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|[{_ID: 1:0},{_ID: 1:6},{_ID: 1:15}]|Bob +#Bob|Alice|{_NODES: [0:0,0:1], _RELS: [1:3]} +#Carol|Alice|{_NODES: [0:0,0:2], _RELS: [1:7]} +#Dan|Alice|{_NODES: [0:0,0:3], _RELS: [1:10]} +#Elizabeth|Alice|{_NODES: [0:0,0:7,0:4], _RELS: [1:20,1:15]} +#Farooq|Alice|{_NODES: [0:0,0:7,0:5], _RELS: [1:20,1:17]} +#Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|Alice|{_NODES: [0:0,0:7], _RELS: [1:20]} -NAME AllSourcesSingleDestinationQuery --QUERY MATCH (a:person)-[r:knows* SHORTEST 1..30]->(b:person) WHERE b.fName = 'Alice' RETURN a.fName, b.fName, r +-QUERY MATCH (a:person)-[r:knows* SHORTEST 1..30]->(b:person) WHERE b.fName = 'Alice' RETURN a.fName, b.fName, rels(r), (nodes(r)[2]).fName ---- 6 -Bob|Alice|{_NODES: [0:0,0:1], _RELS: [1:3]} -Carol|Alice|{_NODES: [0:0,0:2], _RELS: [1:7]} -Dan|Alice|{_NODES: [0:0,0:3], _RELS: [1:10]} -Elizabeth|Alice|{_NODES: [0:0,0:7,0:4], _RELS: [1:20,1:15]} -Farooq|Alice|{_NODES: [0:0,0:7,0:5], _RELS: [1:20,1:17]} -Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|Alice|{_NODES: [0:0,0:7], _RELS: [1:20]} +Bob|Alice|[{_ID: 1:3}]|Bob +Carol|Alice|[{_ID: 1:7}]|Carol +Dan|Alice|[{_ID: 1:10}]|Dan +Elizabeth|Alice|[{_ID: 1:20},{_ID: 1:15}]|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff +Farooq|Alice|[{_ID: 1:20},{_ID: 1:17}]|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff +Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|Alice|[{_ID: 1:20}]|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff + -NAME SingleSourceWithAllProperties -QUERY MATCH (a:person)-[r:knows* SHORTEST 1..30]->(b:person) WHERE a.fName = 'Alice' RETURN length(r), b, a @@ -43,14 +56,20 @@ Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|Alice|{_NODES: [0:0,0:7], _REL ---- 1 Alice|Bob|1 +#Elizabeth|Alice|{_NODES: [0:4,0:7,0:0], _RELS: [1:15,1:20]} +#Elizabeth|Dan|{_NODES: [0:4,0:7,0:3], _RELS: [1:15,1:21]} +#Elizabeth|Farooq|{_NODES: [0:4,0:5], _RELS: [1:13]} +#Elizabeth|Greg|{_NODES: [0:4,0:6], _RELS: [1:14]} +#Elizabeth|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|{_NODES: [0:4,0:7], _RELS: [1:15]} -NAME SingleSourceAllDestinations2 --QUERY MATCH (a:person)-[r:knows* SHORTEST 1..2]->(b:person) WHERE a.fName = 'Elizabeth' RETURN a.fName, b.fName, r +-QUERY MATCH (a:person)-[r:knows* SHORTEST 1..2]->(b:person) WHERE a.fName = 'Elizabeth' RETURN a.fName, b.fName, rels(r), (nodes(r)[2]).fName ---- 5 -Elizabeth|Alice|{_NODES: [0:4,0:7,0:0], _RELS: [1:15,1:20]} -Elizabeth|Dan|{_NODES: [0:4,0:7,0:3], _RELS: [1:15,1:21]} -Elizabeth|Farooq|{_NODES: [0:4,0:5], _RELS: [1:13]} -Elizabeth|Greg|{_NODES: [0:4,0:6], _RELS: [1:14]} -Elizabeth|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|{_NODES: [0:4,0:7], _RELS: [1:15]} +Elizabeth|Alice|[{_ID: 1:15},{_ID: 1:20}]|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff +Elizabeth|Dan|[{_ID: 1:15},{_ID: 1:21}]|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff +Elizabeth|Farooq|[{_ID: 1:13}]|Farooq +Elizabeth|Greg|[{_ID: 1:14}]|Greg +Elizabeth|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|[{_ID: 1:15}]|Hubert Blaine Wolfeschlegelsteinhausenbergerdorff + -NAME SingleSourceUnreachableDestination -QUERY MATCH (a:person)-[r:knows* SHORTEST 1..30]->(b:person) WHERE a.fName = 'Alice' AND b.fName = 'Alice11' RETURN a.fName, b.fName, r diff --git a/test/test_files/tinysnb/exception/relation.test b/test/test_files/tinysnb/exception/relation.test index 3ad63ed911..211c243d18 100644 --- a/test/test_files/tinysnb/exception/relation.test +++ b/test/test_files/tinysnb/exception/relation.test @@ -11,7 +11,7 @@ Binder exception: e has data type RECURSIVE_REL. (NODE,REL,STRUCT) was expected. -CASE ReadVarlengthRelPropertyTest2 -QUERY MATCH (a:person)-[e:knows*1..3]->(b:person) WHERE ID(e) = 0 RETURN COUNT(*) ---- error -Binder exception: e has data type RECURSIVE_REL. (NODE,REL) was expected. +Binder exception: e has data type RECURSIVE_REL. (NODE,REL,STRUCT) was expected. -CASE AccessRelInternalIDTest -QUERY MATCH (a:person)-[e:knows]->(b:person) WHERE e._id > 1 RETURN COUNT(*) diff --git a/test/test_files/tinysnb/function/node_rel.test b/test/test_files/tinysnb/function/node_rel.test deleted file mode 100644 index 2345822ac9..0000000000 --- a/test/test_files/tinysnb/function/node_rel.test +++ /dev/null @@ -1,14 +0,0 @@ --GROUP TinySnbReadTest --DATASET CSV tinysnb - --- - --CASE FunctionNodeRel - --NAME KnowsOneToTwoHopTest --QUERY MATCH (a:person)-[e:knows*1..2]-(b:person) WHERE a.fName='Elizabeth' RETURN nodes(e), rels(e) ----- 4 -[0:4,0:5,0:4]|[3:12,3:12] -[0:4,0:5]|[3:12] -[0:4,0:6,0:4]|[3:13,3:13] -[0:4,0:6]|[3:13] diff --git a/test/test_files/tinysnb/var_length_extend/multi_label.test b/test/test_files/tinysnb/var_length_extend/multi_label.test index d14346e7af..7c9d3ad1b4 100644 --- a/test/test_files/tinysnb/var_length_extend/multi_label.test +++ b/test/test_files/tinysnb/var_length_extend/multi_label.test @@ -15,37 +15,54 @@ ---- 1 10 +#{_NODES: [0:4,0:5,0:4], _RELS: [3:12,3:12]} +#{_NODES: [0:4,0:5,1:0], _RELS: [3:12,4:2]} +#{_NODES: [0:4,0:5], _RELS: [3:12]} +#{_NODES: [0:4,0:6,0:4], _RELS: [3:13,3:13]} +#{_NODES: [0:4,0:6], _RELS: [3:13]} +#{_NODES: [0:4,1:2,0:3], _RELS: [5:2,5:1]} +#{_NODES: [0:4,1:2,0:4], _RELS: [5:2,5:2]} +#{_NODES: [0:4,1:2], _RELS: [5:2]} -NAME NodeUndirectedTest2 --QUERY MATCH (a)-[e:knows|:studyAt|:workAt*1..2]-(b) WHERE a.ID=7 RETURN e, label(b) +-QUERY MATCH (a)-[e:knows|:studyAt|:workAt*1..2]-(b) WHERE a.ID=7 RETURN id(rels(e)[1]), (rels(e)[1]).rating, (nodes(e)[2]).ID ---- 8 -{_NODES: [0:4,0:5,0:4], _RELS: [3:12,3:12]}|person -{_NODES: [0:4,0:5,1:0], _RELS: [3:12,4:2]}|organisation -{_NODES: [0:4,0:5], _RELS: [3:12]}|person -{_NODES: [0:4,0:6,0:4], _RELS: [3:13,3:13]}|person -{_NODES: [0:4,0:6], _RELS: [3:13]}|person -{_NODES: [0:4,1:2,0:3], _RELS: [5:2,5:1]}|person -{_NODES: [0:4,1:2,0:4], _RELS: [5:2,5:2]}|person -{_NODES: [0:4,1:2], _RELS: [5:2]}|organisation +3:12||8 +3:12||8 +3:12||8 +3:13||9 +3:13||9 +5:2|9.200000|6 +5:2|9.200000|6 +5:2|9.200000|6 +#1|{_NODES: [0:0,0:1,1:0], _RELS: [3:0,4:1]} +#1|{_NODES: [0:0,0:1,1:0], _RELS: [6:0,4:1]} +#1|{_NODES: [0:0,0:1,1:0], _RELS: [7:0,4:1]} +#1|{_NODES: [0:0,1:0], _RELS: [4:0]} +#4|{_NODES: [0:0,0:2,1:1], _RELS: [3:1,5:0]} +#6|{_NODES: [0:0,0:3,1:2], _RELS: [3:2,5:1]} -NAME RelMultiLabelTest --QUERY MATCH (a:person)-[e*1..2]->(b:organisation) WHERE a.fName = 'Alice' RETURN b.ID, e +-QUERY MATCH (a:person)-[e*1..2]->(b:organisation) WHERE a.fName = 'Alice' RETURN b.ID, id(rels(e)[1]), (nodes(e)[2]).ID ---- 6 -1|{_NODES: [0:0,0:1,1:0], _RELS: [3:0,4:1]} -1|{_NODES: [0:0,0:1,1:0], _RELS: [6:0,4:1]} -1|{_NODES: [0:0,0:1,1:0], _RELS: [7:0,4:1]} -1|{_NODES: [0:0,1:0], _RELS: [4:0]} -4|{_NODES: [0:0,0:2,1:1], _RELS: [3:1,5:0]} -6|{_NODES: [0:0,0:3,1:2], _RELS: [3:2,5:1]} +1|3:0|2 +1|4:0|1 +1|6:0|2 +1|7:0|2 +4|3:1|3 +6|3:2|5 +#1|{_NODES: [0:0,0:1,1:0], _RELS: [6:0,4:1]} +#1|{_NODES: [0:0,0:1,1:0], _RELS: [7:0,4:1]} +#5|{_NODES: [0:0,0:1,0:3], _RELS: [6:0,6:1]} +#5|{_NODES: [0:0,0:1,0:3], _RELS: [7:0,6:1]} -NAME MixMultiLabelTest2 --QUERY MATCH (a:person)-[e:meets|:marries|:studyAt*2..2]->(b) WHERE a.fName = 'Alice' RETURN b.ID, e +-QUERY MATCH (a:person)-[e:meets|:marries|:studyAt*2..2]->(b) WHERE a.fName = 'Alice' RETURN b.ID, id(rels(e)[2]), (nodes(e)[3]).ID ---- 4 -1|{_NODES: [0:0,0:1,1:0], _RELS: [6:0,4:1]} -1|{_NODES: [0:0,0:1,1:0], _RELS: [7:0,4:1]} -5|{_NODES: [0:0,0:1,0:3], _RELS: [6:0,6:1]} -5|{_NODES: [0:0,0:1,0:3], _RELS: [7:0,6:1]} - +1|4:1|1 +1|4:1|1 +5|6:1|5 +5|6:1|5 -NAME MixMultiLabelTest3 -QUERY MATCH (a:person)-[e:meets|:marries|:studyAt*2..2]->(b) WHERE a.fName = 'Alice' AND b.ID < 5 RETURN COUNT(*) diff --git a/test/test_files/tinysnb/var_length_extend/n_n.test b/test/test_files/tinysnb/var_length_extend/n_n.test index d2d6284cc2..e76d492b2f 100644 --- a/test/test_files/tinysnb/var_length_extend/n_n.test +++ b/test/test_files/tinysnb/var_length_extend/n_n.test @@ -42,18 +42,30 @@ Greg ---- 1 148 +#[3:0,3:3]|Bob +#[3:0,3:4]|Bob +#[3:0,3:5]|Bob +#[3:0]|Bob +#[3:1,3:6]|Carol +#[3:1,3:7]|Carol +#[3:1,3:8]|Carol +#[3:1]|Carol +#[3:2,3:10]|Dan +#[3:2,3:11]|Dan +#[3:2,3:9]|Dan +#[3:2]|Dan -NAME KnowsOneToTwoHopTest --QUERY MATCH (a:person)-[e:knows*1..2]->(b:person) WHERE a.fName='Alice' RETURN e +-QUERY MATCH (a:person)-[e:knows*1..2]->(b:person) WHERE a.fName='Alice' RETURN id(rels(e)[1]), ((nodes(e))[2]).fName ---- 12 -{_NODES: [0:0,0:1,0:0], _RELS: [3:0,3:3]} -{_NODES: [0:0,0:1,0:2], _RELS: [3:0,3:4]} -{_NODES: [0:0,0:1,0:3], _RELS: [3:0,3:5]} -{_NODES: [0:0,0:1], _RELS: [3:0]} -{_NODES: [0:0,0:2,0:0], _RELS: [3:1,3:6]} -{_NODES: [0:0,0:2,0:1], _RELS: [3:1,3:7]} -{_NODES: [0:0,0:2,0:3], _RELS: [3:1,3:8]} -{_NODES: [0:0,0:2], _RELS: [3:1]} -{_NODES: [0:0,0:3,0:0], _RELS: [3:2,3:9]} -{_NODES: [0:0,0:3,0:1], _RELS: [3:2,3:10]} -{_NODES: [0:0,0:3,0:2], _RELS: [3:2,3:11]} -{_NODES: [0:0,0:3], _RELS: [3:2]} +3:0|Bob +3:0|Bob +3:0|Bob +3:0|Bob +3:1|Carol +3:1|Carol +3:1|Carol +3:1|Carol +3:2|Dan +3:2|Dan +3:2|Dan +3:2|Dan