Skip to content

Commit

Permalink
Merge pull request #1538 from kuzudb/recursive-path
Browse files Browse the repository at this point in the history
Recursive path
  • Loading branch information
andyfengHKU committed May 16, 2023
2 parents ab47844 + 875a0d8 commit ea6499a
Show file tree
Hide file tree
Showing 28 changed files with 403 additions and 336 deletions.
12 changes: 6 additions & 6 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,12 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
}
// bind variable length
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
auto queryRel = make_shared<RelExpression>(getUniqueExpressionName(parsedName), parsedName,
tableIDs, srcNode, dstNode, relPattern.getDirection() != BOTH, relPattern.getRelType(),
lowerBound, upperBound);
auto isVariableLength = !(lowerBound == 1 && upperBound == 1);
auto dataType = isVariableLength ? common::DataType(std::make_unique<DataType>(INTERNAL_ID)) :
common::DataType(common::REL);
auto queryRel = make_shared<RelExpression>(std::move(dataType),
getUniqueExpressionName(parsedName), parsedName, tableIDs, srcNode, dstNode,
relPattern.getDirection() != BOTH, relPattern.getRelType(), lowerBound, upperBound);
queryRel->setAlias(parsedName);
// resolve properties associate with rel table
std::vector<RelTableSchema*> relTableSchemas;
Expand All @@ -147,9 +150,6 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
*queryRel, propertySchemas, false /* isPrimaryKey */);
queryRel->addPropertyExpression(propertyName, std::move(propertyExpression));
}
} else if (queryRel->getRelType() == common::QueryRelType::SHORTEST) {
queryRel->setInternalLengthProperty(
expressionBinder.createInternalLengthExpression(*queryRel));
}
if (!parsedName.empty()) {
variablesInScope.insert({parsedName, queryRel});
Expand Down
12 changes: 0 additions & 12 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,18 +150,6 @@ std::unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(
return result;
}

std::shared_ptr<Expression> ExpressionBinder::createInternalLengthExpression(
const Expression& expression) {
auto& rel = (RelExpression&)expression;
std::unordered_map<table_id_t, property_id_t> propertyIDPerTable;
for (auto tableID : rel.getTableIDs()) {
propertyIDPerTable.insert({tableID, INVALID_PROPERTY_ID});
}
auto result = std::make_unique<PropertyExpression>(DataType(common::INT64),
INTERNAL_LENGTH_SUFFIX, rel, std::move(propertyIDPerTable), false /* isPrimaryKey */);
return result;
}

std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(
const ParsedExpression& parsedExpression) {
// bind child node
Expand Down
3 changes: 0 additions & 3 deletions src/binder/bind_expression/bind_property_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ static std::unordered_map<table_id_t, property_id_t> populatePropertyIDPerTable(
std::shared_ptr<Expression> ExpressionBinder::bindRelPropertyExpression(
const Expression& expression, const std::string& propertyName) {
auto& rel = (RelExpression&)expression;
if (propertyName == INTERNAL_LENGTH_SUFFIX) {
return rel.getInternalLengthProperty();
}
switch (rel.getRelType()) {
case common::QueryRelType::VARIABLE_LENGTH:
case common::QueryRelType::SHORTEST:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ class ExistentialSubqueryExpression : public Expression {
public:
ExistentialSubqueryExpression(std::unique_ptr<QueryGraphCollection> queryGraphCollection,
std::string uniqueName, std::string rawName)
: Expression{common::EXISTENTIAL_SUBQUERY, common::BOOL, std::move(uniqueName)},
: Expression{common::EXISTENTIAL_SUBQUERY, common::DataType(common::BOOL),
std::move(uniqueName)},
queryGraphCollection{std::move(queryGraphCollection)}, rawName{std::move(rawName)} {}

inline QueryGraphCollection* getQueryGraphCollection() const {
Expand Down
9 changes: 1 addition & 8 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,6 @@ class Expression : public std::enable_shared_from_this<Expression> {

virtual ~Expression() = default;

protected:
Expression(common::ExpressionType expressionType, common::DataTypeID dataTypeID,
std::string uniqueName)
: Expression{expressionType, common::DataType(dataTypeID), std::move(uniqueName)} {
assert(dataTypeID != common::VAR_LIST);
}

public:
inline void setAlias(const std::string& name) { alias = name; }

Expand All @@ -80,7 +73,7 @@ class Expression : public std::enable_shared_from_this<Expression> {
return children[idx];
}
inline void setChild(common::vector_idx_t idx, std::shared_ptr<Expression> child) {
children[idx] = child;
children[idx] = std::move(child);
}

inline virtual expression_vector getChildren() const { return children; }
Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/expression/node_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ class NodeExpression : public NodeOrRelExpression {
public:
NodeExpression(
std::string uniqueName, std::string variableName, std::vector<common::table_id_t> tableIDs)
: NodeOrRelExpression{
common::NODE, std::move(uniqueName), std::move(variableName), std::move(tableIDs)} {}
: NodeOrRelExpression{common::DataType(common::NODE), std::move(uniqueName),
std::move(variableName), std::move(tableIDs)} {}

inline void setInternalIDProperty(std::unique_ptr<Expression> expression) {
internalIDExpression = std::move(expression);
Expand Down
6 changes: 3 additions & 3 deletions src/include/binder/expression/node_rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ namespace binder {

class NodeOrRelExpression : public Expression {
public:
NodeOrRelExpression(common::DataTypeID dataTypeID, std::string uniqueName,
std::string variableName, std::vector<common::table_id_t> tableIDs)
: Expression{common::VARIABLE, dataTypeID, std::move(uniqueName)},
NodeOrRelExpression(common::DataType dataType, std::string uniqueName, std::string variableName,
std::vector<common::table_id_t> tableIDs)
: Expression{common::VARIABLE, std::move(dataType), std::move(uniqueName)},
variableName(std::move(variableName)), tableIDs{std::move(tableIDs)} {}
virtual ~NodeOrRelExpression() override = default;

Expand Down
3 changes: 2 additions & 1 deletion src/include/binder/expression/parameter_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ class ParameterExpression : public Expression {
public:
explicit ParameterExpression(
const std::string& parameterName, std::shared_ptr<common::Value> value)
: Expression{common::PARAMETER, common::ANY, createUniqueName(parameterName)},
: Expression{common::PARAMETER, common::DataType(common::ANY),
createUniqueName(parameterName)},
parameterName(parameterName), value{std::move(value)} {}

inline void setDataType(const common::DataType& targetType) {
Expand Down
12 changes: 2 additions & 10 deletions src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ namespace binder {

class RelExpression : public NodeOrRelExpression {
public:
RelExpression(std::string uniqueName, std::string variableName,
RelExpression(common::DataType dataType, std::string uniqueName, std::string variableName,
std::vector<common::table_id_t> tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, bool directed, common::QueryRelType relType,
uint64_t lowerBound, uint64_t upperBound)
: NodeOrRelExpression{common::REL, std::move(uniqueName), std::move(variableName),
: NodeOrRelExpression{dataType, std::move(uniqueName), std::move(variableName),
std::move(tableIDs)},
srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, directed{directed},
relType{relType}, lowerBound{lowerBound}, upperBound{upperBound} {}
Expand All @@ -40,21 +40,13 @@ class RelExpression : public NodeOrRelExpression {
return getPropertyExpression(common::INTERNAL_ID_SUFFIX);
}

inline void setInternalLengthProperty(std::shared_ptr<Expression> expression) {
internalLengthExpression = std::move(expression);
}
inline std::shared_ptr<Expression> getInternalLengthProperty() {
return internalLengthExpression;
}

private:
std::shared_ptr<NodeExpression> srcNode;
std::shared_ptr<NodeExpression> dstNode;
bool directed;
common::QueryRelType relType;
uint64_t lowerBound;
uint64_t upperBound;
std::shared_ptr<Expression> internalLengthExpression;
};

} // namespace binder
Expand Down
1 change: 0 additions & 1 deletion src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ class ExpressionBinder {
const parser::ParsedExpression& parsedExpression);
std::shared_ptr<Expression> bindInternalIDExpression(const Expression& expression);
std::unique_ptr<Expression> createInternalNodeIDExpression(const Expression& node);
std::shared_ptr<Expression> createInternalLengthExpression(const Expression& rel);
std::shared_ptr<Expression> bindLabelFunction(const parser::ParsedExpression& parsedExpression);
std::shared_ptr<Expression> bindNodeLabelFunction(const Expression& expression);
std::shared_ptr<Expression> bindRelLabelFunction(const Expression& expression);
Expand Down
120 changes: 38 additions & 82 deletions src/include/processor/operator/recursive_extend/bfs_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,64 +14,50 @@ enum VisitedState : uint8_t {

struct Frontier {
std::vector<common::offset_t> nodeOffsets;
std::unordered_map<common::offset_t, std::vector<common::offset_t>> bwdEdges;

Frontier() = default;
virtual ~Frontier() = default;
inline virtual void resetState() { nodeOffsets.clear(); }
inline virtual uint64_t getMultiplicity(common::offset_t offset) { return 1; }
};

struct FrontierWithMultiplicity : public Frontier {
// Multiplicity stands for number of paths that can reach an offset.
std::unordered_map<common::offset_t, uint64_t> offsetToMultiplicity;

FrontierWithMultiplicity() : Frontier() {}
inline void resetState() override {
Frontier::resetState();
offsetToMultiplicity.clear();
}
inline uint64_t getMultiplicity(common::offset_t offset) override {
assert(offsetToMultiplicity.contains(offset));
return offsetToMultiplicity.at(offset);
}
inline void addOffset(common::offset_t offset, uint64_t multiplicity) {
if (offsetToMultiplicity.contains(offset)) {
offsetToMultiplicity.at(offset) += multiplicity;
} else {
offsetToMultiplicity.insert({offset, multiplicity});
nodeOffsets.push_back(offset);
inline virtual void resetState() {
nodeOffsets.clear();
bwdEdges.clear();
}
inline void addEdge(common::offset_t boundOffset, common::offset_t nbrOffset) {
if (!bwdEdges.contains(nbrOffset)) {
nodeOffsets.push_back(nbrOffset);
bwdEdges.insert({nbrOffset, std::vector<common::offset_t>{}});
}
}
inline bool contains(common::offset_t offset) const {
return offsetToMultiplicity.contains(offset);
bwdEdges.at(nbrOffset).push_back(boundOffset);
}
};

struct BaseBFSMorsel {
friend struct FixedLengthPathScanner;
// Static information
common::offset_t maxOffset;
uint8_t lowerBound;
uint8_t upperBound;
// Level state
uint8_t currentLevel;
uint64_t nextNodeIdxToExtend; // next node to extend from current frontier.
std::unique_ptr<Frontier> currentFrontier;
std::unique_ptr<Frontier> nextFrontier;
Frontier* currentFrontier;
Frontier* nextFrontier;
std::vector<std::unique_ptr<Frontier>> frontiers;

// Target information.
// Target dst nodes are populated from semi mask and is expected to have small size.
// TargetDstNodeOffsets is empty if no semi mask available. Note that at the end of BFS, we may
// not visit all target dst nodes because they may simply not connect to src.
uint64_t numTargetDstNodes;
std::vector<common::offset_t> targetDstNodeOffsets;
std::unordered_set<common::offset_t> targetDstNodeOffsets;

explicit BaseBFSMorsel(common::offset_t maxOffset, uint8_t lowerBound, uint8_t upperBound,
NodeOffsetSemiMask* semiMask)
: maxOffset{maxOffset}, lowerBound{lowerBound}, upperBound{upperBound}, currentLevel{0},
nextNodeIdxToExtend{0}, numTargetDstNodes{0} {
if (semiMask->isEnabled()) {
for (auto offset = 0u; offset < maxOffset + 1; ++offset) {
for (auto offset = 0u; offset < getNumNodes(); ++offset) {
if (semiMask->isNodeMasked(offset)) {
targetDstNodeOffsets.push_back(offset);
targetDstNodeOffsets.insert(offset);
}
}
}
Expand All @@ -84,31 +70,35 @@ struct BaseBFSMorsel {
virtual void resetState();
virtual bool isComplete() = 0;
virtual void markSrc(common::offset_t offset) = 0;
virtual void markVisited(common::offset_t offset, uint64_t multiplicity) = 0;
virtual void finalizeCurrentLevel() = 0;
virtual void markVisited(common::offset_t boundOffset, common::offset_t nbrOffset) = 0;
inline void finalizeCurrentLevel() { moveNextLevelAsCurrentLevel(); }

protected:
inline uint64_t getNumNodes() const { return maxOffset + 1; }
inline bool isAllDstTarget() const { return targetDstNodeOffsets.empty(); }
inline bool isCurrentFrontierEmpty() const { return currentFrontier->nodeOffsets.empty(); }
inline bool isUpperBoundReached() const { return currentLevel == upperBound; }
inline bool isAllDstTarget() const { return targetDstNodeOffsets.empty(); }
inline void initStartFrontier() {
assert(frontiers.empty());
frontiers.push_back(std::make_unique<Frontier>());
currentFrontier = frontiers[frontiers.size() - 1].get();
}
inline void addNextFrontier() {
frontiers.push_back(std::make_unique<Frontier>());
nextFrontier = frontiers[frontiers.size() - 1].get();
}
void moveNextLevelAsCurrentLevel();
virtual std::unique_ptr<Frontier> createFrontier() = 0;
};

struct ShortestPathBFSMorsel : public BaseBFSMorsel {
// Visited state
uint64_t numVisitedDstNodes;
uint8_t* visitedNodes;
// Results
std::vector<common::offset_t> dstNodeOffsets;
std::unordered_map<common::offset_t, uint64_t> dstNodeOffset2PathLength;

ShortestPathBFSMorsel(common::offset_t maxOffset, uint8_t lowerBound, uint8_t upperBound,
NodeOffsetSemiMask* semiMask)
: BaseBFSMorsel{maxOffset, lowerBound, upperBound, semiMask}, numVisitedDstNodes{0} {
currentFrontier = std::make_unique<Frontier>();
nextFrontier = std::make_unique<Frontier>();
visitedNodesBuffer = std::make_unique<uint8_t[]>(maxOffset + 1 * sizeof(uint8_t));
visitedNodesBuffer = std::make_unique<uint8_t[]>(getNumNodes() * sizeof(uint8_t));
visitedNodes = visitedNodesBuffer.get();
}

Expand All @@ -120,65 +110,31 @@ struct ShortestPathBFSMorsel : public BaseBFSMorsel {
resetVisitedState();
}
void markSrc(common::offset_t offset) override;
void markVisited(common::offset_t offset, uint64_t multiplicity) override;
inline void finalizeCurrentLevel() override { moveNextLevelAsCurrentLevel(); }
void markVisited(common::offset_t boundOffset, common::offset_t nbrOffset) override;

private:
inline bool isAllDstReached() const { return numVisitedDstNodes == numTargetDstNodes; }
void resetVisitedState();
inline std::unique_ptr<Frontier> createFrontier() override {
return std::make_unique<Frontier>();
}

private:
std::unique_ptr<uint8_t[]> visitedNodesBuffer;
};

struct VariableLengthBFSMorsel : public BaseBFSMorsel {
// Results
std::vector<common::offset_t> dstNodeOffsets;
std::unordered_map<common::offset_t, uint64_t> dstNodeOffset2NumPath;

explicit VariableLengthBFSMorsel(common::offset_t maxOffset, uint8_t lowerBound,
uint8_t upperBound, NodeOffsetSemiMask* semiMask)
: BaseBFSMorsel{maxOffset, lowerBound, upperBound, semiMask} {
currentFrontier = std::make_unique<FrontierWithMultiplicity>();
nextFrontier = std::make_unique<FrontierWithMultiplicity>();
}
: BaseBFSMorsel{maxOffset, lowerBound, upperBound, semiMask} {}

inline void resetState() override {
BaseBFSMorsel::resetState();
resetNumPath();
numTargetDstNodes = isAllDstTarget() ? getNumNodes() : targetDstNodeOffsets.size();
}
inline bool isComplete() override { return isCurrentFrontierEmpty() || isUpperBoundReached(); }
inline void markSrc(common::offset_t offset) override {
((FrontierWithMultiplicity&)*currentFrontier).addOffset(offset, 1);
}
inline void markVisited(common::offset_t offset, uint64_t multiplicity) override {
((FrontierWithMultiplicity&)*nextFrontier).addOffset(offset, multiplicity);
}
inline void finalizeCurrentLevel() override {
moveNextLevelAsCurrentLevel();
updateNumPathFromCurrentFrontier();
}

private:
inline void resetNumPath() {
dstNodeOffsets.clear();
dstNodeOffset2NumPath.clear();
numTargetDstNodes = isAllDstTarget() ? maxOffset + 1 : targetDstNodeOffsets.size();
}
inline void updateNumPath(common::offset_t offset, uint64_t numPath) {
if (!dstNodeOffset2NumPath.contains(offset)) {
dstNodeOffsets.push_back(offset);
dstNodeOffset2NumPath.insert({offset, numPath});
} else {
dstNodeOffset2NumPath.at(offset) += numPath;
}
currentFrontier->nodeOffsets.push_back(offset);
}
void updateNumPathFromCurrentFrontier();
inline std::unique_ptr<Frontier> createFrontier() override {
return std::make_unique<FrontierWithMultiplicity>();
inline void markVisited(common::offset_t boundOffset, common::offset_t nbrOffset) override {
nextFrontier->addEdge(boundOffset, nbrOffset);
}
};

Expand Down
Loading

0 comments on commit ea6499a

Please sign in to comment.