Skip to content

Commit

Permalink
Add recursive join planning
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Apr 18, 2023
1 parent d8856cb commit d6a18ba
Show file tree
Hide file tree
Showing 36 changed files with 3,043 additions and 2,662 deletions.
4 changes: 3 additions & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ oC_NodeLabel
: ':' SP? oC_LabelName ;

oC_RangeLiteral
: '*' SP? oC_IntegerLiteral SP? '..' SP? oC_IntegerLiteral ;
: '*' SP? SHORTEST? SP? oC_IntegerLiteral SP? '..' SP? oC_IntegerLiteral ;

SHORTEST : ( 'S' | 's' ) ( 'H' | 'h' ) ( 'O' | 'o' ) ( 'R' | 'r' ) ( 'T' | 't' ) ( 'E' | 'e' ) ( 'S' | 's' ) ( 'T' | 't' ) ;

oC_LabelName
: oC_SchemaName ;
Expand Down
6 changes: 3 additions & 3 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,15 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
// bind variable length
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
auto queryRel = make_shared<RelExpression>(getUniqueExpressionName(parsedName), parsedName,
tableIDs, srcNode, dstNode, lowerBound, upperBound);
tableIDs, srcNode, dstNode, relPattern.getRelType(), lowerBound, upperBound);
queryRel->setAlias(parsedName);
// resolve properties associated with rel table
std::vector<RelTableSchema*> relTableSchemas;
for (auto tableID : tableIDs) {
relTableSchemas.push_back(catalog.getReadOnlyVersion()->getRelTableSchema(tableID));
}
// we don't support reading property for variable length rel yet.
if (!queryRel->isVariableLength()) {
// we don't support reading property for VARIABLE_LENGTH or SHORTEST rel.
if (queryRel->getRelType() == common::QueryRelType::NON_RECURSIVE) {
for (auto& [propertyName, propertySchemas] :
getRelPropertyNameAndPropertiesPairs(relTableSchemas)) {
auto propertyExpression = expressionBinder.createPropertyExpression(
Expand Down
6 changes: 5 additions & 1 deletion src/binder/bind_expression/bind_property_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,13 @@ static std::unordered_map<table_id_t, property_id_t> populatePropertyIDPerTable(
std::shared_ptr<Expression> ExpressionBinder::bindRelPropertyExpression(
const Expression& expression, const std::string& propertyName) {
auto& rel = (RelExpression&)expression;
if (rel.isVariableLength()) {
switch (rel.getRelType()) {
case common::QueryRelType::VARIABLE_LENGTH:
case common::QueryRelType::SHORTEST:
throw BinderException(
"Cannot read property of variable length rel " + rel.toString() + ".");
default:
break;
}
if (!rel.hasPropertyExpression(propertyName)) {
throw BinderException(
Expand Down
11 changes: 7 additions & 4 deletions src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "common/exception.h"
#include "common/query_rel_type.h"
#include "node_expression.h"

namespace kuzu {
Expand All @@ -10,11 +11,12 @@ class RelExpression : public NodeOrRelExpression {
public:
RelExpression(std::string uniqueName, std::string variableName,
std::vector<common::table_id_t> tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, uint64_t lowerBound, uint64_t upperBound)
std::shared_ptr<NodeExpression> dstNode, common::QueryRelType relType, uint64_t lowerBound,
uint64_t upperBound)
: NodeOrRelExpression{common::REL, std::move(uniqueName), std::move(variableName),
std::move(tableIDs)},
srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, lowerBound{lowerBound},
upperBound{upperBound} {}
srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, relType{relType},
lowerBound{lowerBound}, upperBound{upperBound} {}

inline bool isBoundByMultiLabeledNode() const {
return srcNode->isMultiLabeled() || dstNode->isMultiLabeled();
Expand All @@ -25,9 +27,9 @@ class RelExpression : public NodeOrRelExpression {
inline std::shared_ptr<NodeExpression> getDstNode() const { return dstNode; }
inline std::string getDstNodeName() const { return dstNode->getUniqueName(); }

inline common::QueryRelType getRelType() const { return relType; }
inline uint64_t getLowerBound() const { return lowerBound; }
inline uint64_t getUpperBound() const { return upperBound; }
inline bool isVariableLength() const { return !(lowerBound == 1 && upperBound == 1); }

inline bool hasInternalIDProperty() const {
return hasPropertyExpression(common::INTERNAL_ID_SUFFIX);
Expand All @@ -39,6 +41,7 @@ class RelExpression : public NodeOrRelExpression {
private:
std::shared_ptr<NodeExpression> srcNode;
std::shared_ptr<NodeExpression> dstNode;
common::QueryRelType relType;
uint64_t lowerBound;
uint64_t upperBound;
};
Expand Down
1 change: 1 addition & 0 deletions src/include/common/clause_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ enum class ClauseType : uint8_t {
SET = 0,
DELETE = 1,
CREATE = 2,
// reading clause
MATCH = 3,
UNWIND = 4
};
Expand Down
15 changes: 15 additions & 0 deletions src/include/common/query_rel_type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include <cstdint>

namespace kuzu {
namespace common {

/**
 * Kind of relationship pattern carried from the parser (RelPattern) through the binder
 * (RelExpression). Stored in a single byte; the numeric values are spelled out so they stay stable
 * if enumerators are ever reordered.
 */
enum class QueryRelType : uint8_t {
    /// Ordinary rel pattern. The binder only resolves rel properties for this kind.
    NON_RECURSIVE = 0,
    /// Variable length rel pattern. Reading a property of such a rel is rejected by the binder.
    VARIABLE_LENGTH = 1,
    /// Shortest path rel pattern. Reading a property of such a rel is rejected by the binder.
    /// NOTE(review): presumably selected by the SHORTEST keyword in the range literal — confirm
    /// against the parser's transformer.
    SHORTEST = 2
};

} // namespace common
} // namespace kuzu
1 change: 1 addition & 0 deletions src/include/optimizer/factorization_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class FactorizationRewriter : public LogicalOperatorVisitor {
private:
void visitOperator(planner::LogicalOperator* op);
void visitExtend(planner::LogicalOperator* op) override;
void visitRecursiveExtend(planner::LogicalOperator* op) override;
void visitHashJoin(planner::LogicalOperator* op) override;
void visitIntersect(planner::LogicalOperator* op) override;
void visitProjection(planner::LogicalOperator* op) override;
Expand Down
6 changes: 6 additions & 0 deletions src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class LogicalOperatorVisitor {
return op;
}

virtual void visitRecursiveExtend(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitRecursiveExtendReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitHashJoin(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitHashJoinReplace(
std::shared_ptr<planner::LogicalOperator> op) {
Expand Down
10 changes: 7 additions & 3 deletions src/include/parser/query/graph_pattern/rel_pattern.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "common/query_rel_type.h"
#include "node_pattern.h"

namespace kuzu {
Expand All @@ -12,22 +13,25 @@ enum ArrowDirection : uint8_t { LEFT = 0, RIGHT = 1 };
*/
class RelPattern : public NodePattern {
public:
RelPattern(std::string name, std::vector<std::string> tableNames, std::string lowerBound,
std::string upperBound, ArrowDirection arrowDirection,
RelPattern(std::string name, std::vector<std::string> tableNames, common::QueryRelType relType,
std::string lowerBound, std::string upperBound, ArrowDirection arrowDirection,
std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>> propertyKeyValPairs)
: NodePattern{std::move(name), std::move(tableNames), std::move(propertyKeyValPairs)},
lowerBound{std::move(lowerBound)}, upperBound{std::move(upperBound)},
relType{relType}, lowerBound{std::move(lowerBound)}, upperBound{std::move(upperBound)},
arrowDirection{arrowDirection} {}

~RelPattern() override = default;

inline common::QueryRelType getRelType() const { return relType; }

inline std::string getLowerBound() const { return lowerBound; }

inline std::string getUpperBound() const { return upperBound; }

inline ArrowDirection getDirection() const { return arrowDirection; }

private:
common::QueryRelType relType;
std::string lowerBound;
std::string upperBound;
ArrowDirection arrowDirection;
Expand Down
2 changes: 2 additions & 0 deletions src/include/planner/join_order/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ namespace planner {
class CostModel {
public:
static uint64_t computeExtendCost(const LogicalPlan& childPlan);
static uint64_t computeRecursiveExtendCost(
uint8_t upperBound, double extensionRate, const LogicalPlan& childPlan);
static uint64_t computeHashJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build);
static uint64_t computeMarkJoinCost(const binder::expression_vector& joinNodeIDs,
Expand Down
10 changes: 7 additions & 3 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,16 @@ class JoinOrderEnumerator {

void appendScanNodeID(std::shared_ptr<NodeExpression>& node, LogicalPlan& plan);

bool needExtendToNewGroup(
RelExpression& rel, NodeExpression& boundNode, common::RelDirection direction);
void appendExtend(std::shared_ptr<NodeExpression> boundNode,
void appendNonRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::RelDirection direction, const binder::expression_vector& properties,
LogicalPlan& plan);
void appendVariableLengthExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::RelDirection direction, LogicalPlan& plan);
void appendShortestPathExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::RelDirection direction, LogicalPlan& plan);

void planJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include "base_logical_operator.h"
#include "binder/expression/rel_expression.h"

namespace kuzu {
namespace planner {

/**
 * Abstract base shared by the logical extend operators. It keeps the state they have in common:
 * the bound node, the neighbour node, the rel and the direction. The operator type and the single
 * child are forwarded to LogicalOperator.
 */
class BaseLogicalExtend : public LogicalOperator {
public:
    /**
     * @param operatorType concrete operator tag forwarded to LogicalOperator.
     * @param boundNode node expression stored as the bound node.
     * @param nbrNode node expression stored as the neighbour node.
     * @param rel rel expression stored for this extend.
     * @param direction direction stored for this extend.
     * @param child child operator forwarded to LogicalOperator.
     */
    BaseLogicalExtend(LogicalOperatorType operatorType,
        std::shared_ptr<binder::NodeExpression> boundNode,
        std::shared_ptr<binder::NodeExpression> nbrNode,
        std::shared_ptr<binder::RelExpression> rel, common::RelDirection direction,
        std::shared_ptr<LogicalOperator> child)
        : LogicalOperator{operatorType, std::move(child)}, boundNode{std::move(boundNode)},
          nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, direction{direction} {}

    // Must be implemented by every concrete extend operator.
    // NOTE(review): presumably the factorization group positions that have to be flattened before
    // this operator runs — confirm against the implementations.
    virtual f_group_pos_set getGroupsPosToFlatten() = 0;

    std::string getExpressionsForPrinting() const override;

    // Accessors; each returns a copy of the stored handle / value.
    std::shared_ptr<binder::NodeExpression> getBoundNode() const { return boundNode; }
    std::shared_ptr<binder::NodeExpression> getNbrNode() const { return nbrNode; }
    std::shared_ptr<binder::RelExpression> getRel() const { return rel; }
    common::RelDirection getDirection() const { return direction; }

protected:
    // Protected so that subclasses can pass these straight to their own constructors in copy().
    std::shared_ptr<binder::NodeExpression> boundNode;
    std::shared_ptr<binder::NodeExpression> nbrNode;
    std::shared_ptr<binder::RelExpression> rel;
    common::RelDirection direction;
};

} // namespace planner
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ enum class LogicalOperatorType : uint8_t {
MULTIPLICITY_REDUCER,
ORDER_BY,
PROJECTION,
RECURSIVE_EXTEND,
RENAME_TABLE,
RENAME_PROPERTY,
SCAN_NODE,
Expand Down
36 changes: 10 additions & 26 deletions src/include/planner/logical_plan/logical_operator/logical_extend.h
Original file line number Diff line number Diff line change
@@ -1,51 +1,35 @@
#pragma once

#include "base_logical_operator.h"
#include "binder/expression/rel_expression.h"
#include "base_logical_extend.h"

namespace kuzu {
namespace planner {

class LogicalExtend : public LogicalOperator {
class LogicalExtend : public BaseLogicalExtend {
public:
LogicalExtend(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::RelDirection direction, binder::expression_vector properties, bool extendToNewGroup,
common::RelDirection direction, binder::expression_vector properties, bool hasAtMostOneNbr,
std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::EXTEND, std::move(child)}, boundNode{std::move(
boundNode)},
nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, direction{direction},
properties{std::move(properties)}, extendToNewGroup{extendToNewGroup} {}
: BaseLogicalExtend{LogicalOperatorType::EXTEND, std::move(boundNode), std::move(nbrNode),
std::move(rel), direction, std::move(child)},
properties{std::move(properties)}, hasAtMostOneNbr{hasAtMostOneNbr} {}

f_group_pos_set getGroupsPosToFlatten();
f_group_pos_set getGroupsPosToFlatten() override;

void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return boundNode->toString() + (direction == common::RelDirection::FWD ? "->" : "<-") +
nbrNode->toString();
}

inline std::shared_ptr<binder::NodeExpression> getBoundNode() const { return boundNode; }
inline std::shared_ptr<binder::NodeExpression> getNbrNode() const { return nbrNode; }
inline std::shared_ptr<binder::RelExpression> getRel() const { return rel; }
inline common::RelDirection getDirection() const { return direction; }
inline binder::expression_vector getProperties() const { return properties; }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalExtend>(
boundNode, nbrNode, rel, direction, properties, extendToNewGroup, children[0]->copy());
boundNode, nbrNode, rel, direction, properties, hasAtMostOneNbr, children[0]->copy());
}

protected:
std::shared_ptr<binder::NodeExpression> boundNode;
std::shared_ptr<binder::NodeExpression> nbrNode;
std::shared_ptr<binder::RelExpression> rel;
common::RelDirection direction;
private:
binder::expression_vector properties;
// When extend might increase cardinality (i.e. n * m), we extend to a new factorization group.
bool extendToNewGroup;
bool hasAtMostOneNbr;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#pragma once

#include "base_logical_extend.h"

namespace kuzu {
namespace planner {

// TODO(Xiyang): we should have a single LogicalRecursiveExtend once we migrate VariableLengthExtend
// to use the same infrastructure as shortest path.
/**
 * Intermediate base for the recursive extend operators. It tags the operator as RECURSIVE_EXTEND
 * and declares the flatten-group and flat-schema overrides; computeFactorizedSchema() and copy()
 * are left to the concrete subclasses.
 */
class LogicalRecursiveExtend : public BaseLogicalExtend {
public:
    // All arguments are forwarded unchanged to BaseLogicalExtend.
    LogicalRecursiveExtend(std::shared_ptr<binder::NodeExpression> boundNode,
        std::shared_ptr<binder::NodeExpression> nbrNode,
        std::shared_ptr<binder::RelExpression> rel, common::RelDirection direction,
        std::shared_ptr<LogicalOperator> child)
        : BaseLogicalExtend{LogicalOperatorType::RECURSIVE_EXTEND, std::move(boundNode),
              std::move(nbrNode), std::move(rel), direction, std::move(child)} {}

    f_group_pos_set getGroupsPosToFlatten() override;

    void computeFlatSchema() override;
};

/**
 * Recursive extend operator for variable length rels. Adds the hasAtMostOneNbr flag on top of
 * LogicalRecursiveExtend and provides its own factorized schema computation.
 */
class LogicalVariableLengthExtend : public LogicalRecursiveExtend {
public:
    /**
     * @param hasAtMostOneNbr flag stored on the operator and carried over by copy().
     *        NOTE(review): presumably consulted by computeFactorizedSchema() — confirm.
     * The remaining arguments are forwarded unchanged to LogicalRecursiveExtend.
     */
    LogicalVariableLengthExtend(std::shared_ptr<binder::NodeExpression> boundNode,
        std::shared_ptr<binder::NodeExpression> nbrNode,
        std::shared_ptr<binder::RelExpression> rel, common::RelDirection direction,
        bool hasAtMostOneNbr, std::shared_ptr<LogicalOperator> child)
        : LogicalRecursiveExtend{std::move(boundNode), std::move(nbrNode), std::move(rel),
              direction, std::move(child)},
          hasAtMostOneNbr{hasAtMostOneNbr} {}

    void computeFactorizedSchema() override;

    // Builds a new operator that shares the node/rel expressions and sits on a copy of the child.
    std::unique_ptr<LogicalOperator> copy() override {
        auto childCopy = children[0]->copy();
        return std::make_unique<LogicalVariableLengthExtend>(
            boundNode, nbrNode, rel, direction, hasAtMostOneNbr, std::move(childCopy));
    }

private:
    bool hasAtMostOneNbr;
};

/**
 * Recursive extend operator for shortest path rels. Carries no state beyond
 * LogicalRecursiveExtend; it only supplies its own factorized schema computation and copy().
 */
class LogicalShortestPathExtend : public LogicalRecursiveExtend {
public:
    // All arguments are forwarded unchanged to LogicalRecursiveExtend.
    LogicalShortestPathExtend(std::shared_ptr<binder::NodeExpression> boundNode,
        std::shared_ptr<binder::NodeExpression> nbrNode,
        std::shared_ptr<binder::RelExpression> rel, common::RelDirection direction,
        std::shared_ptr<LogicalOperator> child)
        : LogicalRecursiveExtend{std::move(boundNode), std::move(nbrNode), std::move(rel),
              direction, std::move(child)} {}

    void computeFactorizedSchema() override;

    // Builds a new operator that shares the node/rel expressions and sits on a copy of the child.
    std::unique_ptr<LogicalOperator> copy() override {
        auto childCopy = children[0]->copy();
        return std::make_unique<LogicalShortestPathExtend>(
            boundNode, nbrNode, rel, direction, std::move(childCopy));
    }
};

} // namespace planner
} // namespace kuzu
1 change: 1 addition & 0 deletions src/include/planner/logical_plan/logical_plan_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class LogicalPlanUtil {
static void encodeIntersect(LogicalOperator* logicalOperator, std::string& encodeString);
static void encodeHashJoin(LogicalOperator* logicalOperator, std::string& encodeString);
static void encodeExtend(LogicalOperator* logicalOperator, std::string& encodeString);
static void encodeRecursiveExtend(LogicalOperator* logicalOperator, std::string& encodeString);
static void encodeScanNodeID(LogicalOperator* logicalOperator, std::string& encodeString);
};

Expand Down
2 changes: 2 additions & 0 deletions src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class PlanMapper {
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalExtendToPhysical(
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalRecursiveExtendToPhysical(
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalFlattenToPhysical(
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalFilterToPhysical(
Expand Down
Loading

0 comments on commit d6a18ba

Please sign in to comment.