Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shortest path planning #1472

Merged
merged 1 commit into from
Apr 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ oC_NodeLabel
: ':' SP? oC_LabelName ;

oC_RangeLiteral
: '*' SP? oC_IntegerLiteral SP? '..' SP? oC_IntegerLiteral ;
: '*' SP? SHORTEST? SP? oC_IntegerLiteral SP? '..' SP? oC_IntegerLiteral ;

SHORTEST : ( 'S' | 's' ) ( 'H' | 'h' ) ( 'O' | 'o' ) ( 'R' | 'r' ) ( 'T' | 't' ) ( 'E' | 'e' ) ( 'S' | 's' ) ( 'T' | 't' ) ;

oC_LabelName
: oC_SchemaName ;
Expand Down
6 changes: 3 additions & 3 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,15 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
// bind variable length
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
auto queryRel = make_shared<RelExpression>(getUniqueExpressionName(parsedName), parsedName,
tableIDs, srcNode, dstNode, lowerBound, upperBound);
tableIDs, srcNode, dstNode, relPattern.getRelType(), lowerBound, upperBound);
queryRel->setAlias(parsedName);
// resolve properties associate with rel table
std::vector<RelTableSchema*> relTableSchemas;
for (auto tableID : tableIDs) {
relTableSchemas.push_back(catalog.getReadOnlyVersion()->getRelTableSchema(tableID));
}
// we don't support reading property for variable length rel yet.
if (!queryRel->isVariableLength()) {
// we don't support reading property for VARIABLE_LENGTH or SHORTEST rel.
if (queryRel->getRelType() == common::QueryRelType::NON_RECURSIVE) {
for (auto& [propertyName, propertySchemas] :
getRelPropertyNameAndPropertiesPairs(relTableSchemas)) {
auto propertyExpression = expressionBinder.createPropertyExpression(
Expand Down
6 changes: 5 additions & 1 deletion src/binder/bind_expression/bind_property_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,13 @@ static std::unordered_map<table_id_t, property_id_t> populatePropertyIDPerTable(
std::shared_ptr<Expression> ExpressionBinder::bindRelPropertyExpression(
const Expression& expression, const std::string& propertyName) {
auto& rel = (RelExpression&)expression;
if (rel.isVariableLength()) {
switch (rel.getRelType()) {
case common::QueryRelType::VARIABLE_LENGTH:
case common::QueryRelType::SHORTEST:
throw BinderException(
"Cannot read property of variable length rel " + rel.toString() + ".");
default:
break;
}
if (!rel.hasPropertyExpression(propertyName)) {
throw BinderException(
Expand Down
11 changes: 7 additions & 4 deletions src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "common/exception.h"
#include "common/query_rel_type.h"
#include "node_expression.h"

namespace kuzu {
Expand All @@ -10,11 +11,12 @@ class RelExpression : public NodeOrRelExpression {
public:
RelExpression(std::string uniqueName, std::string variableName,
std::vector<common::table_id_t> tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, uint64_t lowerBound, uint64_t upperBound)
std::shared_ptr<NodeExpression> dstNode, common::QueryRelType relType, uint64_t lowerBound,
uint64_t upperBound)
: NodeOrRelExpression{common::REL, std::move(uniqueName), std::move(variableName),
std::move(tableIDs)},
srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, lowerBound{lowerBound},
upperBound{upperBound} {}
srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, relType{relType},
lowerBound{lowerBound}, upperBound{upperBound} {}

inline bool isBoundByMultiLabeledNode() const {
return srcNode->isMultiLabeled() || dstNode->isMultiLabeled();
Expand All @@ -25,9 +27,9 @@ class RelExpression : public NodeOrRelExpression {
inline std::shared_ptr<NodeExpression> getDstNode() const { return dstNode; }
inline std::string getDstNodeName() const { return dstNode->getUniqueName(); }

inline common::QueryRelType getRelType() const { return relType; }
inline uint64_t getLowerBound() const { return lowerBound; }
inline uint64_t getUpperBound() const { return upperBound; }
inline bool isVariableLength() const { return !(lowerBound == 1 && upperBound == 1); }

inline bool hasInternalIDProperty() const {
return hasPropertyExpression(common::INTERNAL_ID_SUFFIX);
Expand All @@ -39,6 +41,7 @@ class RelExpression : public NodeOrRelExpression {
private:
std::shared_ptr<NodeExpression> srcNode;
std::shared_ptr<NodeExpression> dstNode;
common::QueryRelType relType;
uint64_t lowerBound;
uint64_t upperBound;
};
Expand Down
1 change: 1 addition & 0 deletions src/include/common/clause_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ enum class ClauseType : uint8_t {
SET = 0,
DELETE = 1,
CREATE = 2,
// reading clause
MATCH = 3,
UNWIND = 4
};
Expand Down
15 changes: 15 additions & 0 deletions src/include/common/query_rel_type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include <cstdint>

namespace kuzu {
namespace common {

enum class QueryRelType : uint8_t {
NON_RECURSIVE = 0,
VARIABLE_LENGTH = 1,
SHORTEST = 2,
};

} // namespace common
} // namespace kuzu
1 change: 1 addition & 0 deletions src/include/optimizer/factorization_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class FactorizationRewriter : public LogicalOperatorVisitor {
private:
void visitOperator(planner::LogicalOperator* op);
void visitExtend(planner::LogicalOperator* op) override;
void visitRecursiveExtend(planner::LogicalOperator* op) override;
void visitHashJoin(planner::LogicalOperator* op) override;
void visitIntersect(planner::LogicalOperator* op) override;
void visitProjection(planner::LogicalOperator* op) override;
Expand Down
6 changes: 6 additions & 0 deletions src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class LogicalOperatorVisitor {
return op;
}

virtual void visitRecursiveExtend(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitRecursiveExtendReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitHashJoin(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitHashJoinReplace(
std::shared_ptr<planner::LogicalOperator> op) {
Expand Down
10 changes: 7 additions & 3 deletions src/include/parser/query/graph_pattern/rel_pattern.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "common/query_rel_type.h"
#include "node_pattern.h"

namespace kuzu {
Expand All @@ -12,22 +13,25 @@ enum ArrowDirection : uint8_t { LEFT = 0, RIGHT = 1 };
*/
class RelPattern : public NodePattern {
public:
RelPattern(std::string name, std::vector<std::string> tableNames, std::string lowerBound,
std::string upperBound, ArrowDirection arrowDirection,
RelPattern(std::string name, std::vector<std::string> tableNames, common::QueryRelType relType,
std::string lowerBound, std::string upperBound, ArrowDirection arrowDirection,
std::vector<std::pair<std::string, std::unique_ptr<ParsedExpression>>> propertyKeyValPairs)
: NodePattern{std::move(name), std::move(tableNames), std::move(propertyKeyValPairs)},
lowerBound{std::move(lowerBound)}, upperBound{std::move(upperBound)},
relType{relType}, lowerBound{std::move(lowerBound)}, upperBound{std::move(upperBound)},
arrowDirection{arrowDirection} {}

~RelPattern() override = default;

inline common::QueryRelType getRelType() const { return relType; }

inline std::string getLowerBound() const { return lowerBound; }

inline std::string getUpperBound() const { return upperBound; }

inline ArrowDirection getDirection() const { return arrowDirection; }

private:
common::QueryRelType relType;
std::string lowerBound;
std::string upperBound;
ArrowDirection arrowDirection;
Expand Down
2 changes: 2 additions & 0 deletions src/include/planner/join_order/cost_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ namespace planner {
class CostModel {
public:
static uint64_t computeExtendCost(const LogicalPlan& childPlan);
static uint64_t computeRecursiveExtendCost(
uint8_t upperBound, double extensionRate, const LogicalPlan& childPlan);
static uint64_t computeHashJoinCost(const binder::expression_vector& joinNodeIDs,
const LogicalPlan& probe, const LogicalPlan& build);
static uint64_t computeMarkJoinCost(const binder::expression_vector& joinNodeIDs,
Expand Down
10 changes: 7 additions & 3 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,16 @@ class JoinOrderEnumerator {

void appendScanNodeID(std::shared_ptr<NodeExpression>& node, LogicalPlan& plan);

bool needExtendToNewGroup(
RelExpression& rel, NodeExpression& boundNode, common::RelDirection direction);
void appendExtend(std::shared_ptr<NodeExpression> boundNode,
void appendNonRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::RelDirection direction, const binder::expression_vector& properties,
LogicalPlan& plan);
void appendVariableLengthExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::RelDirection direction, LogicalPlan& plan);
void appendShortestPathExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::RelDirection direction, LogicalPlan& plan);

void planJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include "base_logical_operator.h"
#include "binder/expression/rel_expression.h"

namespace kuzu {
namespace planner {

class BaseLogicalExtend : public LogicalOperator {
public:
BaseLogicalExtend(LogicalOperatorType operatorType,
std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::RelDirection direction, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{operatorType, std::move(child)}, boundNode{std::move(boundNode)},
nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, direction{direction} {}

inline std::shared_ptr<binder::NodeExpression> getBoundNode() const { return boundNode; }
inline std::shared_ptr<binder::NodeExpression> getNbrNode() const { return nbrNode; }
inline std::shared_ptr<binder::RelExpression> getRel() const { return rel; }
inline common::RelDirection getDirection() const { return direction; }

virtual f_group_pos_set getGroupsPosToFlatten() = 0;

std::string getExpressionsForPrinting() const override;

protected:
std::shared_ptr<binder::NodeExpression> boundNode;
std::shared_ptr<binder::NodeExpression> nbrNode;
std::shared_ptr<binder::RelExpression> rel;
common::RelDirection direction;
};

} // namespace planner
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ enum class LogicalOperatorType : uint8_t {
MULTIPLICITY_REDUCER,
ORDER_BY,
PROJECTION,
RECURSIVE_EXTEND,
RENAME_TABLE,
RENAME_PROPERTY,
SCAN_NODE,
Expand Down
36 changes: 10 additions & 26 deletions src/include/planner/logical_plan/logical_operator/logical_extend.h
Original file line number Diff line number Diff line change
@@ -1,51 +1,35 @@
#pragma once

#include "base_logical_operator.h"
#include "binder/expression/rel_expression.h"
#include "base_logical_extend.h"

namespace kuzu {
namespace planner {

class LogicalExtend : public LogicalOperator {
class LogicalExtend : public BaseLogicalExtend {
public:
LogicalExtend(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::RelDirection direction, binder::expression_vector properties, bool extendToNewGroup,
common::RelDirection direction, binder::expression_vector properties, bool hasAtMostOneNbr,
std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::EXTEND, std::move(child)}, boundNode{std::move(
boundNode)},
nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, direction{direction},
properties{std::move(properties)}, extendToNewGroup{extendToNewGroup} {}
: BaseLogicalExtend{LogicalOperatorType::EXTEND, std::move(boundNode), std::move(nbrNode),
std::move(rel), direction, std::move(child)},
properties{std::move(properties)}, hasAtMostOneNbr{hasAtMostOneNbr} {}

f_group_pos_set getGroupsPosToFlatten();
f_group_pos_set getGroupsPosToFlatten() override;

void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return boundNode->toString() + (direction == common::RelDirection::FWD ? "->" : "<-") +
nbrNode->toString();
}

inline std::shared_ptr<binder::NodeExpression> getBoundNode() const { return boundNode; }
inline std::shared_ptr<binder::NodeExpression> getNbrNode() const { return nbrNode; }
inline std::shared_ptr<binder::RelExpression> getRel() const { return rel; }
inline common::RelDirection getDirection() const { return direction; }
inline binder::expression_vector getProperties() const { return properties; }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalExtend>(
boundNode, nbrNode, rel, direction, properties, extendToNewGroup, children[0]->copy());
boundNode, nbrNode, rel, direction, properties, hasAtMostOneNbr, children[0]->copy());
}

protected:
std::shared_ptr<binder::NodeExpression> boundNode;
std::shared_ptr<binder::NodeExpression> nbrNode;
std::shared_ptr<binder::RelExpression> rel;
common::RelDirection direction;
private:
binder::expression_vector properties;
// When extend might increase cardinality (i.e. n * m), we extend to a new factorization group.
bool extendToNewGroup;
bool hasAtMostOneNbr;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#pragma once

#include "base_logical_extend.h"

namespace kuzu {
namespace planner {

// TODO(Xiyang): we should have a single LogicalRecursiveExtend once we migrate VariableLengthExtend
// to use the same infrastructure as shortest path.
class LogicalRecursiveExtend : public BaseLogicalExtend {
public:
LogicalRecursiveExtend(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::RelDirection direction, std::shared_ptr<LogicalOperator> child)
: BaseLogicalExtend{LogicalOperatorType::RECURSIVE_EXTEND, std::move(boundNode),
std::move(nbrNode), std::move(rel), direction, std::move(child)} {}

f_group_pos_set getGroupsPosToFlatten() override;

void computeFlatSchema() override;
};

andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
class LogicalVariableLengthExtend : public LogicalRecursiveExtend {
public:
LogicalVariableLengthExtend(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::RelDirection direction, bool hasAtMostOneNbr,
std::shared_ptr<LogicalOperator> child)
: LogicalRecursiveExtend{std::move(boundNode), std::move(nbrNode), std::move(rel),
direction, std::move(child)},
hasAtMostOneNbr{hasAtMostOneNbr} {}

void computeFactorizedSchema() override;

inline std::unique_ptr<LogicalOperator> copy() override {
return std::make_unique<LogicalVariableLengthExtend>(
boundNode, nbrNode, rel, direction, hasAtMostOneNbr, children[0]->copy());
}

private:
bool hasAtMostOneNbr;
};

class LogicalShortestPathExtend : public LogicalRecursiveExtend {
public:
LogicalShortestPathExtend(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::RelDirection direction, std::shared_ptr<LogicalOperator> child)
: LogicalRecursiveExtend{std::move(boundNode), std::move(nbrNode), std::move(rel),
direction, std::move(child)} {}

void computeFactorizedSchema() override;

inline std::unique_ptr<LogicalOperator> copy() override {
return std::make_unique<LogicalShortestPathExtend>(
boundNode, nbrNode, rel, direction, children[0]->copy());
}
};

} // namespace planner
} // namespace kuzu
1 change: 1 addition & 0 deletions src/include/planner/logical_plan/logical_plan_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class LogicalPlanUtil {
static void encodeIntersect(LogicalOperator* logicalOperator, std::string& encodeString);
static void encodeHashJoin(LogicalOperator* logicalOperator, std::string& encodeString);
static void encodeExtend(LogicalOperator* logicalOperator, std::string& encodeString);
static void encodeRecursiveExtend(LogicalOperator* logicalOperator, std::string& encodeString);
static void encodeScanNodeID(LogicalOperator* logicalOperator, std::string& encodeString);
};

Expand Down
2 changes: 2 additions & 0 deletions src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class PlanMapper {
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalExtendToPhysical(
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalRecursiveExtendToPhysical(
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalFlattenToPhysical(
planner::LogicalOperator* logicalOperator);
std::unique_ptr<PhysicalOperator> mapLogicalFilterToPhysical(
Expand Down
Loading