Skip to content

Commit

Permalink
Merge pull request #1500 from kuzudb/ssp-sink
Browse files Browse the repository at this point in the history
Add sink before recursive join
  • Loading branch information
andyfengHKU committed Apr 27, 2023
2 parents d9695fd + ddfdcde commit 4e5f17d
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 18 deletions.
19 changes: 14 additions & 5 deletions src/include/processor/operator/var_length_extend/recursive_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,18 @@ struct BFSScanState {
class RecursiveJoin : public PhysicalOperator {
public:
// Constructs a RecursiveJoin operator. The PhysicalOperator base is initialised with
// operator type SCAN_BFS_LEVEL, takes ownership of `child`, and receives `id` and
// `paramsString`.
//   upperBound             - stored as-is; presumably the maximum recursion depth (TODO confirm).
//   nodeTable              - raw pointer, stored as-is.
//   inputFTableSharedState - shared_ptr moved into the member of the same name.
//   vectorsToScanPos       - moved into the member of the same name.
//   colIndicesToScan       - moved into the member of the same name.
//   srcNodeIDVectorPos, dstNodeIDVectorPos, distanceVectorPos - copied into members.
//   root                   - ownership is moved into the `root` member.
// bfsScanState is value-initialised.
RecursiveJoin(uint8_t upperBound, storage::NodeTable* nodeTable,
std::shared_ptr<FTableSharedState> inputFTableSharedState,
std::vector<DataPos> vectorsToScanPos, std::vector<ft_col_idx_t> colIndicesToScan,
const DataPos& srcNodeIDVectorPos, const DataPos& dstNodeIDVectorPos,
const DataPos& distanceVectorPos, std::unique_ptr<PhysicalOperator> child, uint32_t id,
const std::string& paramsString, std::unique_ptr<PhysicalOperator> root)
: PhysicalOperator{PhysicalOperatorType::SCAN_BFS_LEVEL, std::move(child), id,
paramsString},
// NOTE(review): the next two lines and the five lines after them both initialise
// upperBound, nodeTable, srcNodeIDVectorPos and dstNodeIDVectorPos. This text comes from a
// rendered diff, so the first pair looks like the pre-change initialiser list shown next to
// its replacement -- confirm against the real header before compiling.
upperBound{upperBound}, nodeTable{nodeTable}, srcNodeIDVectorPos{srcNodeIDVectorPos},
dstNodeIDVectorPos{dstNodeIDVectorPos},
upperBound{upperBound}, nodeTable{nodeTable}, inputFTableSharedState{std::move(
inputFTableSharedState)},
vectorsToScanPos{std::move(vectorsToScanPos)}, colIndicesToScan{std::move(
colIndicesToScan)},
srcNodeIDVectorPos{srcNodeIDVectorPos}, dstNodeIDVectorPos{dstNodeIDVectorPos},
distanceVectorPos{distanceVectorPos}, root{std::move(root)}, bfsScanState{} {}

// Returns the fixed position DataPos{0, 0}. Judging by the name, this is where the temporary
// source-node vector is expected to live -- confirm against the code that builds the result set.
static inline DataPos getTmpSrcNodeVectorPos() {
    return DataPos{0, 0};
}
Expand All @@ -66,9 +71,9 @@ class RecursiveJoin : public PhysicalOperator {
bool getNextTuplesInternal(ExecutionContext* context) override;

// Returns a new RecursiveJoin configured like this one. children[0] and root are duplicated
// through their own clone(); inputFTableSharedState is passed as a shared_ptr copy, so the
// clone refers to the same FTableSharedState object as this operator.
std::unique_ptr<PhysicalOperator> clone() override {
// NOTE(review): this first return passes 9 arguments, which does not match the 12-parameter
// constructor declared above. This text comes from a rendered diff, so it looks like the
// pre-change statement shown next to its replacement -- confirm against the real header.
return std::make_unique<RecursiveJoin>(upperBound, nodeTable, srcNodeIDVectorPos,
dstNodeIDVectorPos, distanceVectorPos, children[0]->clone(), id, paramsString,
root->clone());
// This call supplies all 12 constructor arguments, including the shared input table state
// and the scan positions / column indices (the two vectors are passed by value, i.e. copied).
return std::make_unique<RecursiveJoin>(upperBound, nodeTable, inputFTableSharedState,
vectorsToScanPos, colIndicesToScan, srcNodeIDVectorPos, dstNodeIDVectorPos,
distanceVectorPos, children[0]->clone(), id, paramsString, root->clone());
}

private:
Expand All @@ -85,6 +90,9 @@ class RecursiveJoin : public PhysicalOperator {
private:
uint8_t upperBound;
storage::NodeTable* nodeTable;
std::shared_ptr<FTableSharedState> inputFTableSharedState;
std::vector<DataPos> vectorsToScanPos;
std::vector<ft_col_idx_t> colIndicesToScan;
DataPos srcNodeIDVectorPos;
DataPos dstNodeIDVectorPos;
DataPos distanceVectorPos;
Expand All @@ -97,6 +105,7 @@ class RecursiveJoin : public PhysicalOperator {
std::unique_ptr<BFSMorsel> bfsMorsel;

common::offset_t maxNodeOffset;
std::vector<common::ValueVector*> vectorsToScan;
std::shared_ptr<common::ValueVector> srcNodeIDVector;
std::shared_ptr<common::ValueVector> dstNodeIDVector;
std::shared_ptr<common::ValueVector> distanceVector;
Expand Down
9 changes: 1 addition & 8 deletions src/planner/join_order/cardinality_estimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,7 @@ double CardinalityEstimator::getExtensionRate(
}
case common::QueryRelType::VARIABLE_LENGTH:
case common::QueryRelType::SHORTEST: {
auto extensionRate = oneHopExtensionRate;
for (auto i = 0u; i < rel.getUpperBound(); ++i) {
extensionRate *= oneHopExtensionRate;
if (extensionRate > numRels) { // extension rate is bounded by numRels under BFS.
return numRels;
}
}
return extensionRate;
return oneHopExtensionRate * 2 /*magic number*/;
}
default:
throw common::NotImplementedException("getExtensionRate()");
Expand Down
4 changes: 3 additions & 1 deletion src/planner/operator/logical_recursive_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ void LogicalVariableLengthExtend::computeFactorizedSchema() {
}

void LogicalShortestPathExtend::computeFactorizedSchema() {
copyChildSchema(0);
createEmptySchema();
auto childSchema = children[0]->getSchema();
SinkOperatorUtil::recomputeSchema(*childSchema, childSchema->getExpressionsInScope(), *schema);
auto nbrGroupPos = schema->createGroup();
schema->insertToGroupAndScope(nbrNode->getInternalIDProperty(), nbrGroupPos);
schema->insertToGroupAndScope(rel->getInternalLengthProperty(), nbrGroupPos);
Expand Down
15 changes: 13 additions & 2 deletions src/processor/mapper/map_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,16 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalRecursiveExtendToPhysica
}
} else {
assert(rel->getRelType() == common::QueryRelType::SHORTEST);
auto expressions = inSchema->getExpressionsInScope();
auto resultCollector =
appendResultCollector(expressions, *inSchema, std::move(prevOperator));
auto sharedInputFTable = resultCollector->getSharedState();
std::vector<DataPos> outDataPoses;
std::vector<uint32_t> colIndicesToScan;
for (auto i = 0u; i < expressions.size(); ++i) {
outDataPoses.emplace_back(outSchema->getExpressionPos(*expressions[i]));
colIndicesToScan.push_back(i);
}
auto upperBound = rel->getUpperBound();
auto& nodeStore = storageManager.getNodesStore();
auto nodeTable = nodeStore.getNodeTable(boundNode->getSingleTableID());
Expand All @@ -161,8 +171,9 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalRecursiveExtendToPhysica
emptyPropertyIDs, tmpSrcNodePos, std::vector<DataPos>{tmpDstNodePos},
std::move(scanFrontier), getOperatorID(), emptyParamString);
}
return std::make_unique<RecursiveJoin>(upperBound, nodeTable, inNodeIDVectorPos,
outNodeIDVectorPos, distanceVectorPos, std::move(prevOperator), getOperatorID(),
return std::make_unique<RecursiveJoin>(upperBound, nodeTable, sharedInputFTable,
outDataPoses, colIndicesToScan, inNodeIDVectorPos, outNodeIDVectorPos,
distanceVectorPos, std::move(resultCollector), getOperatorID(),
extend->getExpressionsForPrinting(), std::move(scanRelTable));
}
}
Expand Down
9 changes: 7 additions & 2 deletions src/processor/operator/var_length_extend/recursive_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ bool ScanFrontier::getNextTuplesInternal(ExecutionContext* context) {

void RecursiveJoin::initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) {
maxNodeOffset = nodeTable->getMaxNodeOffset(transaction);
for (auto& dataPos : vectorsToScanPos) {
vectorsToScan.push_back(resultSet->getValueVector(dataPos).get());
}
srcNodeIDVector = resultSet->getValueVector(srcNodeIDVectorPos);
dstNodeIDVector = resultSet->getValueVector(dstNodeIDVectorPos);
distanceVector = resultSet->getValueVector(distanceVectorPos);
Expand Down Expand Up @@ -48,11 +51,13 @@ bool RecursiveJoin::getNextTuplesInternal(ExecutionContext* context) {
}

bool RecursiveJoin::computeBFS(ExecutionContext* context) {
if (!children[0]->getNextTuple(context)) {
auto inputFTableMorsel = inputFTableSharedState->getMorsel(1);
if (inputFTableMorsel->numTuples == 0) {
return false;
}
inputFTableSharedState->getTable()->scan(vectorsToScan, inputFTableMorsel->startTupleIdx,
inputFTableMorsel->numTuples, colIndicesToScan);
bfsMorsel->resetState();
assert(srcNodeIDVector->state->isFlat());
auto nodeID = srcNodeIDVector->getValue<common::nodeID_t>(
srcNodeIDVector->state->selVector->selectedPositions[0]);
bfsMorsel->markSrc(nodeID.offset);
Expand Down

0 comments on commit 4e5f17d

Please sign in to comment.