Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add all shortest path #1652

Merged
merged 1 commit into from
Jun 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ oC_NodeLabel
: ':' SP? oC_LabelName ;

oC_RangeLiteral
: '*' SP? SHORTEST? SP? oC_IntegerLiteral SP? '..' SP? oC_IntegerLiteral ;
: '*' SP? ( SHORTEST | ALL SP SHORTEST )? SP? oC_IntegerLiteral SP? '..' SP? oC_IntegerLiteral ;

// Keyword token SHORTEST, matched case-insensitively: every letter is accepted in either upper or lower case.
SHORTEST : ( 'S' | 's' ) ( 'H' | 'h' ) ( 'O' | 'o' ) ( 'R' | 'r' ) ( 'T' | 't' ) ( 'E' | 'e' ) ( 'S' | 's' ) ( 'T' | 't' ) ;

Expand Down
7 changes: 7 additions & 0 deletions src/include/common/query_rel_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ enum class QueryRelType : uint8_t {
NON_RECURSIVE = 0,
VARIABLE_LENGTH = 1,
SHORTEST = 2,
ALL_SHORTEST = 3,
};

// Stateless helper predicates over QueryRelType.
struct QueryRelTypeUtils {
    // Returns true for every rel type except NON_RECURSIVE.
    static inline bool isRecursive(QueryRelType queryRelType) {
        const bool isPlainRel = (queryRelType == QueryRelType::NON_RECURSIVE);
        return !isPlainRel;
    }
};

} // namespace common
Expand Down
2 changes: 1 addition & 1 deletion src/include/optimizer/acc_hash_join_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class HashJoinSIPOptimizer : public LogicalOperatorVisitor {
const binder::Expression& nodeID, planner::LogicalOperator* root);

std::shared_ptr<planner::LogicalOperator> appendSemiMask(
std::shared_ptr<binder::Expression> nodeID, std::vector<planner::LogicalOperator*> ops,
std::vector<planner::LogicalOperator*> ops,
std::shared_ptr<planner::LogicalOperator> child);
std::shared_ptr<planner::LogicalOperator> appendAccumulate(
std::shared_ptr<planner::LogicalOperator> child);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#pragma once

#include "bfs_state.h"

namespace kuzu {
namespace processor {

template<bool TRACK_PATH>
class AllShortestPathState : public BaseBFSState {
public:
AllShortestPathState(uint8_t upperBound, TargetDstNodes* targetDstNodes)
: BaseBFSState{upperBound, targetDstNodes}, minDistance{0}, numVisitedDstNodes{0} {}

inline bool isComplete() final {
return isCurrentFrontierEmpty() || isUpperBoundReached() ||
isAllDstReachedWithMinDistance();
}

inline void resetState() final {
BaseBFSState::resetState();
minDistance = 0;
numVisitedDstNodes = 0;
visitedNodeToDistance.clear();
}

inline void markSrc(common::nodeID_t nodeID) override {
visitedNodeToDistance.insert({nodeID, 0});
if (targetDstNodes->contains(nodeID)) {
numVisitedDstNodes++;
}
currentFrontier->addNodeWithMultiplicity(nodeID, 1);
}

void markVisited(common::nodeID_t boundNodeID, common::nodeID_t nbrNodeID,
common::relID_t relID, uint64_t multiplicity) final {
if (!visitedNodeToDistance.contains(nbrNodeID)) {
visitedNodeToDistance.insert({nbrNodeID, currentLevel});
if (targetDstNodes->contains(nbrNodeID)) {
minDistance = currentLevel;
numVisitedDstNodes++;
}
if constexpr (TRACK_PATH) {
nextFrontier->addEdge(boundNodeID, nbrNodeID, relID);
} else {
nextFrontier->addNodeWithMultiplicity(nbrNodeID, multiplicity);
}
} else if (currentLevel <= visitedNodeToDistance.at(nbrNodeID)) {
if constexpr (TRACK_PATH) {
nextFrontier->addEdge(boundNodeID, nbrNodeID, relID);
} else {
nextFrontier->addNodeWithMultiplicity(nbrNodeID, multiplicity);
}
}
}

private:
inline bool isAllDstReachedWithMinDistance() const {
return numVisitedDstNodes == targetDstNodes->getNumNodes() && currentLevel > minDistance;
}

private:
uint32_t minDistance; // Min distance to add dst nodes that have been reached.
uint64_t numVisitedDstNodes;
frontier::node_id_map_t<uint32_t> visitedNodeToDistance;
};

} // namespace processor
} // namespace kuzu
78 changes: 27 additions & 51 deletions src/include/processor/operator/recursive_extend/recursive_join.h
Original file line number Diff line number Diff line change
@@ -1,43 +1,16 @@
#pragma once

#include "bfs_state.h"
#include "common/query_rel_type.h"
#include "frontier_scanner.h"
#include "planner/logical_plan/logical_operator/recursive_join_type.h"
#include "processor/operator/physical_operator.h"
#include "processor/operator/result_collector.h"
#include "storage/store/node_table.h"

namespace kuzu {
namespace processor {

/**
 * Source operator that publishes a single node ID through a value vector.
 *
 * The node ID is supplied externally via setNodeID(); getNextTuplesInternal() is defined out of
 * line. NOTE(review): presumably driven by the recursive join to feed one frontier node at a
 * time into the recursive plan - confirm against the caller.
 */
class ScanFrontier : public PhysicalOperator {
public:
    // hasExecuted is initialised explicitly so it never holds an indeterminate value if it is
    // read before the first setNodeID() call; `false` matches the value setNodeID() resets it to.
    ScanFrontier(DataPos nodeIDVectorPos, uint32_t id, const std::string& paramsString)
        : PhysicalOperator{PhysicalOperatorType::SCAN_NODE_ID, id, paramsString},
          nodeIDVectorPos{nodeIDVectorPos}, hasExecuted{false} {}

    // This operator has no child; it is the source of its pipeline.
    inline bool isSource() const override { return true; }

    // Resolves the output vector at nodeIDVectorPos.
    // NOTE(review): reads the `resultSet` member rather than the `resultSet_` parameter -
    // presumably the base class assigns the member before this hook runs; confirm.
    inline void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) override {
        nodeIDVector = resultSet->getValueVector(nodeIDVectorPos);
    }

    bool getNextTuplesInternal(kuzu::processor::ExecutionContext* context) override;

    // Stores nodeID at position 0 of the output vector and clears the executed flag.
    void setNodeID(common::nodeID_t nodeID) {
        nodeIDVector->setValue<common::nodeID_t>(0, nodeID);
        hasExecuted = false;
    }

    // Creates a fresh operator with the same position, id and params; runtime state is not copied.
    std::unique_ptr<PhysicalOperator> clone() override {
        return std::make_unique<ScanFrontier>(nodeIDVectorPos, id, paramsString);
    }

private:
    DataPos nodeIDVectorPos;
    std::shared_ptr<common::ValueVector> nodeIDVector;
    bool hasExecuted;
};
class ScanFrontier;

struct RecursiveJoinSharedState {
std::shared_ptr<FTableSharedState> inputFTableSharedState;
Expand Down Expand Up @@ -65,10 +38,9 @@ struct RecursiveJoinDataInfo {
// Recursive join info.
std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor;
DataPos recursiveDstNodeIDPos;
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs;
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs; // TODO: move this out?
DataPos recursiveEdgeIDPos;
// Path info
planner::RecursiveJoinType joinType;
DataPos pathPos;

RecursiveJoinDataInfo(std::vector<DataPos> vectorsToScanPos,
Expand All @@ -77,65 +49,67 @@ struct RecursiveJoinDataInfo {
const DataPos& pathLengthPos, std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
const DataPos& recursiveDstNodeIDPos,
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs,
const DataPos& recursiveEdgeIDPos, planner::RecursiveJoinType joinType)
const DataPos& recursiveEdgeIDPos)
: RecursiveJoinDataInfo{std::move(vectorsToScanPos), std::move(colIndicesToScan),
srcNodePos, dstNodePos, std::move(dstNodeTableIDs), pathLengthPos,
std::move(localResultSetDescriptor), recursiveDstNodeIDPos,
std::move(recursiveDstNodeTableIDs), recursiveEdgeIDPos, joinType, DataPos()} {}
std::move(recursiveDstNodeTableIDs), recursiveEdgeIDPos, DataPos()} {}
RecursiveJoinDataInfo(std::vector<DataPos> vectorsToScanPos,
std::vector<ft_col_idx_t> colIndicesToScan, const DataPos& srcNodePos,
const DataPos& dstNodePos, std::unordered_set<common::table_id_t> dstNodeTableIDs,
const DataPos& pathLengthPos, std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
const DataPos& recursiveDstNodeIDPos,
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs,
const DataPos& recursiveEdgeIDPos, planner::RecursiveJoinType joinType,
const DataPos& pathPos)
const DataPos& recursiveEdgeIDPos, const DataPos& pathPos)
: vectorsToScanPos{std::move(vectorsToScanPos)},
colIndicesToScan{std::move(colIndicesToScan)}, srcNodePos{srcNodePos},
dstNodePos{dstNodePos}, dstNodeTableIDs{std::move(dstNodeTableIDs)},
pathLengthPos{pathLengthPos}, localResultSetDescriptor{std::move(
localResultSetDescriptor)},
recursiveDstNodeIDPos{recursiveDstNodeIDPos}, recursiveDstNodeTableIDs{std::move(
recursiveDstNodeTableIDs)},
recursiveEdgeIDPos{recursiveEdgeIDPos}, joinType{joinType}, pathPos{pathPos} {}
recursiveEdgeIDPos{recursiveEdgeIDPos}, pathPos{pathPos} {}

inline std::unique_ptr<RecursiveJoinDataInfo> copy() {
return std::make_unique<RecursiveJoinDataInfo>(vectorsToScanPos, colIndicesToScan,
srcNodePos, dstNodePos, dstNodeTableIDs, pathLengthPos,
localResultSetDescriptor->copy(), recursiveDstNodeIDPos, recursiveDstNodeTableIDs,
recursiveEdgeIDPos, joinType, pathPos);
recursiveEdgeIDPos, pathPos);
}
};

class BaseRecursiveJoin : public PhysicalOperator {
class RecursiveJoin : public PhysicalOperator {
public:
BaseRecursiveJoin(uint8_t lowerBound, uint8_t upperBound,
std::shared_ptr<RecursiveJoinSharedState> sharedState,
RecursiveJoin(uint8_t lowerBound, uint8_t upperBound, common::QueryRelType queryRelType,
planner::RecursiveJoinType joinType, std::shared_ptr<RecursiveJoinSharedState> sharedState,
std::unique_ptr<RecursiveJoinDataInfo> dataInfo, std::unique_ptr<PhysicalOperator> child,
uint32_t id, const std::string& paramsString,
std::unique_ptr<PhysicalOperator> recursiveRoot)
: PhysicalOperator{PhysicalOperatorType::RECURSIVE_JOIN, std::move(child), id,
paramsString},
lowerBound{lowerBound}, upperBound{upperBound}, sharedState{std::move(sharedState)},
dataInfo{std::move(dataInfo)}, recursiveRoot{std::move(recursiveRoot)} {}
lowerBound{lowerBound}, upperBound{upperBound}, queryRelType{queryRelType},
joinType{joinType}, sharedState{std::move(sharedState)}, dataInfo{std::move(dataInfo)},
recursiveRoot{std::move(recursiveRoot)} {}

BaseRecursiveJoin(uint8_t lowerBound, uint8_t upperBound,
std::shared_ptr<RecursiveJoinSharedState> sharedState,
RecursiveJoin(uint8_t lowerBound, uint8_t upperBound, common::QueryRelType queryRelType,
planner::RecursiveJoinType joinType, std::shared_ptr<RecursiveJoinSharedState> sharedState,
std::unique_ptr<RecursiveJoinDataInfo> dataInfo, uint32_t id,
const std::string& paramsString, std::unique_ptr<PhysicalOperator> recursiveRoot)
: PhysicalOperator{PhysicalOperatorType::RECURSIVE_JOIN, id, paramsString},
lowerBound{lowerBound}, upperBound{upperBound}, sharedState{std::move(sharedState)},
dataInfo{std::move(dataInfo)}, recursiveRoot{std::move(recursiveRoot)} {}

virtual ~BaseRecursiveJoin() = default;
lowerBound{lowerBound}, upperBound{upperBound}, queryRelType{queryRelType},
joinType{joinType}, sharedState{std::move(sharedState)}, dataInfo{std::move(dataInfo)},
recursiveRoot{std::move(recursiveRoot)} {}

inline RecursiveJoinSharedState* getSharedState() const { return sharedState.get(); }

void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) override;
void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) final;

bool getNextTuplesInternal(ExecutionContext* context) override;
bool getNextTuplesInternal(ExecutionContext* context) final;

std::unique_ptr<PhysicalOperator> clone() override = 0;
inline std::unique_ptr<PhysicalOperator> clone() final {
return std::make_unique<RecursiveJoin>(lowerBound, upperBound, queryRelType, joinType,
sharedState, dataInfo->copy(), id, paramsString, recursiveRoot->clone());
}

private:
void initLocalRecursivePlan(ExecutionContext* context);
Expand All @@ -152,6 +126,8 @@ class BaseRecursiveJoin : public PhysicalOperator {
protected:
uint8_t lowerBound;
uint8_t upperBound;
common::QueryRelType queryRelType;
planner::RecursiveJoinType joinType;

std::shared_ptr<RecursiveJoinSharedState> sharedState;
std::unique_ptr<RecursiveJoinDataInfo> dataInfo;
Expand Down
38 changes: 38 additions & 0 deletions src/include/processor/operator/recursive_extend/scan_frontier.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#pragma once

#include "processor/operator/physical_operator.h"

namespace kuzu {
namespace processor {

/**
 * Source operator that publishes a single node ID through a value vector.
 *
 * The node ID is supplied externally via setNodeID(); getNextTuplesInternal() is defined out of
 * line. NOTE(review): presumably driven by the recursive join to feed one frontier node at a
 * time into the recursive plan - confirm against the caller.
 */
class ScanFrontier : public PhysicalOperator {
public:
    // hasExecuted is initialised explicitly so it never holds an indeterminate value if it is
    // read before the first setNodeID() call; `false` matches the value setNodeID() resets it to.
    ScanFrontier(DataPos nodeIDVectorPos, uint32_t id, const std::string& paramsString)
        : PhysicalOperator{PhysicalOperatorType::SCAN_NODE_ID, id, paramsString},
          nodeIDVectorPos{nodeIDVectorPos}, hasExecuted{false} {}

    // This operator has no child; it is the source of its pipeline.
    inline bool isSource() const override { return true; }

    // Resolves the output vector at nodeIDVectorPos.
    // NOTE(review): reads the `resultSet` member rather than the `resultSet_` parameter -
    // presumably the base class assigns the member before this hook runs; confirm.
    inline void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) override {
        nodeIDVector = resultSet->getValueVector(nodeIDVectorPos);
    }

    bool getNextTuplesInternal(kuzu::processor::ExecutionContext* context) override;

    // Stores nodeID at position 0 of the output vector and clears the executed flag.
    inline void setNodeID(common::nodeID_t nodeID) {
        nodeIDVector->setValue<common::nodeID_t>(0, nodeID);
        hasExecuted = false;
    }

    // Creates a fresh operator with the same position, id and params; runtime state is not copied.
    std::unique_ptr<PhysicalOperator> clone() override {
        return std::make_unique<ScanFrontier>(nodeIDVectorPos, id, paramsString);
    }

private:
    DataPos nodeIDVectorPos;
    std::shared_ptr<common::ValueVector> nodeIDVector;
    bool hasExecuted;
};

} // namespace processor
} // namespace kuzu

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@ namespace kuzu {
namespace processor {

template<bool TRACK_PATH>
struct ShortestPathState : public BaseBFSState {
// Visited state
uint64_t numVisitedDstNodes;
frontier::node_id_set_t visited;

class ShortestPathState : public BaseBFSState {
public:
ShortestPathState(uint8_t upperBound, TargetDstNodes* targetDstNodes)
: BaseBFSState{upperBound, targetDstNodes}, numVisitedDstNodes{0} {}
~ShortestPathState() override = default;
Expand All @@ -20,7 +17,8 @@ struct ShortestPathState : public BaseBFSState {
}
inline void resetState() final {
BaseBFSState::resetState();
resetVisitedState();
numVisitedDstNodes = 0;
visited.clear();
}

inline void markSrc(common::nodeID_t nodeID) final {
Expand All @@ -47,13 +45,14 @@ struct ShortestPathState : public BaseBFSState {
}
}

private:
inline bool isAllDstReached() const {
return numVisitedDstNodes == targetDstNodes->getNumNodes();
}
inline void resetVisitedState() {
numVisitedDstNodes = 0;
visited.clear();
}

private:
uint64_t numVisitedDstNodes;
frontier::node_id_set_t visited;
};

} // namespace processor
Expand Down
Loading
Loading