Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursive join fix #1607

Merged
merged 1 commit into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/include/processor/operator/recursive_extend/bfs_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@ class TargetDstNodes {
// Constructs the target set from a node count and a set of node IDs.
// nodeIDs is taken by value and moved into the member. An empty nodeIDs set is treated by
// contains() as "no semi mask available".
TargetDstNodes(uint64_t numNodes, frontier::node_id_set_t nodeIDs)
: numNodes{numNodes}, nodeIDs{std::move(nodeIDs)} {}

// Installs the destination table-ID filter, replacing any previously stored one.
// The set is taken by value and moved into tableIDFilter. contains() consults this filter
// only when nodeIDs is empty; an empty filter means no table-ID restriction.
inline void setTableIDFilter(std::unordered_set<common::table_id_t> filter) {
tableIDFilter = std::move(filter);
}

inline bool contains(common::nodeID_t nodeID) {
if (nodeIDs.empty()) { // All nodeIDs are targets
return true;
if (nodeIDs.empty()) { // no semi mask available
if (tableIDFilter.empty()) { // no dst table ID filter available
return true;
}
return tableIDFilter.contains(nodeID.tableID);
}
return nodeIDs.contains(nodeID);
}
Expand All @@ -26,6 +33,7 @@ class TargetDstNodes {
private:
uint64_t numNodes;
frontier::node_id_set_t nodeIDs;
std::unordered_set<common::table_id_t> tableIDFilter;
};

class BaseBFSState {
Expand Down
49 changes: 29 additions & 20 deletions src/include/processor/operator/recursive_extend/recursive_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,41 +60,50 @@ struct RecursiveJoinDataInfo {
DataPos srcNodePos;
// Join output info.
DataPos dstNodePos;
std::unordered_set<common::table_id_t> dstNodeTableIDs;
DataPos pathLengthPos;
// Recursive join info.
std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor;
DataPos tmpDstNodeIDPos;
DataPos tmpEdgeIDPos;
DataPos recursiveDstNodeIDPos;
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs;
DataPos recursiveEdgeIDPos;
// Path info
planner::RecursiveJoinType joinType;
DataPos pathPos;

RecursiveJoinDataInfo(std::vector<DataPos> vectorsToScanPos,
std::vector<ft_col_idx_t> colIndicesToScan, const DataPos& srcNodePos,
const DataPos& dstNodePos, const DataPos& pathLengthPos,
std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
const DataPos& tmpDstNodeIDPos, const DataPos& tmpEdgeIDPos,
planner::RecursiveJoinType joinType)
const DataPos& dstNodePos, std::unordered_set<common::table_id_t> dstNodeTableIDs,
const DataPos& pathLengthPos, std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
const DataPos& recursiveDstNodeIDPos,
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs,
const DataPos& recursiveEdgeIDPos, planner::RecursiveJoinType joinType)
: RecursiveJoinDataInfo{std::move(vectorsToScanPos), std::move(colIndicesToScan),
srcNodePos, dstNodePos, pathLengthPos, std::move(localResultSetDescriptor),
tmpDstNodeIDPos, tmpEdgeIDPos, joinType, DataPos()} {}
srcNodePos, dstNodePos, std::move(dstNodeTableIDs), pathLengthPos,
std::move(localResultSetDescriptor), recursiveDstNodeIDPos,
std::move(recursiveDstNodeTableIDs), recursiveEdgeIDPos, joinType, DataPos()} {}
RecursiveJoinDataInfo(std::vector<DataPos> vectorsToScanPos,
std::vector<ft_col_idx_t> colIndicesToScan, const DataPos& srcNodePos,
const DataPos& dstNodePos, const DataPos& pathLengthPos,
std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
const DataPos& tmpDstNodeIDPos, const DataPos& tmpEdgeIDPos,
planner::RecursiveJoinType joinType, const DataPos& pathPos)
const DataPos& dstNodePos, std::unordered_set<common::table_id_t> dstNodeTableIDs,
const DataPos& pathLengthPos, std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
const DataPos& recursiveDstNodeIDPos,
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs,
const DataPos& recursiveEdgeIDPos, planner::RecursiveJoinType joinType,
const DataPos& pathPos)
: vectorsToScanPos{std::move(vectorsToScanPos)},
colIndicesToScan{std::move(colIndicesToScan)}, srcNodePos{srcNodePos},
dstNodePos{dstNodePos}, pathLengthPos{pathLengthPos}, localResultSetDescriptor{std::move(
localResultSetDescriptor)},
tmpDstNodeIDPos{tmpDstNodeIDPos},
tmpEdgeIDPos{tmpEdgeIDPos}, joinType{joinType}, pathPos{pathPos} {}
dstNodePos{dstNodePos}, dstNodeTableIDs{std::move(dstNodeTableIDs)},
pathLengthPos{pathLengthPos}, localResultSetDescriptor{std::move(
localResultSetDescriptor)},
recursiveDstNodeIDPos{recursiveDstNodeIDPos}, recursiveDstNodeTableIDs{std::move(
recursiveDstNodeTableIDs)},
recursiveEdgeIDPos{recursiveEdgeIDPos}, joinType{joinType}, pathPos{pathPos} {}

inline std::unique_ptr<RecursiveJoinDataInfo> copy() {
return std::make_unique<RecursiveJoinDataInfo>(vectorsToScanPos, colIndicesToScan,
srcNodePos, dstNodePos, pathLengthPos, localResultSetDescriptor->copy(),
tmpDstNodeIDPos, tmpEdgeIDPos, joinType, pathPos);
srcNodePos, dstNodePos, dstNodeTableIDs, pathLengthPos,
localResultSetDescriptor->copy(), recursiveDstNodeIDPos, recursiveDstNodeTableIDs,
recursiveEdgeIDPos, joinType, pathPos);
}
};

Expand Down Expand Up @@ -160,8 +169,8 @@ class BaseRecursiveJoin : public PhysicalOperator {
common::ValueVector* pathVector;

// temporary recursive join result.
common::ValueVector* tmpEdgeIDVector;
common::ValueVector* tmpDstNodeIDVector;
common::ValueVector* recursiveEdgeIDVector;
common::ValueVector* recursiveDstNodeIDVector;

std::unique_ptr<BaseBFSState> bfsState;
std::unique_ptr<FrontiersScanner> frontiersScanner;
Expand Down
22 changes: 11 additions & 11 deletions src/processor/mapper/map_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,17 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalRecursiveExtendToPhysica
auto recursivePlanSchema = logicalRecursiveRoot->getSchema();
auto recursivePlanResultSetDescriptor =
std::make_unique<ResultSetDescriptor>(recursivePlanSchema);
auto tmpDstNodeIDPos =
auto recursiveDstNodeIDPos =
DataPos(recursivePlanSchema->getExpressionPos(*recursiveNode->getInternalIDProperty()));
auto tmpEdgeIDPos =
auto recursiveEdgeIDPos =
DataPos(recursivePlanSchema->getExpressionPos(*rel->getInternalIDProperty()));
// map child plan
auto outSchema = extend->getSchema();
auto inSchema = extend->getChild(0)->getSchema();
auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0));
auto inNodeIDVectorPos =
auto boundNodeIDVectorPos =
DataPos(inSchema->getExpressionPos(*boundNode->getInternalIDProperty()));
auto outNodeIDVectorPos =
auto nbrNodeIDVectorPos =
DataPos(outSchema->getExpressionPos(*nbrNode->getInternalIDProperty()));
auto lengthVectorPos = DataPos(outSchema->getExpressionPos(*lengthExpression));
auto expressions = inSchema->getExpressionsInScope();
Expand All @@ -188,22 +188,22 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalRecursiveExtendToPhysica
outDataPoses.emplace_back(outSchema->getExpressionPos(*expressions[i]));
colIndicesToScan.push_back(i);
}

// Generate RecursiveJoinDataInfo
std::unique_ptr<RecursiveJoinDataInfo> dataInfo;
switch (extend->getJoinType()) {
case planner::RecursiveJoinType::TRACK_PATH: {
auto pathVectorPos = DataPos(outSchema->getExpressionPos(*rel));
dataInfo = std::make_unique<RecursiveJoinDataInfo>(outDataPoses, colIndicesToScan,
inNodeIDVectorPos, outNodeIDVectorPos, lengthVectorPos,
std::move(recursivePlanResultSetDescriptor), tmpDstNodeIDPos, tmpEdgeIDPos,
extend->getJoinType(), pathVectorPos);
boundNodeIDVectorPos, nbrNodeIDVectorPos, nbrNode->getTableIDsSet(), lengthVectorPos,
std::move(recursivePlanResultSetDescriptor), recursiveDstNodeIDPos,
recursiveNode->getTableIDsSet(), recursiveEdgeIDPos, extend->getJoinType(),
pathVectorPos);
} break;
case planner::RecursiveJoinType::TRACK_NONE: {
dataInfo = std::make_unique<RecursiveJoinDataInfo>(outDataPoses, colIndicesToScan,
inNodeIDVectorPos, outNodeIDVectorPos, lengthVectorPos,
std::move(recursivePlanResultSetDescriptor), tmpDstNodeIDPos, tmpEdgeIDPos,
extend->getJoinType());
boundNodeIDVectorPos, nbrNodeIDVectorPos, nbrNode->getTableIDsSet(), lengthVectorPos,
std::move(recursivePlanResultSetDescriptor), recursiveDstNodeIDPos,
recursiveNode->getTableIDsSet(), recursiveEdgeIDPos, extend->getJoinType());
} break;
default:
throw common::NotImplementedException("PlanMapper::mapLogicalRecursiveExtendToPhysical");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ size_t BaseFrontierScanner::scan(common::ValueVector* pathVector,
currentDstNodeID = lastFrontier->nodeIDs[lastFrontierCursor++];
// Skip nodes that are not in the semi mask.
if (!targetDstNodes->contains(currentDstNodeID)) {
currentDstNodeID.offset = common::INVALID_OFFSET;
continue;
}
initScanFromDstOffset();
Expand Down
19 changes: 13 additions & 6 deletions src/processor/operator/recursive_extend/recursive_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ void BaseRecursiveJoin::computeBFS(ExecutionContext* context) {

void BaseRecursiveJoin::updateVisitedNodes(common::nodeID_t boundNodeID) {
auto boundNodeMultiplicity = bfsState->getMultiplicity(boundNodeID);
for (auto i = 0u; i < tmpDstNodeIDVector->state->selVector->selectedSize; ++i) {
auto pos = tmpDstNodeIDVector->state->selVector->selectedPositions[i];
auto nbrNodeID = tmpDstNodeIDVector->getValue<common::nodeID_t>(pos);
auto edgeID = tmpEdgeIDVector->getValue<common::relID_t>(pos);
for (auto i = 0u; i < recursiveDstNodeIDVector->state->selVector->selectedSize; ++i) {
auto pos = recursiveDstNodeIDVector->state->selVector->selectedPositions[i];
auto nbrNodeID = recursiveDstNodeIDVector->getValue<common::nodeID_t>(pos);
auto edgeID = recursiveEdgeIDVector->getValue<common::relID_t>(pos);
bfsState->markVisited(boundNodeID, nbrNodeID, edgeID, boundNodeMultiplicity);
}
}
Expand All @@ -115,8 +115,9 @@ void BaseRecursiveJoin::initLocalRecursivePlan(ExecutionContext* context) {
scanFrontier = (ScanFrontier*)op;
localResultSet = std::make_unique<ResultSet>(
dataInfo->localResultSetDescriptor.get(), context->memoryManager);
tmpDstNodeIDVector = localResultSet->getValueVector(dataInfo->tmpDstNodeIDPos).get();
tmpEdgeIDVector = localResultSet->getValueVector(dataInfo->tmpEdgeIDPos).get();
recursiveDstNodeIDVector =
localResultSet->getValueVector(dataInfo->recursiveDstNodeIDPos).get();
recursiveEdgeIDVector = localResultSet->getValueVector(dataInfo->recursiveEdgeIDPos).get();
recursiveRoot->initLocalState(localResultSet.get(), context);
}

Expand All @@ -139,6 +140,12 @@ void BaseRecursiveJoin::populateTargetDstNodes() {
}
}
targetDstNodes = std::make_unique<TargetDstNodes>(numTargetNodes, std::move(targetNodeIDs));
for (auto tableID : dataInfo->recursiveDstNodeTableIDs) {
if (!dataInfo->dstNodeTableIDs.contains(tableID)) {
targetDstNodes->setTableIDFilter(dataInfo->dstNodeTableIDs);
return;
}
}
}

} // namespace processor
Expand Down
3 changes: 2 additions & 1 deletion src/storage/storage_structure/in_mem_file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ common::ku_list_t InMemOverflowFile::appendList(common::LogicalType& type,
for (auto offset = startOffset; offset < endOffset; offset++) {
auto stringView = stringArray.GetView(offset);
childStrings[offset - startOffset] = copyString(stringView.data(),
std::min(BufferPoolConstants::PAGE_4KB_SIZE, (uint64_t)stringView.length()),
std::min<uint64_t>(
BufferPoolConstants::PAGE_4KB_SIZE, (uint64_t)stringView.length()),
overflowCursor);
}
lock.lock();
Expand Down
9 changes: 7 additions & 2 deletions test/test_files/tinysnb/var_length_extend/multi_label.test
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,21 @@
4|[0:0,3:1,0:2,5:0,1:1]
6|[0:0,3:2,0:3,5:1,1:2]

-NAME MixMultiLabelTest
-NAME MixMultiLabelTest2
-QUERY MATCH (a:person)-[e:meets|:marries|:studyAt*2..2]->(b) WHERE a.fName = 'Alice' RETURN b.ID, e
---- 4
1|[0:0,6:0,0:1,4:1,1:0]
1|[0:0,7:0,0:1,4:1,1:0]
5|[0:0,6:0,0:1,6:1,0:3]
5|[0:0,7:0,0:1,6:1,0:3]

-NAME MixMultiLabelTest
-NAME MixMultiLabelTest3
-QUERY MATCH (a:person)-[e:meets|:marries|:studyAt*2..2]->(b) WHERE a.fName = 'Alice' AND b.ID < 5 RETURN COUNT(*)
-ENUMERATE
---- 1
2

-NAME MixMultiLabelTest4
-QUERY MATCH (a:person)-[e*2..2]->(b:organisation) WHERE a.fName = 'Alice' RETURN COUNT(*)
---- 1
5