Skip to content

Commit

Permalink
Merge pull request #1364 from kuzudb/multi-key-asp
Browse files Browse the repository at this point in the history
Multi key asp
  • Loading branch information
andyfengHKU committed Mar 10, 2023
2 parents 247748c + 86cbad2 commit 5978d3b
Show file tree
Hide file tree
Showing 15 changed files with 147 additions and 123 deletions.
4 changes: 4 additions & 0 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <cassert>
#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>

#include "common/exception.h"
Expand All @@ -20,6 +21,9 @@ struct ExpressionHasher;
struct ExpressionEquality;
// Hash set of expressions held through std::shared_ptr, using ExpressionHasher as the hash functor
// and ExpressionEquality as the key-equality functor.
using expression_set =
std::unordered_set<std::shared_ptr<Expression>, ExpressionHasher, ExpressionEquality>;
// Hash map from an expression (held through std::shared_ptr) to a value of type T, using the same
// hash and key-equality functors as expression_set.
template<typename T>
using expression_map =
std::unordered_map<std::shared_ptr<Expression>, T, ExpressionHasher, ExpressionEquality>;

class Expression : public std::enable_shared_from_this<Expression> {
public:
Expand Down
5 changes: 3 additions & 2 deletions src/include/optimizer/asp_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
namespace kuzu {
namespace optimizer {

// This optimizer implements the ASP join algorithm as introduced in the paper "Kùzu Graph Database
// This optimizer enables the ASP join algorithm as introduced in the paper "Kuzu Graph Database
// Management System".
class ASPOptimizer : public LogicalOperatorVisitor {
public:
Expand All @@ -15,7 +15,8 @@ class ASPOptimizer : public LogicalOperatorVisitor {

void visitHashJoin(planner::LogicalOperator* op) override;

bool canApplyASP(planner::LogicalOperator* op);
std::vector<planner::LogicalOperator*> resolveScanNodesToApplySemiMask(
planner::LogicalOperator* op);
};

} // namespace optimizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,31 @@

#include "base_logical_operator.h"
#include "binder/expression/node_expression.h"
#include "logical_scan_node.h"

namespace kuzu {
namespace planner {

// Logical operator of type SEMI_MASKER with a single child, whose schema it copies unchanged.
// NOTE(review): this text comes from a rendered diff without +/- markers, so two variants of the
// constructor, printing function, accessor, copy() statement and data member (one based on a node
// ID expression, one based on a LogicalScanNode pointer) appear side by side.
class LogicalSemiMasker : public LogicalOperator {
public:
// Variant that stores the node ID expression.
LogicalSemiMasker(
std::shared_ptr<binder::Expression> nodeID, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::SEMI_MASKER, std::move(child)}, nodeID{std::move(
nodeID)} {}
// Variant that stores a pointer to a scan operator; the pointer is kept exactly as passed in.
LogicalSemiMasker(LogicalScanNode* scanNode, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::SEMI_MASKER, std::move(child)}, scanNode{scanNode} {}

// Both schema computations delegate to copyChildSchema(0).
inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

// Printed form is the string form of the node ID, or of the scan's node expression.
inline std::string getExpressionsForPrinting() const override { return nodeID->toString(); }
inline std::string getExpressionsForPrinting() const override {
return scanNode->getNode()->toString();
}

inline std::shared_ptr<binder::Expression> getNodeID() const { return nodeID; }
inline LogicalScanNode* getScanNode() const { return scanNode; }

// Builds a new masker around a copy of child 0.
// NOTE(review): the scanNode pointer is passed through unchanged, so the copied masker refers to
// the same LogicalScanNode object as this one. Confirm this is intended when the scan operator
// itself belongs to a plan that is being copied.
inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalSemiMasker>(nodeID, children[0]->copy());
return make_unique<LogicalSemiMasker>(scanNode, children[0]->copy());
}

private:
std::shared_ptr<binder::Expression> nodeID;
// NOTE(review): raw pointer; ownership and lifetime of the scan operator are not visible here.
LogicalScanNode* scanNode;
};

} // namespace planner
Expand Down
1 change: 1 addition & 0 deletions src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class PlanMapper {
catalog::Catalog* catalog;

private:
std::unordered_map<planner::LogicalOperator*, PhysicalOperator*> logicalOpToPhysicalOpMap;
uint32_t physicalOperatorID;
};

Expand Down
28 changes: 16 additions & 12 deletions src/include/processor/operator/scan_node_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ struct ScanNodeIDSemiMask {
explicit ScanNodeIDSemiMask() : numMaskers{0} {}

// Allocates the node mask (maxNodeOffset + 1 entries) and the morsel mask (maxMorselIdx + 1
// entries). Allocation happens only while nodeMask is still null.
// NOTE(review): this span is a rendered diff without +/- markers, so two versions of the body are
// interleaved and the braces do not balance as shown.
inline void initializeMaskData(common::offset_t maxNodeOffset, common::offset_t maxMorselIdx) {
if (nodeMask == nullptr) {
assert(morselMask == nullptr);
nodeMask = std::make_unique<Mask>(maxNodeOffset + 1);
morselMask = std::make_unique<Mask>(maxMorselIdx + 1);
if (nodeMask != nullptr) {
// Multiple semi masks may be applied to the same scan, so this function can be called
// repeatedly; the masks are already allocated in that case.
return;
}
// Both masks are allocated together, and the max offset must have been initialized beforehand.
assert(morselMask == nullptr && maxNodeOffset != common::INVALID_NODE_OFFSET);
nodeMask = std::make_unique<Mask>(maxNodeOffset + 1);
morselMask = std::make_unique<Mask>(maxMorselIdx + 1);
}

inline bool isMorselMasked(uint64_t morselIdx) {
Expand Down Expand Up @@ -67,12 +69,17 @@ class ScanTableNodeIDSharedState {
inline storage::NodeTable* getTable() { return table; }

// Reads the table's max node offset for the given transaction and derives the max morsel index
// from it by shifting right by DEFAULT_VECTOR_CAPACITY_LOG_2.
inline void initializeMaxOffset(transaction::Transaction* transaction) {
assert(maxNodeOffset == UINT64_MAX && maxMorselIdx == UINT64_MAX);
if (maxNodeOffset != common::INVALID_NODE_OFFSET) {
// This may be called twice: the semi masker, which runs on a different pipeline that
// executes beforehand, also triggers initialization. Later calls return early once
// maxNodeOffset has been set.
return;
}
maxNodeOffset = table->getMaxNodeOffset(transaction);
maxMorselIdx = maxNodeOffset >> common::DEFAULT_VECTOR_CAPACITY_LOG_2;
}

// Prepares the semi mask for use: first makes sure maxNodeOffset / maxMorselIdx are initialized for
// this transaction, then sizes the mask data from those two values.
inline void initSemiMask(transaction::Transaction* transaction) {
initializeMaxOffset(transaction);
semiMask->initializeMaskData(maxNodeOffset, maxMorselIdx);
}
inline bool isSemiMaskEnabled() { return semiMask->getNumMaskers() > 0; }
Expand Down Expand Up @@ -119,23 +126,21 @@ class ScanNodeIDSharedState {

class ScanNodeID : public PhysicalOperator {
public:
// Constructs a physical operator of type SCAN_NODE_ID that writes to outDataPos and uses the given
// shared state.
// NOTE(review): two signatures (with and without a leading nodeID string) and two member
// initializer lines appear here because this text is a rendered diff without +/- markers.
ScanNodeID(std::string nodeID, const DataPos& outDataPos,
std::shared_ptr<ScanNodeIDSharedState> sharedState, uint32_t id,
const std::string& paramsString)
ScanNodeID(const DataPos& outDataPos, std::shared_ptr<ScanNodeIDSharedState> sharedState,
uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::SCAN_NODE_ID, id, paramsString},
nodeID{std::move(nodeID)}, outDataPos{outDataPos}, sharedState{std::move(sharedState)} {}
outDataPos{outDataPos}, sharedState{std::move(sharedState)} {}

bool isSource() const override { return true; }

inline std::string getNodeID() const { return nodeID; }
inline ScanNodeIDSharedState* getSharedState() const { return sharedState.get(); }

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

bool getNextTuplesInternal() override;

// Returns a new ScanNodeID built from this operator's output position, id and params string.
// The sharedState shared_ptr is copied, so the clone refers to the same shared state object.
inline std::unique_ptr<PhysicalOperator> clone() override {
return std::make_unique<ScanNodeID>(nodeID, outDataPos, sharedState, id, paramsString);
return std::make_unique<ScanNodeID>(outDataPos, sharedState, id, paramsString);
}

private:
Expand All @@ -147,7 +152,6 @@ class ScanNodeID : public PhysicalOperator {
common::offset_t endOffset);

private:
std::string nodeID;
DataPos outDataPos;
std::shared_ptr<ScanNodeIDSharedState> sharedState;
std::shared_ptr<common::ValueVector> outValueVector;
Expand Down
15 changes: 6 additions & 9 deletions src/include/processor/operator/semi_masker.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ namespace kuzu {
namespace processor {

class SemiMasker : public PhysicalOperator {

public:
SemiMasker(const DataPos& keyDataPos, std::unique_ptr<PhysicalOperator> child, uint32_t id,
const std::string& paramsString)
Expand All @@ -20,11 +19,9 @@ class SemiMasker : public PhysicalOperator {
keyDataPos{other.keyDataPos}, maskerIdx{other.maskerIdx},
scanTableNodeIDSharedState{other.scanTableNodeIDSharedState} {}

// This function is used in the plan mapper to configure the shared state between the SemiMasker
// and ScanNodeID.
void setSharedState(ScanTableNodeIDSharedState* sharedState);

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;
// Stores the given pointer to a scan's per-table shared state; it is kept as a raw pointer.
// NOTE(review): lifetime of the pointed-to state is not visible here — confirm it outlives this
// operator.
inline void setSharedState(ScanTableNodeIDSharedState* sharedState) {
scanTableNodeIDSharedState = sharedState;
}

bool getNextTuplesInternal() override;

Expand All @@ -33,9 +30,9 @@ class SemiMasker : public PhysicalOperator {
}

private:
// Initializes the semi mask on the shared scan state, using the transaction carried by the
// execution context. Requires scanTableNodeIDSharedState to have been set beforehand, since it is
// dereferenced unconditionally.
inline void initGlobalStateInternal(ExecutionContext* context) override {
scanTableNodeIDSharedState->initSemiMask(context->transaction);
}
void initGlobalStateInternal(ExecutionContext* context) override;

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

private:
DataPos keyDataPos;
Expand Down
83 changes: 48 additions & 35 deletions src/optimizer/asp_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,54 +25,67 @@ void ASPOptimizer::visitOperator(planner::LogicalOperator* op) {

// Rewrites a hash join so that its probe side (child 0) is wrapped in semi masker operator(s) and
// an accumulate operator, and marks the join as passing information LEFT_TO_RIGHT. The rewrite is
// skipped when the probe side is already an ACCUMULATE, when no filter operator is found on the
// probe side, or when no scan node can take a semi mask.
// NOTE(review): this text is a rendered diff without +/- markers, so statements from two versions
// of the function (single masker keyed by joinNodeID vs. one masker per resolved scan node) are
// interleaved below.
void ASPOptimizer::visitHashJoin(planner::LogicalOperator* op) {
auto hashJoin = (LogicalHashJoin*)op;
if (!canApplyASP(op)) {
if (hashJoin->getChild(0)->getOperatorType() == LogicalOperatorType::ACCUMULATE) {
// No ASP if probe side has already been accumulated. This can be solved.
return;
}
auto probeSideFilterCollector = LogicalFilterCollector();
probeSideFilterCollector.collect(op->getChild(0).get());
if (!probeSideFilterCollector.hasOperators()) {
// Probe side is not selective so we don't apply ASP.
return;
}
auto scanNodes = resolveScanNodesToApplySemiMask(op);
if (scanNodes.empty()) {
// Nothing on the build side to push a semi mask to.
return;
}
auto joinNodeID = hashJoin->getJoinNodeIDs()[0];
// apply ASP
hashJoin->setInfoPassing(planner::HashJoinSideWayInfoPassing::LEFT_TO_RIGHT);
auto semiMasker = std::make_shared<LogicalSemiMasker>(joinNodeID, op->getChild(0));
semiMasker->computeFlatSchema();
auto accumulate = std::make_shared<LogicalAccumulate>(std::move(semiMasker));
// Stack one semi masker per resolved scan node on top of the probe side: each new masker takes
// the previous top of the chain as its child and then becomes the new top.
auto currentChild = hashJoin->getChild(0);
for (auto& op_ : scanNodes) {
auto scanNode = (LogicalScanNode*)op_;
auto semiMasker = std::make_shared<LogicalSemiMasker>(scanNode, currentChild);
semiMasker->computeFlatSchema();
currentChild = semiMasker;
}
// The accumulate operator sits above the masker chain and replaces the join's probe child.
auto accumulate = std::make_shared<LogicalAccumulate>(std::move(currentChild));
accumulate->computeFlatSchema();
op->setChild(0, std::move(accumulate));
}

// Decides which scan operators on the build side (child 1 of the hash join) can receive a semi
// mask. Build-side scans of nodes that are not multi-labeled are grouped by the internal ID
// property of the node they scan; for every join node ID that maps to exactly one such scan, that
// scan is added to the result. An empty result means no semi mask can be pushed.
// NOTE(review): this text is a rendered diff without +/- markers, so statements of the older
// bool-returning canApplyASP are interleaved with the vector-returning implementation and the
// braces do not balance as shown.
bool ASPOptimizer::canApplyASP(planner::LogicalOperator* op) {
std::vector<planner::LogicalOperator*> ASPOptimizer::resolveScanNodesToApplySemiMask(
planner::LogicalOperator* op) {
auto hashJoin = (LogicalHashJoin*)op;
// TODO(Xiyang): solve the cases where we cannot apply ASP.
if (hashJoin->getJoinNodeIDs().size() > 1) {
// No ASP for multiple join keys. This can be solved.
return false;
}
auto joinNodeID = hashJoin->getJoinNodeIDs()[0];
if (hashJoin->getChild(0)->getOperatorType() == LogicalOperatorType::ACCUMULATE) {
// No ASP if probe side has already been accumulated. This can be solved.
return false;
}
auto probeSideFilterCollector = LogicalFilterCollector();
probeSideFilterCollector.collect(op->getChild(0).get());
if (!probeSideFilterCollector.hasOperators()) {
// Probe side is not selective.
return false;
}
binder::expression_map<std::vector<LogicalOperator*>> nodeIDToScanOperatorsMap;
auto buildSideScanNodesCollector = LogicalScanNodeCollector();
buildSideScanNodesCollector.collect(op->getChild(1).get());
auto buildSideScanNodes = buildSideScanNodesCollector.getOperators();
if (buildSideScanNodes.size() != 1) {
// No ASP if we try to apply semi mask to multiple scan nodes. This can be solved.
return false;
}
auto buildSideNode = ((LogicalScanNode*)buildSideScanNodes[0])->getNode();
if (buildSideNode->isMultiLabeled()) {
// No ASP for multi-labeled scan. This can be solved.
return false;
// Group the build-side scan operators by the internal ID property of the node they scan.
for (auto& op_ : buildSideScanNodesCollector.getOperators()) {
auto scanNode = (LogicalScanNode*)op_;
if (scanNode->getNode()->isMultiLabeled()) {
// We don't push semi mask to multi-labeled scan. This can be solved.
continue;
}
auto nodeID = scanNode->getNode()->getInternalIDProperty();
if (!nodeIDToScanOperatorsMap.contains(nodeID)) {
nodeIDToScanOperatorsMap.insert({nodeID, std::vector<LogicalOperator*>{}});
}
nodeIDToScanOperatorsMap.at(nodeID).push_back(op_);
}
if (joinNodeID->getUniqueName() != buildSideNode->getInternalIDPropertyName()) {
// We only push semi mask to scan nodes.
return false;
// For each join key, select the build-side scan on that key when there is exactly one.
std::vector<LogicalOperator*> result;
for (auto& joinNodeID : hashJoin->getJoinNodeIDs()) {
if (!nodeIDToScanOperatorsMap.contains(joinNodeID)) {
// No scan on the build side to push semi mask to.
continue;
}
if (nodeIDToScanOperatorsMap.at(joinNodeID).size() > 1) {
// We don't push semi mask to multiple scans. This can be solved.
continue;
}
result.push_back(nodeIDToScanOperatorsMap.at(joinNodeID)[0]);
}
return true;
return result;
}

} // namespace optimizer
Expand Down
60 changes: 11 additions & 49 deletions src/processor/mapper/map_hash_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,6 @@ using namespace kuzu::planner;
namespace kuzu {
namespace processor {

// Returns true if a SEMI_MASKER operator is found while walking down the probe side (child 0) of
// the given hash join. The walk only continues through operators that have exactly one child and
// stops at the first operator that does not; that last operator is not inspected.
static bool containASPOnPipeline(LogicalHashJoin* logicalHashJoin) {
    for (auto current = logicalHashJoin->getChild(0); current->getNumChildren() == 1;
         current = current->getChild(0)) {
        if (current->getOperatorType() == LogicalOperatorType::SEMI_MASKER) {
            return true;
        }
    }
    return false;
}

static FactorizedTableScan* getTableScanForAccHashJoin(HashJoinProbe* hashJoinProbe) {
auto op = hashJoinProbe->getChild(0);
while (op->getOperatorType() == PhysicalOperatorType::FLATTEN) {
Expand All @@ -33,44 +22,12 @@ static FactorizedTableScan* getTableScanForAccHashJoin(HashJoinProbe* hashJoinPr
return (FactorizedTableScan*)op;
}

// Starting at the given table scan, follows the chain of single-child operators downwards and
// returns the first SEMI_MASKER encountered. The walk stops at an operator whose child count is
// not one; the assert documents the expectation that the stopping point is a semi masker.
static SemiMasker* getSemiMasker(FactorizedTableScan* tableScan) {
    auto current = (PhysicalOperator*)tableScan;
    for (;;) {
        // Leaving the chain of unary operators ends the search.
        if (current->getNumChildren() != 1) {
            break;
        }
        if (current->getOperatorType() == PhysicalOperatorType::SEMI_MASKER) {
            break;
        }
        current = current->getChild(0);
    }
    assert(current->getOperatorType() == PhysicalOperatorType::SEMI_MASKER);
    return (SemiMasker*)current;
}

// Detaches the single child of the table scan found under the probe operator and appends it as an
// additional child of the hash join probe.
// NOTE(review): two function headers (constructAccPipeline and the one-argument mapASPJoin) appear
// here because this text is a rendered diff without +/- markers; they share the body below.
static void constructAccPipeline(FactorizedTableScan* tableScan, HashJoinProbe* hashJoinProbe) {
static void mapASPJoin(HashJoinProbe* hashJoinProbe) {
auto tableScan = getTableScanForAccHashJoin(hashJoinProbe);
// Ownership of the scan's child is moved out of the scan and into the probe operator.
auto resultCollector = tableScan->moveUnaryChild();
hashJoinProbe->addChild(std::move(resultCollector));
}

// Wires the semi masker on the probe side to the build-side scan of the join key:
//   1. finds the SCAN_NODE_ID operators under the hash join build (child 1 of the probe) whose
//      node ID equals the unique name of joinNodeID; exactly one is expected,
//   2. hands the first table state of that scan's shared state to the semi masker,
//   3. restructures the probe via constructAccPipeline.
static void mapASPJoin(Expression* joinNodeID, HashJoinProbe* hashJoinProbe) {
    auto buildSide = hashJoinProbe->getChild(1);
    assert(buildSide->getOperatorType() == PhysicalOperatorType::HASH_JOIN_BUILD);

    // Collect the build-side scans that match the join key.
    std::vector<ScanNodeID*> matchingScans;
    for (auto& candidate :
        PhysicalPlanUtil::collectOperators(buildSide, PhysicalOperatorType::SCAN_NODE_ID)) {
        auto scan = (ScanNodeID*)candidate;
        if (scan->getNodeID() != joinNodeID->getUniqueName()) {
            continue;
        }
        matchingScans.push_back(scan);
    }
    assert(matchingScans.size() == 1);

    // Connect the masker to the scan's (single) per-table shared state.
    auto probeTableScan = getTableScanForAccHashJoin(hashJoinProbe);
    auto masker = getSemiMasker(probeTableScan);
    auto scanSharedState = matchingScans[0]->getSharedState();
    assert(scanSharedState->getNumTableStates() == 1);
    masker->setSharedState(scanSharedState->getTableState(0));

    constructAccPipeline(probeTableScan, hashJoinProbe);
}

BuildDataInfo PlanMapper::generateBuildDataInfo(const Schema& buildSideSchema,
const expression_vector& keys, const expression_vector& payloads) {
std::vector<std::pair<DataPos, common::DataType>> buildKeysPosAndType, buildPayloadsPosAndTypes;
Expand Down Expand Up @@ -134,8 +91,7 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalHashJoinToPhysical(
probeDataInfo, std::move(probeSidePrevOperator), std::move(hashJoinBuild), getOperatorID(),
paramsString);
if (hashJoin->getInfoPassing() == planner::HashJoinSideWayInfoPassing::LEFT_TO_RIGHT) {
assert(containASPOnPipeline(hashJoin));
mapASPJoin(hashJoin->getJoinNodeIDs()[0].get(), hashJoinProbe.get());
mapASPJoin(hashJoinProbe.get());
}
return hashJoinProbe;
}
Expand All @@ -145,9 +101,15 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalSemiMaskerToPhysical(
auto logicalSemiMasker = (LogicalSemiMasker*)logicalOperator;
auto inSchema = logicalSemiMasker->getChild(0)->getSchema();
auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0));
auto keyDataPos = DataPos(inSchema->getExpressionPos(*logicalSemiMasker->getNodeID()));
return make_unique<SemiMasker>(keyDataPos, std::move(prevOperator), getOperatorID(),
auto logicalScanNode = logicalSemiMasker->getScanNode();
auto physicalScanNode = (ScanNodeID*)logicalOpToPhysicalOpMap.at(logicalScanNode);
auto keyDataPos =
DataPos(inSchema->getExpressionPos(*logicalScanNode->getNode()->getInternalIDProperty()));
auto semiMasker = make_unique<SemiMasker>(keyDataPos, std::move(prevOperator), getOperatorID(),
logicalSemiMasker->getExpressionsForPrinting());
assert(physicalScanNode->getSharedState()->getNumTableStates() == 1);
semiMasker->setSharedState(physicalScanNode->getSharedState()->getTableState(0));
return semiMasker;
}

} // namespace processor
Expand Down
Loading

0 comments on commit 5978d3b

Please sign in to comment.