Skip to content

Commit

Permalink
Merge pull request #1652 from kuzudb/all-shortest-path
Browse files Browse the repository at this point in the history
Add all shortest path
  • Loading branch information
andyfengHKU committed Jun 9, 2023
2 parents 322dab6 + e9a28db commit d5d13c8
Show file tree
Hide file tree
Showing 29 changed files with 1,787 additions and 1,686 deletions.
2 changes: 1 addition & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ oC_NodeLabel
: ':' SP? oC_LabelName ;

oC_RangeLiteral
: '*' SP? SHORTEST? SP? oC_IntegerLiteral SP? '..' SP? oC_IntegerLiteral ;
: '*' SP? ( SHORTEST | ALL SP SHORTEST )? SP? oC_IntegerLiteral SP? '..' SP? oC_IntegerLiteral ;

// Keyword token SHORTEST. Each letter accepts either its upper- or lower-case form, so the
// keyword is matched case-insensitively.
SHORTEST : ( 'S' | 's' ) ( 'H' | 'h' ) ( 'O' | 'o' ) ( 'R' | 'r' ) ( 'T' | 't' ) ( 'E' | 'e' ) ( 'S' | 's' ) ( 'T' | 't' ) ;

Expand Down
7 changes: 7 additions & 0 deletions src/include/common/query_rel_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ enum class QueryRelType : uint8_t {
NON_RECURSIVE = 0,
VARIABLE_LENGTH = 1,
SHORTEST = 2,
ALL_SHORTEST = 3,
};

// Stateless helper predicates over QueryRelType.
struct QueryRelTypeUtils {
    // Returns false only for QueryRelType::NON_RECURSIVE; every other value is reported as
    // recursive.
    static inline bool isRecursive(QueryRelType queryRelType) {
        switch (queryRelType) {
        case QueryRelType::NON_RECURSIVE:
            return false;
        default:
            return true;
        }
    }
};

} // namespace common
Expand Down
2 changes: 1 addition & 1 deletion src/include/optimizer/acc_hash_join_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class HashJoinSIPOptimizer : public LogicalOperatorVisitor {
const binder::Expression& nodeID, planner::LogicalOperator* root);

std::shared_ptr<planner::LogicalOperator> appendSemiMask(
std::shared_ptr<binder::Expression> nodeID, std::vector<planner::LogicalOperator*> ops,
std::vector<planner::LogicalOperator*> ops,
std::shared_ptr<planner::LogicalOperator> child);
std::shared_ptr<planner::LogicalOperator> appendAccumulate(
std::shared_ptr<planner::LogicalOperator> child);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#pragma once

#include "bfs_state.h"

namespace kuzu {
namespace processor {

/**
 * BFS state used for ALL SHORTEST path queries.
 *
 * For every node reached, the level at which it was first seen is remembered. A node that is
 * reached again is still forwarded to the next frontier as long as the current level does not
 * exceed that remembered level.
 *
 * @tparam TRACK_PATH when true, edges are recorded in the next frontier (addEdge); when false,
 *         only node multiplicities are accumulated (addNodeWithMultiplicity).
 */
template<bool TRACK_PATH>
class AllShortestPathState : public BaseBFSState {
public:
    AllShortestPathState(uint8_t upperBound, TargetDstNodes* targetDstNodes)
        : BaseBFSState{upperBound, targetDstNodes}, minDistance{0}, numVisitedDstNodes{0} {}

    // The search is finished as soon as any one of the three checks below holds. They are
    // evaluated in the same short-circuit order as a chained `||`.
    inline bool isComplete() final {
        if (isCurrentFrontierEmpty()) {
            return true;
        }
        if (isUpperBoundReached()) {
            return true;
        }
        return isAllDstReachedWithMinDistance();
    }

    // Resets the base state first, then drops everything recorded by this class.
    inline void resetState() final {
        BaseBFSState::resetState();
        visitedNodeToDistance.clear();
        numVisitedDstNodes = 0;
        minDistance = 0;
    }

    // Registers the source node at distance 0, counts it if it is itself a target destination,
    // and seeds the current frontier with multiplicity 1.
    inline void markSrc(common::nodeID_t nodeID) override {
        visitedNodeToDistance.insert({nodeID, 0});
        const bool srcIsDst = targetDstNodes->contains(nodeID);
        if (srcIsDst) {
            numVisitedDstNodes++;
        }
        currentFrontier->addNodeWithMultiplicity(nodeID, 1);
    }

    // Called for every (boundNodeID -> nbrNodeID) extension discovered at the current level.
    void markVisited(common::nodeID_t boundNodeID, common::nodeID_t nbrNodeID,
        common::relID_t relID, uint64_t multiplicity) final {
        const auto seenIt = visitedNodeToDistance.find(nbrNodeID);
        const bool firstVisit = seenIt == visitedNodeToDistance.end();
        // Already reached at a strictly smaller level: nothing to record.
        if (!firstVisit && currentLevel > seenIt->second) {
            return;
        }
        if (firstVisit) {
            visitedNodeToDistance.insert({nbrNodeID, currentLevel});
            if (targetDstNodes->contains(nbrNodeID)) {
                minDistance = currentLevel;
                numVisitedDstNodes++;
            }
        }
        // Either a first visit, or a repeat visit at a level not exceeding the recorded one.
        if constexpr (TRACK_PATH) {
            nextFrontier->addEdge(boundNodeID, nbrNodeID, relID);
        } else {
            nextFrontier->addNodeWithMultiplicity(nbrNodeID, multiplicity);
        }
    }

private:
    // True once every target destination has been counted and the current level has moved
    // past the level stored in minDistance.
    inline bool isAllDstReachedWithMinDistance() const {
        if (numVisitedDstNodes != targetDstNodes->getNumNodes()) {
            return false;
        }
        return currentLevel > minDistance;
    }

private:
    // Level at which a target destination node was most recently reached for the first time
    // (0 after construction/reset).
    uint32_t minDistance;
    // Number of distinct target destination nodes reached so far.
    uint64_t numVisitedDstNodes;
    // Maps each visited node to the level at which it was first reached.
    frontier::node_id_map_t<uint32_t> visitedNodeToDistance;
};

} // namespace processor
} // namespace kuzu
78 changes: 27 additions & 51 deletions src/include/processor/operator/recursive_extend/recursive_join.h
Original file line number Diff line number Diff line change
@@ -1,43 +1,16 @@
#pragma once

#include "bfs_state.h"
#include "common/query_rel_type.h"
#include "frontier_scanner.h"
#include "planner/logical_plan/logical_operator/recursive_join_type.h"
#include "processor/operator/physical_operator.h"
#include "processor/operator/result_collector.h"
#include "storage/store/node_table.h"

namespace kuzu {
namespace processor {

/**
 * Source operator (isSource() returns true) whose single output vector, located at
 * nodeIDVectorPos, is filled in by setNodeID().
 */
class ScanFrontier : public PhysicalOperator {
public:
    ScanFrontier(DataPos nodeIDVectorPos, uint32_t id, const std::string& paramsString)
        : PhysicalOperator{PhysicalOperatorType::SCAN_NODE_ID, id, paramsString},
          nodeIDVectorPos{nodeIDVectorPos} {}

    inline bool isSource() const override { return true; }

    // Resolves the output vector through the `resultSet` member, not the `resultSet_` argument.
    // NOTE(review): presumably the base class stores `resultSet_` into `resultSet` before
    // invoking this hook - confirm against PhysicalOperator.
    inline void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) override {
        nodeIDVector = resultSet->getValueVector(nodeIDVectorPos);
    }

    bool getNextTuplesInternal(kuzu::processor::ExecutionContext* context) override;

    // Writes nodeID into slot 0 of the output vector and clears hasExecuted so the new value
    // is pending. Must be called after initLocalStateInternal(), which sets nodeIDVector.
    void setNodeID(common::nodeID_t nodeID) {
        nodeIDVector->setValue<common::nodeID_t>(0, nodeID);
        hasExecuted = false;
    }

    std::unique_ptr<PhysicalOperator> clone() override {
        return std::make_unique<ScanFrontier>(nodeIDVectorPos, id, paramsString);
    }

private:
    DataPos nodeIDVectorPos;
    std::shared_ptr<common::ValueVector> nodeIDVector;
    // Previously left uninitialized by the constructor, so any read before the first
    // setNodeID() call used an indeterminate value. Starts as true because no node has been
    // supplied yet; setNodeID() flips it to false.
    // NOTE(review): assumes getNextTuplesInternal() treats `true` as "nothing pending" - confirm.
    bool hasExecuted = true;
};
class ScanFrontier;

struct RecursiveJoinSharedState {
std::shared_ptr<FTableSharedState> inputFTableSharedState;
Expand Down Expand Up @@ -65,10 +38,9 @@ struct RecursiveJoinDataInfo {
// Recursive join info.
std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor;
DataPos recursiveDstNodeIDPos;
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs;
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs; // TODO: move this out?
DataPos recursiveEdgeIDPos;
// Path info
planner::RecursiveJoinType joinType;
DataPos pathPos;

RecursiveJoinDataInfo(std::vector<DataPos> vectorsToScanPos,
Expand All @@ -77,65 +49,67 @@ struct RecursiveJoinDataInfo {
const DataPos& pathLengthPos, std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
const DataPos& recursiveDstNodeIDPos,
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs,
const DataPos& recursiveEdgeIDPos, planner::RecursiveJoinType joinType)
const DataPos& recursiveEdgeIDPos)
: RecursiveJoinDataInfo{std::move(vectorsToScanPos), std::move(colIndicesToScan),
srcNodePos, dstNodePos, std::move(dstNodeTableIDs), pathLengthPos,
std::move(localResultSetDescriptor), recursiveDstNodeIDPos,
std::move(recursiveDstNodeTableIDs), recursiveEdgeIDPos, joinType, DataPos()} {}
std::move(recursiveDstNodeTableIDs), recursiveEdgeIDPos, DataPos()} {}
RecursiveJoinDataInfo(std::vector<DataPos> vectorsToScanPos,
std::vector<ft_col_idx_t> colIndicesToScan, const DataPos& srcNodePos,
const DataPos& dstNodePos, std::unordered_set<common::table_id_t> dstNodeTableIDs,
const DataPos& pathLengthPos, std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
const DataPos& recursiveDstNodeIDPos,
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs,
const DataPos& recursiveEdgeIDPos, planner::RecursiveJoinType joinType,
const DataPos& pathPos)
const DataPos& recursiveEdgeIDPos, const DataPos& pathPos)
: vectorsToScanPos{std::move(vectorsToScanPos)},
colIndicesToScan{std::move(colIndicesToScan)}, srcNodePos{srcNodePos},
dstNodePos{dstNodePos}, dstNodeTableIDs{std::move(dstNodeTableIDs)},
pathLengthPos{pathLengthPos}, localResultSetDescriptor{std::move(
localResultSetDescriptor)},
recursiveDstNodeIDPos{recursiveDstNodeIDPos}, recursiveDstNodeTableIDs{std::move(
recursiveDstNodeTableIDs)},
recursiveEdgeIDPos{recursiveEdgeIDPos}, joinType{joinType}, pathPos{pathPos} {}
recursiveEdgeIDPos{recursiveEdgeIDPos}, pathPos{pathPos} {}

inline std::unique_ptr<RecursiveJoinDataInfo> copy() {
return std::make_unique<RecursiveJoinDataInfo>(vectorsToScanPos, colIndicesToScan,
srcNodePos, dstNodePos, dstNodeTableIDs, pathLengthPos,
localResultSetDescriptor->copy(), recursiveDstNodeIDPos, recursiveDstNodeTableIDs,
recursiveEdgeIDPos, joinType, pathPos);
recursiveEdgeIDPos, pathPos);
}
};

class BaseRecursiveJoin : public PhysicalOperator {
class RecursiveJoin : public PhysicalOperator {
public:
BaseRecursiveJoin(uint8_t lowerBound, uint8_t upperBound,
std::shared_ptr<RecursiveJoinSharedState> sharedState,
RecursiveJoin(uint8_t lowerBound, uint8_t upperBound, common::QueryRelType queryRelType,
planner::RecursiveJoinType joinType, std::shared_ptr<RecursiveJoinSharedState> sharedState,
std::unique_ptr<RecursiveJoinDataInfo> dataInfo, std::unique_ptr<PhysicalOperator> child,
uint32_t id, const std::string& paramsString,
std::unique_ptr<PhysicalOperator> recursiveRoot)
: PhysicalOperator{PhysicalOperatorType::RECURSIVE_JOIN, std::move(child), id,
paramsString},
lowerBound{lowerBound}, upperBound{upperBound}, sharedState{std::move(sharedState)},
dataInfo{std::move(dataInfo)}, recursiveRoot{std::move(recursiveRoot)} {}
lowerBound{lowerBound}, upperBound{upperBound}, queryRelType{queryRelType},
joinType{joinType}, sharedState{std::move(sharedState)}, dataInfo{std::move(dataInfo)},
recursiveRoot{std::move(recursiveRoot)} {}

BaseRecursiveJoin(uint8_t lowerBound, uint8_t upperBound,
std::shared_ptr<RecursiveJoinSharedState> sharedState,
RecursiveJoin(uint8_t lowerBound, uint8_t upperBound, common::QueryRelType queryRelType,
planner::RecursiveJoinType joinType, std::shared_ptr<RecursiveJoinSharedState> sharedState,
std::unique_ptr<RecursiveJoinDataInfo> dataInfo, uint32_t id,
const std::string& paramsString, std::unique_ptr<PhysicalOperator> recursiveRoot)
: PhysicalOperator{PhysicalOperatorType::RECURSIVE_JOIN, id, paramsString},
lowerBound{lowerBound}, upperBound{upperBound}, sharedState{std::move(sharedState)},
dataInfo{std::move(dataInfo)}, recursiveRoot{std::move(recursiveRoot)} {}

virtual ~BaseRecursiveJoin() = default;
lowerBound{lowerBound}, upperBound{upperBound}, queryRelType{queryRelType},
joinType{joinType}, sharedState{std::move(sharedState)}, dataInfo{std::move(dataInfo)},
recursiveRoot{std::move(recursiveRoot)} {}

inline RecursiveJoinSharedState* getSharedState() const { return sharedState.get(); }

void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) override;
void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) final;

bool getNextTuplesInternal(ExecutionContext* context) override;
bool getNextTuplesInternal(ExecutionContext* context) final;

std::unique_ptr<PhysicalOperator> clone() override = 0;
inline std::unique_ptr<PhysicalOperator> clone() final {
return std::make_unique<RecursiveJoin>(lowerBound, upperBound, queryRelType, joinType,
sharedState, dataInfo->copy(), id, paramsString, recursiveRoot->clone());
}

private:
void initLocalRecursivePlan(ExecutionContext* context);
Expand All @@ -152,6 +126,8 @@ class BaseRecursiveJoin : public PhysicalOperator {
protected:
uint8_t lowerBound;
uint8_t upperBound;
common::QueryRelType queryRelType;
planner::RecursiveJoinType joinType;

std::shared_ptr<RecursiveJoinSharedState> sharedState;
std::unique_ptr<RecursiveJoinDataInfo> dataInfo;
Expand Down
38 changes: 38 additions & 0 deletions src/include/processor/operator/recursive_extend/scan_frontier.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#pragma once

#include "processor/operator/physical_operator.h"

namespace kuzu {
namespace processor {

/**
 * Source operator (isSource() returns true) whose single output vector, located at
 * nodeIDVectorPos, is filled in by setNodeID().
 */
class ScanFrontier : public PhysicalOperator {
public:
    ScanFrontier(DataPos nodeIDVectorPos, uint32_t id, const std::string& paramsString)
        : PhysicalOperator{PhysicalOperatorType::SCAN_NODE_ID, id, paramsString},
          nodeIDVectorPos{nodeIDVectorPos} {}

    inline bool isSource() const override { return true; }

    // Resolves the output vector through the `resultSet` member, not the `resultSet_` argument.
    // NOTE(review): presumably the base class stores `resultSet_` into `resultSet` before
    // invoking this hook - confirm against PhysicalOperator.
    inline void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) override {
        nodeIDVector = resultSet->getValueVector(nodeIDVectorPos);
    }

    bool getNextTuplesInternal(kuzu::processor::ExecutionContext* context) override;

    // Writes nodeID into slot 0 of the output vector and clears hasExecuted so the new value
    // is pending. Must be called after initLocalStateInternal(), which sets nodeIDVector.
    inline void setNodeID(common::nodeID_t nodeID) {
        nodeIDVector->setValue<common::nodeID_t>(0, nodeID);
        hasExecuted = false;
    }

    std::unique_ptr<PhysicalOperator> clone() override {
        return std::make_unique<ScanFrontier>(nodeIDVectorPos, id, paramsString);
    }

private:
    DataPos nodeIDVectorPos;
    std::shared_ptr<common::ValueVector> nodeIDVector;
    // Previously left uninitialized by the constructor, so any read before the first
    // setNodeID() call used an indeterminate value. Starts as true because no node has been
    // supplied yet; setNodeID() flips it to false.
    // NOTE(review): assumes getNextTuplesInternal() treats `true` as "nothing pending" - confirm.
    bool hasExecuted = true;
};

} // namespace processor
} // namespace kuzu

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@ namespace kuzu {
namespace processor {

template<bool TRACK_PATH>
struct ShortestPathState : public BaseBFSState {
// Visited state
uint64_t numVisitedDstNodes;
frontier::node_id_set_t visited;

class ShortestPathState : public BaseBFSState {
public:
ShortestPathState(uint8_t upperBound, TargetDstNodes* targetDstNodes)
: BaseBFSState{upperBound, targetDstNodes}, numVisitedDstNodes{0} {}
~ShortestPathState() override = default;
Expand All @@ -20,7 +17,8 @@ struct ShortestPathState : public BaseBFSState {
}
inline void resetState() final {
BaseBFSState::resetState();
resetVisitedState();
numVisitedDstNodes = 0;
visited.clear();
}

inline void markSrc(common::nodeID_t nodeID) final {
Expand All @@ -47,13 +45,14 @@ struct ShortestPathState : public BaseBFSState {
}
}

private:
inline bool isAllDstReached() const {
return numVisitedDstNodes == targetDstNodes->getNumNodes();
}
inline void resetVisitedState() {
numVisitedDstNodes = 0;
visited.clear();
}

private:
uint64_t numVisitedDstNodes;
frontier::node_id_set_t visited;
};

} // namespace processor
Expand Down
Loading

0 comments on commit d5d13c8

Please sign in to comment.