Skip to content

Commit

Permalink
Merge pull request #1411 from kuzudb/join-order-refactor
Browse files Browse the repository at this point in the history
Refactor index nested loop join planning
  • Loading branch information
andyfengHKU committed Mar 25, 2023
2 parents 73f3f91 + 4f5261e commit 3f574a4
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 113 deletions.
19 changes: 6 additions & 13 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ class JoinOrderEnumeratorContext;
* filter push down
*/
class JoinOrderEnumerator {
friend class ASPOptimizer;

public:
JoinOrderEnumerator(const catalog::Catalog& catalog,
const storage::NodesStatisticsAndDeletedIDs& nodesStatistics,
Expand All @@ -32,8 +30,7 @@ class JoinOrderEnumerator {

inline void resetState() { context->resetState(); }

std::unique_ptr<JoinOrderEnumeratorContext> enterSubquery(LogicalPlan* outerPlan,
binder::expression_vector expressionsToScan,
std::unique_ptr<JoinOrderEnumeratorContext> enterSubquery(
binder::expression_vector nodeIDsToScanFromInnerAndOuter);
void exitSubquery(std::unique_ptr<JoinOrderEnumeratorContext> prevContext);

Expand Down Expand Up @@ -70,13 +67,11 @@ class JoinOrderEnumerator {
std::vector<std::unique_ptr<LogicalPlan>> enumerate(
QueryGraph* queryGraph, binder::expression_vector& predicates);

void planTableScan();

void planBaseTableScan();
void planNodeScan(uint32_t nodePos);
void planRelScan(uint32_t relPos);

void planExtendAndFilters(std::shared_ptr<RelExpression> rel, common::RelDirection direction,
binder::expression_vector& predicates, LogicalPlan& plan);
void appendExtendAndFilter(std::shared_ptr<RelExpression> rel, common::RelDirection direction,
const expression_vector& predicates, LogicalPlan& plan);

void planLevel(uint32_t level);
void planLevelExactly(uint32_t level);
Expand All @@ -89,16 +84,14 @@ class JoinOrderEnumerator {

void planInnerJoin(uint32_t leftLevel, uint32_t rightLevel);

bool canApplyINLJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph,
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes);
void planInnerINLJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph,
bool tryPlanINLJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph,
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes);
void planInnerHashJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph,
std::vector<std::shared_ptr<NodeExpression>> joinNodes, bool flipPlan);
// Filter push down for hash join.
void planFiltersForHashJoin(binder::expression_vector& predicates, LogicalPlan& plan);

void appendScanNode(std::shared_ptr<NodeExpression>& node, LogicalPlan& plan);
void appendScanNodeID(std::shared_ptr<NodeExpression>& node, LogicalPlan& plan);

bool needExtendToNewGroup(
RelExpression& rel, NodeExpression& boundNode, common::RelDirection direction);
Expand Down
6 changes: 2 additions & 4 deletions src/include/planner/join_order_enumerator_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ class JoinOrderEnumeratorContext {
public:
JoinOrderEnumeratorContext()
: currentLevel{0}, maxLevel{0}, subPlansTable{std::make_unique<SubPlansTable>()},
queryGraph{nullptr}, outerPlan{nullptr} {}
queryGraph{nullptr} {}

void init(QueryGraph* queryGraph, expression_vector& predicates);
void init(QueryGraph* queryGraph, const expression_vector& predicates);

inline expression_vector getWhereExpressions() { return whereExpressionsSplitOnAND; }

Expand Down Expand Up @@ -55,8 +55,6 @@ class JoinOrderEnumeratorContext {
std::unique_ptr<SubPlansTable> subPlansTable;
QueryGraph* queryGraph;

LogicalPlan* outerPlan;
expression_vector expressionsToScanFromOuter;
expression_vector nodeIDsToScanFromInnerAndOuter;
};

Expand Down
5 changes: 0 additions & 5 deletions src/include/planner/logical_plan/logical_plan_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,11 @@ namespace planner {

class LogicalPlanUtil {
public:
// Return the node whose ID has sequential guarantee on the plan.
static std::shared_ptr<binder::NodeExpression> getSequentialNode(LogicalPlan& plan);

// Convenience overload: encodes the join structure of a whole plan by delegating to the
// operator-level encodeJoin overload, starting from the plan's last operator.
static inline std::string encodeJoin(LogicalPlan& logicalPlan) {
return encodeJoin(logicalPlan.getLastOperator().get());
}

private:
static LogicalOperator* getCurrentPipelineSourceOperator(LogicalPlan& plan);

static std::string encodeJoin(LogicalOperator* logicalOperator) {
std::string result;
encodeJoinRecursive(logicalOperator, result);
Expand Down
128 changes: 64 additions & 64 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ std::vector<std::unique_ptr<LogicalPlan>> JoinOrderEnumerator::planCrossProduct(
std::vector<std::unique_ptr<LogicalPlan>> JoinOrderEnumerator::enumerate(
QueryGraph* queryGraph, expression_vector& predicates) {
context->init(queryGraph, predicates);
assert(context->expressionsToScanFromOuter.empty());
planTableScan();
planBaseTableScan();
context->currentLevel++;
while (context->currentLevel < context->maxLevel) {
planLevel(context->currentLevel++);
Expand All @@ -83,12 +82,9 @@ std::vector<std::unique_ptr<LogicalPlan>> JoinOrderEnumerator::enumerate(
}

std::unique_ptr<JoinOrderEnumeratorContext> JoinOrderEnumerator::enterSubquery(
LogicalPlan* outerPlan, expression_vector expressionsToScan,
expression_vector nodeIDsToScanFromInnerAndOuter) {
auto prevContext = std::move(context);
context = std::make_unique<JoinOrderEnumeratorContext>();
context->outerPlan = outerPlan;
context->expressionsToScanFromOuter = std::move(expressionsToScan);
context->nodeIDsToScanFromInnerAndOuter = std::move(nodeIDsToScanFromInnerAndOuter);
return prevContext;
}
Expand Down Expand Up @@ -121,7 +117,7 @@ void JoinOrderEnumerator::planLevelApproximately(uint32_t level) {
planInnerJoin(1, level - 1);
}

void JoinOrderEnumerator::planTableScan() {
void JoinOrderEnumerator::planBaseTableScan() {
auto queryGraph = context->getQueryGraph();
for (auto nodePos = 0u; nodePos < queryGraph->getNumQueryNodes(); ++nodePos) {
planNodeScan(nodePos);
Expand All @@ -136,19 +132,19 @@ void JoinOrderEnumerator::planNodeScan(uint32_t nodePos) {
auto newSubgraph = context->getEmptySubqueryGraph();
newSubgraph.addQueryNode(nodePos);
auto plan = std::make_unique<LogicalPlan>();
auto predicates = getNewlyMatchedExpressions(
context->getEmptySubqueryGraph(), newSubgraph, context->getWhereExpressions());
// In un-nested subquery, e.g. MATCH (a) OPTIONAL MATCH (a)-[e1]->(b), the inner query
// ("(a)-[e1]->(b)") needs to scan a, which is already scanned in the outer query (a). To avoid
// scanning storage twice, we keep track of node table "a" and make sure when planning inner
// query, we only scan internal ID of "a".
if (!context->nodeToScanFromInnerAndOuter(node.get())) {
appendScanNode(node, *plan);
appendScanNodeID(node, *plan);
auto properties = queryPlanner->getPropertiesForNode(*node);
queryPlanner->appendScanNodePropIfNecessary(properties, node, *plan);
auto predicates = getNewlyMatchedExpressions(
context->getEmptySubqueryGraph(), newSubgraph, context->getWhereExpressions());
queryPlanner->appendFilters(predicates, *plan);
} else {
appendScanNode(node, *plan);
appendScanNodeID(node, *plan);
}
context->addPlan(newSubgraph, std::move(plan));
}
Expand All @@ -169,14 +165,14 @@ void JoinOrderEnumerator::planRelScan(uint32_t relPos) {
for (auto direction : REL_DIRECTIONS) {
auto plan = std::make_unique<LogicalPlan>();
auto [boundNode, _] = getBoundAndNbrNodes(*rel, direction);
appendScanNode(boundNode, *plan);
planExtendAndFilters(rel, direction, predicates, *plan);
appendScanNodeID(boundNode, *plan);
appendExtendAndFilter(rel, direction, predicates, *plan);
context->addPlan(newSubgraph, std::move(plan));
}
}

void JoinOrderEnumerator::planExtendAndFilters(std::shared_ptr<RelExpression> rel,
RelDirection direction, expression_vector& predicates, LogicalPlan& plan) {
void JoinOrderEnumerator::appendExtendAndFilter(std::shared_ptr<RelExpression> rel,
common::RelDirection direction, const expression_vector& predicates, LogicalPlan& plan) {
auto [boundNode, dstNode] = getBoundAndNbrNodes(*rel, direction);
auto properties = queryPlanner->getPropertiesForRel(*rel);
appendExtend(boundNode, dstNode, rel, direction, properties, plan);
Expand Down Expand Up @@ -225,13 +221,38 @@ void JoinOrderEnumerator::planWCOJoin(uint32_t leftLevel, uint32_t rightLevel) {
}
}

// Follows the first-child chain starting at `op` and returns the SCAN_NODE operator found at its
// end. Only FLATTEN, FILTER, SCAN_NODE_PROPERTY, EXTEND and PROJECTION are searched through; any
// other operator type encountered on the way ends the search with nullptr.
static LogicalScanNode* getSequentialScanNodeOperator(LogicalOperator* op) {
    auto* current = op;
    while (true) {
        switch (current->getOperatorType()) {
        case LogicalOperatorType::FLATTEN:
        case LogicalOperatorType::FILTER:
        case LogicalOperatorType::SCAN_NODE_PROPERTY:
        case LogicalOperatorType::EXTEND:
        case LogicalOperatorType::PROJECTION: { // operators we directly search through
            // Descend into the first child; `break` leaves the switch and the loop continues.
            current = current->getChild(0).get();
        } break;
        case LogicalOperatorType::SCAN_NODE: {
            // The operator type tag was just checked, so a named downcast is used instead of a
            // C-style cast.
            return static_cast<LogicalScanNode*>(current);
        }
        default:
            return nullptr;
        }
    }
}

// Returns true when the scan-node operator located by getSequentialScanNodeOperator for this plan
// scans a node with the same unique name as `node`; returns false when no such operator is found.
static bool isNodeSequentialOnPlan(LogicalPlan& plan, const NodeExpression& node) {
    auto* scanOperator = getSequentialScanNodeOperator(plan.getLastOperator().get());
    return scanOperator != nullptr &&
           scanOperator->getNode()->getUniqueName() == node.getUniqueName();
}

// As a heuristic for wcoj, we always pick rel scan that starts from the bound node.
static std::unique_ptr<LogicalPlan> getWCOJBuildPlanForRel(
std::vector<std::unique_ptr<LogicalPlan>>& candidatePlans, const NodeExpression& boundNode) {
std::unique_ptr<LogicalPlan> result;
for (auto& candidatePlan : candidatePlans) {
if (LogicalPlanUtil::getSequentialNode(*candidatePlan)->getUniqueName() ==
boundNode.getUniqueName()) {
if (isNodeSequentialOnPlan(*candidatePlan, boundNode)) {
assert(result == nullptr);
result = candidatePlan->shallowCopy();
}
Expand Down Expand Up @@ -321,70 +342,50 @@ void JoinOrderEnumerator::planInnerJoin(uint32_t leftLevel, uint32_t rightLevel)
continue;
}
// If index nested loop (INL) join is possible, we prune hash join plans
if (canApplyINLJoin(rightSubgraph, nbrSubgraph, joinNodes)) {
planInnerINLJoin(rightSubgraph, nbrSubgraph, joinNodes);
} else if (canApplyINLJoin(nbrSubgraph, rightSubgraph, joinNodes)) {
planInnerINLJoin(nbrSubgraph, rightSubgraph, joinNodes);
} else {
planInnerHashJoin(rightSubgraph, nbrSubgraph, joinNodes, leftLevel != rightLevel);
if (tryPlanINLJoin(rightSubgraph, nbrSubgraph, joinNodes)) {
continue;
}
planInnerHashJoin(rightSubgraph, nbrSubgraph, joinNodes, leftLevel != rightLevel);
}
}
}

// Returns true when LogicalPlanUtil::getSequentialNode reports a node for this plan and that
// node has the same unique name as `node`.
static bool isNodeSequential(LogicalPlan& plan, NodeExpression* node) {
    auto planSequentialNode = LogicalPlanUtil::getSequentialNode(plan);
    if (planSequentialNode == nullptr) {
        return false;
    }
    return planSequentialNode->getUniqueName() == node->getUniqueName();
}

// We apply index nested loop join if the following two conditions are satisfied
// - otherSubgraph is an edge; and
// - join node is sequential on at least one plan corresponding to subgraph. (Otherwise INLJ will
// trigger non-sequential read).
bool JoinOrderEnumerator::canApplyINLJoin(const SubqueryGraph& subgraph,
bool JoinOrderEnumerator::tryPlanINLJoin(const SubqueryGraph& subgraph,
const SubqueryGraph& otherSubgraph,
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes) {
if (!otherSubgraph.isSingleRel() || joinNodes.size() > 1) {
if (joinNodes.size() > 1) {
return false;
}
for (auto& plan : context->getPlans(subgraph)) {
if (isNodeSequential(*plan, joinNodes[0].get())) {
return true;
}
if (!subgraph.isSingleRel() && !otherSubgraph.isSingleRel()) {
return false;
}
return false;
}

static uint32_t extractJoinRelPos(const SubqueryGraph& subgraph, const QueryGraph& queryGraph) {
for (auto relPos = 0u; relPos < queryGraph.getNumQueryRels(); ++relPos) {
if (subgraph.queryRelsSelector[relPos]) {
return relPos;
if (subgraph.isSingleRel()) { // Always put single rel subgraph to right.
return tryPlanINLJoin(otherSubgraph, subgraph, joinNodes);
}
auto relPos = UINT32_MAX;
for (auto i = 0u; i < context->queryGraph->getNumQueryRels(); ++i) {
if (otherSubgraph.queryRelsSelector[i]) {
relPos = i;
}
}
throw InternalException("Cannot extract relPos.");
}

void JoinOrderEnumerator::planInnerINLJoin(const SubqueryGraph& subgraph,
const SubqueryGraph& otherSubgraph,
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes) {
assert(otherSubgraph.getNumQueryRels() == 1 && joinNodes.size() == 1);
auto boundNode = joinNodes[0].get();
auto queryGraph = context->getQueryGraph();
auto relPos = extractJoinRelPos(otherSubgraph, *queryGraph);
auto rel = queryGraph->getQueryRel(relPos);
assert(relPos != UINT32_MAX);
auto rel = context->queryGraph->getQueryRel(relPos);
auto boundNode = joinNodes[0];
auto direction = boundNode->getUniqueName() == rel->getSrcNodeName() ? FWD : BWD;
auto newSubgraph = subgraph;
newSubgraph.addQueryRel(relPos);
auto predicates =
getNewlyMatchedExpressions(subgraph, newSubgraph, context->getWhereExpressions());
bool hasAppliedINLJoin = false;
for (auto& prevPlan : context->getPlans(subgraph)) {
if (isNodeSequential(*prevPlan, boundNode)) {
if (isNodeSequentialOnPlan(*prevPlan, *boundNode)) {
auto plan = prevPlan->shallowCopy();
auto direction = boundNode->getUniqueName() == rel->getSrcNodeName() ? FWD : BWD;
planExtendAndFilters(rel, direction, predicates, *plan);
context->addPlan(newSubgraph, move(plan));
appendExtendAndFilter(rel, direction, predicates, *plan);
context->addPlan(newSubgraph, std::move(plan));
hasAppliedINLJoin = true;
}
}
return hasAppliedINLJoin;
}

void JoinOrderEnumerator::planInnerHashJoin(const SubqueryGraph& subgraph,
Expand Down Expand Up @@ -415,12 +416,11 @@ void JoinOrderEnumerator::planInnerHashJoin(const SubqueryGraph& subgraph,
}

void JoinOrderEnumerator::planFiltersForHashJoin(expression_vector& predicates, LogicalPlan& plan) {
for (auto& predicate : predicates) {
queryPlanner->appendFilter(predicate, plan);
}
queryPlanner->appendFilters(predicates, plan);
}

void JoinOrderEnumerator::appendScanNode(std::shared_ptr<NodeExpression>& node, LogicalPlan& plan) {
void JoinOrderEnumerator::appendScanNodeID(
std::shared_ptr<NodeExpression>& node, LogicalPlan& plan) {
assert(plan.isEmpty());
auto scan = make_shared<LogicalScanNode>(node);
scan->computeFactorizedSchema();
Expand Down
3 changes: 2 additions & 1 deletion src/planner/join_order_enumerator_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
namespace kuzu {
namespace planner {

void JoinOrderEnumeratorContext::init(QueryGraph* queryGraph_, expression_vector& predicates) {
void JoinOrderEnumeratorContext::init(
QueryGraph* queryGraph_, const expression_vector& predicates) {
whereExpressionsSplitOnAND = predicates;
this->queryGraph = queryGraph_;
// clear and resize subPlansTable
Expand Down
20 changes: 0 additions & 20 deletions src/planner/operator/logical_plan_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,6 @@ using namespace kuzu::binder;
namespace kuzu {
namespace planner {

// Returns the node scanned by the source operator of the plan's current pipeline, or nullptr
// when that source operator is not a SCAN_NODE.
std::shared_ptr<NodeExpression> LogicalPlanUtil::getSequentialNode(LogicalPlan& plan) {
    auto pipelineSource = getCurrentPipelineSourceOperator(plan);
    if (pipelineSource->getOperatorType() != LogicalOperatorType::SCAN_NODE) {
        // Pipeline source is not ScanNodeID, meaning at least one sink has happened (e.g. HashJoin)
        // and we lose any sequential guarantees.
        return nullptr;
    }
    // The operator type tag was just checked, so a named downcast is used instead of a C-style
    // cast.
    return static_cast<LogicalScanNode*>(pipelineSource)->getNode();
}

// Starting from the plan's last operator, follows the first child for as long as the current
// operator has exactly one child, and returns the operator where that walk stops.
LogicalOperator* LogicalPlanUtil::getCurrentPipelineSourceOperator(LogicalPlan& plan) {
    auto op = plan.getLastOperator().get();
    // Validate before the loop: the loop condition dereferences op, so an assertion placed after
    // the loop could never catch a null operator.
    assert(op != nullptr);
    // Operator with more than one child will be broken into different pipelines.
    while (op->getNumChildren() == 1) {
        op = op->getChild(0).get();
        assert(op != nullptr);
    }
    return op;
}

void LogicalPlanUtil::encodeJoinRecursive(
LogicalOperator* logicalOperator, std::string& encodeString) {
switch (logicalOperator->getOperatorType()) {
Expand Down
9 changes: 3 additions & 6 deletions src/planner/query_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,7 @@ void QueryPlanner::planOptionalMatch(const QueryGraphCollection& queryGraphColle
// When correlated variables are all NODE IDs, the subquery can be un-nested as left join.
// Join nodes are scanned twice in both outer and inner. However, we make sure inner table
// scan only scans node ID and does not scan from storage (i.e. no property scan).
auto prevContext = joinOrderEnumerator.enterSubquery(
&outerPlan, expression_vector{} /* nothing to scan from outer */, joinNodeIDs);
auto prevContext = joinOrderEnumerator.enterSubquery(joinNodeIDs);
auto innerPlans = joinOrderEnumerator.enumerate(queryGraphCollection, predicates);
auto bestInnerPlan = getBestPlan(std::move(innerPlans));
joinOrderEnumerator.exitSubquery(std::move(prevContext));
Expand Down Expand Up @@ -212,8 +211,7 @@ void QueryPlanner::planRegularMatch(const QueryGraphCollection& queryGraphCollec
// A multi-part query is actually a CTE, and a CTE can be considered a subquery that does not
// scan from the outer query (i.e. it can always be un-nested). So we plan a multi-part query in
// the same way as planning an un-nested subquery.
auto prevContext = joinOrderEnumerator.enterSubquery(
&prevPlan, expression_vector{} /* nothing to scan from outer */, joinNodeIDs);
auto prevContext = joinOrderEnumerator.enterSubquery(joinNodeIDs);
auto plans = joinOrderEnumerator.enumerate(queryGraphCollection, predicatesToPushDown);
joinOrderEnumerator.exitSubquery(std::move(prevContext));
auto bestPlan = getBestPlan(std::move(plans));
Expand All @@ -238,8 +236,7 @@ void QueryPlanner::planExistsSubquery(
if (ExpressionUtil::allExpressionsHaveDataType(correlatedExpressions, INTERNAL_ID)) {
auto joinNodeIDs = getJoinNodeIDs(correlatedExpressions);
// Unnest as mark join. See planOptionalMatch for unnesting logic.
auto prevContext = joinOrderEnumerator.enterSubquery(
&outerPlan, expression_vector{} /* nothing to scan from outer */, joinNodeIDs);
auto prevContext = joinOrderEnumerator.enterSubquery(joinNodeIDs);
auto predicates = subquery->hasWhereExpression() ?
subquery->getWhereExpression()->splitOnAND() :
expression_vector{};
Expand Down

0 comments on commit 3f574a4

Please sign in to comment.