Skip to content

Commit

Permalink
Rework variable-length join to reuse BFS framework
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed May 3, 2023
1 parent ecb5d46 commit 7280162
Show file tree
Hide file tree
Showing 28 changed files with 678 additions and 880 deletions.
5 changes: 1 addition & 4 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,7 @@ class JoinOrderEnumerator {
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::RelDirection direction, const binder::expression_vector& properties,
LogicalPlan& plan);
void appendVariableLengthExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::RelDirection direction, LogicalPlan& plan);
void appendShortestPathExtend(std::shared_ptr<NodeExpression> boundNode,
void appendRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::RelDirection direction, LogicalPlan& plan);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
namespace kuzu {
namespace planner {

// TODO(Xiyang): we should have a single LogicalRecursiveExtend once we migrate VariableLengthExtend
// to use the same infrastructure as shortest path.
class LogicalRecursiveExtend : public BaseLogicalExtend {
public:
LogicalRecursiveExtend(std::shared_ptr<binder::NodeExpression> boundNode,
Expand All @@ -17,42 +15,11 @@ class LogicalRecursiveExtend : public BaseLogicalExtend {

f_group_pos_set getGroupsPosToFlatten() override;

void computeFlatSchema() override;
};

class LogicalVariableLengthExtend : public LogicalRecursiveExtend {
public:
LogicalVariableLengthExtend(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::RelDirection direction, bool hasAtMostOneNbr,
std::shared_ptr<LogicalOperator> child)
: LogicalRecursiveExtend{std::move(boundNode), std::move(nbrNode), std::move(rel),
direction, std::move(child)},
hasAtMostOneNbr{hasAtMostOneNbr} {}

void computeFactorizedSchema() override;

inline std::unique_ptr<LogicalOperator> copy() override {
return std::make_unique<LogicalVariableLengthExtend>(
boundNode, nbrNode, rel, direction, hasAtMostOneNbr, children[0]->copy());
}

private:
bool hasAtMostOneNbr;
};

class LogicalShortestPathExtend : public LogicalRecursiveExtend {
public:
LogicalShortestPathExtend(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::RelDirection direction, std::shared_ptr<LogicalOperator> child)
: LogicalRecursiveExtend{std::move(boundNode), std::move(nbrNode), std::move(rel),
direction, std::move(child)} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::unique_ptr<LogicalOperator> copy() override {
return std::make_unique<LogicalShortestPathExtend>(
return std::make_unique<LogicalRecursiveExtend>(
boundNode, nbrNode, rel, direction, children[0]->copy());
}
};
Expand Down
2 changes: 2 additions & 0 deletions src/include/processor/operator/mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class MaskCollection {
MaskCollection() : numMasks{0} {}

inline void init(common::offset_t maxOffset) {
std::unique_lock lck{mtx};
if (maskData != nullptr) { // MaskCollection might be initialized repeatedly.
return;
}
Expand All @@ -55,6 +56,7 @@ class MaskCollection {
inline void incrementNumMasks() { numMasks++; }

private:
std::mutex mtx;
std::unique_ptr<MaskData> maskData;
uint8_t numMasks;
};
Expand Down
186 changes: 186 additions & 0 deletions src/include/processor/operator/recursive_extend/bfs_state.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
#pragma once

#include "processor/operator/mask.h"

namespace kuzu {
namespace processor {

// Per-node visitation flag, stored in a single byte.
// NOTE(review): the *_DST enumerators presumably mark nodes that are BFS destination
// candidates, and the plain ones mark all other nodes -- confirm against the users of
// ShortestPathBFSMorsel::visitedNodes.
// The numeric values are spelled out explicitly so that they stay stable if enumerators
// are ever reordered.
enum VisitedState : uint8_t {
    NOT_VISITED_DST = 0, // destination candidate, not reached yet
    VISITED_DST = 1,     // destination candidate, already reached
    NOT_VISITED = 2,     // non-destination node, not reached yet
    VISITED = 3,         // non-destination node, already reached
};

// A BFS frontier: the list of node offsets that belong to one level.
// Polymorphic base; FrontierWithMultiplicity overrides the virtual hooks.
struct Frontier {
    std::vector<common::offset_t> nodeOffsets;

    Frontier() = default;
    virtual ~Frontier() = default;

    // Drops every offset so the frontier can be reused.
    virtual void resetState() { nodeOffsets.clear(); }

    // Multiplicity of an offset in this frontier. The base frontier does not track
    // multiplicities, so every offset counts once regardless of the argument.
    virtual uint64_t getMultiplicity(common::offset_t /*offset*/) { return 1; }
};

struct FrontierWithMultiplicity : public Frontier {
// Multiplicity stands for number of paths that can reach an offset.
std::unordered_map<common::offset_t, uint64_t> offsetToMultiplicity;

FrontierWithMultiplicity() : Frontier() {}
inline void resetState() override {
Frontier::resetState();
offsetToMultiplicity.clear();
}
inline uint64_t getMultiplicity(common::offset_t offset) override {
assert(offsetToMultiplicity.contains(offset));
return offsetToMultiplicity.at(offset);
}
inline void addOffset(common::offset_t offset, uint64_t multiplicity) {
if (offsetToMultiplicity.contains(offset)) {
offsetToMultiplicity.at(offset) += multiplicity;
} else {
offsetToMultiplicity.insert({offset, multiplicity});
nodeOffsets.push_back(offset);
}
}
inline bool contains(common::offset_t offset) const {
return offsetToMultiplicity.contains(offset);
}
};

// State shared by every BFS flavour: the bounds of the search, the current/next frontier
// of the level being processed, and the optional set of target destination nodes.
// Concrete morsels decide when the search is complete and how visits are recorded.
struct BaseBFSMorsel {
    // Static information
    common::offset_t maxOffset;
    uint8_t lowerBound;
    uint8_t upperBound;
    // Level state
    uint8_t currentLevel;
    uint64_t nextNodeIdxToExtend; // next node to extend from current frontier.
    std::unique_ptr<Frontier> currentFrontier;
    std::unique_ptr<Frontier> nextFrontier;
    // Target information.
    // Target dst nodes are populated from semi mask and is expected to have small size.
    // TargetDstNodeOffsets is empty if no semi mask available. Note that at the end of BFS, we may
    // not visit all target dst nodes because they may simply not connect to src.
    uint64_t numTargetDstNodes;
    std::vector<common::offset_t> targetDstNodeOffsets;

    // Initializes the static bounds and level counters. When `semiMask` is enabled, every
    // masked offset in [0, maxOffset] is collected into targetDstNodeOffsets.
    // `semiMask` must not be null; it is dereferenced unconditionally.
    // The frontiers are left null here; the derived constructors create them.
    explicit BaseBFSMorsel(common::offset_t maxOffset, uint8_t lowerBound, uint8_t upperBound,
        NodeOffsetSemiMask* semiMask)
        : maxOffset{maxOffset}, lowerBound{lowerBound}, upperBound{upperBound}, currentLevel{0},
          nextNodeIdxToExtend{0}, numTargetDstNodes{0} {
        if (semiMask->isEnabled()) {
            // The counter has the same type as the bound. An `unsigned int` counter (what
            // `0u` deduces to) would wrap around before reaching a bound that does not fit
            // in it, and the loop would never terminate.
            const common::offset_t numOffsets = maxOffset + 1;
            for (common::offset_t offset = 0; offset < numOffsets; ++offset) {
                if (semiMask->isNodeMasked(offset)) {
                    targetDstNodeOffsets.push_back(offset);
                }
            }
        }
    }
    virtual ~BaseBFSMorsel() = default;

    // Get next node offset to extend from current level.
    common::offset_t getNextNodeOffset();

    // Resets the morsel so it can be reused (defined out of line).
    virtual void resetState();
    // True when the search has nothing left to do.
    virtual bool isComplete() = 0;
    // Registers `offset` as the source of the search.
    virtual void markSrc(common::offset_t offset) = 0;
    // Records that `offset` was reached, with the given multiplicity.
    virtual void markVisited(common::offset_t offset, uint64_t multiplicity) = 0;
    // Called once a level has been fully extended.
    virtual void finalizeCurrentLevel() = 0;

protected:
    inline bool isCurrentFrontierEmpty() const { return currentFrontier->nodeOffsets.empty(); }
    inline bool isUpperBoundReached() const { return currentLevel == upperBound; }
    // No explicit targets collected means every node is treated as a target.
    inline bool isAllDstTarget() const { return targetDstNodeOffsets.empty(); }
    void moveNextLevelAsCurrentLevel();
    // Factory for the frontier type used by the concrete morsel.
    virtual std::unique_ptr<Frontier> createFrontier() = 0;
};

// BFS morsel for shortest-path search. Keeps one visited-state byte per node offset and
// records, per destination, the length of the path that reached it.
struct ShortestPathBFSMorsel : public BaseBFSMorsel {
    // Visited state
    uint64_t numVisitedDstNodes;
    uint8_t* visitedNodes; // non-owning view of visitedNodesBuffer
    // Results
    std::vector<common::offset_t> dstNodeOffsets;
    std::unordered_map<common::offset_t, uint64_t> dstNodeOffset2PathLength;

    // Creates plain frontiers and allocates the visited-state array with one entry for
    // every offset in [0, maxOffset].
    ShortestPathBFSMorsel(common::offset_t maxOffset, uint8_t lowerBound, uint8_t upperBound,
        NodeOffsetSemiMask* semiMask)
        : BaseBFSMorsel{maxOffset, lowerBound, upperBound, semiMask}, numVisitedDstNodes{0} {
        currentFrontier = std::make_unique<Frontier>();
        nextFrontier = std::make_unique<Frontier>();
        // make_unique<T[]> takes an element count, not a byte count. The previous
        // `maxOffset + 1 * sizeof(uint8_t)` parsed as `maxOffset + (1 * sizeof(uint8_t))`
        // and produced the right size only because sizeof(uint8_t) is 1.
        visitedNodesBuffer = std::make_unique<uint8_t[]>(maxOffset + 1);
        visitedNodes = visitedNodesBuffer.get();
    }

    // The search stops when there is nothing left to extend, the upper bound on the level
    // is reached, or every target destination has been visited.
    inline bool isComplete() override {
        return isCurrentFrontierEmpty() || isUpperBoundReached() || isAllDstReached();
    }
    // Resets the shared BFS state first, then the visited-state bookkeeping.
    inline void resetState() override {
        BaseBFSMorsel::resetState();
        resetVisitedState();
    }
    void markSrc(common::offset_t offset) override;
    void markVisited(common::offset_t offset, uint64_t multiplicity) override;
    inline void finalizeCurrentLevel() override { moveNextLevelAsCurrentLevel(); }

private:
    inline bool isAllDstReached() const { return numVisitedDstNodes == numTargetDstNodes; }
    void resetVisitedState();
    inline std::unique_ptr<Frontier> createFrontier() override {
        return std::make_unique<Frontier>();
    }

private:
    // Owns the storage that visitedNodes points into.
    std::unique_ptr<uint8_t[]> visitedNodesBuffer;
};

struct VariableLengthBFSMorsel : public BaseBFSMorsel {
// Results
std::vector<common::offset_t> dstNodeOffsets;
std::unordered_map<common::offset_t, uint64_t> dstNodeOffset2NumPath;

explicit VariableLengthBFSMorsel(common::offset_t maxOffset, uint8_t lowerBound,
uint8_t upperBound, NodeOffsetSemiMask* semiMask)
: BaseBFSMorsel{maxOffset, lowerBound, upperBound, semiMask} {
currentFrontier = std::make_unique<FrontierWithMultiplicity>();
nextFrontier = std::make_unique<FrontierWithMultiplicity>();
}

inline void resetState() override {
BaseBFSMorsel::resetState();
resetNumPath();
}
inline bool isComplete() override { return isCurrentFrontierEmpty() || isUpperBoundReached(); }
inline void markSrc(common::offset_t offset) override {
((FrontierWithMultiplicity&)*currentFrontier).addOffset(offset, 1);
}
inline void markVisited(common::offset_t offset, uint64_t multiplicity) override {
((FrontierWithMultiplicity&)*nextFrontier).addOffset(offset, multiplicity);
}
inline void finalizeCurrentLevel() override {
moveNextLevelAsCurrentLevel();
updateNumPathFromCurrentFrontier();
}

private:
inline void resetNumPath() {
dstNodeOffsets.clear();
dstNodeOffset2NumPath.clear();
numTargetDstNodes = isAllDstTarget() ? maxOffset + 1 : targetDstNodeOffsets.size();
}
inline void updateNumPath(common::offset_t offset, uint64_t numPath) {
if (!dstNodeOffset2NumPath.contains(offset)) {
dstNodeOffsets.push_back(offset);
dstNodeOffset2NumPath.insert({offset, numPath});
} else {
dstNodeOffset2NumPath.at(offset) += numPath;
}
}
void updateNumPathFromCurrentFrontier();
inline std::unique_ptr<Frontier> createFrontier() override {
return std::make_unique<FrontierWithMultiplicity>();
}
};

} // namespace processor
} // namespace kuzu
Loading

0 comments on commit 7280162

Please sign in to comment.