Skip to content

Commit

Permalink
Add asp for wcoj
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 12, 2023
1 parent 1a4ae3d commit 6cd980f
Show file tree
Hide file tree
Showing 11 changed files with 172 additions and 98 deletions.
9 changes: 8 additions & 1 deletion src/include/optimizer/asp_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@ class ASPOptimizer : public LogicalOperatorVisitor {
void visitOperator(planner::LogicalOperator* op);

void visitHashJoin(planner::LogicalOperator* op) override;
void visitIntersect(planner::LogicalOperator* op) override;

bool isProbeSideQualified(planner::LogicalOperator* probeRoot);

std::vector<planner::LogicalOperator*> resolveScanNodesToApplySemiMask(
planner::LogicalOperator* op);
const binder::expression_vector& nodeIDCandidates,
const std::vector<planner::LogicalOperator*>& buildRoots);

void applyASP(
const std::vector<planner::LogicalOperator*>& scanNodes, planner::LogicalOperator* op);
};

} // namespace optimizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,12 @@
#include <utility>

#include "base_logical_operator.h"
#include "binder/expression/node_expression.h"
#include "common/join_type.h"
#include "side_way_info_passing.h"

namespace kuzu {
namespace planner {

enum class HashJoinSideWayInfoPassing : uint8_t {
NONE = 0,
LEFT_TO_RIGHT = 1,
};

// Probe side on left, i.e. children[0]. Build side on right, i.e. children[1].
class LogicalHashJoin : public LogicalOperator {
public:
Expand All @@ -37,7 +32,7 @@ class LogicalHashJoin : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::HASH_JOIN, std::move(probeSideChild),
std::move(buildSideChild)},
joinNodeIDs(std::move(joinNodeIDs)), joinType{joinType}, mark{std::move(mark)},
infoPassing{HashJoinSideWayInfoPassing::NONE} {}
sip{SidewaysInfoPassing::NONE} {}

f_group_pos_set getGroupsPosToFlattenOnProbeSide();
f_group_pos_set getGroupsPosToFlattenOnBuildSide();
Expand All @@ -56,10 +51,8 @@ class LogicalHashJoin : public LogicalOperator {
assert(joinType == common::JoinType::MARK && mark);
return mark;
}
inline void setInfoPassing(HashJoinSideWayInfoPassing infoPassing_) {
infoPassing = infoPassing_;
}
inline HashJoinSideWayInfoPassing getInfoPassing() const { return infoPassing; }
inline void setSIP(SidewaysInfoPassing sip_) { sip = sip_; }
inline SidewaysInfoPassing getSIP() const { return sip; }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalHashJoin>(
Expand All @@ -82,7 +75,7 @@ class LogicalHashJoin : public LogicalOperator {
binder::expression_vector joinNodeIDs;
common::JoinType joinType;
std::shared_ptr<binder::Expression> mark; // when joinType is Mark
HashJoinSideWayInfoPassing infoPassing;
SidewaysInfoPassing sip;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#pragma once

#include "base_logical_operator.h"
#include "binder/expression/node_expression.h"
#include "schema.h"
#include "side_way_info_passing.h"

namespace kuzu {
namespace planner {
Expand All @@ -13,7 +12,8 @@ class LogicalIntersect : public LogicalOperator {
binder::expression_vector keyNodeIDs, std::shared_ptr<LogicalOperator> probeChild,
std::vector<std::shared_ptr<LogicalOperator>> buildChildren)
: LogicalOperator{LogicalOperatorType::INTERSECT, std::move(probeChild)},
intersectNodeID{std::move(intersectNodeID)}, keyNodeIDs{std::move(keyNodeIDs)} {
intersectNodeID{std::move(intersectNodeID)},
keyNodeIDs{std::move(keyNodeIDs)}, sip{SidewaysInfoPassing::NONE} {
for (auto& child : buildChildren) {
children.push_back(std::move(child));
}
Expand All @@ -30,17 +30,20 @@ class LogicalIntersect : public LogicalOperator {
inline std::shared_ptr<binder::Expression> getIntersectNodeID() const {
return intersectNodeID;
}

inline uint32_t getNumBuilds() const { return keyNodeIDs.size(); }
inline binder::expression_vector getKeyNodeIDs() const { return keyNodeIDs; }
inline std::shared_ptr<binder::Expression> getKeyNodeID(uint32_t idx) const {
return keyNodeIDs[idx];
}
inline void setSIP(SidewaysInfoPassing sip_) { sip = sip_; }
inline SidewaysInfoPassing getSIP() const { return sip; }

std::unique_ptr<LogicalOperator> copy() override;

private:
std::shared_ptr<binder::Expression> intersectNodeID;
binder::expression_vector keyNodeIDs;
SidewaysInfoPassing sip;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#pragma once

#include <cstddef>
#include <cstdint>

namespace kuzu {
namespace planner {

/**
 * Direction in which a join-like logical operator passes information sideways between its
 * children. Stored in a single byte.
 */
enum class SidewaysInfoPassing : std::uint8_t {
    // No sideways information passing.
    NONE = 0,
    // Information is passed from the left child to the right child(ren).
    LEFT_TO_RIGHT = 1,
};

} // namespace planner
} // namespace kuzu
4 changes: 3 additions & 1 deletion src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,11 @@ class PlanMapper {
const planner::Schema& inSchema, const planner::Schema& outSchema,
std::vector<bool>& isInputGroupByHashKeyVectorFlat);

static BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema,
BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema,
const binder::expression_vector& keys, const binder::expression_vector& payloads);

void mapASP(PhysicalOperator* probe);

public:
storage::StorageManager& storageManager;
storage::MemoryManager* memoryManager;
Expand Down
110 changes: 72 additions & 38 deletions src/optimizer/asp_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "optimizer/logical_operator_collector.h"
#include "planner/logical_plan/logical_operator/logical_accumulate.h"
#include "planner/logical_plan/logical_operator/logical_hash_join.h"
#include "planner/logical_plan/logical_operator/logical_intersect.h"
#include "planner/logical_plan/logical_operator/logical_scan_node.h"
#include "planner/logical_plan/logical_operator/logical_semi_masker.h"

Expand All @@ -25,68 +26,101 @@ void ASPOptimizer::visitOperator(planner::LogicalOperator* op) {

void ASPOptimizer::visitHashJoin(planner::LogicalOperator* op) {
auto hashJoin = (LogicalHashJoin*)op;
if (hashJoin->getChild(0)->getOperatorType() == LogicalOperatorType::ACCUMULATE) {
// No ASP if probe side has already been accumulated. This can be solved.
if (!isProbeSideQualified(op->getChild(0).get())) {
return;
}
auto probeSideFilterCollector = LogicalFilterCollector();
probeSideFilterCollector.collect(op->getChild(0).get());
if (!probeSideFilterCollector.hasOperators()) {
// Probe side is not selective so we don't apply ASP.
std::vector<LogicalOperator*> buildRoots;
buildRoots.push_back(hashJoin->getChild(1).get());
auto scanNodes = resolveScanNodesToApplySemiMask(hashJoin->getJoinNodeIDs(), buildRoots);
if (scanNodes.empty()) {
return;
}
// apply ASP
hashJoin->setSIP(SidewaysInfoPassing::LEFT_TO_RIGHT);
applyASP(scanNodes, op);
}

void ASPOptimizer::visitIntersect(planner::LogicalOperator* op) {
auto intersect = (LogicalIntersect*)op;
if (!isProbeSideQualified(op->getChild(0).get())) {
return;
}
auto scanNodes = resolveScanNodesToApplySemiMask(op);
std::vector<LogicalOperator*> buildRoots;
for (auto i = 1; i < intersect->getNumChildren(); ++i) {
buildRoots.push_back(intersect->getChild(i).get());
}
auto scanNodes = resolveScanNodesToApplySemiMask(intersect->getKeyNodeIDs(), buildRoots);
if (scanNodes.empty()) {
return;
}
// apply ASP
hashJoin->setInfoPassing(planner::HashJoinSideWayInfoPassing::LEFT_TO_RIGHT);
auto currentChild = hashJoin->getChild(0);
for (auto& op_ : scanNodes) {
auto scanNode = (LogicalScanNode*)op_;
auto semiMasker = std::make_shared<LogicalSemiMasker>(scanNode, currentChild);
semiMasker->computeFlatSchema();
currentChild = semiMasker;
intersect->setSIP(SidewaysInfoPassing::LEFT_TO_RIGHT);
applyASP(scanNodes, op);
}

// Probe side is qualified if it is selective.
bool ASPOptimizer::isProbeSideQualified(planner::LogicalOperator* probeRoot) {
if (probeRoot->getOperatorType() == LogicalOperatorType::ACCUMULATE) {
// No ASP if probe side has already been accumulated. This can be solved.
return false;
}
auto accumulate = std::make_shared<LogicalAccumulate>(std::move(currentChild));
accumulate->computeFlatSchema();
op->setChild(0, std::move(accumulate));
auto filterCollector = LogicalFilterCollector();
filterCollector.collect(probeRoot);
if (!filterCollector.hasOperators()) {
// Probe side is not selective. So we don't apply ASP.
return false;
}
return true;
}

std::vector<planner::LogicalOperator*> ASPOptimizer::resolveScanNodesToApplySemiMask(
planner::LogicalOperator* op) {
auto hashJoin = (LogicalHashJoin*)op;
const binder::expression_vector& nodeIDCandidates,
const std::vector<planner::LogicalOperator*>& buildRoots) {
binder::expression_map<std::vector<LogicalOperator*>> nodeIDToScanOperatorsMap;
auto buildSideScanNodesCollector = LogicalScanNodeCollector();
buildSideScanNodesCollector.collect(op->getChild(1).get());
// populate node ID to scan operator map
for (auto& op_ : buildSideScanNodesCollector.getOperators()) {
auto scanNode = (LogicalScanNode*)op_;
if (scanNode->getNode()->isMultiLabeled()) {
// We don't push semi mask to multi-labeled scan. This can be solved.
continue;
for (auto& buildRoot : buildRoots) {
auto scanNodesCollector = LogicalScanNodeCollector();
scanNodesCollector.collect(buildRoot);
for (auto& op : scanNodesCollector.getOperators()) {
auto scanNode = (LogicalScanNode*)op;
if (scanNode->getNode()->isMultiLabeled()) {
// We don't push semi mask to multi-labeled scan. This can be solved.
continue;
}
auto nodeID = scanNode->getNode()->getInternalIDProperty();
if (!nodeIDToScanOperatorsMap.contains(nodeID)) {
nodeIDToScanOperatorsMap.insert({nodeID, std::vector<LogicalOperator*>{}});
}
nodeIDToScanOperatorsMap.at(nodeID).push_back(op);
}
auto nodeID = scanNode->getNode()->getInternalIDProperty();
if (!nodeIDToScanOperatorsMap.contains(nodeID)) {
nodeIDToScanOperatorsMap.insert({nodeID, std::vector<LogicalOperator*>{}});
}
nodeIDToScanOperatorsMap.at(nodeID).push_back(op_);
}
// generate semi mask info
// Match node ID candidate with scanNode operators.
std::vector<LogicalOperator*> result;
for (auto& joinNodeID : hashJoin->getJoinNodeIDs()) {
if (!nodeIDToScanOperatorsMap.contains(joinNodeID)) {
for (auto& nodeID : nodeIDCandidates) {
if (!nodeIDToScanOperatorsMap.contains(nodeID)) {
// No scan on the build side to push semi mask to.
continue;
}
if (nodeIDToScanOperatorsMap.at(joinNodeID).size() > 1) {
if (nodeIDToScanOperatorsMap.at(nodeID).size() > 1) {
// We don't push semi mask to multiple scans. This can be solved.
continue;
}
result.push_back(nodeIDToScanOperatorsMap.at(joinNodeID)[0]);
result.push_back(nodeIDToScanOperatorsMap.at(nodeID)[0]);
}
return result;
}

// Rewrites child 0 of `op`: stacks one LogicalSemiMasker per entry of `scanNodes` on top of the
// current child 0 (in the order given), wraps the result in a LogicalAccumulate, and installs
// that accumulate as the new child 0 of `op`.
void ASPOptimizer::applyASP(
    const std::vector<planner::LogicalOperator*>& scanNodes, planner::LogicalOperator* op) {
    auto probePlan = op->getChild(0);
    for (auto* scanOp : scanNodes) {
        // Each masker takes the plan built so far as its child.
        auto masker = std::make_shared<LogicalSemiMasker>(
            static_cast<LogicalScanNode*>(scanOp), probePlan);
        masker->computeFlatSchema();
        probePlan = masker;
    }
    auto accumulated = std::make_shared<LogicalAccumulate>(std::move(probePlan));
    accumulated->computeFlatSchema();
    op->setChild(0, std::move(accumulated));
}

} // namespace optimizer
} // namespace kuzu
1 change: 1 addition & 0 deletions src/processor/mapper/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_library(kuzu_processor_mapper
expression_mapper.cpp
map_accumulate.cpp
map_aggregate.cpp
map_asp.cpp
map_create.cpp
map_cross_product.cpp
map_ddl.cpp
Expand Down
44 changes: 44 additions & 0 deletions src/processor/mapper/map_asp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include "planner/logical_plan/logical_operator/logical_semi_masker.h"
#include "processor/mapper/plan_mapper.h"
#include "processor/operator/scan_node_id.h"
#include "processor/operator/semi_masker.h"
#include "processor/operator/table_scan/factorized_table_scan.h"

using namespace kuzu::planner;

namespace kuzu {
namespace processor {

// Starting at the first child of `probe`, skips over any chain of FLATTEN operators (following
// child 0 each time) and returns the operator reached as a FactorizedTableScan. The operator type
// is only checked by an assert.
static FactorizedTableScan* getTableScanForAccHashJoin(PhysicalOperator* probe) {
    auto current = probe->getChild(0);
    for (; current->getOperatorType() == PhysicalOperatorType::FLATTEN;
         current = current->getChild(0)) {
        // Nothing to do; just descend past the flatten.
    }
    assert(current->getOperatorType() == PhysicalOperatorType::FACTORIZED_TABLE_SCAN);
    return static_cast<FactorizedTableScan*>(current);
}

// Locates the factorized table scan under `probe` (see getTableScanForAccHashJoin), takes the
// scan's unary child away from it, and appends that child to `probe`'s children.
void PlanMapper::mapASP(kuzu::processor::PhysicalOperator* probe) {
    auto* factorizedScan = getTableScanForAccHashJoin(probe);
    probe->addChild(factorizedScan->moveUnaryChild());
}

// Maps a LogicalSemiMasker to a physical SemiMasker.
// The masker's key is the internal ID property of the node scanned by the associated logical
// scan node, resolved against the schema of the masker's child. The physical scan is looked up
// in logicalOpToPhysicalOpMap, and the masker is given table state 0 of that scan's shared state.
std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalSemiMaskerToPhysical(
    LogicalOperator* logicalOperator) {
    auto semiMaskerOp = static_cast<LogicalSemiMasker*>(logicalOperator);
    auto childSchema = semiMaskerOp->getChild(0)->getSchema();
    // Map the child before anything else so the statement order of the mapping is unchanged.
    auto physicalChild = mapLogicalOperatorToPhysical(logicalOperator->getChild(0));
    auto scanNodeOp = semiMaskerOp->getScanNode();
    auto physicalScan = static_cast<ScanNodeID*>(logicalOpToPhysicalOpMap.at(scanNodeOp));
    auto keyDataPos =
        DataPos(childSchema->getExpressionPos(*scanNodeOp->getNode()->getInternalIDProperty()));
    auto physicalMasker = make_unique<SemiMasker>(keyDataPos, std::move(physicalChild),
        getOperatorID(), semiMaskerOp->getExpressionsForPrinting());
    // Only scans with exactly one table state are expected here.
    assert(physicalScan->getSharedState()->getNumTableStates() == 1);
    physicalMasker->setSharedState(physicalScan->getSharedState()->getTableState(0));
    return physicalMasker;
}

} // namespace processor
} // namespace kuzu
Loading

0 comments on commit 6cd980f

Please sign in to comment.