Skip to content

Commit

Permalink
change RelDataDirection to ExtendDirection in join_order_enumerator; …
Browse files Browse the repository at this point in the history
…and other minor changes
  • Loading branch information
aziz-mu committed May 11, 2023
1 parent de7de30 commit 40bebb9
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 36 deletions.
4 changes: 2 additions & 2 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
// bind variable length
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
auto queryRel = make_shared<RelExpression>(getUniqueExpressionName(parsedName), parsedName,
tableIDs, srcNode, dstNode, relPattern.isDirected(), relPattern.getRelType(), lowerBound,
upperBound);
tableIDs, srcNode, dstNode, relPattern.getDirection() != BOTH, relPattern.getRelType(),
lowerBound, upperBound);
queryRel->setAlias(parsedName);
// resolve properties associate with rel table
std::vector<RelTableSchema*> relTableSchemas;
Expand Down
8 changes: 2 additions & 6 deletions src/include/parser/query/graph_pattern/rel_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
namespace kuzu {
namespace parser {

enum ArrowDirection : uint8_t { LEFT = 0, RIGHT = 1, UNDIRECTED = 2 };
enum ArrowDirection : uint8_t { LEFT = 0, RIGHT = 1, BOTH = 2 };

/**
* RelationshipPattern represents "-[relName:RelTableName+]-"
Expand All @@ -28,11 +28,7 @@ class RelPattern : public NodePattern {

inline std::string getUpperBound() const { return upperBound; }

inline ArrowDirection getDirection() const {
return arrowDirection == RIGHT ? ArrowDirection::RIGHT : ArrowDirection::LEFT;
}

inline bool isDirected() const { return arrowDirection != UNDIRECTED; }
inline ArrowDirection getDirection() const { return arrowDirection; }

private:
common::QueryRelType relType;
Expand Down
6 changes: 3 additions & 3 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class JoinOrderEnumerator {
void planNodeScan(uint32_t nodePos);
void planRelScan(uint32_t relPos);
void appendExtendAndFilter(std::shared_ptr<RelExpression> rel,
common::RelDataDirection direction, const expression_vector& predicates, LogicalPlan& plan);
common::ExtendDirection direction, const expression_vector& predicates, LogicalPlan& plan);

void planLevel(uint32_t level);
void planLevelExactly(uint32_t level);
Expand All @@ -84,11 +84,11 @@ class JoinOrderEnumerator {

void appendNonRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::RelDataDirection direction, const binder::expression_vector& properties,
common::ExtendDirection direction, const binder::expression_vector& properties,
LogicalPlan& plan);
void appendRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::RelDataDirection direction, LogicalPlan& plan);
common::ExtendDirection direction, LogicalPlan& plan);

void planJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan);
Expand Down
2 changes: 1 addition & 1 deletion src/parser/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ std::unique_ptr<RelPattern> Transformer::transformRelationshipPattern(
} else if (ctx.oC_RightArrowHead()) {
arrowDirection = ArrowDirection::RIGHT;
} else {
arrowDirection = ArrowDirection::UNDIRECTED;
arrowDirection = ArrowDirection::BOTH;
}
auto properties = relDetail->kU_Properties() ?
transformProperties(*relDetail->kU_Properties()) :
Expand Down
40 changes: 20 additions & 20 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ void JoinOrderEnumerator::planNodeScan(uint32_t nodePos) {
}

static std::pair<std::shared_ptr<NodeExpression>, std::shared_ptr<NodeExpression>>
getBoundAndNbrNodes(const RelExpression& rel, RelDataDirection direction) {
auto boundNode = direction == FWD ? rel.getSrcNode() : rel.getDstNode();
auto dstNode = direction == FWD ? rel.getDstNode() : rel.getSrcNode();
getBoundAndNbrNodes(const RelExpression& rel, ExtendDirection direction) {
auto boundNode = direction == ExtendDirection::FWD ? rel.getSrcNode() : rel.getDstNode();
auto dstNode = direction == ExtendDirection::FWD ? rel.getDstNode() : rel.getSrcNode();
return make_pair(boundNode, dstNode);
}

Expand All @@ -164,17 +164,20 @@ void JoinOrderEnumerator::planRelScan(uint32_t relPos) {
newSubgraph.addQueryRel(relPos);
auto predicates = getNewlyMatchedExpressions(
context->getEmptySubqueryGraph(), newSubgraph, context->getWhereExpressions());
for (auto direction : REL_DIRECTIONS) {

std::vector<ExtendDirection> EXTEND_DIRECTIONS = {ExtendDirection::FWD, ExtendDirection::BWD};
for (auto direction : EXTEND_DIRECTIONS) {
auto plan = std::make_unique<LogicalPlan>();
auto [boundNode, _] = getBoundAndNbrNodes(*rel, direction);
auto extendDirection = rel->isDirected() ? direction : ExtendDirection::BOTH;
appendScanNodeID(boundNode, *plan);
appendExtendAndFilter(rel, direction, predicates, *plan);
appendExtendAndFilter(rel, extendDirection, predicates, *plan);
context->addPlan(newSubgraph, std::move(plan));
}
}

void JoinOrderEnumerator::appendExtendAndFilter(std::shared_ptr<RelExpression> rel,
common::RelDataDirection direction, const expression_vector& predicates, LogicalPlan& plan) {
common::ExtendDirection direction, const expression_vector& predicates, LogicalPlan& plan) {
auto [boundNode, nbrNode] = getBoundAndNbrNodes(*rel, direction);
switch (rel->getRelType()) {
case common::QueryRelType::NON_RECURSIVE: {
Expand Down Expand Up @@ -383,7 +386,8 @@ bool JoinOrderEnumerator::tryPlanINLJoin(const SubqueryGraph& subgraph,
assert(relPos != UINT32_MAX);
auto rel = context->queryGraph->getQueryRel(relPos);
auto boundNode = joinNodes[0];
auto direction = boundNode->getUniqueName() == rel->getSrcNodeName() ? FWD : BWD;
auto direction = boundNode->getUniqueName() == rel->getSrcNodeName() ? ExtendDirection::FWD :
ExtendDirection::BWD;
auto newSubgraph = subgraph;
newSubgraph.addQueryRel(relPos);
auto predicates =
Expand Down Expand Up @@ -450,27 +454,24 @@ void JoinOrderEnumerator::appendScanNodeID(
}

static bool extendHasAtMostOneNbrGuarantee(RelExpression& rel, NodeExpression& boundNode,
RelDataDirection direction, const catalog::Catalog& catalog) {
ExtendDirection direction, const catalog::Catalog& catalog) {
if (boundNode.isMultiLabeled()) {
return false;
}
if (rel.isMultiLabeled()) {
return false;
}
auto relDirection = direction == ExtendDirection::BWD ? BWD : FWD;
return catalog.getReadOnlyVersion()->isSingleMultiplicityInDirection(
rel.getSingleTableID(), direction);
rel.getSingleTableID(), relDirection);
}

void JoinOrderEnumerator::appendNonRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
RelDataDirection direction, const expression_vector& properties, LogicalPlan& plan) {
ExtendDirection direction, const expression_vector& properties, LogicalPlan& plan) {
auto hasAtMostOneNbr = extendHasAtMostOneNbrGuarantee(*rel, *boundNode, direction, catalog);
auto extendDirection = direction == FWD ? ExtendDirection::FWD : ExtendDirection::BWD;
if (!rel->isDirected()) {
extendDirection = ExtendDirection::BOTH;
}
auto extend = make_shared<LogicalExtend>(boundNode, nbrNode, rel, extendDirection, properties,
hasAtMostOneNbr, plan.getLastOperator());
auto extend = make_shared<LogicalExtend>(
boundNode, nbrNode, rel, direction, properties, hasAtMostOneNbr, plan.getLastOperator());
queryPlanner->appendFlattens(extend->getGroupsPosToFlatten(), plan);
extend->setChild(0, plan.getLastOperator());
extend->computeFactorizedSchema();
Expand All @@ -487,14 +488,13 @@ void JoinOrderEnumerator::appendNonRecursiveExtend(std::shared_ptr<NodeExpressio

void JoinOrderEnumerator::appendRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::RelDataDirection direction, LogicalPlan& plan) {
common::ExtendDirection direction, LogicalPlan& plan) {
auto hasAtMostOneNbr = extendHasAtMostOneNbrGuarantee(*rel, *boundNode, direction, catalog);
auto extendDirection = direction == FWD ? ExtendDirection::FWD : ExtendDirection::BWD;
if (!rel->isDirected()) {
extendDirection = ExtendDirection::BOTH;
direction = ExtendDirection::BOTH;
}
auto extend = std::make_shared<LogicalRecursiveExtend>(
boundNode, nbrNode, rel, extendDirection, plan.getLastOperator());
boundNode, nbrNode, rel, direction, plan.getLastOperator());
queryPlanner->appendFlattens(extend->getGroupsPosToFlatten(), plan);
extend->setChild(0, plan.getLastOperator());
extend->computeFactorizedSchema();
Expand Down
11 changes: 7 additions & 4 deletions src/processor/mapper/map_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ using namespace kuzu::storage;
namespace kuzu {
namespace processor {

static RelDataDirection getRelDataDirection(ExtendDirection extendDirection) {
// assert(extendDirection != ExtendDirection::BOTH);
return extendDirection == ExtendDirection::BWD ? BWD : FWD;
}

static std::vector<property_id_t> populatePropertyIds(
table_id_t relID, const expression_vector& properties) {
std::vector<property_id_t> outputColumns;
Expand Down Expand Up @@ -59,8 +64,7 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalExtendToPhysical(
auto boundNode = extend->getBoundNode();
auto nbrNode = extend->getNbrNode();
auto rel = extend->getRel();
auto extendDirection = extend->getDirection();
auto direction = (extendDirection == ExtendDirection::BWD) ? BWD : FWD;
auto direction = getRelDataDirection(extend->getDirection());
auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0));
auto inNodeIDVectorPos =
DataPos(inSchema->getExpressionPos(*boundNode->getInternalIDProperty()));
Expand Down Expand Up @@ -113,8 +117,7 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalRecursiveExtendToPhysica
auto boundNode = extend->getBoundNode();
auto nbrNode = extend->getNbrNode();
auto rel = extend->getRel();
auto extendDirection = extend->getDirection();
auto direction = (extendDirection == ExtendDirection::BWD) ? BWD : FWD;
auto direction = getRelDataDirection(extend->getDirection());
auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0));
auto inNodeIDVectorPos =
DataPos(inSchema->getExpressionPos(*boundNode->getInternalIDProperty()));
Expand Down

0 comments on commit 40bebb9

Please sign in to comment.