diff --git a/src/binder/bind/bind_graph_pattern.cpp b/src/binder/bind/bind_graph_pattern.cpp index 0fd3ef5b37..554e307424 100644 --- a/src/binder/bind/bind_graph_pattern.cpp +++ b/src/binder/bind/bind_graph_pattern.cpp @@ -141,10 +141,31 @@ void Binder::bindQueryRel(const RelPattern& relPattern, " to relationship with same name is not supported."); } auto tableIDs = bindRelTableIDs(relPattern.getTableNames()); - // bind node to rel - auto isLeftNodeSrc = RIGHT == relPattern.getDirection(); - auto srcNode = isLeftNodeSrc ? leftNode : rightNode; - auto dstNode = isLeftNodeSrc ? rightNode : leftNode; + // bind src & dst node + RelDirectionType directionType; + std::shared_ptr srcNode; + std::shared_ptr dstNode; + switch (relPattern.getDirection()) { + case ArrowDirection::LEFT: { + srcNode = rightNode; + dstNode = leftNode; + directionType = RelDirectionType::SINGLE; + } break; + case ArrowDirection::RIGHT: { + srcNode = leftNode; + dstNode = rightNode; + directionType = RelDirectionType::SINGLE; + } break; + case ArrowDirection::BOTH: { + // For both direction, left and right will be written with the same label set. So either one + // being src will be correct. + srcNode = leftNode; + dstNode = rightNode; + directionType = RelDirectionType::BOTH; + } break; + default: + throw common::NotImplementedException("Binder::bindQueryRel"); + } if (srcNode->getUniqueName() == dstNode->getUniqueName()) { throw BinderException("Self-loop rel " + parsedName + " is not supported."); } @@ -160,8 +181,8 @@ void Binder::bindQueryRel(const RelPattern& relPattern, auto dataType = isVariableLength ? 
common::LogicalType(common::LogicalTypeID::RECURSIVE_REL) : common::LogicalType(common::LogicalTypeID::REL); auto queryRel = make_shared(std::move(dataType), - getUniqueExpressionName(parsedName), parsedName, tableIDs, srcNode, dstNode, - relPattern.getDirection() != BOTH, relPattern.getRelType(), lowerBound, upperBound); + getUniqueExpressionName(parsedName), parsedName, tableIDs, srcNode, dstNode, directionType, + relPattern.getRelType(), lowerBound, upperBound); queryRel->setAlias(parsedName); if (isVariableLength) { queryRel->setInternalLengthExpression( diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index 9d723d028e..3caaabfe49 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -8,6 +8,7 @@ add_subdirectory(vector) add_library(kuzu_common OBJECT assert.cpp + rel_direction.cpp expression_type.cpp file_utils.cpp in_mem_overflow_buffer.cpp diff --git a/src/common/rel_direction.cpp b/src/common/rel_direction.cpp new file mode 100644 index 0000000000..faceb7d382 --- /dev/null +++ b/src/common/rel_direction.cpp @@ -0,0 +1,22 @@ +#include "common/rel_direction.h" + +#include "common/exception.h" + +namespace kuzu { +namespace common { + +/* Returns the display name of a rel data direction: "forward" for FWD, "backward" for BWD. Throws NotImplementedException for any other value. */ std::string RelDataDirectionUtils::relDataDirectionToString(RelDataDirection direction) { + switch (direction) { + case RelDataDirection::FWD: { + return "forward"; + } + case RelDataDirection::BWD: { + return "backward"; + } + default: /* reject out-of-range values instead of silently mapping them to a name */ + throw NotImplementedException("RelDataDirectionUtils::relDataDirectionToString"); + } +} + +} // namespace common +} // namespace kuzu diff --git a/src/common/types/types.cpp b/src/common/types/types.cpp index 05dc895c07..025115dd09 100644 --- a/src/common/types/types.cpp +++ b/src/common/types/types.cpp @@ -475,14 +475,6 @@ void LogicalType::setPhysicalType() { } } -RelDataDirection operator!(RelDataDirection& direction) { - return (FWD == direction) ? 
BWD : FWD; -} - -std::string getRelDataDirectionAsString(RelDataDirection direction) { - return (FWD == direction) ? "forward" : "backward"; -} - // Specialized Ser/Deser functions for logical dataTypes. template<> uint64_t SerDeser::serializeValue( diff --git a/src/include/binder/expression/rel_expression.h b/src/include/binder/expression/rel_expression.h index 0f3128c615..46fb579f98 100644 --- a/src/include/binder/expression/rel_expression.h +++ b/src/include/binder/expression/rel_expression.h @@ -7,15 +7,20 @@ namespace kuzu { namespace binder { +enum class RelDirectionType : uint8_t { + SINGLE = 0, + BOTH = 1, +}; + class RelExpression : public NodeOrRelExpression { public: RelExpression(common::LogicalType dataType, std::string uniqueName, std::string variableName, std::vector tableIDs, std::shared_ptr srcNode, - std::shared_ptr dstNode, bool directed, common::QueryRelType relType, - uint64_t lowerBound, uint64_t upperBound) + std::shared_ptr dstNode, RelDirectionType directionType, + common::QueryRelType relType, uint64_t lowerBound, uint64_t upperBound) : NodeOrRelExpression{std::move(dataType), std::move(uniqueName), std::move(variableName), std::move(tableIDs)}, - srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, directed{directed}, + srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, directionType{directionType}, relType{relType}, lowerBound{lowerBound}, upperBound{upperBound} {} inline bool isBoundByMultiLabeledNode() const { @@ -31,7 +36,7 @@ class RelExpression : public NodeOrRelExpression { inline uint64_t getLowerBound() const { return lowerBound; } inline uint64_t getUpperBound() const { return upperBound; } - inline bool isDirected() const { return directed; } + inline RelDirectionType getDirectionType() const { return directionType; } inline bool hasInternalIDProperty() const { return hasPropertyExpression(common::INTERNAL_ID_SUFFIX); @@ -50,7 +55,7 @@ class RelExpression : public NodeOrRelExpression { private: std::shared_ptr 
srcNode; std::shared_ptr dstNode; - bool directed; + RelDirectionType directionType; common::QueryRelType relType; uint64_t lowerBound; uint64_t upperBound; diff --git a/src/include/catalog/catalog_structs.h b/src/include/catalog/catalog_structs.h index d3d5d10976..739b9495bb 100644 --- a/src/include/catalog/catalog_structs.h +++ b/src/include/catalog/catalog_structs.h @@ -6,6 +6,7 @@ #include "common/constants.h" #include "common/exception.h" +#include "common/rel_direction.h" #include "common/types/types_include.h" namespace kuzu { diff --git a/src/include/common/rel_direction.h b/src/include/common/rel_direction.h new file mode 100644 index 0000000000..fb0cd74650 --- /dev/null +++ b/src/include/common/rel_direction.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include +#include + +namespace kuzu { +namespace common { + +/* Direction of stored rel data, with explicit values FWD = 0 and BWD = 1. NOTE(review): call sites in this patch index containers with the value (nodeIDs[relDirection]) and negate it (nodeIDs[!relDirection]); the negation relies on the built-in enum-to-bool conversion, so confirm those sites before turning this into an enum class. */ enum RelDataDirection : uint8_t { FWD = 0, BWD = 1 }; + +/* Stateless helpers for RelDataDirection. */ struct RelDataDirectionUtils { + /* Returns both directions in the fixed order FWD, BWD; a new vector is built on every call. */ static inline std::vector getRelDataDirections() { + return std::vector{RelDataDirection::FWD, RelDataDirection::BWD}; + } + + /* Returns a display name for the given direction. */ static std::string relDataDirectionToString(RelDataDirection direction); +}; + +} // namespace common +} // namespace kuzu diff --git a/src/include/common/types/types.h b/src/include/common/types/types.h index 2db978ff42..b8db8323e1 100644 --- a/src/include/common/types/types.h +++ b/src/include/common/types/types.h @@ -299,14 +299,6 @@ class LogicalTypeUtils { static LogicalTypeID dataTypeIDFromString(const std::string& dataTypeIDString); }; -// RelDataDirection -enum RelDataDirection : uint8_t { FWD = 0, BWD = 1 }; -const std::vector REL_DIRECTIONS = {FWD, BWD}; -RelDataDirection operator!(RelDataDirection& direction); -std::string getRelDataDirectionAsString(RelDataDirection relDirection); - -enum class ExtendDirection : uint8_t { FWD = 0, BWD = 1, BOTH = 2 }; - enum class DBFileType : uint8_t { ORIGINAL = 0, WAL_VERSION = 1 }; } // namespace common diff --git a/src/include/parser/query/graph_pattern/rel_pattern.h 
b/src/include/parser/query/graph_pattern/rel_pattern.h index 701d417978..9af71b67e9 100644 --- a/src/include/parser/query/graph_pattern/rel_pattern.h +++ b/src/include/parser/query/graph_pattern/rel_pattern.h @@ -6,8 +6,7 @@ namespace kuzu { namespace parser { -enum ArrowDirection : uint8_t { LEFT = 0, RIGHT = 1, BOTH = 2 }; - +enum class ArrowDirection : uint8_t { LEFT = 0, RIGHT = 1, BOTH = 2 }; /** * RelationshipPattern represents "-[relName:RelTableName+]-" */ diff --git a/src/include/planner/join_order_enumerator.h b/src/include/planner/join_order_enumerator.h index 90da068174..86bec25aa4 100644 --- a/src/include/planner/join_order_enumerator.h +++ b/src/include/planner/join_order_enumerator.h @@ -4,6 +4,7 @@ #include "catalog/catalog.h" #include "common/join_type.h" #include "planner/join_order_enumerator_context.h" +#include "planner/logical_plan/logical_operator/extend_direction.h" #include "storage/store/nodes_statistics_and_deleted_ids.h" namespace kuzu { @@ -59,8 +60,9 @@ class JoinOrderEnumerator { void planBaseTableScan(); void planNodeScan(uint32_t nodePos); void planRelScan(uint32_t relPos); - void appendExtendAndFilter(std::shared_ptr rel, - common::ExtendDirection direction, const expression_vector& predicates, LogicalPlan& plan); + void appendExtendAndFilter(std::shared_ptr boundNode, + std::shared_ptr nbrNode, std::shared_ptr rel, + ExtendDirection direction, const expression_vector& predicates, LogicalPlan& plan); void planLevel(uint32_t level); void planLevelExactly(uint32_t level); @@ -84,11 +86,10 @@ class JoinOrderEnumerator { void appendNonRecursiveExtend(std::shared_ptr boundNode, std::shared_ptr nbrNode, std::shared_ptr rel, - common::ExtendDirection direction, const binder::expression_vector& properties, - LogicalPlan& plan); + ExtendDirection direction, const binder::expression_vector& properties, LogicalPlan& plan); void appendRecursiveExtend(std::shared_ptr boundNode, std::shared_ptr nbrNode, std::shared_ptr rel, - 
common::ExtendDirection direction, LogicalPlan& plan); + ExtendDirection direction, LogicalPlan& plan); void planJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType, std::shared_ptr mark, LogicalPlan& probePlan, LogicalPlan& buildPlan); diff --git a/src/include/planner/logical_plan/logical_operator/base_logical_extend.h b/src/include/planner/logical_plan/logical_operator/base_logical_extend.h index 058fcb02e2..15017f88c0 100644 --- a/src/include/planner/logical_plan/logical_operator/base_logical_extend.h +++ b/src/include/planner/logical_plan/logical_operator/base_logical_extend.h @@ -2,6 +2,7 @@ #include "base_logical_operator.h" #include "binder/expression/rel_expression.h" +#include "extend_direction.h" namespace kuzu { namespace planner { @@ -11,14 +12,14 @@ class BaseLogicalExtend : public LogicalOperator { BaseLogicalExtend(LogicalOperatorType operatorType, std::shared_ptr boundNode, std::shared_ptr nbrNode, std::shared_ptr rel, - common::ExtendDirection direction, std::shared_ptr child) + ExtendDirection direction, std::shared_ptr child) : LogicalOperator{operatorType, std::move(child)}, boundNode{std::move(boundNode)}, nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, direction{direction} {} inline std::shared_ptr getBoundNode() const { return boundNode; } inline std::shared_ptr getNbrNode() const { return nbrNode; } inline std::shared_ptr getRel() const { return rel; } - inline common::ExtendDirection getDirection() const { return direction; } + inline ExtendDirection getDirection() const { return direction; } virtual f_group_pos_set getGroupsPosToFlatten() = 0; std::string getExpressionsForPrinting() const override; @@ -27,7 +28,7 @@ class BaseLogicalExtend : public LogicalOperator { std::shared_ptr boundNode; std::shared_ptr nbrNode; std::shared_ptr rel; - common::ExtendDirection direction; + ExtendDirection direction; }; } // namespace planner diff --git a/src/include/planner/logical_plan/logical_operator/extend_direction.h 
b/src/include/planner/logical_plan/logical_operator/extend_direction.h new file mode 100644 index 0000000000..6c24c7bb32 --- /dev/null +++ b/src/include/planner/logical_plan/logical_operator/extend_direction.h @@ -0,0 +1,34 @@ +#pragma once + +#include + +#include "binder/expression/rel_expression.h" +#include "common/rel_direction.h" + +namespace kuzu { +namespace planner { + +/* Direction in which an extend operator traverses a rel relative to its bound node: FWD, BWD, or BOTH. */ enum class ExtendDirection : uint8_t { FWD = 0, BWD = 1, BOTH = 2 }; + +/* Stateless helpers that derive and convert ExtendDirection values. */ struct ExtendDirectionUtils { + /* Returns BOTH when the rel's direction type is BOTH; otherwise FWD when boundNode is the rel's source node (compared by unique name), else BWD. */ static inline ExtendDirection getExtendDirection( + const binder::RelExpression& relExpression, const binder::NodeExpression& boundNode) { + if (relExpression.getDirectionType() == binder::RelDirectionType::BOTH) { + return ExtendDirection::BOTH; + } + if (relExpression.getSrcNodeName() == boundNode.getUniqueName()) { + return ExtendDirection::FWD; + } else { + return ExtendDirection::BWD; + } + } + + /* Converts FWD or BWD to the matching common::RelDataDirection. Precondition: extendDirection is not BOTH. Only the assert below enforces this, so when asserts are compiled out BOTH falls through to BWD. */ static inline common::RelDataDirection getRelDataDirection(ExtendDirection extendDirection) { + assert(extendDirection != ExtendDirection::BOTH); + return extendDirection == ExtendDirection::FWD ? 
common::RelDataDirection::FWD : + common::RelDataDirection::BWD; + } +}; + +} // namespace planner +} // namespace kuzu diff --git a/src/include/planner/logical_plan/logical_operator/logical_extend.h b/src/include/planner/logical_plan/logical_operator/logical_extend.h index 78d09a986c..0478aff89b 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_extend.h +++ b/src/include/planner/logical_plan/logical_operator/logical_extend.h @@ -9,8 +9,8 @@ class LogicalExtend : public BaseLogicalExtend { public: LogicalExtend(std::shared_ptr boundNode, std::shared_ptr nbrNode, std::shared_ptr rel, - common::ExtendDirection direction, binder::expression_vector properties, - bool hasAtMostOneNbr, std::shared_ptr child) + ExtendDirection direction, binder::expression_vector properties, bool hasAtMostOneNbr, + std::shared_ptr child) : BaseLogicalExtend{LogicalOperatorType::EXTEND, std::move(boundNode), std::move(nbrNode), std::move(rel), direction, std::move(child)}, properties{std::move(properties)}, hasAtMostOneNbr{hasAtMostOneNbr} {} diff --git a/src/include/planner/logical_plan/logical_operator/logical_recursive_extend.h b/src/include/planner/logical_plan/logical_operator/logical_recursive_extend.h index a7ace84217..fad0d4c7bc 100644 --- a/src/include/planner/logical_plan/logical_operator/logical_recursive_extend.h +++ b/src/include/planner/logical_plan/logical_operator/logical_recursive_extend.h @@ -10,12 +10,12 @@ class LogicalRecursiveExtend : public BaseLogicalExtend { public: LogicalRecursiveExtend(std::shared_ptr boundNode, std::shared_ptr nbrNode, std::shared_ptr rel, - common::ExtendDirection direction, std::shared_ptr child) + ExtendDirection direction, std::shared_ptr child) : LogicalRecursiveExtend{std::move(boundNode), std::move(nbrNode), std::move(rel), direction, RecursiveJoinType::TRACK_PATH, std::move(child)} {} LogicalRecursiveExtend(std::shared_ptr boundNode, std::shared_ptr nbrNode, std::shared_ptr rel, - common::ExtendDirection direction, 
RecursiveJoinType joinType, + ExtendDirection direction, RecursiveJoinType joinType, std::shared_ptr child) : BaseLogicalExtend{LogicalOperatorType::RECURSIVE_EXTEND, std::move(boundNode), std::move(nbrNode), std::move(rel), direction, std::move(child)}, diff --git a/src/include/storage/storage_structure/lists/lists_update_store.h b/src/include/storage/storage_structure/lists/lists_update_store.h index b45c017571..0002edbe10 100644 --- a/src/include/storage/storage_structure/lists/lists_update_store.h +++ b/src/include/storage/storage_structure/lists/lists_update_store.h @@ -4,6 +4,7 @@ #include "catalog/catalog_structs.h" #include "common/data_chunk/data_chunk.h" +#include "common/rel_direction.h" #include "common/types/types.h" #include "processor/result/factorized_table.h" #include "storage/storage_structure/lists/list_handle.h" @@ -78,8 +79,9 @@ class ListsUpdatesStore { } initListsUpdatesPerTablePerDirection(); } - inline ListsUpdatesPerChunk& getListsUpdatesPerChunk(common::RelDataDirection relDirection) { - return listsUpdatesPerDirection[relDirection]; + inline ListsUpdatesPerChunk& getListsUpdatesPerChunk( + common::RelDataDirection relDataDirection) { + return listsUpdatesPerDirection[relDataDirection]; } void updateSchema(catalog::RelTableSchema& relTableSchema); diff --git a/src/include/storage/wal/wal_record.h b/src/include/storage/wal/wal_record.h index f678c9c6b0..905716f09a 100644 --- a/src/include/storage/wal/wal_record.h +++ b/src/include/storage/wal/wal_record.h @@ -1,5 +1,6 @@ #pragma once +#include "common/rel_direction.h" #include "common/types/types_include.h" #include "common/utils.h" diff --git a/src/planner/join_order_enumerator.cpp b/src/planner/join_order_enumerator.cpp index ca8e9a938e..39fb292b83 100644 --- a/src/planner/join_order_enumerator.cpp +++ b/src/planner/join_order_enumerator.cpp @@ -153,6 +153,7 @@ void JoinOrderEnumerator::planNodeScan(uint32_t nodePos) { static std::pair, std::shared_ptr> getBoundAndNbrNodes(const 
RelExpression& rel, ExtendDirection direction) { + assert(direction != ExtendDirection::BOTH); auto boundNode = direction == ExtendDirection::FWD ? rel.getSrcNode() : rel.getDstNode(); auto dstNode = direction == ExtendDirection::FWD ? rel.getDstNode() : rel.getSrcNode(); return make_pair(boundNode, dstNode); @@ -168,17 +169,17 @@ void JoinOrderEnumerator::planRelScan(uint32_t relPos) { // we always enumerate two plans, one from src to dst, and the other from dst to src. for (auto direction : {ExtendDirection::FWD, ExtendDirection::BWD}) { auto plan = std::make_unique(); - auto [boundNode, _] = getBoundAndNbrNodes(*rel, direction); - auto extendDirection = rel->isDirected() ? direction : ExtendDirection::BOTH; + auto [boundNode, nbrNode] = getBoundAndNbrNodes(*rel, direction); + auto extendDirection = ExtendDirectionUtils::getExtendDirection(*rel, *boundNode); appendScanNodeID(boundNode, *plan); - appendExtendAndFilter(rel, direction, predicates, *plan); + appendExtendAndFilter(boundNode, nbrNode, rel, extendDirection, predicates, *plan); context->addPlan(newSubgraph, std::move(plan)); } } -void JoinOrderEnumerator::appendExtendAndFilter(std::shared_ptr rel, - common::ExtendDirection direction, const expression_vector& predicates, LogicalPlan& plan) { - auto [boundNode, nbrNode] = getBoundAndNbrNodes(*rel, direction); +void JoinOrderEnumerator::appendExtendAndFilter(std::shared_ptr boundNode, + std::shared_ptr nbrNode, std::shared_ptr rel, + ExtendDirection direction, const expression_vector& predicates, LogicalPlan& plan) { switch (rel->getRelType()) { case common::QueryRelType::NON_RECURSIVE: { auto properties = queryPlanner->getPropertiesForRel(*rel); @@ -189,7 +190,7 @@ void JoinOrderEnumerator::appendExtendAndFilter(std::shared_ptr r appendRecursiveExtend(boundNode, nbrNode, rel, direction, plan); } break; default: - throw common::NotImplementedException("appendExtendAndFilter()"); + throw 
common::NotImplementedException("JoinOrderEnumerator::appendExtendAndFilter"); } queryPlanner->appendFilters(predicates, plan); } @@ -386,8 +387,9 @@ bool JoinOrderEnumerator::tryPlanINLJoin(const SubqueryGraph& subgraph, assert(relPos != UINT32_MAX); auto rel = context->queryGraph->getQueryRel(relPos); auto boundNode = joinNodes[0]; - auto direction = boundNode->getUniqueName() == rel->getSrcNodeName() ? ExtendDirection::FWD : - ExtendDirection::BWD; + auto nbrNode = + boundNode->getUniqueName() == rel->getSrcNodeName() ? rel->getDstNode() : rel->getSrcNode(); + auto extendDirection = ExtendDirectionUtils::getExtendDirection(*rel, *boundNode); auto newSubgraph = subgraph; newSubgraph.addQueryRel(relPos); auto predicates = @@ -396,7 +398,7 @@ bool JoinOrderEnumerator::tryPlanINLJoin(const SubqueryGraph& subgraph, for (auto& prevPlan : context->getPlans(subgraph)) { if (isNodeSequentialOnPlan(*prevPlan, *boundNode)) { auto plan = prevPlan->shallowCopy(); - appendExtendAndFilter(rel, direction, predicates, *plan); + appendExtendAndFilter(boundNode, nbrNode, rel, extendDirection, predicates, *plan); context->addPlan(newSubgraph, std::move(plan)); hasAppliedINLJoin = true; } @@ -461,7 +463,10 @@ static bool extendHasAtMostOneNbrGuarantee(RelExpression& rel, NodeExpression& b if (rel.isMultiLabeled()) { return false; } - auto relDirection = direction == ExtendDirection::BWD ? 
BWD : FWD; + if (direction == ExtendDirection::BOTH) { + return false; + } + auto relDirection = ExtendDirectionUtils::getRelDataDirection(direction); return catalog.getReadOnlyVersion()->isSingleMultiplicityInDirection( rel.getSingleTableID(), relDirection); } @@ -470,9 +475,8 @@ void JoinOrderEnumerator::appendNonRecursiveExtend(std::shared_ptr nbrNode, std::shared_ptr rel, ExtendDirection direction, const expression_vector& properties, LogicalPlan& plan) { auto hasAtMostOneNbr = extendHasAtMostOneNbrGuarantee(*rel, *boundNode, direction, catalog); - auto extend = make_shared(boundNode, nbrNode, rel, - rel->isDirected() ? direction : ExtendDirection::BOTH, properties, hasAtMostOneNbr, - plan.getLastOperator()); + auto extend = make_shared( + boundNode, nbrNode, rel, direction, properties, hasAtMostOneNbr, plan.getLastOperator()); queryPlanner->appendFlattens(extend->getGroupsPosToFlatten(), plan); extend->setChild(0, plan.getLastOperator()); extend->computeFactorizedSchema(); @@ -489,9 +493,9 @@ void JoinOrderEnumerator::appendNonRecursiveExtend(std::shared_ptr boundNode, std::shared_ptr nbrNode, std::shared_ptr rel, - common::ExtendDirection direction, LogicalPlan& plan) { + ExtendDirection direction, LogicalPlan& plan) { auto hasAtMostOneNbr = extendHasAtMostOneNbrGuarantee(*rel, *boundNode, direction, catalog); - assert(rel->isDirected()); + assert(direction != ExtendDirection::BOTH); auto extend = std::make_shared( boundNode, nbrNode, rel, direction, plan.getLastOperator()); queryPlanner->appendFlattens(extend->getGroupsPosToFlatten(), plan); diff --git a/src/planner/operator/base_logical_extend.cpp b/src/planner/operator/base_logical_extend.cpp index cb857e2a56..54b42011de 100644 --- a/src/planner/operator/base_logical_extend.cpp +++ b/src/planner/operator/base_logical_extend.cpp @@ -21,18 +21,24 @@ static std::string relToString(const binder::RelExpression& rel) { std::string BaseLogicalExtend::getExpressionsForPrinting() const { auto result = 
boundNode->toString(); - if (!rel->isDirected()) { - result += "<-"; - result += relToString(*rel); - result += "->"; - } else if (direction == common::ExtendDirection::FWD) { + switch (direction) { + case ExtendDirection::FWD: { result += "-"; result += relToString(*rel); result += "->"; - } else { + } break; + case ExtendDirection::BWD: { result += "<-"; result += relToString(*rel); result += "-"; + } break; + case ExtendDirection::BOTH: { + result += "<-"; + result += relToString(*rel); + result += "->"; + } break; + default: + throw common::NotImplementedException("BaseLogicalExtend::getExpressionsForPrinting"); } result += nbrNode->toString(); return result; diff --git a/src/processor/mapper/map_extend.cpp b/src/processor/mapper/map_extend.cpp index 7e3aa3572e..5f703ed629 100644 --- a/src/processor/mapper/map_extend.cpp +++ b/src/processor/mapper/map_extend.cpp @@ -15,11 +15,6 @@ using namespace kuzu::storage; namespace kuzu { namespace processor { -static RelDataDirection getRelDataDirection(ExtendDirection extendDirection) { - assert(extendDirection != ExtendDirection::BOTH); - return extendDirection == ExtendDirection::BWD ? 
BWD : FWD; -} - static std::vector populatePropertyIds( table_id_t relID, const expression_vector& properties) { std::vector outputColumns; @@ -69,8 +64,7 @@ static std::unique_ptr populateRelTableDataCollection( relTableDatas.push_back(relTableData); tableScanStates.push_back(std::move(scanState)); } - break; - } + } break; case ExtendDirection::BWD: { auto [relTableData, scanState] = getRelTableDataAndScanState(RelDataDirection::BWD, relTableSchema, boundNodeTableID, relsStore, relTableID, properties); @@ -78,8 +72,7 @@ static std::unique_ptr populateRelTableDataCollection( relTableDatas.push_back(relTableData); tableScanStates.push_back(std::move(scanState)); } - break; - } + } break; case ExtendDirection::BOTH: { auto [relTableDataFWD, scanStateFWD] = getRelTableDataAndScanState(RelDataDirection::FWD, relTableSchema, boundNodeTableID, @@ -95,8 +88,9 @@ static std::unique_ptr populateRelTableDataCollection( relTableDatas.push_back(relTableDataBWD); tableScanStates.push_back(std::move(scanStateBWD)); } - break; - } + } break; + default: + throw common::NotImplementedException("populateRelTableDataCollection"); } } return std::make_unique( @@ -123,8 +117,9 @@ std::unique_ptr PlanMapper::mapLogicalExtendToPhysical( outputVectorsPos.emplace_back(outSchema->getExpressionPos(*expression)); } auto& relsStore = storageManager.getRelsStore(); - if (!rel->isMultiLabeled() && !boundNode->isMultiLabeled() && rel->isDirected()) { - auto relDataDirection = getRelDataDirection(extendDirection); + if (!rel->isMultiLabeled() && !boundNode->isMultiLabeled() && + extendDirection != planner::ExtendDirection::BOTH) { + auto relDataDirection = ExtendDirectionUtils::getRelDataDirection(extendDirection); auto relTableID = rel->getSingleTableID(); if (relsStore.isSingleMultiplicityInDirection(relDataDirection, relTableID)) { auto propertyIds = populatePropertyIds(relTableID, extend->getProperties()); @@ -165,7 +160,7 @@ std::unique_ptr PlanMapper::mapLogicalRecursiveExtendToPhysica auto 
boundNode = extend->getBoundNode(); auto nbrNode = extend->getNbrNode(); auto rel = extend->getRel(); - auto direction = getRelDataDirection(extend->getDirection()); + auto relDataDirection = ExtendDirectionUtils::getRelDataDirection(extend->getDirection()); auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0)); auto inNodeIDVectorPos = DataPos(inSchema->getExpressionPos(*boundNode->getInternalIDProperty())); @@ -196,18 +191,18 @@ std::unique_ptr PlanMapper::mapLogicalRecursiveExtendToPhysica std::unique_ptr scanRelTable; std::vector emptyPropertyIDs; DataPos tmpDstNodePos{0, 0}; - if (relsStore.isSingleMultiplicityInDirection(direction, relTableID)) { + if (relsStore.isSingleMultiplicityInDirection(relDataDirection, relTableID)) { tmpDstNodePos = DataPos{0, 1}; scanRelTable = make_unique( - relsStore.getRelTable(relTableID)->getDirectedTableData(direction), emptyPropertyIDs, - tmpSrcNodePos, std::vector{tmpDstNodePos}, std::move(scanFrontier), - getOperatorID(), emptyParamString); + relsStore.getRelTable(relTableID)->getDirectedTableData(relDataDirection), + emptyPropertyIDs, tmpSrcNodePos, std::vector{tmpDstNodePos}, + std::move(scanFrontier), getOperatorID(), emptyParamString); } else { tmpDstNodePos = DataPos{0, 1}; scanRelTable = make_unique( - relsStore.getRelTable(relTableID)->getDirectedTableData(direction), emptyPropertyIDs, - tmpSrcNodePos, std::vector{tmpDstNodePos}, std::move(scanFrontier), - getOperatorID(), emptyParamString); + relsStore.getRelTable(relTableID)->getDirectedTableData(relDataDirection), + emptyPropertyIDs, tmpSrcNodePos, std::vector{tmpDstNodePos}, + std::move(scanFrontier), getOperatorID(), emptyParamString); } std::unique_ptr dataInfo; switch (extend->getJoinType()) { diff --git a/src/storage/copier/rel_copy_executor.cpp b/src/storage/copier/rel_copy_executor.cpp index bf2d36dd5f..6a01b8d6be 100644 --- a/src/storage/copier/rel_copy_executor.cpp +++ b/src/storage/copier/rel_copy_executor.cpp @@ -37,7 +37,7 @@ 
std::string RelCopyExecutor::getTaskTypeName(PopulateTaskType populateTaskType) } void RelCopyExecutor::initializeColumnsAndLists() { - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { listSizesPerDirection[relDirection] = std::make_unique( maxNodeOffsetsPerTable.at( reinterpret_cast(tableSchema)->getBoundTableID(relDirection)) + @@ -69,7 +69,7 @@ void RelCopyExecutor::populateColumnsAndLists(processor::ExecutionContext* execu void RelCopyExecutor::saveToFile() { logger->debug("Writing columns and Lists to disk for rel {}.", tableSchema->tableName); - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { if (reinterpret_cast(tableSchema) ->isSingleMultiplicityInDirection(relDirection)) { adjColumnsPerDirection[relDirection]->flushChunk( @@ -139,7 +139,7 @@ void RelCopyExecutor::initializeLists(RelDataDirection relDirection) { void RelCopyExecutor::initAdjListsHeaders() { // TODO(Semih): Schedule one at a time and wait. logger->debug("Initializing AdjListHeaders for rel {}.", tableSchema->tableName); - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { if (!reinterpret_cast(tableSchema) ->isSingleMultiplicityInDirection(relDirection)) { auto boundTableID = @@ -158,7 +158,7 @@ void RelCopyExecutor::initListsMetadata() { // TODO(Semih): Schedule one at a time and wait. 
logger->debug( "Initializing adjLists and propertyLists metadata for rel {}.", tableSchema->tableName); - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { if (!reinterpret_cast(tableSchema) ->isSingleMultiplicityInDirection(relDirection)) { auto boundTableID = @@ -283,7 +283,7 @@ void RelCopyExecutor::populateLists() { } void RelCopyExecutor::sortAndCopyOverflowValues() { - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { // Sort overflow values of property Lists. if (!reinterpret_cast(tableSchema) ->isSingleMultiplicityInDirection(relDirection)) { @@ -313,7 +313,7 @@ void RelCopyExecutor::sortAndCopyOverflowValues() { } } // Sort overflow values of property columns. - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { if (reinterpret_cast(tableSchema) ->isSingleMultiplicityInDirection(relDirection)) { auto numNodes = @@ -354,7 +354,7 @@ void RelCopyExecutor::inferTableIDsAndOffsets(const std::vector& nodeIDs, std::vector& nodeIDTypes, const std::map& pkIndexes, Transaction* transaction, int64_t blockOffset, int64_t& colIndex) { - for (auto& relDirection : REL_DIRECTIONS) { + for (auto& relDirection : RelDataDirectionUtils::getRelDataDirections()) { if (colIndex >= batchColumns.size()) { throw CopyException("Number of columns mismatch."); } @@ -615,7 +615,7 @@ void RelCopyExecutor::populateAdjColumnsAndCountRelsInAdjListsTask(uint64_t bloc std::vector nodeIDs{2}; std::vector nodePKTypes{2}; auto relTableSchema = reinterpret_cast(copier->tableSchema); - for (auto& relDirection : REL_DIRECTIONS) { + for (auto& relDirection : RelDataDirectionUtils::getRelDataDirections()) { auto boundTableID = relTableSchema->getBoundTableID(relDirection); nodeIDs[relDirection].tableID = boundTableID; nodePKTypes[relDirection] = copier->catalog.getReadOnlyVersion() @@ -630,7 
+630,7 @@ void RelCopyExecutor::populateAdjColumnsAndCountRelsInAdjListsTask(uint64_t bloc int64_t colIndex = 0; inferTableIDsAndOffsets(batchColumns, nodeIDs, nodePKTypes, copier->pkIndexes, copier->dummyReadOnlyTrx.get(), blockOffset, colIndex); - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { auto tableID = nodeIDs[relDirection].tableID; auto nodeOffset = nodeIDs[relDirection].offset; if (relTableSchema->isSingleMultiplicityInDirection(relDirection)) { @@ -641,7 +641,7 @@ void RelCopyExecutor::populateAdjColumnsAndCountRelsInAdjListsTask(uint64_t bloc relTableSchema->tableName, getRelMultiplicityAsString(relTableSchema->relMultiplicity), nodeOffset, copier->catalog.getReadOnlyVersion()->getTableName(tableID), - getRelDataDirectionAsString(relDirection))); + RelDataDirectionUtils::relDataDirectionToString(relDirection))); } copier->adjColumnChunksPerDirection[relDirection]->setValue( (uint8_t*)&nodeIDs[!relDirection].offset, nodeOffset); @@ -671,7 +671,7 @@ void RelCopyExecutor::populateListsTask(uint64_t blockId, uint64_t blockStartRel std::vector nodePKTypes(2); std::vector reversePos(2); auto relTableSchema = reinterpret_cast(copier->tableSchema); - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { auto boundTableID = relTableSchema->getBoundTableID(relDirection); nodeIDs[relDirection].tableID = boundTableID; nodePKTypes[relDirection] = copier->catalog.getReadOnlyVersion() @@ -686,7 +686,7 @@ void RelCopyExecutor::populateListsTask(uint64_t blockId, uint64_t blockStartRel int64_t colIndex = 0; inferTableIDsAndOffsets(batchColumns, nodeIDs, nodePKTypes, copier->pkIndexes, copier->dummyReadOnlyTrx.get(), blockOffset, colIndex); - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { if (!copier->catalog.getReadOnlyVersion()->isSingleMultiplicityInDirection( 
copier->tableSchema->tableID, relDirection)) { auto nodeOffset = nodeIDs[relDirection].offset; @@ -761,7 +761,7 @@ void RelCopyExecutor::putValueIntoColumns(uint64_t propertyID, std::vector>>& directionTablePropertyColumnChunks, const std::vector& nodeIDs, uint8_t* val) { - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { if (directionTablePropertyColumnChunks[relDirection].empty()) { continue; } @@ -777,7 +777,7 @@ void RelCopyExecutor::putValueIntoLists(uint64_t propertyID, directionTablePropertyLists, std::vector>& directionTableAdjLists, const std::vector& nodeIDs, const std::vector& reversePos, uint8_t* val) { - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { if (directionTableAdjLists[relDirection] == nullptr) { continue; } diff --git a/src/storage/storage_structure/lists/lists_update_store.cpp b/src/storage/storage_structure/lists/lists_update_store.cpp index f78647bc51..76f3b94e1d 100644 --- a/src/storage/storage_structure/lists/lists_update_store.cpp +++ b/src/storage/storage_structure/lists/lists_update_store.cpp @@ -63,7 +63,7 @@ bool ListsUpdatesStore::isRelDeletedInPersistentStore( } bool ListsUpdatesStore::hasUpdates() const { - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { for (auto& [_, listsUpdatesPerNode] : listsUpdatesPerDirection[relDirection]) { for (auto& [_, listsUpdatesForNodeOffset] : listsUpdatesPerNode) { if (listsUpdatesForNodeOffset->hasUpdates()) { @@ -96,7 +96,7 @@ void ListsUpdatesStore::insertRelIfNecessary(const ValueVector* srcNodeIDVector, std::vector{(ValueVector*)srcNodeIDVector, (ValueVector*)dstNodeIDVector}; vectorsToAppendToFT.insert( vectorsToAppendToFT.end(), relPropertyVectors.begin(), relPropertyVectors.end()); - for (auto direction : REL_DIRECTIONS) { + for (auto direction : 
RelDataDirectionUtils::getRelDataDirections()) { auto boundNodeID = direction == FWD ? srcNodeID : dstNodeID; if (!relTableSchema.isSingleMultiplicityInDirection(direction)) { if (!hasInsertedToFT) { @@ -122,7 +122,7 @@ void ListsUpdatesStore::deleteRelIfNecessary(common::ValueVector* srcNodeIDVecto // If the rel that we are going to delete is a newly inserted rel, we need to delete // its tupleIdx from the insertedRelsTupleIdxInFT of listsUpdatesStore in FWD and BWD // direction. Note: we don't reuse the space for inserted rel tuple in factorizedTable. - for (auto direction : REL_DIRECTIONS) { + for (auto direction : RelDataDirectionUtils::getRelDataDirections()) { auto boundNodeID = direction == RelDataDirection::FWD ? srcNodeID : dstNodeID; if (!relTableSchema.isSingleMultiplicityInDirection(direction)) { auto& insertedRelsTupleIdxInFT = @@ -138,7 +138,7 @@ void ListsUpdatesStore::deleteRelIfNecessary(common::ValueVector* srcNodeIDVecto } } } else { - for (auto direction : REL_DIRECTIONS) { + for (auto direction : RelDataDirectionUtils::getRelDataDirections()) { auto boundNodeID = direction == RelDataDirection::FWD ? srcNodeID : dstNodeID; if (!relTableSchema.isSingleMultiplicityInDirection(direction)) { getOrCreateListsUpdatesForNodeOffset(direction, boundNodeID) @@ -192,7 +192,7 @@ void ListsUpdatesStore::updateRelIfNecessary(ValueVector* srcNodeIDVector, auto dstNodeID = dstNodeIDVector->getValue( dstNodeIDVector->state->selVector->selectedPositions[0]); bool insertUpdatedRel = true; - for (auto direction : REL_DIRECTIONS) { + for (auto direction : RelDataDirectionUtils::getRelDataDirections()) { auto boundNodeID = direction == FWD ? srcNodeID : dstNodeID; // We should only store the property update if the property is stored as a list in the // current direction. (E.g. 
We update a rel property of a MANY-ONE rel table which stores @@ -267,7 +267,7 @@ void ListsUpdatesStore::readPropertyUpdateToInMemList(ListFileID& listFileID, } void ListsUpdatesStore::initNewlyAddedNodes(nodeID_t& nodeID) { - for (auto direction : REL_DIRECTIONS) { + for (auto direction : RelDataDirectionUtils::getRelDataDirections()) { if (!relTableSchema.isSingleMultiplicityInDirection(direction) && nodeID.tableID == relTableSchema.getBoundTableID(direction)) { auto& listsUpdatesPerNode = @@ -319,7 +319,7 @@ ft_col_idx_t ListsUpdatesStore::getColIdxInFT(ListFileID& listFileID) const { void ListsUpdatesStore::initListsUpdatesPerTablePerDirection() { listsUpdatesPerDirection.clear(); - for (auto direction : REL_DIRECTIONS) { + for (auto direction : RelDataDirectionUtils::getRelDataDirections()) { listsUpdatesPerDirection.emplace_back(); } } diff --git a/src/storage/storage_utils.cpp b/src/storage/storage_utils.cpp index 9244935824..83bc4eade3 100644 --- a/src/storage/storage_utils.cpp +++ b/src/storage/storage_utils.cpp @@ -200,7 +200,7 @@ void StorageUtils::createFileForNodePropertyWithDefaultVal(table_id_t tableID, void StorageUtils::createFileForRelPropertyWithDefaultVal(RelTableSchema* tableSchema, const Property& property, uint8_t* defaultVal, bool isDefaultValNull, StorageManager& storageManager) { - for (auto direction : REL_DIRECTIONS) { + for (auto direction : RelDataDirectionUtils::getRelDataDirections()) { auto createPropertyFileFunc = tableSchema->isSingleMultiplicityInDirection(direction) ? 
createFileForRelColumnPropertyWithDefaultVal : createFileForRelListsPropertyWithDefaultVal; diff --git a/src/storage/store/rel_table.cpp b/src/storage/store/rel_table.cpp index bbb289f93e..11e006b0bd 100644 --- a/src/storage/store/rel_table.cpp +++ b/src/storage/store/rel_table.cpp @@ -160,7 +160,7 @@ void DirectedRelTableData::insertRel(common::ValueVector* boundVector, nodeOffset, boundVector->getValue(boundVector->state->selVector->selectedPositions[0]) .tableID, - tableID, getRelDataDirectionAsString(direction))); + tableID, RelDataDirectionUtils::relDataDirectionToString(direction))); } adjColumn->write(boundVector, nbrVector); for (auto i = 0u; i < relPropertyVectors.size(); i++) { diff --git a/src/storage/wal_replayer_utils.cpp b/src/storage/wal_replayer_utils.cpp index e59a109bd0..8b1022d1f7 100644 --- a/src/storage/wal_replayer_utils.cpp +++ b/src/storage/wal_replayer_utils.cpp @@ -11,7 +11,7 @@ namespace storage { void WALReplayerUtils::removeDBFilesForRelProperty( const std::string& directory, RelTableSchema* relTableSchema, property_id_t propertyID) { - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { auto boundTableID = relTableSchema->getBoundTableID(relDirection); if (relTableSchema->isSingleMultiplicityInDirection(relDirection)) { removeColumnFilesForPropertyIfExists(directory, relTableSchema->tableID, boundTableID, @@ -25,7 +25,7 @@ void WALReplayerUtils::removeDBFilesForRelProperty( void WALReplayerUtils::createEmptyDBFilesForNewRelTable(RelTableSchema* relTableSchema, const std::string& directory, const std::map& maxNodeOffsetsPerTable) { - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { if (relTableSchema->isSingleMultiplicityInDirection(relDirection)) { createEmptyDBFilesForColumns( maxNodeOffsetsPerTable, relDirection, directory, relTableSchema); @@ -74,7 +74,7 @@ void 
WALReplayerUtils::createEmptyDBFilesForNewNodeTable( void WALReplayerUtils::renameDBFilesForRelProperty(const std::string& directory, kuzu::catalog::RelTableSchema* relTableSchema, kuzu::common::property_id_t propertyID) { - for (auto direction : REL_DIRECTIONS) { + for (auto direction : RelDataDirectionUtils::getRelDataDirections()) { auto boundTableID = relTableSchema->getBoundTableID(direction); if (relTableSchema->isSingleMultiplicityInDirection(direction)) { replaceOriginalColumnFilesWithWALVersionIfExists( @@ -91,7 +91,7 @@ void WALReplayerUtils::replaceListsHeadersFilesWithVersionFromWALIfExists( const std::unordered_set& relTableSchemas, table_id_t boundTableID, const std::string& directory) { for (auto relTableSchema : relTableSchemas) { - for (auto direction : REL_DIRECTIONS) { + for (auto direction : RelDataDirectionUtils::getRelDataDirections()) { if (!relTableSchema->isSingleMultiplicityInDirection(direction)) { auto listsHeadersFileName = StorageUtils::getListHeadersFName(StorageUtils::getAdjListsFName( @@ -214,7 +214,7 @@ void WALReplayerUtils::fileOperationOnNodeFiles(NodeTableSchema* nodeTableSchema void WALReplayerUtils::fileOperationOnRelFiles(RelTableSchema* relTableSchema, const std::string& directory, std::function columnFileOperation, std::function listFileOperation) { - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { auto boundTableID = relTableSchema->getBoundTableID(relDirection); auto isColumnProperty = relTableSchema->isSingleMultiplicityInDirection(relDirection); if (isColumnProperty) { diff --git a/test/graph_test/graph_test.cpp b/test/graph_test/graph_test.cpp index ce2beaebb6..4c33038f00 100644 --- a/test/graph_test/graph_test.cpp +++ b/test/graph_test/graph_test.cpp @@ -54,7 +54,7 @@ void BaseGraphTest::validateNodeColumnFilesExistence( void BaseGraphTest::validateRelColumnAndListFilesExistence( RelTableSchema* relTableSchema, DBFileType dbFileType, bool 
existence) { - for (auto relDirection : REL_DIRECTIONS) { + for (auto relDirection : RelDataDirectionUtils::getRelDataDirections()) { if (relTableSchema->relMultiplicity) { validateColumnFilesExistence(StorageUtils::getAdjColumnFName(databasePath, relTableSchema->tableID, relDirection, dbFileType), diff --git a/test/test_files/tinysnb/match/undirected.test b/test/test_files/tinysnb/match/undirected.test index 8fa13c39ba..73bea6a244 100644 --- a/test/test_files/tinysnb/match/undirected.test +++ b/test/test_files/tinysnb/match/undirected.test @@ -6,6 +6,7 @@ -NAME UndirKnows1 -QUERY MATCH (a:person)-[:knows]-(b:person) WHERE b.fName = "Bob" RETURN a.fName; +-ENUMERATE ---- 6 Alice Carol @@ -16,6 +17,7 @@ Dan -NAME UndirKnows2 -QUERY MATCH (a:person)-[:knows]-(b:person)-[:knows]-(c:person) WHERE a.gender = 1 AND b.gender = 2 AND c.fName = "Bob" RETURN a.fName, b.fName; +-ENUMERATE ---- 8 Alice|Dan Carol|Dan @@ -28,16 +30,19 @@ Carol|Dan -NAME UndirMultiLabel1 -QUERY MATCH (a:person:organisation)-[:meets|:marries|:workAt]-(b:person:organisation) RETURN COUNT(*); +-ENUMERATE ---- 1 26 -NAME UndirMultiLabel2 -QUERY MATCH (a:person)-[:studyAt|:meets]-(b:person:organisation) RETURN COUNT(*); +-ENUMERATE ---- 1 20 -NAME UndirMultiLabel3 -QUERY MATCH (a:person)-[:meets|:marries|:knows]-(b:person)-[:knows|:meets]-(c:person) WHERE c.fName = "Farooq" AND a.fName <> "Farooq" RETURN a.fName, b.fName; +-ENUMERATE ---- 13 Carol|Elizabeth Alice|Carol @@ -55,11 +60,13 @@ Dan|Carol -NAME UndirUnlabelled -QUERY MATCH (a:person)-[]-() RETURN COUNT(*); +-ENUMERATE ---- 1 60 -NAME UndirPattern -QUERY MATCH ()-[:studyAt]-(a)-[:meets]-()-[:workAt]-() RETURN a.fName; +-ENUMERATE ---- 2 Farooq Bob