Skip to content

Commit

Permalink
Fix issue 3262 (#3384)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Apr 25, 2024
1 parent 2a8a68a commit 9ad863a
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 97 deletions.
4 changes: 2 additions & 2 deletions src/include/planner/operator/extend/extend_direction.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace planner {
// Direction of an extend (presumably the direction in which relationships are traversed
// from the bound node -- confirm against callers).
// ExtendDirectionUtils::getExtendDirection returns BOTH when the rel expression's direction
// type is RelDirectionType::BOTH. ExtendDirectionUtils::getRelDataDirection asserts that it
// is never given BOTH, and maps FWD to RelDataDirection::FWD and otherwise to BWD.
enum class ExtendDirection : uint8_t { FWD = 0, BWD = 1, BOTH = 2 };

struct ExtendDirectionUtils {
static inline ExtendDirection getExtendDirection(const binder::RelExpression& relExpression,
static ExtendDirection getExtendDirection(const binder::RelExpression& relExpression,
const binder::NodeExpression& boundNode) {
if (relExpression.getDirectionType() == binder::RelDirectionType::BOTH) {
return ExtendDirection::BOTH;
Expand All @@ -23,7 +23,7 @@ struct ExtendDirectionUtils {
}
}

static inline common::RelDataDirection getRelDataDirection(ExtendDirection extendDirection) {
static common::RelDataDirection getRelDataDirection(ExtendDirection extendDirection) {
KU_ASSERT(extendDirection != ExtendDirection::BOTH);
return extendDirection == ExtendDirection::FWD ? common::RelDataDirection::FWD :
common::RelDataDirection::BWD;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ class PathScanner : public BaseFrontierScanner {
public:
PathScanner(TargetDstNodes* targetDstNodes, size_t k,
std::unordered_map<common::table_id_t, std::string> tableIDToName,
path_semantic_check_t semanticCheckFunc)
path_semantic_check_t semanticCheckFunc, bool extendInBwd)
: BaseFrontierScanner{targetDstNodes, k}, tableIDToName{std::move(tableIDToName)},
semanticCheckFunc{semanticCheckFunc} {
semanticCheckFunc{semanticCheckFunc}, extendInBwd{extendInBwd} {
nodeIDs.resize(k + 1);
relIDs.resize(k + 1);
}
Expand Down Expand Up @@ -97,8 +97,10 @@ class PathScanner : public BaseFrontierScanner {
std::stack<nbrs_t> nbrsStack;
std::stack<int64_t> cursorStack;
std::unordered_map<common::table_id_t, std::string> tableIDToName;

// Path semantic
path_semantic_check_t semanticCheckFunc;
// Extend direction
bool extendInBwd;
};

/*
Expand Down
82 changes: 46 additions & 36 deletions src/include/processor/operator/recursive_extend/recursive_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "bfs_state.h"
#include "common/enums/query_rel_type.h"
#include "frontier_scanner.h"
#include "planner/operator/extend/extend_direction.h"
#include "planner/operator/extend/recursive_join_type.h"
#include "processor/operator/mask.h"
#include "processor/operator/physical_operator.h"
Expand Down Expand Up @@ -35,25 +36,21 @@ struct RecursiveJoinDataInfo {
DataPos pathPos;
std::unordered_map<common::table_id_t, std::string> tableIDToName;

RecursiveJoinDataInfo(const DataPos& srcNodePos, const DataPos& dstNodePos,
std::unordered_set<common::table_id_t> dstNodeTableIDs, const DataPos& pathLengthPos,
std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
const DataPos& recursiveDstNodeIDPos,
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs,
const DataPos& recursiveEdgeIDPos, const DataPos& pathPos,
std::unordered_map<common::table_id_t, std::string> tableIDToName)
: srcNodePos{srcNodePos}, dstNodePos{dstNodePos},
dstNodeTableIDs{std::move(dstNodeTableIDs)}, pathLengthPos{pathLengthPos},
localResultSetDescriptor{std::move(localResultSetDescriptor)},
recursiveDstNodeIDPos{recursiveDstNodeIDPos},
recursiveDstNodeTableIDs{std::move(recursiveDstNodeTableIDs)},
recursiveEdgeIDPos{recursiveEdgeIDPos}, pathPos{pathPos},
tableIDToName{std::move(tableIDToName)} {}

inline std::unique_ptr<RecursiveJoinDataInfo> copy() {
return std::make_unique<RecursiveJoinDataInfo>(srcNodePos, dstNodePos, dstNodeTableIDs,
pathLengthPos, localResultSetDescriptor->copy(), recursiveDstNodeIDPos,
recursiveDstNodeTableIDs, recursiveEdgeIDPos, pathPos, tableIDToName);
RecursiveJoinDataInfo() = default;
EXPLICIT_COPY_DEFAULT_MOVE(RecursiveJoinDataInfo);

private:
RecursiveJoinDataInfo(const RecursiveJoinDataInfo& other) {
srcNodePos = other.srcNodePos;
dstNodePos = other.dstNodePos;
dstNodeTableIDs = other.dstNodeTableIDs;
pathLengthPos = other.pathLengthPos;
localResultSetDescriptor = other.localResultSetDescriptor->copy();
recursiveDstNodeIDPos = other.recursiveDstNodeIDPos;
recursiveDstNodeTableIDs = other.recursiveDstNodeTableIDs;
recursiveEdgeIDPos = other.recursiveEdgeIDPos;
pathPos = other.pathPos;
tableIDToName = other.tableIDToName;
}
};

Expand All @@ -75,29 +72,47 @@ struct RecursiveJoinVectors {
common::ValueVector* recursiveDstNodeIDVector = nullptr;
};

struct RecursiveJoinInfo {
RecursiveJoinDataInfo dataInfo;
uint8_t lowerBound;
uint8_t upperBound;
common::QueryRelType queryRelType;
planner::RecursiveJoinType joinType;
planner::ExtendDirection direction;

RecursiveJoinInfo() = default;
EXPLICIT_COPY_DEFAULT_MOVE(RecursiveJoinInfo);

private:
RecursiveJoinInfo(const RecursiveJoinInfo& other) {
dataInfo = other.dataInfo.copy();
lowerBound = other.lowerBound;
upperBound = other.upperBound;
queryRelType = other.queryRelType;
joinType = other.joinType;
direction = other.direction;
}
};

class RecursiveJoin : public PhysicalOperator {
public:
RecursiveJoin(uint8_t lowerBound, uint8_t upperBound, common::QueryRelType queryRelType,
planner::RecursiveJoinType joinType, std::shared_ptr<RecursiveJoinSharedState> sharedState,
std::unique_ptr<RecursiveJoinDataInfo> dataInfo, std::unique_ptr<PhysicalOperator> child,
uint32_t id, const std::string& paramsString,
RecursiveJoin(RecursiveJoinInfo info, std::shared_ptr<RecursiveJoinSharedState> sharedState,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString,
std::unique_ptr<PhysicalOperator> recursiveRoot)
: PhysicalOperator{PhysicalOperatorType::RECURSIVE_JOIN, std::move(child), id,
paramsString},
lowerBound{lowerBound}, upperBound{upperBound}, queryRelType{queryRelType},
joinType{joinType}, sharedState{std::move(sharedState)}, dataInfo{std::move(dataInfo)},
info{std::move(info)}, sharedState{std::move(sharedState)},
recursiveRoot{std::move(recursiveRoot)} {}

inline RecursiveJoinSharedState* getSharedState() const { return sharedState.get(); }
RecursiveJoinSharedState* getSharedState() const { return sharedState.get(); }

void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) final;

bool getNextTuplesInternal(ExecutionContext* context) final;

inline std::unique_ptr<PhysicalOperator> clone() final {
return std::make_unique<RecursiveJoin>(lowerBound, upperBound, queryRelType, joinType,
sharedState, dataInfo->copy(), children[0]->clone(), id, paramsString,
recursiveRoot->clone());
std::unique_ptr<PhysicalOperator> clone() final {
return std::make_unique<RecursiveJoin>(info.copy(), sharedState, children[0]->clone(), id,
paramsString, recursiveRoot->clone());
}

private:
Expand All @@ -113,13 +128,8 @@ class RecursiveJoin : public PhysicalOperator {
void updateVisitedNodes(common::nodeID_t boundNodeID);

private:
uint8_t lowerBound;
uint8_t upperBound;
common::QueryRelType queryRelType;
planner::RecursiveJoinType joinType;

RecursiveJoinInfo info;
std::shared_ptr<RecursiveJoinSharedState> sharedState;
std::unique_ptr<RecursiveJoinDataInfo> dataInfo;

// Local recursive plan
std::unique_ptr<ResultSet> localResultSet;
Expand Down
55 changes: 29 additions & 26 deletions src/processor/map/map_recursive_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,48 +20,51 @@ static std::shared_ptr<RecursiveJoinSharedState> createSharedState(
return std::make_shared<RecursiveJoinSharedState>(std::move(semiMasks));
}

std::unique_ptr<PhysicalOperator> PlanMapper::mapRecursiveExtend(
planner::LogicalOperator* logicalOperator) {
auto extend = (LogicalRecursiveExtend*)logicalOperator;
std::unique_ptr<PhysicalOperator> PlanMapper::mapRecursiveExtend(LogicalOperator* logicalOperator) {
auto extend = logicalOperator->constPtrCast<LogicalRecursiveExtend>();
auto boundNode = extend->getBoundNode();
auto nbrNode = extend->getNbrNode();
auto rel = extend->getRel();
auto recursiveInfo = rel->getRecursiveInfo();
auto lengthExpression = rel->getLengthExpression();
// Map recursive plan
auto logicalRecursiveRoot = extend->getRecursiveChild();
auto recursiveRoot = mapOperator(logicalRecursiveRoot.get());
auto recursivePlanSchema = logicalRecursiveRoot->getSchema();
auto recursivePlanResultSetDescriptor =
std::make_unique<ResultSetDescriptor>(recursivePlanSchema);
auto recursiveDstNodeIDPos =
DataPos(recursivePlanSchema->getExpressionPos(*recursiveInfo->nodeCopy->getInternalID()));
auto recursiveEdgeIDPos = DataPos(
recursivePlanSchema->getExpressionPos(*recursiveInfo->rel->getInternalIDProperty()));
// Generate RecursiveJoin
auto outSchema = extend->getSchema();
auto inSchema = extend->getChild(0)->getSchema();
auto boundNodeIDPos = DataPos(inSchema->getExpressionPos(*boundNode->getInternalID()));
auto nbrNodeIDPos = DataPos(outSchema->getExpressionPos(*nbrNode->getInternalID()));
auto lengthPos = DataPos(outSchema->getExpressionPos(*lengthExpression));
auto sharedState = createSharedState(*nbrNode, *clientContext->getStorageManager());
auto pathPos = DataPos();
if (extend->getJoinType() == planner::RecursiveJoinType::TRACK_PATH) {
pathPos = DataPos(outSchema->getExpressionPos(*rel));
// Data info
auto dataInfo = RecursiveJoinDataInfo();
dataInfo.srcNodePos = getDataPos(*boundNode->getInternalID(), *inSchema);
dataInfo.dstNodePos = getDataPos(*nbrNode->getInternalID(), *outSchema);
dataInfo.dstNodeTableIDs = nbrNode->getTableIDsSet();
dataInfo.pathLengthPos = getDataPos(*rel->getLengthExpression(), *outSchema);
dataInfo.localResultSetDescriptor = std::make_unique<ResultSetDescriptor>(recursivePlanSchema);
dataInfo.recursiveDstNodeIDPos =
getDataPos(*recursiveInfo->nodeCopy->getInternalID(), *recursivePlanSchema);
dataInfo.recursiveDstNodeTableIDs = recursiveInfo->node->getTableIDsSet();
dataInfo.recursiveEdgeIDPos =
getDataPos(*recursiveInfo->rel->getInternalIDProperty(), *recursivePlanSchema);
if (extend->getJoinType() == RecursiveJoinType::TRACK_PATH) {
dataInfo.pathPos = getDataPos(*rel, *outSchema);
} else {
dataInfo.pathPos = DataPos::getInvalidPos();
}
std::unordered_map<common::table_id_t, std::string> tableIDToName;
for (auto& entry : clientContext->getCatalog()->getTableEntries(clientContext->getTx())) {
tableIDToName.insert({entry->getTableID(), entry->getName()});
dataInfo.tableIDToName.insert({entry->getTableID(), entry->getName()});
}
auto dataInfo = std::make_unique<RecursiveJoinDataInfo>(boundNodeIDPos, nbrNodeIDPos,
nbrNode->getTableIDsSet(), lengthPos, std::move(recursivePlanResultSetDescriptor),
recursiveDstNodeIDPos, recursiveInfo->node->getTableIDsSet(), recursiveEdgeIDPos, pathPos,
std::move(tableIDToName));
// Info
auto info = RecursiveJoinInfo();
info.dataInfo = std::move(dataInfo);
info.lowerBound = rel->getLowerBound();
info.upperBound = rel->getUpperBound();
info.queryRelType = rel->getRelType();
info.joinType = extend->getJoinType();
info.direction = extend->getDirection();
auto prevOperator = mapOperator(logicalOperator->getChild(0).get());
return std::make_unique<RecursiveJoin>(rel->getLowerBound(), rel->getUpperBound(),
rel->getRelType(), extend->getJoinType(), sharedState, std::move(dataInfo),
std::move(prevOperator), getOperatorID(), extend->getExpressionsForPrinting(),
std::move(recursiveRoot));
return std::make_unique<RecursiveJoin>(std::move(info), sharedState, std::move(prevOperator),
getOperatorID(), extend->getExpressionsForPrinting(), std::move(recursiveRoot));
}

} // namespace processor
Expand Down
41 changes: 30 additions & 11 deletions src/processor/operator/recursive_extend/frontier_scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,29 @@ void PathScanner::initDfs(const frontier::node_rel_id_t& nodeAndRelID, size_t cu
initDfs(nbrs->at(0), currentDepth - 1);
}

// Writes a single relationship of a path into slot `pos` of the path-rel data vectors:
// source node ID, destination node ID, rel ID and the rel's label string.
// The caller decides which endpoint is the source and which the destination.
static void writePathRels(RecursiveJoinVectors* vectors, sel_t pos, nodeID_t srcNodeID,
    nodeID_t dstNodeID, relID_t relID, const std::string& labelName) {
    auto& srcIDVec = vectors->pathRelsSrcIDDataVector;
    auto& dstIDVec = vectors->pathRelsDstIDDataVector;
    auto& relIDVec = vectors->pathRelsIDDataVector;
    auto& labelVec = vectors->pathRelsLabelDataVector;

    // Endpoints first, then the rel's own ID, then its label.
    srcIDVec->setValue<nodeID_t>(pos, srcNodeID);
    dstIDVec->setValue<nodeID_t>(pos, dstNodeID);
    relIDVec->setValue<relID_t>(pos, relID);
    StringVector::addString(labelVec, pos, labelName);
}

void PathScanner::writePathToVector(RecursiveJoinVectors* vectors, sel_t& vectorPos,
sel_t& nodeIDDataVectorPos, sel_t& relIDDataVectorPos) {
if (semanticCheckFunc && !semanticCheckFunc(nodeIDs, relIDs)) {
return;
}
KU_ASSERT(vectorPos < DEFAULT_VECTOR_CAPACITY);
// Allocate list entries.
auto nodeTableEntry = ListVector::addList(vectors->pathNodesVector, k > 0 ? k - 1 : 0);
auto relTableEntry = ListVector::addList(vectors->pathRelsVector, k);
vectors->pathNodesVector->setValue(vectorPos, nodeTableEntry);
vectors->pathRelsVector->setValue(vectorPos, relTableEntry);
// Write dst
writeDstNodeOffsetAndLength(vectors->dstNodeIDVector, vectors->pathLengthVector, vectorPos);
vectorPos++;
// Write path nodes.
for (auto i = 1u; i < k; ++i) {
auto nodeID = nodeIDs[i];
vectors->pathNodesIDDataVector->setValue<nodeID_t>(nodeIDDataVectorPos, nodeID);
Expand All @@ -142,17 +153,25 @@ void PathScanner::writePathToVector(RecursiveJoinVectors* vectors, sel_t& vector
labelName.data(), labelName.length());
nodeIDDataVectorPos++;
}
for (auto i = 0u; i < k; ++i) {
auto srcNodeID = nodeIDs[i];
auto dstNodeID = nodeIDs[i + 1];
vectors->pathRelsSrcIDDataVector->setValue<nodeID_t>(relIDDataVectorPos, srcNodeID);
vectors->pathRelsDstIDDataVector->setValue<nodeID_t>(relIDDataVectorPos, dstNodeID);
auto relID = relIDs[i];
vectors->pathRelsIDDataVector->setValue<relID_t>(relIDDataVectorPos, relID);
auto labelName = tableIDToName.at(relID.tableID);
StringVector::addString(vectors->pathRelsLabelDataVector, relIDDataVectorPos,
labelName.data(), labelName.length());
relIDDataVectorPos++;
// Write path rels.
if (extendInBwd) {
for (auto i = 0u; i < k; ++i) {
auto srcNodeID = nodeIDs[i + 1];
auto dstNodeID = nodeIDs[i];
auto relID = relIDs[i];
writePathRels(vectors, relIDDataVectorPos, srcNodeID, dstNodeID, relID,
tableIDToName.at(relID.tableID));
relIDDataVectorPos++;
}
} else {
for (auto i = 0u; i < k; ++i) {
auto srcNodeID = nodeIDs[i];
auto dstNodeID = nodeIDs[i + 1];
auto relID = relIDs[i];
writePathRels(vectors, relIDDataVectorPos, srcNodeID, dstNodeID, relID,
tableIDToName.at(relID.tableID));
relIDDataVectorPos++;
}
}
}

Expand Down
Loading

0 comments on commit 9ad863a

Please sign in to comment.