Skip to content

Commit

Permalink
Add multi-scan asp
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 15, 2023
1 parent e3961e6 commit 01aa911
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 70 deletions.
5 changes: 3 additions & 2 deletions src/include/optimizer/asp_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ class ASPOptimizer : public LogicalOperatorVisitor {

bool isProbeSideQualified(planner::LogicalOperator* probeRoot);

std::vector<planner::LogicalOperator*> resolveScanNodesToApplySemiMask(
binder::expression_map<std::vector<planner::LogicalOperator*>> resolveScanNodesToApplySemiMask(
const binder::expression_vector& nodeIDCandidates,
const std::vector<planner::LogicalOperator*>& buildRoots);

void applyASP(
const std::vector<planner::LogicalOperator*>& scanNodes, planner::LogicalOperator* op);
const binder::expression_map<std::vector<planner::LogicalOperator*>>& nodeIDToScanNodes,
planner::LogicalOperator* op);
};

} // namespace optimizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,29 @@ namespace planner {

// Logical operator of type SEMI_MASKER sitting on top of a single child. Both schema
// computations simply copy the schema of child 0.
// NOTE(review): this listing interleaves the pre-change lines (one `scanNode`) with the
// post-change lines (`nodeID` plus a vector of `scanNodes`), so several members below appear in
// two variants.
class LogicalSemiMasker : public LogicalOperator {
public:
// Pre-change constructor: remembers a single scan node as a raw pointer.
LogicalSemiMasker(LogicalScanNode* scanNode, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::SEMI_MASKER, std::move(child)}, scanNode{scanNode} {}
// Post-change constructor: takes the node ID expression and the scan operators by value and
// moves both into the members.
LogicalSemiMasker(std::shared_ptr<binder::Expression> nodeID,
std::vector<LogicalOperator*> scanNodes, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::SEMI_MASKER, std::move(child)},
nodeID{std::move(nodeID)}, scanNodes{std::move(scanNodes)} {}

// The masker's schema is exactly that of child 0.
inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

// Pre-change variant: prints the string form of the scan node's node.
inline std::string getExpressionsForPrinting() const override {
return scanNode->getNode()->toString();
}
// Post-change variant: prints the string form of the node ID expression.
inline std::string getExpressionsForPrinting() const override { return nodeID->toString(); }

inline LogicalScanNode* getScanNode() const { return scanNode; }
inline std::shared_ptr<binder::Expression> getNodeID() const { return nodeID; }
// Only scanNodes[0] is inspected, and it is cast to LogicalScanNode* without a check.
// Indexing an empty scanNodes vector is undefined behavior.
// NOTE(review): presumably every scan of the same node ID agrees on multi-labeledness — confirm.
inline bool isMultiLabel() const {
return ((LogicalScanNode*)scanNodes[0])->getNode()->isMultiLabeled();
}
// Returns the vector by value, i.e. each call hands out a copy of the pointer list.
inline std::vector<LogicalOperator*> getScanNodes() const { return scanNodes; }

// Builds a new masker over a copy of child 0. The scan-node pointers are passed through
// unchanged, so the copy refers to the same scan operators as this instance.
inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalSemiMasker>(scanNode, children[0]->copy());
return make_unique<LogicalSemiMasker>(nodeID, scanNodes, children[0]->copy());
}

private:
// Pre-change member.
LogicalScanNode* scanNode;
// Post-change members: the node ID expression and the scan operators associated with it.
// NOTE(review): raw pointers, presumably non-owning — confirm who keeps the operators alive.
std::shared_ptr<binder::Expression> nodeID;
std::vector<LogicalOperator*> scanNodes;
};

} // namespace planner
Expand Down
37 changes: 19 additions & 18 deletions src/include/processor/operator/semi_masker.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,66 +6,67 @@
namespace kuzu {
namespace processor {

// A semi mask paired with the idx of the masker writing to it: `first` is the
// NodeTableSemiMask, `second` is the masker idx.
// Multiple maskers can point to the same SemiMask, so each masker is associated with an idx that
// indicates the execution sequence of its pipeline. The maskerIdx also serves as a flag telling
// whether a value in the mask is masked or not, since each masker increments the selected value
// in the mask by 1. More details are described in NodeTableSemiMask.
using mask_and_idx_pair = std::pair<NodeTableSemiMask*, uint8_t>;

// Common base of the semi-masker physical operators (operator type SEMI_MASKER). It stores the
// position of the key vector and the shared scan state(s) the masker is associated with.
// The constructor is protected, so only derived classes can create instances.
// NOTE(review): this listing interleaves pre-change lines (a single `scanNodeIDSharedState`)
// with post-change lines (a vector `scanStates`).
class BaseSemiMasker : public PhysicalOperator {
protected:
// In the post-change variant `scanStates` is taken by value and moved into the member.
BaseSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
BaseSemiMasker(const DataPos& keyDataPos, std::vector<ScanNodeIDSharedState*> scanStates,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::SEMI_MASKER, std::move(child), id, paramsString},
keyDataPos{keyDataPos}, scanNodeIDSharedState{scanNodeIDSharedState} {}
keyDataPos{keyDataPos}, scanStates{std::move(scanStates)} {}

// NOTE(review): presumably resolves keyValueVector from keyDataPos; the body is not fully
// visible here — confirm in semi_masker.cpp.
void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

protected:
// Position of the key (node ID) vector in the result set.
DataPos keyDataPos;
// Pre-change member: the single shared scan state.
ScanNodeIDSharedState* scanNodeIDSharedState;
// Post-change member: one shared scan state per scan this masker is associated with.
std::vector<ScanNodeIDSharedState*> scanStates;
// Vector holding the node IDs read by getNextTuplesInternal of the derived classes.
std::shared_ptr<common::ValueVector> keyValueVector;
};

// Semi masker for scans whose shared state has exactly one table state (asserted in
// initGlobalStateInternal).
// NOTE(review): this listing interleaves pre-change lines (`scanNodeIDSharedState`,
// `maskerIdxAndMask`) with post-change lines (`scanStates`, `maskPerScan`).
class SingleTableSemiMasker : public BaseSemiMasker {
public:
// Forwards all arguments to BaseSemiMasker.
SingleTableSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
SingleTableSemiMasker(const DataPos& keyDataPos, std::vector<ScanNodeIDSharedState*> scanStates,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: BaseSemiMasker{keyDataPos, scanNodeIDSharedState, std::move(child), id, paramsString} {}
: BaseSemiMasker{keyDataPos, std::move(scanStates), std::move(child), id, paramsString} {}

void initGlobalStateInternal(kuzu::processor::ExecutionContext* context) override;

bool getNextTuplesInternal() override;

// Creates a new masker over a clone of child 0 with the same key position, scan states, id and
// params string, then copies the already-initialized mask/idx data into it.
inline std::unique_ptr<PhysicalOperator> clone() override {
auto result = std::make_unique<SingleTableSemiMasker>(
keyDataPos, scanNodeIDSharedState, children[0]->clone(), id, paramsString);
result->maskerIdxAndMask = maskerIdxAndMask;
keyDataPos, scanStates, children[0]->clone(), id, paramsString);
result->maskPerScan = maskPerScan;
return result;
}

private:
// Multiple maskers can point to the same SemiMask, thus we associate each masker with an idx
// to indicate the execution sequence of its pipeline. Also, the maskerIdx is used as a flag to
// indicate if a value in the mask is masked or not, as each masker will increment the selected
// value in the mask by 1. More details are described in NodeTableSemiMask.
std::pair<uint8_t, NodeTableSemiMask*> maskerIdxAndMask;
// Post-change member: one (mask, masker idx) pair per scan state, appended in scanStates order
// by initGlobalStateInternal.
std::vector<mask_and_idx_pair> maskPerScan;
};

// Semi masker for scans whose shared state has more than one table state (asserted in
// initGlobalStateInternal). Masks are looked up by the table ID of each node ID.
// NOTE(review): this listing interleaves pre-change lines (`scanNodeIDSharedState`,
// `maskerIdxAndMasks`) with post-change lines (`scanStates`, `maskerPerLabelPerScan`).
class MultiTableSemiMasker : public BaseSemiMasker {
public:
// Forwards all arguments to BaseSemiMasker.
MultiTableSemiMasker(const DataPos& keyDataPos, ScanNodeIDSharedState* scanNodeIDSharedState,
MultiTableSemiMasker(const DataPos& keyDataPos, std::vector<ScanNodeIDSharedState*> scanStates,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: BaseSemiMasker{keyDataPos, scanNodeIDSharedState, std::move(child), id, paramsString} {}
: BaseSemiMasker{keyDataPos, std::move(scanStates), std::move(child), id, paramsString} {}

void initGlobalStateInternal(kuzu::processor::ExecutionContext* context) override;

bool getNextTuplesInternal() override;

// Creates a new masker over a clone of child 0 with the same key position, scan states, id and
// params string, then copies the already-initialized per-table mask data into it.
inline std::unique_ptr<PhysicalOperator> clone() override {
auto result = std::make_unique<MultiTableSemiMasker>(
keyDataPos, scanNodeIDSharedState, children[0]->clone(), id, paramsString);
result->maskerIdxAndMasks = maskerIdxAndMasks;
keyDataPos, scanStates, children[0]->clone(), id, paramsString);
result->maskerPerLabelPerScan = maskerPerLabelPerScan;
return result;
}

private:
// Pre-change member: table ID -> (masker idx, mask).
std::unordered_map<common::table_id_t, std::pair<uint8_t, NodeTableSemiMask*>>
maskerIdxAndMasks;
// Post-change member: one map per scan state, each mapping table ID -> (mask, masker idx).
std::vector<std::unordered_map<common::table_id_t, mask_and_idx_pair>> maskerPerLabelPerScan;
};

} // namespace processor
Expand Down
20 changes: 8 additions & 12 deletions src/optimizer/asp_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ bool ASPOptimizer::isProbeSideQualified(planner::LogicalOperator* probeRoot) {
return true;
}

std::vector<planner::LogicalOperator*> ASPOptimizer::resolveScanNodesToApplySemiMask(
const binder::expression_vector& nodeIDCandidates,
binder::expression_map<std::vector<planner::LogicalOperator*>>
ASPOptimizer::resolveScanNodesToApplySemiMask(const binder::expression_vector& nodeIDCandidates,
const std::vector<planner::LogicalOperator*>& buildRoots) {
binder::expression_map<std::vector<LogicalOperator*>> nodeIDToScanOperatorsMap;
for (auto& buildRoot : buildRoots) {
Expand All @@ -89,27 +89,23 @@ std::vector<planner::LogicalOperator*> ASPOptimizer::resolveScanNodesToApplySemi
}
}
// Match node ID candidate with scanNode operators.
std::vector<LogicalOperator*> result;
binder::expression_map<std::vector<planner::LogicalOperator*>> result;
for (auto& nodeID : nodeIDCandidates) {
if (!nodeIDToScanOperatorsMap.contains(nodeID)) {
// No scan on the build side to push semi mask to.
continue;
}
if (nodeIDToScanOperatorsMap.at(nodeID).size() > 1) {
// We don't push semi mask to multiple scans. This can be solved.
continue;
}
result.push_back(nodeIDToScanOperatorsMap.at(nodeID)[0]);
result.insert({nodeID, nodeIDToScanOperatorsMap.at(nodeID)});
}
return result;
}

void ASPOptimizer::applyASP(
const std::vector<planner::LogicalOperator*>& scanNodes, planner::LogicalOperator* op) {
const binder::expression_map<std::vector<planner::LogicalOperator*>>& nodeIDToScanNodes,
planner::LogicalOperator* op) {
auto currentChild = op->getChild(0);
for (auto& op_ : scanNodes) {
auto scanNode = (LogicalScanNode*)op_;
auto semiMasker = std::make_shared<LogicalSemiMasker>(scanNode, currentChild);
for (auto& [nodeID, scanNodes] : nodeIDToScanNodes) {
auto semiMasker = std::make_shared<LogicalSemiMasker>(nodeID, scanNodes, currentChild);
semiMasker->computeFlatSchema();
currentChild = semiMasker;
}
Expand Down
20 changes: 11 additions & 9 deletions src/processor/mapper/map_asp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,19 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalSemiMaskerToPhysical(
auto logicalSemiMasker = (LogicalSemiMasker*)logicalOperator;
auto inSchema = logicalSemiMasker->getChild(0)->getSchema();
auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0));
auto logicalScanNode = logicalSemiMasker->getScanNode();
auto physicalScanNode = (ScanNodeID*)logicalOpToPhysicalOpMap.at(logicalScanNode);
auto keyDataPos =
DataPos(inSchema->getExpressionPos(*logicalScanNode->getNode()->getInternalIDProperty()));
if (physicalScanNode->getSharedState()->getNumTableStates() > 1) {
return std::make_unique<MultiTableSemiMasker>(keyDataPos,
physicalScanNode->getSharedState(), std::move(prevOperator), getOperatorID(),
std::vector<ScanNodeIDSharedState*> scanStates;
for (auto& op : logicalSemiMasker->getScanNodes()) {
auto physicalScanNode = (ScanNodeID*)logicalOpToPhysicalOpMap.at(op);
scanStates.push_back(physicalScanNode->getSharedState());
}
auto keyDataPos = DataPos(inSchema->getExpressionPos(*logicalSemiMasker->getNodeID()));
if (logicalSemiMasker->isMultiLabel()) {
return std::make_unique<MultiTableSemiMasker>(keyDataPos, std::move(scanStates),
std::move(prevOperator), getOperatorID(),
logicalSemiMasker->getExpressionsForPrinting());
} else {
return std::make_unique<SingleTableSemiMasker>(keyDataPos,
physicalScanNode->getSharedState(), std::move(prevOperator), getOperatorID(),
return std::make_unique<SingleTableSemiMasker>(keyDataPos, std::move(scanStates),
std::move(prevOperator), getOperatorID(),
logicalSemiMasker->getExpressionsForPrinting());
}
}
Expand Down
50 changes: 29 additions & 21 deletions src/processor/operator/semi_masker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,44 +10,50 @@ void BaseSemiMasker::initLocalStateInternal(ResultSet* resultSet, ExecutionConte
assert(keyValueVector->dataType.typeID == INTERNAL_ID);
}

// Initializes the semi mask of `tableState` for transaction `trx` and registers one more masker
// on it. The masker idx handed out is the number of maskers registered before this call.
// Returns the table state's semi mask together with that idx.
// NOTE(review): this listing interleaves the pre-change signature/return (idx first, mask
// second) with the post-change ones (mask first, idx second, i.e. mask_and_idx_pair).
static std::pair<uint8_t, NodeTableSemiMask*> initSemiMaskForTableState(
static mask_and_idx_pair initSemiMaskForTableState(
NodeTableState* tableState, transaction::Transaction* trx) {
tableState->initSemiMask(trx);
// Read the count before incrementing so the first masker gets the count's initial value.
auto maskerIdx = tableState->getNumMaskers();
// The idx is stored as uint8_t in the returned pair, so it must stay below UINT8_MAX.
assert(maskerIdx < UINT8_MAX);
tableState->incrementNumMaskers();
return std::make_pair(maskerIdx, tableState->getSemiMask());
return std::make_pair(tableState->getSemiMask(), maskerIdx);
}

// Registers this masker on the single table state of every associated scan state and records
// the resulting (mask, masker idx) pairs in scanStates order.
// NOTE(review): this listing interleaves the pre-change lines (single `scanNodeIDSharedState`)
// with the post-change loop over `scanStates`.
void SingleTableSemiMasker::initGlobalStateInternal(kuzu::processor::ExecutionContext* context) {
assert(scanNodeIDSharedState->getNumTableStates() == 1);
auto tableState = scanNodeIDSharedState->getTableState(0);
maskerIdxAndMask = initSemiMaskForTableState(tableState, context->transaction);
for (auto& scanState : scanStates) {
// A single-table masker expects exactly one table state per scan.
assert(scanState->getNumTableStates() == 1);
auto tableState = scanState->getTableState(0);
maskPerScan.push_back(initSemiMaskForTableState(tableState, context->transaction));
}
}

// Pulls the next tuple(s) from child 0 and, for every selected node ID in the key vector,
// increments the mask value at the node's offset. Returns false once the child is exhausted.
// NOTE(review): this listing interleaves the pre-change single-mask loop with the post-change
// loop that repeats the update for every (mask, masker idx) pair in maskPerScan.
bool SingleTableSemiMasker::getNextTuplesInternal() {
if (!children[0]->getNextTuple()) {
return false;
}
auto [maskerIdx, mask] = maskerIdxAndMask;
// A flat key vector contributes exactly one value; otherwise all selected positions are used.
auto numValues =
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize;
for (auto i = 0u; i < numValues; i++) {
auto pos = keyValueVector->state->selVector->selectedPositions[i];
auto nodeID = keyValueVector->getValue<nodeID_t>(pos);
mask->incrementMaskValue(nodeID.offset, maskerIdx);
for (auto [mask, maskerIdx] : maskPerScan) {
for (auto i = 0u; i < numValues; i++) {
auto pos = keyValueVector->state->selVector->selectedPositions[i];
auto nodeID = keyValueVector->getValue<nodeID_t>(pos);
// Only the offset is used here; the node's table ID is not consulted.
mask->incrementMaskValue(nodeID.offset, maskerIdx);
}
}
// The output-tuple metric is increased once by numValues, outside the per-mask loop.
metrics->numOutputTuple.increase(numValues);
return true;
}

// For every associated scan state, registers this masker on each of its table states and builds
// a map from table ID to the resulting (mask, masker idx) pair; the maps are stored in
// scanStates order.
// NOTE(review): this listing interleaves the pre-change lines (single `scanNodeIDSharedState`,
// `maskerIdxAndMasks`) with the post-change loop over `scanStates`.
void MultiTableSemiMasker::initGlobalStateInternal(kuzu::processor::ExecutionContext* context) {
assert(scanNodeIDSharedState->getNumTableStates() > 1);
for (auto i = 0u; i < scanNodeIDSharedState->getNumTableStates(); ++i) {
auto tableState = scanNodeIDSharedState->getTableState(i);
auto maskerIdxAndMask = initSemiMaskForTableState(tableState, context->transaction);
maskerIdxAndMasks.insert(
{tableState->getTable()->getTableID(), std::move(maskerIdxAndMask)});
for (auto& scanState : scanStates) {
// A multi-table masker expects more than one table state per scan.
assert(scanState->getNumTableStates() > 1);
std::unordered_map<common::table_id_t, mask_and_idx_pair> maskerPerLabel;
for (auto i = 0u; i < scanState->getNumTableStates(); ++i) {
auto tableState = scanState->getTableState(i);
// Keyed by the table's ID so a node ID's tableID can be used for lookup later.
maskerPerLabel.insert({tableState->getTable()->getTableID(),
initSemiMaskForTableState(tableState, context->transaction)});
}
maskerPerLabelPerScan.push_back(std::move(maskerPerLabel));
}
}

Expand All @@ -57,11 +63,13 @@ bool MultiTableSemiMasker::getNextTuplesInternal() {
}
auto numValues =
keyValueVector->state->isFlat() ? 1 : keyValueVector->state->selVector->selectedSize;
for (auto i = 0u; i < numValues; i++) {
auto pos = keyValueVector->state->selVector->selectedPositions[i];
auto nodeID = keyValueVector->getValue<nodeID_t>(pos);
auto [maskerIdx, mask] = maskerIdxAndMasks.at(nodeID.tableID);
mask->incrementMaskValue(nodeID.offset, maskerIdx);
for (auto& maskerPerLabel : maskerPerLabelPerScan) {
for (auto i = 0u; i < numValues; i++) {
auto pos = keyValueVector->state->selVector->selectedPositions[i];
auto nodeID = keyValueVector->getValue<nodeID_t>(pos);
auto [mask, maskerIdx] = maskerPerLabel.at(nodeID.tableID);
mask->incrementMaskValue(nodeID.offset, maskerIdx);
}
}
metrics->numOutputTuple.increase(numValues);
return true;
Expand Down
7 changes: 7 additions & 0 deletions test/test_files/tinysnb/asp/asp.test
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ Dan|Carol
Carol
Dan

-NAME AspOneMaskToMultiScan
-QUERY MATCH (a:person)-[e1:knows]->(b:person), (a)-[e2:knows]->(b), (a)-[e3:knows]->(c:person)-[e4:knows]->(d:person) WHERE e1.date=date('1950-05-14') AND a.ID>0 AND c.fName='Bob' AND d.fName='Carol' RETURN a.fName, b.fName
-ENUMERATE
---- 2
Carol|Bob
Dan|Bob

-NAME AspIntersect
-QUERY MATCH (a:person)<-[e1:knows]-(b:person)-[e2:knows]->(c:person), (a)-[e3:knows]->(c) WHERE b.fName='Bob' AND a.fName='Alice' RETURN COUNT(*)
-ENUMERATE
Expand Down

0 comments on commit 01aa911

Please sign in to comment.