Skip to content

Commit

Permalink
Rework asp optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 29, 2023
1 parent 8f34ab1 commit 1710491
Show file tree
Hide file tree
Showing 23 changed files with 198 additions and 75 deletions.
2 changes: 2 additions & 0 deletions src/common/copier_config/copier_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ std::string CopyDescription::getFileTypeName(FileType fileType) {
case FileType::NPY: {
return "npy";
}
default:
throw InternalException("Unimplemented getFileTypeName().");
}
}

Expand Down
5 changes: 4 additions & 1 deletion src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,13 @@ struct LoggerConstants {
};
};

struct EnumeratorKnobs {
// Tuning constants consumed by the planner (cardinality estimator, cost model, and the
// accumulated hash join decision).
struct PlannerKnobs {
// Multiplier applied to an input's estimated cardinality by the cardinality estimator, e.g. when
// an intersect is treated as a filter on the probe side.
static constexpr double NON_EQUALITY_PREDICATE_SELECTIVITY = 0.1;
// Multiplier applied to a child's estimated cardinality by the cardinality estimator's filter
// estimate; presumably the equality-predicate case, as the name suggests.
static constexpr double EQUALITY_PREDICATE_SELECTIVITY = 0.01;
// Weight multiplied with the build side's cardinality when computing hash join cost.
static constexpr uint64_t BUILD_PENALTY = 2;
// Avoid doing probe to build SIP (sideways information passing) if we have to accumulate a probe
// side that is much bigger than the build side. Also avoid doing build to probe SIP if the probe
// side is not much bigger than the build side.
static constexpr uint64_t ACC_HJ_PROBE_BUILD_RATIO = 5;
};

struct ClientContextConstants {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,32 @@
namespace kuzu {
namespace optimizer {

// This optimizer enables the ASP join algorithm as introduced in paper "Kuzu Graph Database
// Management System".
class ASPOptimizer : public LogicalOperatorVisitor {
// This optimizer enables the Accumulated hash join algorithm as introduced in paper "Kuzu Graph
// Database Management System".
class AccHashJoinOptimizer : public LogicalOperatorVisitor {
public:
void rewrite(planner::LogicalPlan* plan);

private:
void visitOperator(planner::LogicalOperator* op);

void visitHashJoin(planner::LogicalOperator* op) override;

bool tryProbeToBuildHJSIP(planner::LogicalOperator* op);
bool tryBuildToProbeHJSIP(planner::LogicalOperator* op);

void visitIntersect(planner::LogicalOperator* op) override;

bool isProbeSideQualified(planner::LogicalOperator* probeRoot);

binder::expression_map<std::vector<planner::LogicalOperator*>> resolveScanNodesToApplySemiMask(
const binder::expression_vector& nodeIDCandidates,
const std::vector<planner::LogicalOperator*>& buildRoots);
const std::vector<planner::LogicalOperator*>& roots);

void applyASP(
std::shared_ptr<planner::LogicalOperator> applySemiMasks(
const binder::expression_map<std::vector<planner::LogicalOperator*>>& nodeIDToScanNodes,
std::shared_ptr<planner::LogicalOperator> root);
void applyAccHashJoin(
const binder::expression_map<std::vector<planner::LogicalOperator*>>& nodeIDToScanNodes,
planner::LogicalOperator* op);
};
Expand Down
5 changes: 5 additions & 0 deletions src/include/optimizer/logical_operator_collector.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,10 @@ class LogicalScanNodeCollector : public LogicalOperatorCollector {
void visitScanNode(planner::LogicalOperator* op) override { ops.push_back(op); }
};

// Collector that gathers the index scan node operators it visits.
class LogicalIndexScanNodeCollector : public LogicalOperatorCollector {
protected:
// Appends every visited index scan node operator to the collected operator list.
void visitIndexScanNode(planner::LogicalOperator* op) override { ops.push_back(op); }
};

} // namespace optimizer
} // namespace kuzu
6 changes: 6 additions & 0 deletions src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ class LogicalOperatorVisitor {
return op;
}

// Hook for INDEX_SCAN_NODE operators. The default implementation does nothing.
virtual void visitIndexScanNode(planner::LogicalOperator* op) {}
// Replace-variant hook for INDEX_SCAN_NODE operators. The default implementation returns `op`
// unchanged.
virtual std::shared_ptr<planner::LogicalOperator> visitIndexScanNodeReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitExtend(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitExtendReplace(
std::shared_ptr<planner::LogicalOperator> op) {
Expand Down
8 changes: 0 additions & 8 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,6 @@ class JoinOrderEnumerator {
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan) {
planJoin(joinNodeIDs, common::JoinType::MARK, mark, probePlan, buildPlan);
}
inline void planInnerHashJoin(const std::vector<std::shared_ptr<NodeExpression>>& joinNodes,
LogicalPlan& probePlan, LogicalPlan& buildPlan) {
binder::expression_vector joinNodeIDs;
for (auto& joinNode : joinNodes) {
joinNodeIDs.push_back(joinNode->getInternalIDProperty());
}
planJoin(joinNodeIDs, common::JoinType::INNER, nullptr /* mark */, probePlan, buildPlan);
}
inline void planInnerHashJoin(const binder::expression_vector& joinNodeIDs,
LogicalPlan& probePlan, LogicalPlan& buildPlan) {
planJoin(joinNodeIDs, common::JoinType::INNER, nullptr /* mark */, probePlan, buildPlan);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ namespace planner {

enum class SidewaysInfoPassing : uint8_t {
NONE = 0,
LEFT_TO_RIGHT = 1,
PROBE_TO_BUILD = 1,
PROHIBIT_PROBE_TO_BUILD = 2,
BUILD_TO_PROBE = 3,
PROHIBIT_BUILD_TO_PROBE = 4,
};

} // namespace planner
Expand Down
5 changes: 5 additions & 0 deletions src/include/planner/subplans_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class SubgraphPlans {
public:
SubgraphPlans(const SubqueryGraph& subqueryGraph);

inline uint64_t getMaxCost() const { return maxCost; }

void addPlan(std::unique_ptr<LogicalPlan> plan);

std::vector<std::unique_ptr<LogicalPlan>>& getPlans() { return plans; }
Expand All @@ -42,6 +44,7 @@ class SubgraphPlans {
constexpr static uint32_t MAX_NUM_PLANS = 10;

private:
uint64_t maxCost = UINT64_MAX;
binder::expression_vector nodeIDsToEncode;
std::vector<std::unique_ptr<LogicalPlan>> plans;
std::unordered_map<std::bitset<MAX_NUM_QUERY_VARIABLES>, common::vector_idx_t>
Expand Down Expand Up @@ -77,6 +80,8 @@ class SubPlansTable {
public:
void resize(uint32_t newSize);

uint64_t getMaxCost(const SubqueryGraph& subqueryGraph) const;

bool containSubgraphPlans(const SubqueryGraph& subqueryGraph) const;

std::vector<std::unique_ptr<LogicalPlan>>& getSubgraphPlans(const SubqueryGraph& subqueryGraph);
Expand Down
2 changes: 1 addition & 1 deletion src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class PlanMapper {
BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema,
const binder::expression_vector& keys, const binder::expression_vector& payloads);

void mapASP(PhysicalOperator* probe);
void mapAccHashJoin(PhysicalOperator* probe);

public:
storage::StorageManager& storageManager;
Expand Down
2 changes: 1 addition & 1 deletion src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
add_library(kuzu_optimizer
OBJECT
asp_optimizer.cpp
acc_hash_join_optimizer.cpp
factorization_rewriter.cpp
filter_push_down_optimizer.cpp
logical_operator_collector.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "optimizer/asp_optimizer.h"
#include "optimizer/acc_hash_join_optimizer.h"

#include "optimizer/logical_operator_collector.h"
#include "planner/logical_plan/logical_operator/logical_accumulate.h"
Expand All @@ -12,36 +12,83 @@ using namespace kuzu::planner;
namespace kuzu {
namespace optimizer {

void ASPOptimizer::rewrite(planner::LogicalPlan* plan) {
// Entry point of the optimizer: starts the traversal from the plan's last operator.
void AccHashJoinOptimizer::rewrite(planner::LogicalPlan* plan) {
    // Hold the operator handle for the duration of the traversal.
    auto lastOperator = plan->getLastOperator();
    visitOperator(lastOperator.get());
}

void ASPOptimizer::visitOperator(planner::LogicalOperator* op) {
// Post-order traversal: every child subtree is processed before `op` itself is dispatched to
// the matching visit function.
void AccHashJoinOptimizer::visitOperator(planner::LogicalOperator* op) {
    for (unsigned childIdx = 0; childIdx != op->getNumChildren(); ++childIdx) {
        auto child = op->getChild(childIdx);
        visitOperator(child.get());
    }
    // Children are done; now handle the current operator.
    visitOperatorSwitch(op);
}

void ASPOptimizer::visitHashJoin(planner::LogicalOperator* op) {
// Probe-to-build SIP takes priority; build-to-probe SIP is attempted only when the former was
// not applied.
void AccHashJoinOptimizer::visitHashJoin(planner::LogicalOperator* op) {
    if (!tryProbeToBuildHJSIP(op)) {
        tryBuildToProbeHJSIP(op);
    }
}

bool AccHashJoinOptimizer::tryProbeToBuildHJSIP(planner::LogicalOperator* op) {
auto hashJoin = (LogicalHashJoin*)op;
if (hashJoin->getSIP() == planner::SidewaysInfoPassing::PROHIBIT_PROBE_TO_BUILD) {
return false;
}
if (!isProbeSideQualified(op->getChild(0).get())) {
return;
return false;
}
std::vector<LogicalOperator*> buildRoots;
buildRoots.push_back(hashJoin->getChild(1).get());
auto scanNodes = resolveScanNodesToApplySemiMask(hashJoin->getJoinNodeIDs(), buildRoots);
if (scanNodes.empty()) {
return;
return false;
}
// apply ASP
hashJoin->setSIP(SidewaysInfoPassing::LEFT_TO_RIGHT);
applyASP(scanNodes, op);
// apply accumulated hash join
hashJoin->setSIP(SidewaysInfoPassing::PROBE_TO_BUILD);
applyAccHashJoin(scanNodes, op);
return true;
}

void ASPOptimizer::visitIntersect(planner::LogicalOperator* op) {
// Returns true when the sub-plan rooted at `root` contains at least one filter operator or at
// least one index scan node operator.
static bool subPlanContainsFilter(LogicalOperator* root) {
    LogicalFilterCollector filters{};
    filters.collect(root);
    LogicalIndexScanNodeCollector indexScans{};
    indexScans.collect(root);
    return filters.hasOperators() || indexScans.hasOperators();
}

// Tries to apply build-to-probe sideways information passing on a hash join: semi maskers are
// stacked on top of the build side (child 1) and target scan nodes found in the probe side
// (child 0). Returns true if the rewrite was applied, false if any precondition failed.
bool AccHashJoinOptimizer::tryBuildToProbeHJSIP(planner::LogicalOperator* op) {
    auto* join = static_cast<LogicalHashJoin*>(op);
    // Preconditions, checked in order: not prohibited, inner join, build side has a filter or
    // index scan.
    if (join->getSIP() == planner::SidewaysInfoPassing::PROHIBIT_BUILD_TO_PROBE ||
        join->getJoinType() != common::JoinType::INNER ||
        !subPlanContainsFilter(join->getChild(1).get())) {
        return false;
    }
    // Look for maskable scan nodes in the probe side only.
    const std::vector<LogicalOperator*> probeRoots{join->getChild(0).get()};
    const auto probeScanNodes =
        resolveScanNodesToApplySemiMask(join->getJoinNodeIDs(), probeRoots);
    if (probeScanNodes.empty()) {
        return false;
    }
    join->setSIP(planner::SidewaysInfoPassing::BUILD_TO_PROBE);
    join->setChild(1, applySemiMasks(probeScanNodes, op->getChild(1)));
    return true;
}

void AccHashJoinOptimizer::visitIntersect(planner::LogicalOperator* op) {
auto intersect = (LogicalIntersect*)op;
if (intersect->getSIP() == planner::SidewaysInfoPassing::PROHIBIT_PROBE_TO_BUILD) {
return;
}
if (!isProbeSideQualified(op->getChild(0).get())) {
return;
}
Expand All @@ -53,32 +100,28 @@ void ASPOptimizer::visitIntersect(planner::LogicalOperator* op) {
if (scanNodes.empty()) {
return;
}
intersect->setSIP(SidewaysInfoPassing::LEFT_TO_RIGHT);
applyASP(scanNodes, op);
intersect->setSIP(SidewaysInfoPassing::PROBE_TO_BUILD);
applyAccHashJoin(scanNodes, op);
}

// Probe side is qualified if it is selective.
bool ASPOptimizer::isProbeSideQualified(planner::LogicalOperator* probeRoot) {
bool AccHashJoinOptimizer::isProbeSideQualified(planner::LogicalOperator* probeRoot) {
if (probeRoot->getOperatorType() == LogicalOperatorType::ACCUMULATE) {
// No ASP if probe side has already been accumulated. This can be solved.
// No Acc hash join if probe side has already been accumulated. This can be solved.
return false;
}
auto filterCollector = LogicalFilterCollector();
filterCollector.collect(probeRoot);
if (!filterCollector.hasOperators()) {
// Probe side is not selective. So we don't apply ASP.
return false;
}
return true;
// Only apply acc hash join when the probe side is selective (contains a filter or index scan).
return subPlanContainsFilter(probeRoot);
}

binder::expression_map<std::vector<planner::LogicalOperator*>>
ASPOptimizer::resolveScanNodesToApplySemiMask(const binder::expression_vector& nodeIDCandidates,
const std::vector<planner::LogicalOperator*>& buildRoots) {
AccHashJoinOptimizer::resolveScanNodesToApplySemiMask(
const binder::expression_vector& nodeIDCandidates,
const std::vector<planner::LogicalOperator*>& roots) {
binder::expression_map<std::vector<LogicalOperator*>> nodeIDToScanOperatorsMap;
for (auto& buildRoot : buildRoots) {
for (auto& root : roots) {
auto scanNodesCollector = LogicalScanNodeCollector();
scanNodesCollector.collect(buildRoot);
scanNodesCollector.collect(root);
for (auto& op : scanNodesCollector.getOperators()) {
auto scanNode = (LogicalScanNode*)op;
auto nodeID = scanNode->getNode()->getInternalIDProperty();
Expand All @@ -100,16 +143,23 @@ ASPOptimizer::resolveScanNodesToApplySemiMask(const binder::expression_vector& n
return result;
}

void ASPOptimizer::applyASP(
std::shared_ptr<planner::LogicalOperator> AccHashJoinOptimizer::applySemiMasks(
const binder::expression_map<std::vector<planner::LogicalOperator*>>& nodeIDToScanNodes,
planner::LogicalOperator* op) {
auto currentChild = op->getChild(0);
std::shared_ptr<planner::LogicalOperator> root) {
auto currentRoot = root;
for (auto& [nodeID, scanNodes] : nodeIDToScanNodes) {
auto semiMasker = std::make_shared<LogicalSemiMasker>(nodeID, scanNodes, currentChild);
auto semiMasker = std::make_shared<LogicalSemiMasker>(nodeID, scanNodes, currentRoot);
semiMasker->computeFlatSchema();
currentChild = semiMasker;
currentRoot = semiMasker;
}
auto accumulate = std::make_shared<LogicalAccumulate>(std::move(currentChild));
return currentRoot;
}

// Rewrites the probe side (child 0) of `op`: the current probe child is wrapped with semi
// maskers for the given scan nodes, the result is placed under an accumulate operator, and that
// accumulate becomes the new child 0.
void AccHashJoinOptimizer::applyAccHashJoin(
    const binder::expression_map<std::vector<planner::LogicalOperator*>>& nodeIDToScanNodes,
    planner::LogicalOperator* op) {
    auto accumulatedProbe = std::make_shared<LogicalAccumulate>(
        applySemiMasks(nodeIDToScanNodes, op->getChild(0)));
    accumulatedProbe->computeFlatSchema();
    op->setChild(0, std::move(accumulatedProbe));
}
Expand Down
6 changes: 6 additions & 0 deletions src/optimizer/logical_operator_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ void LogicalOperatorVisitor::visitOperatorSwitch(planner::LogicalOperator* op) {
case LogicalOperatorType::SCAN_NODE: {
visitScanNode(op);
} break;
case LogicalOperatorType::INDEX_SCAN_NODE: {
visitIndexScanNode(op);
} break;
case LogicalOperatorType::EXTEND: {
visitExtend(op);
} break;
Expand Down Expand Up @@ -84,6 +87,9 @@ std::shared_ptr<planner::LogicalOperator> LogicalOperatorVisitor::visitOperatorR
case LogicalOperatorType::SCAN_NODE: {
return visitScanNodeReplace(op);
}
case LogicalOperatorType::INDEX_SCAN_NODE: {
return visitIndexScanNodeReplace(op);
}
case LogicalOperatorType::EXTEND: {
return visitExtendReplace(op);
}
Expand Down
6 changes: 3 additions & 3 deletions src/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "optimizer/optimizer.h"

#include "optimizer/asp_optimizer.h"
#include "optimizer/acc_hash_join_optimizer.h"
#include "optimizer/factorization_rewriter.h"
#include "optimizer/filter_push_down_optimizer.h"
#include "optimizer/projection_push_down_optimizer.h"
Expand All @@ -22,8 +22,8 @@ void Optimizer::optimize(planner::LogicalPlan* plan) {
filterPushDownOptimizer.rewrite(plan);

// Accumulated hash join optimizer should be applied after optimizers that manipulate hash join.
auto aspOptimizer = ASPOptimizer();
aspOptimizer.rewrite(plan);
auto accHashJoinOptimizer = AccHashJoinOptimizer();
accHashJoinOptimizer.rewrite(plan);

auto projectionPushDownOptimizer = ProjectionPushDownOptimizer();
projectionPushDownOptimizer.rewrite(plan);
Expand Down
6 changes: 3 additions & 3 deletions src/planner/join_order/cardinality_estimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ uint64_t CardinalityEstimator::estimateIntersect(const binder::expression_vector
const LogicalPlan& probePlan, const std::vector<std::unique_ptr<LogicalPlan>>& buildPlans) {
// Formula 1: treat intersect as a Filter on probe side.
uint64_t estCardinality1 =
probePlan.estCardinality * common::EnumeratorKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY;
probePlan.estCardinality * common::PlannerKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY;
// Formula 2: assume independence on join conditions.
auto denominator = 1u;
for (auto& joinNodeID : joinNodeIDs) {
Expand Down Expand Up @@ -93,11 +93,11 @@ uint64_t CardinalityEstimator::estimateFilter(
return 1;
} else {
return atLeastOne(
childPlan.estCardinality * common::EnumeratorKnobs::EQUALITY_PREDICATE_SELECTIVITY);
childPlan.estCardinality * common::PlannerKnobs::EQUALITY_PREDICATE_SELECTIVITY);
}
} else {
return atLeastOne(
childPlan.estCardinality * common::EnumeratorKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY);
childPlan.estCardinality * common::PlannerKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/planner/join_order/cost_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ uint64_t CostModel::computeHashJoinCost(const binder::expression_vector& joinNod
cost += probe.getCost();
cost += build.getCost();
cost += probe.getCardinality();
cost += common::EnumeratorKnobs::BUILD_PENALTY * build.getCardinality();
cost += common::PlannerKnobs::BUILD_PENALTY * build.getCardinality();
return cost;
}

Expand Down
Loading

0 comments on commit 1710491

Please sign in to comment.