Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue 3262 #3384

Merged
merged 1 commit into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/include/planner/operator/extend/extend_direction.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace planner {
enum class ExtendDirection : uint8_t { FWD = 0, BWD = 1, BOTH = 2 };

struct ExtendDirectionUtils {
static inline ExtendDirection getExtendDirection(const binder::RelExpression& relExpression,
static ExtendDirection getExtendDirection(const binder::RelExpression& relExpression,
const binder::NodeExpression& boundNode) {
if (relExpression.getDirectionType() == binder::RelDirectionType::BOTH) {
return ExtendDirection::BOTH;
Expand All @@ -23,7 +23,7 @@ struct ExtendDirectionUtils {
}
}

static inline common::RelDataDirection getRelDataDirection(ExtendDirection extendDirection) {
static common::RelDataDirection getRelDataDirection(ExtendDirection extendDirection) {
KU_ASSERT(extendDirection != ExtendDirection::BOTH);
return extendDirection == ExtendDirection::FWD ? common::RelDataDirection::FWD :
common::RelDataDirection::BWD;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ class PathScanner : public BaseFrontierScanner {
public:
PathScanner(TargetDstNodes* targetDstNodes, size_t k,
std::unordered_map<common::table_id_t, std::string> tableIDToName,
path_semantic_check_t semanticCheckFunc)
path_semantic_check_t semanticCheckFunc, bool extendInBwd)
: BaseFrontierScanner{targetDstNodes, k}, tableIDToName{std::move(tableIDToName)},
semanticCheckFunc{semanticCheckFunc} {
semanticCheckFunc{semanticCheckFunc}, extendInBwd{extendInBwd} {
nodeIDs.resize(k + 1);
relIDs.resize(k + 1);
}
Expand Down Expand Up @@ -97,8 +97,10 @@ class PathScanner : public BaseFrontierScanner {
std::stack<nbrs_t> nbrsStack;
std::stack<int64_t> cursorStack;
std::unordered_map<common::table_id_t, std::string> tableIDToName;

// Path semantic
path_semantic_check_t semanticCheckFunc;
// Extend direction
bool extendInBwd;
};

/*
Expand Down
82 changes: 46 additions & 36 deletions src/include/processor/operator/recursive_extend/recursive_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "bfs_state.h"
#include "common/enums/query_rel_type.h"
#include "frontier_scanner.h"
#include "planner/operator/extend/extend_direction.h"
#include "planner/operator/extend/recursive_join_type.h"
#include "processor/operator/mask.h"
#include "processor/operator/physical_operator.h"
Expand Down Expand Up @@ -35,25 +36,21 @@ struct RecursiveJoinDataInfo {
DataPos pathPos;
std::unordered_map<common::table_id_t, std::string> tableIDToName;

RecursiveJoinDataInfo(const DataPos& srcNodePos, const DataPos& dstNodePos,
std::unordered_set<common::table_id_t> dstNodeTableIDs, const DataPos& pathLengthPos,
std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
const DataPos& recursiveDstNodeIDPos,
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs,
const DataPos& recursiveEdgeIDPos, const DataPos& pathPos,
std::unordered_map<common::table_id_t, std::string> tableIDToName)
: srcNodePos{srcNodePos}, dstNodePos{dstNodePos},
dstNodeTableIDs{std::move(dstNodeTableIDs)}, pathLengthPos{pathLengthPos},
localResultSetDescriptor{std::move(localResultSetDescriptor)},
recursiveDstNodeIDPos{recursiveDstNodeIDPos},
recursiveDstNodeTableIDs{std::move(recursiveDstNodeTableIDs)},
recursiveEdgeIDPos{recursiveEdgeIDPos}, pathPos{pathPos},
tableIDToName{std::move(tableIDToName)} {}

inline std::unique_ptr<RecursiveJoinDataInfo> copy() {
return std::make_unique<RecursiveJoinDataInfo>(srcNodePos, dstNodePos, dstNodeTableIDs,
pathLengthPos, localResultSetDescriptor->copy(), recursiveDstNodeIDPos,
recursiveDstNodeTableIDs, recursiveEdgeIDPos, pathPos, tableIDToName);
RecursiveJoinDataInfo() = default;
EXPLICIT_COPY_DEFAULT_MOVE(RecursiveJoinDataInfo);

private:
RecursiveJoinDataInfo(const RecursiveJoinDataInfo& other) {
srcNodePos = other.srcNodePos;
dstNodePos = other.dstNodePos;
dstNodeTableIDs = other.dstNodeTableIDs;
pathLengthPos = other.pathLengthPos;
localResultSetDescriptor = other.localResultSetDescriptor->copy();
recursiveDstNodeIDPos = other.recursiveDstNodeIDPos;
recursiveDstNodeTableIDs = other.recursiveDstNodeTableIDs;
recursiveEdgeIDPos = other.recursiveEdgeIDPos;
pathPos = other.pathPos;
tableIDToName = other.tableIDToName;
}
};

Expand All @@ -75,29 +72,47 @@ struct RecursiveJoinVectors {
common::ValueVector* recursiveDstNodeIDVector = nullptr;
};

struct RecursiveJoinInfo {
RecursiveJoinDataInfo dataInfo;
uint8_t lowerBound;
uint8_t upperBound;
common::QueryRelType queryRelType;
planner::RecursiveJoinType joinType;
planner::ExtendDirection direction;

RecursiveJoinInfo() = default;
EXPLICIT_COPY_DEFAULT_MOVE(RecursiveJoinInfo);

private:
RecursiveJoinInfo(const RecursiveJoinInfo& other) {
dataInfo = other.dataInfo.copy();
lowerBound = other.lowerBound;
upperBound = other.upperBound;
queryRelType = other.queryRelType;
joinType = other.joinType;
direction = other.direction;
}
};

class RecursiveJoin : public PhysicalOperator {
public:
RecursiveJoin(uint8_t lowerBound, uint8_t upperBound, common::QueryRelType queryRelType,
planner::RecursiveJoinType joinType, std::shared_ptr<RecursiveJoinSharedState> sharedState,
std::unique_ptr<RecursiveJoinDataInfo> dataInfo, std::unique_ptr<PhysicalOperator> child,
uint32_t id, const std::string& paramsString,
RecursiveJoin(RecursiveJoinInfo info, std::shared_ptr<RecursiveJoinSharedState> sharedState,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString,
std::unique_ptr<PhysicalOperator> recursiveRoot)
: PhysicalOperator{PhysicalOperatorType::RECURSIVE_JOIN, std::move(child), id,
paramsString},
lowerBound{lowerBound}, upperBound{upperBound}, queryRelType{queryRelType},
joinType{joinType}, sharedState{std::move(sharedState)}, dataInfo{std::move(dataInfo)},
info{std::move(info)}, sharedState{std::move(sharedState)},
recursiveRoot{std::move(recursiveRoot)} {}

inline RecursiveJoinSharedState* getSharedState() const { return sharedState.get(); }
RecursiveJoinSharedState* getSharedState() const { return sharedState.get(); }

void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) final;

bool getNextTuplesInternal(ExecutionContext* context) final;

inline std::unique_ptr<PhysicalOperator> clone() final {
return std::make_unique<RecursiveJoin>(lowerBound, upperBound, queryRelType, joinType,
sharedState, dataInfo->copy(), children[0]->clone(), id, paramsString,
recursiveRoot->clone());
std::unique_ptr<PhysicalOperator> clone() final {
return std::make_unique<RecursiveJoin>(info.copy(), sharedState, children[0]->clone(), id,
paramsString, recursiveRoot->clone());
}

private:
Expand All @@ -113,13 +128,8 @@ class RecursiveJoin : public PhysicalOperator {
void updateVisitedNodes(common::nodeID_t boundNodeID);

private:
uint8_t lowerBound;
uint8_t upperBound;
common::QueryRelType queryRelType;
planner::RecursiveJoinType joinType;

RecursiveJoinInfo info;
std::shared_ptr<RecursiveJoinSharedState> sharedState;
std::unique_ptr<RecursiveJoinDataInfo> dataInfo;

// Local recursive plan
std::unique_ptr<ResultSet> localResultSet;
Expand Down
55 changes: 29 additions & 26 deletions src/processor/map/map_recursive_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,48 +20,51 @@ static std::shared_ptr<RecursiveJoinSharedState> createSharedState(
return std::make_shared<RecursiveJoinSharedState>(std::move(semiMasks));
}

std::unique_ptr<PhysicalOperator> PlanMapper::mapRecursiveExtend(
planner::LogicalOperator* logicalOperator) {
auto extend = (LogicalRecursiveExtend*)logicalOperator;
std::unique_ptr<PhysicalOperator> PlanMapper::mapRecursiveExtend(LogicalOperator* logicalOperator) {
auto extend = logicalOperator->constPtrCast<LogicalRecursiveExtend>();
auto boundNode = extend->getBoundNode();
auto nbrNode = extend->getNbrNode();
auto rel = extend->getRel();
auto recursiveInfo = rel->getRecursiveInfo();
auto lengthExpression = rel->getLengthExpression();
// Map recursive plan
auto logicalRecursiveRoot = extend->getRecursiveChild();
auto recursiveRoot = mapOperator(logicalRecursiveRoot.get());
auto recursivePlanSchema = logicalRecursiveRoot->getSchema();
auto recursivePlanResultSetDescriptor =
std::make_unique<ResultSetDescriptor>(recursivePlanSchema);
auto recursiveDstNodeIDPos =
DataPos(recursivePlanSchema->getExpressionPos(*recursiveInfo->nodeCopy->getInternalID()));
auto recursiveEdgeIDPos = DataPos(
recursivePlanSchema->getExpressionPos(*recursiveInfo->rel->getInternalIDProperty()));
// Generate RecursiveJoin
auto outSchema = extend->getSchema();
auto inSchema = extend->getChild(0)->getSchema();
auto boundNodeIDPos = DataPos(inSchema->getExpressionPos(*boundNode->getInternalID()));
auto nbrNodeIDPos = DataPos(outSchema->getExpressionPos(*nbrNode->getInternalID()));
auto lengthPos = DataPos(outSchema->getExpressionPos(*lengthExpression));
auto sharedState = createSharedState(*nbrNode, *clientContext->getStorageManager());
auto pathPos = DataPos();
if (extend->getJoinType() == planner::RecursiveJoinType::TRACK_PATH) {
pathPos = DataPos(outSchema->getExpressionPos(*rel));
// Data info
auto dataInfo = RecursiveJoinDataInfo();
dataInfo.srcNodePos = getDataPos(*boundNode->getInternalID(), *inSchema);
dataInfo.dstNodePos = getDataPos(*nbrNode->getInternalID(), *outSchema);
dataInfo.dstNodeTableIDs = nbrNode->getTableIDsSet();
dataInfo.pathLengthPos = getDataPos(*rel->getLengthExpression(), *outSchema);
dataInfo.localResultSetDescriptor = std::make_unique<ResultSetDescriptor>(recursivePlanSchema);
dataInfo.recursiveDstNodeIDPos =
getDataPos(*recursiveInfo->nodeCopy->getInternalID(), *recursivePlanSchema);
dataInfo.recursiveDstNodeTableIDs = recursiveInfo->node->getTableIDsSet();
dataInfo.recursiveEdgeIDPos =
getDataPos(*recursiveInfo->rel->getInternalIDProperty(), *recursivePlanSchema);
if (extend->getJoinType() == RecursiveJoinType::TRACK_PATH) {
dataInfo.pathPos = getDataPos(*rel, *outSchema);
} else {
dataInfo.pathPos = DataPos::getInvalidPos();
}
std::unordered_map<common::table_id_t, std::string> tableIDToName;
for (auto& entry : clientContext->getCatalog()->getTableEntries(clientContext->getTx())) {
tableIDToName.insert({entry->getTableID(), entry->getName()});
dataInfo.tableIDToName.insert({entry->getTableID(), entry->getName()});
}
auto dataInfo = std::make_unique<RecursiveJoinDataInfo>(boundNodeIDPos, nbrNodeIDPos,
nbrNode->getTableIDsSet(), lengthPos, std::move(recursivePlanResultSetDescriptor),
recursiveDstNodeIDPos, recursiveInfo->node->getTableIDsSet(), recursiveEdgeIDPos, pathPos,
std::move(tableIDToName));
// Info
auto info = RecursiveJoinInfo();
info.dataInfo = std::move(dataInfo);
info.lowerBound = rel->getLowerBound();
info.upperBound = rel->getUpperBound();
info.queryRelType = rel->getRelType();
info.joinType = extend->getJoinType();
info.direction = extend->getDirection();
auto prevOperator = mapOperator(logicalOperator->getChild(0).get());
return std::make_unique<RecursiveJoin>(rel->getLowerBound(), rel->getUpperBound(),
rel->getRelType(), extend->getJoinType(), sharedState, std::move(dataInfo),
std::move(prevOperator), getOperatorID(), extend->getExpressionsForPrinting(),
std::move(recursiveRoot));
return std::make_unique<RecursiveJoin>(std::move(info), sharedState, std::move(prevOperator),
getOperatorID(), extend->getExpressionsForPrinting(), std::move(recursiveRoot));
}

} // namespace processor
Expand Down
41 changes: 30 additions & 11 deletions src/processor/operator/recursive_extend/frontier_scanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,29 @@ void PathScanner::initDfs(const frontier::node_rel_id_t& nodeAndRelID, size_t cu
initDfs(nbrs->at(0), currentDepth - 1);
}

static void writePathRels(RecursiveJoinVectors* vectors, sel_t pos, nodeID_t srcNodeID,
nodeID_t dstNodeID, relID_t relID, const std::string& labelName) {
vectors->pathRelsSrcIDDataVector->setValue<nodeID_t>(pos, srcNodeID);
vectors->pathRelsDstIDDataVector->setValue<nodeID_t>(pos, dstNodeID);
vectors->pathRelsIDDataVector->setValue<relID_t>(pos, relID);
StringVector::addString(vectors->pathRelsLabelDataVector, pos, labelName);
}

void PathScanner::writePathToVector(RecursiveJoinVectors* vectors, sel_t& vectorPos,
sel_t& nodeIDDataVectorPos, sel_t& relIDDataVectorPos) {
if (semanticCheckFunc && !semanticCheckFunc(nodeIDs, relIDs)) {
return;
}
KU_ASSERT(vectorPos < DEFAULT_VECTOR_CAPACITY);
// Allocate list entries.
auto nodeTableEntry = ListVector::addList(vectors->pathNodesVector, k > 0 ? k - 1 : 0);
auto relTableEntry = ListVector::addList(vectors->pathRelsVector, k);
vectors->pathNodesVector->setValue(vectorPos, nodeTableEntry);
vectors->pathRelsVector->setValue(vectorPos, relTableEntry);
// Write dst
writeDstNodeOffsetAndLength(vectors->dstNodeIDVector, vectors->pathLengthVector, vectorPos);
vectorPos++;
// Write path nodes.
for (auto i = 1u; i < k; ++i) {
auto nodeID = nodeIDs[i];
vectors->pathNodesIDDataVector->setValue<nodeID_t>(nodeIDDataVectorPos, nodeID);
Expand All @@ -142,17 +153,25 @@ void PathScanner::writePathToVector(RecursiveJoinVectors* vectors, sel_t& vector
labelName.data(), labelName.length());
nodeIDDataVectorPos++;
}
for (auto i = 0u; i < k; ++i) {
auto srcNodeID = nodeIDs[i];
auto dstNodeID = nodeIDs[i + 1];
vectors->pathRelsSrcIDDataVector->setValue<nodeID_t>(relIDDataVectorPos, srcNodeID);
vectors->pathRelsDstIDDataVector->setValue<nodeID_t>(relIDDataVectorPos, dstNodeID);
auto relID = relIDs[i];
vectors->pathRelsIDDataVector->setValue<relID_t>(relIDDataVectorPos, relID);
auto labelName = tableIDToName.at(relID.tableID);
StringVector::addString(vectors->pathRelsLabelDataVector, relIDDataVectorPos,
labelName.data(), labelName.length());
relIDDataVectorPos++;
// Write path rels.
if (extendInBwd) {
for (auto i = 0u; i < k; ++i) {
auto srcNodeID = nodeIDs[i + 1];
auto dstNodeID = nodeIDs[i];
auto relID = relIDs[i];
writePathRels(vectors, relIDDataVectorPos, srcNodeID, dstNodeID, relID,
tableIDToName.at(relID.tableID));
relIDDataVectorPos++;
}
} else {
for (auto i = 0u; i < k; ++i) {
auto srcNodeID = nodeIDs[i];
auto dstNodeID = nodeIDs[i + 1];
auto relID = relIDs[i];
writePathRels(vectors, relIDDataVectorPos, srcNodeID, dstNodeID, relID,
tableIDToName.at(relID.tableID));
relIDDataVectorPos++;
}
}
}

Expand Down
Loading
Loading