Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework asp optimizer #1417

Merged
merged 1 commit into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/common/copier_config/copier_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ std::string CopyDescription::getFileTypeName(FileType fileType) {
case FileType::NPY: {
return "npy";
}
default:
throw InternalException("Unimplemented getFileTypeName().");
}
}

Expand Down
5 changes: 4 additions & 1 deletion src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,13 @@ struct LoggerConstants {
};
};

// Tuning constants read by the planner's cardinality estimator and cost model.
struct EnumeratorKnobs {
struct PlannerKnobs {
// Multiplied with a plan's estimated cardinality in CardinalityEstimator::estimateFilter and
// CardinalityEstimator::estimateIntersect.
static constexpr double NON_EQUALITY_PREDICATE_SELECTIVITY = 0.1;
// Multiplied with the child plan's estimated cardinality in CardinalityEstimator::estimateFilter.
static constexpr double EQUALITY_PREDICATE_SELECTIVITY = 0.01;
// Multiplier applied to the build side cardinality in CostModel::computeHashJoinCost.
static constexpr uint64_t BUILD_PENALTY = 2;
// Avoid doing probe to build SIP if we have to accumulate a probe side that is much bigger than
// build side. Also avoid doing build to probe SIP if probe side is not much bigger than build.
// NOTE(review): the code that reads this ratio is not visible in this view - confirm usage.
static constexpr uint64_t ACC_HJ_PROBE_BUILD_RATIO = 5;
};

struct ClientContextConstants {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,32 @@
namespace kuzu {
namespace optimizer {

// This optimizer enables the ASP join algorithm as introduced in paper "Kuzu Graph Database
// Management System".
class ASPOptimizer : public LogicalOperatorVisitor {
// This optimizer enables the Accumulated hash join algorithm as introduced in paper "Kuzu Graph
// Database Management System".
class AccHashJoinOptimizer : public LogicalOperatorVisitor {
public:
// Entry point: visits the plan starting from plan->getLastOperator().
void rewrite(planner::LogicalPlan* plan);

private:
// Visits all children of op first, then dispatches op itself through visitOperatorSwitch.
void visitOperator(planner::LogicalOperator* op);

// Tries probe-to-build SIP first and falls back to build-to-probe SIP if that was not applied.
void visitHashJoin(planner::LogicalOperator* op) override;

// Both return true when the corresponding SIP rewrite was applied to the hash join op.
bool tryProbeToBuildHJSIP(planner::LogicalOperator* op);
bool tryBuildToProbeHJSIP(planner::LogicalOperator* op);

// Applies probe-to-build SIP to an intersect unless it is prohibited or the probe side
// (child 0) is not qualified.
void visitIntersect(planner::LogicalOperator* op) override;

// Returns false when probeRoot is an ACCUMULATE operator.
bool isProbeSideQualified(planner::LogicalOperator* probeRoot);

// Collects scan-node operators below the given roots with LogicalScanNodeCollector.
// NOTE(review): how nodeIDCandidates restricts the result is in a part of the definition that
// is not visible here - confirm.
binder::expression_map<std::vector<planner::LogicalOperator*>> resolveScanNodesToApplySemiMask(
const binder::expression_vector& nodeIDCandidates,
const std::vector<planner::LogicalOperator*>& buildRoots);
const std::vector<planner::LogicalOperator*>& roots);

void applyASP(
// Stacks one LogicalSemiMasker per map entry on top of root and returns the topmost operator.
std::shared_ptr<planner::LogicalOperator> applySemiMasks(
const binder::expression_map<std::vector<planner::LogicalOperator*>>& nodeIDToScanNodes,
std::shared_ptr<planner::LogicalOperator> root);
// Applies semi masks above child 0 of op, wraps the result in a LogicalAccumulate and installs
// it as the new child 0.
void applyAccHashJoin(
const binder::expression_map<std::vector<planner::LogicalOperator*>>& nodeIDToScanNodes,
planner::LogicalOperator* op);
};
Expand Down
5 changes: 5 additions & 0 deletions src/include/optimizer/logical_operator_collector.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,10 @@ class LogicalScanNodeCollector : public LogicalOperatorCollector {
void visitScanNode(planner::LogicalOperator* op) override { ops.push_back(op); }
};

// Collector that records every operator handed to its visitIndexScanNode hook.
class LogicalIndexScanNodeCollector : public LogicalOperatorCollector {
protected:
    // Keep the visited index-scan-node operator in the collected operator list.
    void visitIndexScanNode(planner::LogicalOperator* op) override {
        ops.push_back(op);
    }
};

} // namespace optimizer
} // namespace kuzu
6 changes: 6 additions & 0 deletions src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ class LogicalOperatorVisitor {
return op;
}

// Hook called by visitOperatorSwitch for INDEX_SCAN_NODE operators. The default does nothing.
virtual void visitIndexScanNode(planner::LogicalOperator* op) {}
// Replacing counterpart of visitIndexScanNode, dispatched for INDEX_SCAN_NODE operators.
// The default returns op unchanged.
virtual std::shared_ptr<planner::LogicalOperator> visitIndexScanNodeReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitExtend(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitExtendReplace(
std::shared_ptr<planner::LogicalOperator> op) {
Expand Down
8 changes: 0 additions & 8 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,6 @@ class JoinOrderEnumerator {
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan) {
planJoin(joinNodeIDs, common::JoinType::MARK, mark, probePlan, buildPlan);
}
inline void planInnerHashJoin(const std::vector<std::shared_ptr<NodeExpression>>& joinNodes,
LogicalPlan& probePlan, LogicalPlan& buildPlan) {
binder::expression_vector joinNodeIDs;
for (auto& joinNode : joinNodes) {
joinNodeIDs.push_back(joinNode->getInternalIDProperty());
}
planJoin(joinNodeIDs, common::JoinType::INNER, nullptr /* mark */, probePlan, buildPlan);
}
inline void planInnerHashJoin(const binder::expression_vector& joinNodeIDs,
LogicalPlan& probePlan, LogicalPlan& buildPlan) {
planJoin(joinNodeIDs, common::JoinType::INNER, nullptr /* mark */, probePlan, buildPlan);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ namespace planner {

// Sideways information passing (SIP) state stored on a join operator and read/written through
// its getSIP()/setSIP().
enum class SidewaysInfoPassing : uint8_t {
NONE = 0,
// NOTE(review): LEFT_TO_RIGHT and PROBE_TO_BUILD both carry the value 1.
LEFT_TO_RIGHT = 1,
// Set by AccHashJoinOptimizer after semi masks and an accumulate were placed on child 0.
PROBE_TO_BUILD = 1,
// When found on a join, AccHashJoinOptimizer does not attempt probe-to-build SIP.
PROHIBIT_PROBE_TO_BUILD = 2,
// Set by AccHashJoinOptimizer after semi masks were placed on top of child 1.
BUILD_TO_PROBE = 3,
// When found on a hash join, AccHashJoinOptimizer does not attempt build-to-probe SIP.
PROHIBIT_BUILD_TO_PROBE = 4,
};

} // namespace planner
Expand Down
5 changes: 5 additions & 0 deletions src/include/planner/subplans_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class SubgraphPlans {
public:
SubgraphPlans(const SubqueryGraph& subqueryGraph);

// Returns the stored maxCost value.
// NOTE(review): where maxCost gets updated is not visible in this view.
inline uint64_t getMaxCost() const { return maxCost; }

void addPlan(std::unique_ptr<LogicalPlan> plan);

std::vector<std::unique_ptr<LogicalPlan>>& getPlans() { return plans; }
Expand All @@ -42,6 +44,7 @@ class SubgraphPlans {
constexpr static uint32_t MAX_NUM_PLANS = 10;

private:
uint64_t maxCost = UINT64_MAX;
binder::expression_vector nodeIDsToEncode;
std::vector<std::unique_ptr<LogicalPlan>> plans;
std::unordered_map<std::bitset<MAX_NUM_QUERY_VARIABLES>, common::vector_idx_t>
Expand Down Expand Up @@ -77,6 +80,8 @@ class SubPlansTable {
public:
void resize(uint32_t newSize);

uint64_t getMaxCost(const SubqueryGraph& subqueryGraph) const;

bool containSubgraphPlans(const SubqueryGraph& subqueryGraph) const;

std::vector<std::unique_ptr<LogicalPlan>>& getSubgraphPlans(const SubqueryGraph& subqueryGraph);
Expand Down
2 changes: 1 addition & 1 deletion src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class PlanMapper {
BuildDataInfo generateBuildDataInfo(const planner::Schema& buildSideSchema,
const binder::expression_vector& keys, const binder::expression_vector& payloads);

void mapASP(PhysicalOperator* probe);
void mapAccHashJoin(PhysicalOperator* probe);

public:
storage::StorageManager& storageManager;
Expand Down
2 changes: 1 addition & 1 deletion src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
add_library(kuzu_optimizer
OBJECT
asp_optimizer.cpp
acc_hash_join_optimizer.cpp
factorization_rewriter.cpp
filter_push_down_optimizer.cpp
logical_operator_collector.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "optimizer/asp_optimizer.h"
#include "optimizer/acc_hash_join_optimizer.h"

#include "optimizer/logical_operator_collector.h"
#include "planner/logical_plan/logical_operator/logical_accumulate.h"
Expand All @@ -12,36 +12,83 @@ using namespace kuzu::planner;
namespace kuzu {
namespace optimizer {

// Runs the optimizer over the plan, starting the traversal at the plan's last operator.
void ASPOptimizer::rewrite(planner::LogicalPlan* plan) {
void AccHashJoinOptimizer::rewrite(planner::LogicalPlan* plan) {
visitOperator(plan->getLastOperator().get());
}

// Recursively visits every child of op before dispatching op itself, so operators deeper in
// the plan are handled before their parents.
void ASPOptimizer::visitOperator(planner::LogicalOperator* op) {
void AccHashJoinOptimizer::visitOperator(planner::LogicalOperator* op) {
// bottom up traversal
for (auto i = 0u; i < op->getNumChildren(); ++i) {
visitOperator(op->getChild(i).get());
}
// Dispatch on the operator type to the matching visit hook.
visitOperatorSwitch(op);
}

// Hash join hook: attempts probe-to-build SIP and, only when that was not applied, attempts
// build-to-probe SIP.
void ASPOptimizer::visitHashJoin(planner::LogicalOperator* op) {
void AccHashJoinOptimizer::visitHashJoin(planner::LogicalOperator* op) {
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
if (tryProbeToBuildHJSIP(op)) { // Try probe to build SIP first.
return;
}
// The return value of the build-to-probe attempt is not used.
tryBuildToProbeHJSIP(op);
}

// Tries to apply probe-to-build sideways information passing to the hash join op.
// Returns true when the join was rewritten and false when the attempt was skipped.
bool AccHashJoinOptimizer::tryProbeToBuildHJSIP(planner::LogicalOperator* op) {
auto hashJoin = (LogicalHashJoin*)op;
// Skip when the join's SIP state is PROHIBIT_PROBE_TO_BUILD.
if (hashJoin->getSIP() == planner::SidewaysInfoPassing::PROHIBIT_PROBE_TO_BUILD) {
return false;
}
// Child 0 is checked as the probe side.
if (!isProbeSideQualified(op->getChild(0).get())) {
return;
return false;
}
// Child 1 is the only root searched for scan nodes that can receive a semi mask.
std::vector<LogicalOperator*> buildRoots;
buildRoots.push_back(hashJoin->getChild(1).get());
auto scanNodes = resolveScanNodesToApplySemiMask(hashJoin->getJoinNodeIDs(), buildRoots);
// Nothing to mask, so leave the join untouched.
if (scanNodes.empty()) {
return;
return false;
}
// apply ASP
hashJoin->setSIP(SidewaysInfoPassing::LEFT_TO_RIGHT);
applyASP(scanNodes, op);
// apply accumulated hash join
hashJoin->setSIP(SidewaysInfoPassing::PROBE_TO_BUILD);
applyAccHashJoin(scanNodes, op);
return true;
}

void ASPOptimizer::visitIntersect(planner::LogicalOperator* op) {
// Returns true when the sub-plan rooted at root contains at least one operator gathered by
// LogicalFilterCollector or by LogicalIndexScanNodeCollector.
static bool subPlanContainsFilter(LogicalOperator* root) {
    LogicalFilterCollector filters{};
    filters.collect(root);
    LogicalIndexScanNodeCollector indexScans{};
    indexScans.collect(root);
    // Either kind of operator is enough.
    return filters.hasOperators() || indexScans.hasOperators();
}

// Tries to apply build-to-probe sideways information passing to the hash join op.
// The rewrite is attempted only when the join does not prohibit it, the join type is INNER and
// the sub-plan under child 1 contains a filter or index scan. On success, semi maskers for the
// scan nodes found under child 0 are stacked on top of child 1 and true is returned.
bool AccHashJoinOptimizer::tryBuildToProbeHJSIP(planner::LogicalOperator* op) {
    auto* join = static_cast<LogicalHashJoin*>(op);
    const bool prohibited =
        join->getSIP() == planner::SidewaysInfoPassing::PROHIBIT_BUILD_TO_PROBE;
    if (prohibited || join->getJoinType() != common::JoinType::INNER) {
        return false;
    }
    if (!subPlanContainsFilter(join->getChild(1).get())) {
        return false;
    }
    // Scan nodes are searched only below child 0.
    std::vector<LogicalOperator*> probeRoots{join->getChild(0).get()};
    const auto scanNodes = resolveScanNodesToApplySemiMask(join->getJoinNodeIDs(), probeRoots);
    if (scanNodes.empty()) {
        return false;
    }
    join->setSIP(planner::SidewaysInfoPassing::BUILD_TO_PROBE);
    join->setChild(1, applySemiMasks(scanNodes, op->getChild(1)));
    return true;
}

void AccHashJoinOptimizer::visitIntersect(planner::LogicalOperator* op) {
auto intersect = (LogicalIntersect*)op;
if (intersect->getSIP() == planner::SidewaysInfoPassing::PROHIBIT_PROBE_TO_BUILD) {
return;
}
if (!isProbeSideQualified(op->getChild(0).get())) {
return;
}
Expand All @@ -53,32 +100,28 @@ void ASPOptimizer::visitIntersect(planner::LogicalOperator* op) {
if (scanNodes.empty()) {
return;
}
intersect->setSIP(SidewaysInfoPassing::LEFT_TO_RIGHT);
applyASP(scanNodes, op);
intersect->setSIP(SidewaysInfoPassing::PROBE_TO_BUILD);
applyAccHashJoin(scanNodes, op);
}

// Probe side is qualified if it is selective.
// Decides whether the sub-plan rooted at probeRoot may be used as the probe side of an
// accumulated hash join.
bool ASPOptimizer::isProbeSideQualified(planner::LogicalOperator* probeRoot) {
bool AccHashJoinOptimizer::isProbeSideQualified(planner::LogicalOperator* probeRoot) {
if (probeRoot->getOperatorType() == LogicalOperatorType::ACCUMULATE) {
// No ASP if probe side has already been accumulated. This can be solved.
// No Acc hash join if probe side has already been accumulated. This can be solved.
return false;
}
auto filterCollector = LogicalFilterCollector();
filterCollector.collect(probeRoot);
if (!filterCollector.hasOperators()) {
// Probe side is not selective. So we don't apply ASP.
return false;
}
return true;
// Probe side counts as selective only when its sub-plan contains a filter or an index scan
// node; otherwise acc hash join is not applied.
return subPlanContainsFilter(probeRoot);
}

binder::expression_map<std::vector<planner::LogicalOperator*>>
ASPOptimizer::resolveScanNodesToApplySemiMask(const binder::expression_vector& nodeIDCandidates,
const std::vector<planner::LogicalOperator*>& buildRoots) {
AccHashJoinOptimizer::resolveScanNodesToApplySemiMask(
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
const binder::expression_vector& nodeIDCandidates,
const std::vector<planner::LogicalOperator*>& roots) {
binder::expression_map<std::vector<LogicalOperator*>> nodeIDToScanOperatorsMap;
for (auto& buildRoot : buildRoots) {
for (auto& root : roots) {
auto scanNodesCollector = LogicalScanNodeCollector();
scanNodesCollector.collect(buildRoot);
scanNodesCollector.collect(root);
for (auto& op : scanNodesCollector.getOperators()) {
auto scanNode = (LogicalScanNode*)op;
auto nodeID = scanNode->getNode()->getInternalIDProperty();
Expand All @@ -100,16 +143,23 @@ ASPOptimizer::resolveScanNodesToApplySemiMask(const binder::expression_vector& n
return result;
}

// Stacks one LogicalSemiMasker per (nodeID, scanNodes) entry on top of root and returns the
// topmost operator. Returns root itself when the map is empty.
void ASPOptimizer::applyASP(
std::shared_ptr<planner::LogicalOperator> AccHashJoinOptimizer::applySemiMasks(
const binder::expression_map<std::vector<planner::LogicalOperator*>>& nodeIDToScanNodes,
planner::LogicalOperator* op) {
auto currentChild = op->getChild(0);
std::shared_ptr<planner::LogicalOperator> root) {
auto currentRoot = root;
for (auto& [nodeID, scanNodes] : nodeIDToScanNodes) {
// Each new semi masker takes the previous top of the chain as its child.
auto semiMasker = std::make_shared<LogicalSemiMasker>(nodeID, scanNodes, currentChild);
auto semiMasker = std::make_shared<LogicalSemiMasker>(nodeID, scanNodes, currentRoot);
semiMasker->computeFlatSchema();
currentChild = semiMasker;
currentRoot = semiMasker;
}
auto accumulate = std::make_shared<LogicalAccumulate>(std::move(currentChild));
return currentRoot;
}

// Rewrites child 0 of op for an accumulated hash join: semi maskers for nodeIDToScanNodes are
// stacked on the current child 0, the result is wrapped in a LogicalAccumulate, and that
// accumulate becomes the new child 0.
void AccHashJoinOptimizer::applyAccHashJoin(
    const binder::expression_map<std::vector<planner::LogicalOperator*>>& nodeIDToScanNodes,
    planner::LogicalOperator* op) {
    auto accumulated =
        std::make_shared<LogicalAccumulate>(applySemiMasks(nodeIDToScanNodes, op->getChild(0)));
    accumulated->computeFlatSchema();
    op->setChild(0, std::move(accumulated));
}
Expand Down
6 changes: 6 additions & 0 deletions src/optimizer/logical_operator_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ void LogicalOperatorVisitor::visitOperatorSwitch(planner::LogicalOperator* op) {
case LogicalOperatorType::SCAN_NODE: {
visitScanNode(op);
} break;
case LogicalOperatorType::INDEX_SCAN_NODE: {
visitIndexScanNode(op);
} break;
case LogicalOperatorType::EXTEND: {
visitExtend(op);
} break;
Expand Down Expand Up @@ -84,6 +87,9 @@ std::shared_ptr<planner::LogicalOperator> LogicalOperatorVisitor::visitOperatorR
case LogicalOperatorType::SCAN_NODE: {
return visitScanNodeReplace(op);
}
case LogicalOperatorType::INDEX_SCAN_NODE: {
return visitIndexScanNodeReplace(op);
}
case LogicalOperatorType::EXTEND: {
return visitExtendReplace(op);
}
Expand Down
6 changes: 3 additions & 3 deletions src/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "optimizer/optimizer.h"

#include "optimizer/asp_optimizer.h"
#include "optimizer/acc_hash_join_optimizer.h"
#include "optimizer/factorization_rewriter.h"
#include "optimizer/filter_push_down_optimizer.h"
#include "optimizer/projection_push_down_optimizer.h"
Expand All @@ -22,8 +22,8 @@ void Optimizer::optimize(planner::LogicalPlan* plan) {
filterPushDownOptimizer.rewrite(plan);

// Acc hash join optimizer should be applied after optimizers that manipulate hash join.
auto aspOptimizer = ASPOptimizer();
aspOptimizer.rewrite(plan);
auto accHashJoinOptimizer = AccHashJoinOptimizer();
accHashJoinOptimizer.rewrite(plan);

auto projectionPushDownOptimizer = ProjectionPushDownOptimizer();
projectionPushDownOptimizer.rewrite(plan);
Expand Down
6 changes: 3 additions & 3 deletions src/planner/join_order/cardinality_estimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ uint64_t CardinalityEstimator::estimateIntersect(const binder::expression_vector
const LogicalPlan& probePlan, const std::vector<std::unique_ptr<LogicalPlan>>& buildPlans) {
// Formula 1: treat intersect as a Filter on probe side.
uint64_t estCardinality1 =
probePlan.estCardinality * common::EnumeratorKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY;
probePlan.estCardinality * common::PlannerKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY;
// Formula 2: assume independence on join conditions.
auto denominator = 1u;
for (auto& joinNodeID : joinNodeIDs) {
Expand Down Expand Up @@ -93,11 +93,11 @@ uint64_t CardinalityEstimator::estimateFilter(
return 1;
} else {
return atLeastOne(
childPlan.estCardinality * common::EnumeratorKnobs::EQUALITY_PREDICATE_SELECTIVITY);
childPlan.estCardinality * common::PlannerKnobs::EQUALITY_PREDICATE_SELECTIVITY);
}
} else {
return atLeastOne(
childPlan.estCardinality * common::EnumeratorKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY);
childPlan.estCardinality * common::PlannerKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/planner/join_order/cost_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ uint64_t CostModel::computeHashJoinCost(const binder::expression_vector& joinNod
cost += probe.getCost();
cost += build.getCost();
cost += probe.getCardinality();
cost += common::EnumeratorKnobs::BUILD_PENALTY * build.getCardinality();
cost += common::PlannerKnobs::BUILD_PENALTY * build.getCardinality();
return cost;
}

Expand Down
Loading