Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor index nested loop join planning #1411

Merged
merged 1 commit into from
Mar 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ class JoinOrderEnumeratorContext;
* filter push down
*/
class JoinOrderEnumerator {
friend class ASPOptimizer;

public:
JoinOrderEnumerator(const catalog::Catalog& catalog,
const storage::NodesStatisticsAndDeletedIDs& nodesStatistics,
Expand All @@ -32,8 +30,7 @@ class JoinOrderEnumerator {

inline void resetState() { context->resetState(); }

std::unique_ptr<JoinOrderEnumeratorContext> enterSubquery(LogicalPlan* outerPlan,
binder::expression_vector expressionsToScan,
std::unique_ptr<JoinOrderEnumeratorContext> enterSubquery(
binder::expression_vector nodeIDsToScanFromInnerAndOuter);
void exitSubquery(std::unique_ptr<JoinOrderEnumeratorContext> prevContext);

Expand Down Expand Up @@ -70,13 +67,11 @@ class JoinOrderEnumerator {
std::vector<std::unique_ptr<LogicalPlan>> enumerate(
QueryGraph* queryGraph, binder::expression_vector& predicates);

void planTableScan();

void planBaseTableScan();
void planNodeScan(uint32_t nodePos);
void planRelScan(uint32_t relPos);

void planExtendAndFilters(std::shared_ptr<RelExpression> rel, common::RelDirection direction,
binder::expression_vector& predicates, LogicalPlan& plan);
void appendExtendAndFilter(std::shared_ptr<RelExpression> rel, common::RelDirection direction,
const expression_vector& predicates, LogicalPlan& plan);

void planLevel(uint32_t level);
void planLevelExactly(uint32_t level);
Expand All @@ -89,16 +84,14 @@ class JoinOrderEnumerator {

void planInnerJoin(uint32_t leftLevel, uint32_t rightLevel);

bool canApplyINLJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph,
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes);
void planInnerINLJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph,
bool tryPlanINLJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph,
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes);
void planInnerHashJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph,
std::vector<std::shared_ptr<NodeExpression>> joinNodes, bool flipPlan);
// Filter push down for hash join.
void planFiltersForHashJoin(binder::expression_vector& predicates, LogicalPlan& plan);

void appendScanNode(std::shared_ptr<NodeExpression>& node, LogicalPlan& plan);
void appendScanNodeID(std::shared_ptr<NodeExpression>& node, LogicalPlan& plan);

bool needExtendToNewGroup(
RelExpression& rel, NodeExpression& boundNode, common::RelDirection direction);
Expand Down
6 changes: 2 additions & 4 deletions src/include/planner/join_order_enumerator_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ class JoinOrderEnumeratorContext {
public:
JoinOrderEnumeratorContext()
: currentLevel{0}, maxLevel{0}, subPlansTable{std::make_unique<SubPlansTable>()},
queryGraph{nullptr}, outerPlan{nullptr} {}
queryGraph{nullptr} {}

void init(QueryGraph* queryGraph, expression_vector& predicates);
void init(QueryGraph* queryGraph, const expression_vector& predicates);

inline expression_vector getWhereExpressions() { return whereExpressionsSplitOnAND; }

Expand Down Expand Up @@ -55,8 +55,6 @@ class JoinOrderEnumeratorContext {
std::unique_ptr<SubPlansTable> subPlansTable;
QueryGraph* queryGraph;

LogicalPlan* outerPlan;
expression_vector expressionsToScanFromOuter;
expression_vector nodeIDsToScanFromInnerAndOuter;
};

Expand Down
5 changes: 0 additions & 5 deletions src/include/planner/logical_plan/logical_plan_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,11 @@ namespace planner {

class LogicalPlanUtil {
public:
// Return the node whose ID has sequential guarantee on the plan.
static std::shared_ptr<binder::NodeExpression> getSequentialNode(LogicalPlan& plan);

static inline std::string encodeJoin(LogicalPlan& logicalPlan) {
return encodeJoin(logicalPlan.getLastOperator().get());
}

private:
static LogicalOperator* getCurrentPipelineSourceOperator(LogicalPlan& plan);

static std::string encodeJoin(LogicalOperator* logicalOperator) {
std::string result;
encodeJoinRecursive(logicalOperator, result);
Expand Down
128 changes: 64 additions & 64 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ std::vector<std::unique_ptr<LogicalPlan>> JoinOrderEnumerator::planCrossProduct(
std::vector<std::unique_ptr<LogicalPlan>> JoinOrderEnumerator::enumerate(
QueryGraph* queryGraph, expression_vector& predicates) {
context->init(queryGraph, predicates);
assert(context->expressionsToScanFromOuter.empty());
planTableScan();
planBaseTableScan();
context->currentLevel++;
while (context->currentLevel < context->maxLevel) {
planLevel(context->currentLevel++);
Expand All @@ -83,12 +82,9 @@ std::vector<std::unique_ptr<LogicalPlan>> JoinOrderEnumerator::enumerate(
}

std::unique_ptr<JoinOrderEnumeratorContext> JoinOrderEnumerator::enterSubquery(
LogicalPlan* outerPlan, expression_vector expressionsToScan,
expression_vector nodeIDsToScanFromInnerAndOuter) {
auto prevContext = std::move(context);
context = std::make_unique<JoinOrderEnumeratorContext>();
context->outerPlan = outerPlan;
context->expressionsToScanFromOuter = std::move(expressionsToScan);
context->nodeIDsToScanFromInnerAndOuter = std::move(nodeIDsToScanFromInnerAndOuter);
return prevContext;
}
Expand Down Expand Up @@ -121,7 +117,7 @@ void JoinOrderEnumerator::planLevelApproximately(uint32_t level) {
planInnerJoin(1, level - 1);
}

void JoinOrderEnumerator::planTableScan() {
void JoinOrderEnumerator::planBaseTableScan() {
auto queryGraph = context->getQueryGraph();
for (auto nodePos = 0u; nodePos < queryGraph->getNumQueryNodes(); ++nodePos) {
planNodeScan(nodePos);
Expand All @@ -136,19 +132,19 @@ void JoinOrderEnumerator::planNodeScan(uint32_t nodePos) {
auto newSubgraph = context->getEmptySubqueryGraph();
newSubgraph.addQueryNode(nodePos);
auto plan = std::make_unique<LogicalPlan>();
auto predicates = getNewlyMatchedExpressions(
context->getEmptySubqueryGraph(), newSubgraph, context->getWhereExpressions());
// In un-nested subquery, e.g. MATCH (a) OPTIONAL MATCH (a)-[e1]->(b), the inner query
// ("(a)-[e1]->(b)") needs to scan a, which is already scanned in the outer query (a). To avoid
// scanning storage twice, we keep track of node table "a" and make sure when planning inner
// query, we only scan internal ID of "a".
if (!context->nodeToScanFromInnerAndOuter(node.get())) {
appendScanNode(node, *plan);
appendScanNodeID(node, *plan);
auto properties = queryPlanner->getPropertiesForNode(*node);
queryPlanner->appendScanNodePropIfNecessary(properties, node, *plan);
auto predicates = getNewlyMatchedExpressions(
context->getEmptySubqueryGraph(), newSubgraph, context->getWhereExpressions());
queryPlanner->appendFilters(predicates, *plan);
} else {
appendScanNode(node, *plan);
appendScanNodeID(node, *plan);
}
context->addPlan(newSubgraph, std::move(plan));
}
Expand All @@ -169,14 +165,14 @@ void JoinOrderEnumerator::planRelScan(uint32_t relPos) {
for (auto direction : REL_DIRECTIONS) {
auto plan = std::make_unique<LogicalPlan>();
auto [boundNode, _] = getBoundAndNbrNodes(*rel, direction);
appendScanNode(boundNode, *plan);
planExtendAndFilters(rel, direction, predicates, *plan);
appendScanNodeID(boundNode, *plan);
appendExtendAndFilter(rel, direction, predicates, *plan);
context->addPlan(newSubgraph, std::move(plan));
}
}

void JoinOrderEnumerator::planExtendAndFilters(std::shared_ptr<RelExpression> rel,
RelDirection direction, expression_vector& predicates, LogicalPlan& plan) {
void JoinOrderEnumerator::appendExtendAndFilter(std::shared_ptr<RelExpression> rel,
common::RelDirection direction, const expression_vector& predicates, LogicalPlan& plan) {
auto [boundNode, dstNode] = getBoundAndNbrNodes(*rel, direction);
auto properties = queryPlanner->getPropertiesForRel(*rel);
appendExtend(boundNode, dstNode, rel, direction, properties, plan);
Expand Down Expand Up @@ -225,13 +221,38 @@ void JoinOrderEnumerator::planWCOJoin(uint32_t leftLevel, uint32_t rightLevel) {
}
}

static LogicalScanNode* getSequentialScanNodeOperator(LogicalOperator* op) {
switch (op->getOperatorType()) {
case LogicalOperatorType::FLATTEN:
case LogicalOperatorType::FILTER:
case LogicalOperatorType::SCAN_NODE_PROPERTY:
case LogicalOperatorType::EXTEND:
case LogicalOperatorType::PROJECTION: { // operators we directly search through
return getSequentialScanNodeOperator(op->getChild(0).get());
}
case LogicalOperatorType::SCAN_NODE: {
return (LogicalScanNode*)op;
}
default:
return nullptr;
}
}

// Check whether given node ID has sequential guarantee on the plan.
static bool isNodeSequentialOnPlan(LogicalPlan& plan, const NodeExpression& node) {
auto sequentialScanNode = getSequentialScanNodeOperator(plan.getLastOperator().get());
if (sequentialScanNode == nullptr) {
return false;
}
return sequentialScanNode->getNode()->getUniqueName() == node.getUniqueName();
}

// As a heuristic for wcoj, we always pick rel scan that starts from the bound node.
static std::unique_ptr<LogicalPlan> getWCOJBuildPlanForRel(
std::vector<std::unique_ptr<LogicalPlan>>& candidatePlans, const NodeExpression& boundNode) {
std::unique_ptr<LogicalPlan> result;
for (auto& candidatePlan : candidatePlans) {
if (LogicalPlanUtil::getSequentialNode(*candidatePlan)->getUniqueName() ==
boundNode.getUniqueName()) {
if (isNodeSequentialOnPlan(*candidatePlan, boundNode)) {
assert(result == nullptr);
result = candidatePlan->shallowCopy();
}
Expand Down Expand Up @@ -321,70 +342,50 @@ void JoinOrderEnumerator::planInnerJoin(uint32_t leftLevel, uint32_t rightLevel)
continue;
}
// If index nested loop (INL) join is possible, we prune hash join plans
if (canApplyINLJoin(rightSubgraph, nbrSubgraph, joinNodes)) {
planInnerINLJoin(rightSubgraph, nbrSubgraph, joinNodes);
} else if (canApplyINLJoin(nbrSubgraph, rightSubgraph, joinNodes)) {
planInnerINLJoin(nbrSubgraph, rightSubgraph, joinNodes);
} else {
planInnerHashJoin(rightSubgraph, nbrSubgraph, joinNodes, leftLevel != rightLevel);
if (tryPlanINLJoin(rightSubgraph, nbrSubgraph, joinNodes)) {
continue;
}
planInnerHashJoin(rightSubgraph, nbrSubgraph, joinNodes, leftLevel != rightLevel);
}
}
}

// Check whether given node ID has sequential guarantee on the plan.
static bool isNodeSequential(LogicalPlan& plan, NodeExpression* node) {
auto sequentialNode = LogicalPlanUtil::getSequentialNode(plan);
return sequentialNode != nullptr && sequentialNode->getUniqueName() == node->getUniqueName();
}

// We apply index nested loop join if the following to conditions are satisfied
// - otherSubgraph is an edge; and
// - join node is sequential on at least one plan corresponding to subgraph. (Otherwise INLJ will
// trigger non-sequential read).
bool JoinOrderEnumerator::canApplyINLJoin(const SubqueryGraph& subgraph,
bool JoinOrderEnumerator::tryPlanINLJoin(const SubqueryGraph& subgraph,
const SubqueryGraph& otherSubgraph,
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes) {
if (!otherSubgraph.isSingleRel() || joinNodes.size() > 1) {
if (joinNodes.size() > 1) {
return false;
}
for (auto& plan : context->getPlans(subgraph)) {
if (isNodeSequential(*plan, joinNodes[0].get())) {
return true;
}
if (!subgraph.isSingleRel() && !otherSubgraph.isSingleRel()) {
return false;
}
return false;
}

static uint32_t extractJoinRelPos(const SubqueryGraph& subgraph, const QueryGraph& queryGraph) {
for (auto relPos = 0u; relPos < queryGraph.getNumQueryRels(); ++relPos) {
if (subgraph.queryRelsSelector[relPos]) {
return relPos;
if (subgraph.isSingleRel()) { // Always put single rel subgraph to right.
return tryPlanINLJoin(otherSubgraph, subgraph, joinNodes);
}
auto relPos = UINT32_MAX;
for (auto i = 0u; i < context->queryGraph->getNumQueryRels(); ++i) {
if (otherSubgraph.queryRelsSelector[i]) {
relPos = i;
}
}
throw InternalException("Cannot extract relPos.");
}

void JoinOrderEnumerator::planInnerINLJoin(const SubqueryGraph& subgraph,
const SubqueryGraph& otherSubgraph,
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes) {
assert(otherSubgraph.getNumQueryRels() == 1 && joinNodes.size() == 1);
auto boundNode = joinNodes[0].get();
auto queryGraph = context->getQueryGraph();
auto relPos = extractJoinRelPos(otherSubgraph, *queryGraph);
auto rel = queryGraph->getQueryRel(relPos);
assert(relPos != UINT32_MAX);
auto rel = context->queryGraph->getQueryRel(relPos);
auto boundNode = joinNodes[0];
auto direction = boundNode->getUniqueName() == rel->getSrcNodeName() ? FWD : BWD;
auto newSubgraph = subgraph;
newSubgraph.addQueryRel(relPos);
auto predicates =
getNewlyMatchedExpressions(subgraph, newSubgraph, context->getWhereExpressions());
bool hasAppliedINLJoin = false;
for (auto& prevPlan : context->getPlans(subgraph)) {
if (isNodeSequential(*prevPlan, boundNode)) {
if (isNodeSequentialOnPlan(*prevPlan, *boundNode)) {
auto plan = prevPlan->shallowCopy();
auto direction = boundNode->getUniqueName() == rel->getSrcNodeName() ? FWD : BWD;
planExtendAndFilters(rel, direction, predicates, *plan);
context->addPlan(newSubgraph, move(plan));
appendExtendAndFilter(rel, direction, predicates, *plan);
context->addPlan(newSubgraph, std::move(plan));
hasAppliedINLJoin = true;
}
}
return hasAppliedINLJoin;
}

void JoinOrderEnumerator::planInnerHashJoin(const SubqueryGraph& subgraph,
Expand Down Expand Up @@ -415,12 +416,11 @@ void JoinOrderEnumerator::planInnerHashJoin(const SubqueryGraph& subgraph,
}

void JoinOrderEnumerator::planFiltersForHashJoin(expression_vector& predicates, LogicalPlan& plan) {
for (auto& predicate : predicates) {
queryPlanner->appendFilter(predicate, plan);
}
queryPlanner->appendFilters(predicates, plan);
}

void JoinOrderEnumerator::appendScanNode(std::shared_ptr<NodeExpression>& node, LogicalPlan& plan) {
void JoinOrderEnumerator::appendScanNodeID(
std::shared_ptr<NodeExpression>& node, LogicalPlan& plan) {
assert(plan.isEmpty());
auto scan = make_shared<LogicalScanNode>(node);
scan->computeFactorizedSchema();
Expand Down
3 changes: 2 additions & 1 deletion src/planner/join_order_enumerator_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
namespace kuzu {
namespace planner {

void JoinOrderEnumeratorContext::init(QueryGraph* queryGraph_, expression_vector& predicates) {
void JoinOrderEnumeratorContext::init(
QueryGraph* queryGraph_, const expression_vector& predicates) {
whereExpressionsSplitOnAND = predicates;
this->queryGraph = queryGraph_;
// clear and resize subPlansTable
Expand Down
20 changes: 0 additions & 20 deletions src/planner/operator/logical_plan_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,6 @@ using namespace kuzu::binder;
namespace kuzu {
namespace planner {

std::shared_ptr<NodeExpression> LogicalPlanUtil::getSequentialNode(LogicalPlan& plan) {
auto pipelineSource = getCurrentPipelineSourceOperator(plan);
if (pipelineSource->getOperatorType() != LogicalOperatorType::SCAN_NODE) {
// Pipeline source is not ScanNodeID, meaning at least one sink has happened (e.g. HashJoin)
// and we loose any sequential guarantees.
return nullptr;
}
return ((LogicalScanNode*)pipelineSource)->getNode();
}

LogicalOperator* LogicalPlanUtil::getCurrentPipelineSourceOperator(LogicalPlan& plan) {
auto op = plan.getLastOperator().get();
// Operator with more than one child will be broken into different pipelines.
while (op->getNumChildren() == 1) {
op = op->getChild(0).get();
}
assert(op != nullptr);
return op;
}

void LogicalPlanUtil::encodeJoinRecursive(
LogicalOperator* logicalOperator, std::string& encodeString) {
switch (logicalOperator->getOperatorType()) {
Expand Down
9 changes: 3 additions & 6 deletions src/planner/query_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,7 @@ void QueryPlanner::planOptionalMatch(const QueryGraphCollection& queryGraphColle
// When correlated variables are all NODE IDs, the subquery can be un-nested as left join.
// Join nodes are scanned twice in both outer and inner. However, we make sure inner table
// scan only scans node ID and does not scan from storage (i.e. no property scan).
auto prevContext = joinOrderEnumerator.enterSubquery(
&outerPlan, expression_vector{} /* nothing to scan from outer */, joinNodeIDs);
auto prevContext = joinOrderEnumerator.enterSubquery(joinNodeIDs);
auto innerPlans = joinOrderEnumerator.enumerate(queryGraphCollection, predicates);
auto bestInnerPlan = getBestPlan(std::move(innerPlans));
joinOrderEnumerator.exitSubquery(std::move(prevContext));
Expand Down Expand Up @@ -212,8 +211,7 @@ void QueryPlanner::planRegularMatch(const QueryGraphCollection& queryGraphCollec
// Multi-part query is actually CTE and CTE can be considered as a subquery but does not scan
// from outer (i.e. can always be un-nest). So we plan multi-part query in the same way as
// planning an un-nest subquery.
auto prevContext = joinOrderEnumerator.enterSubquery(
&prevPlan, expression_vector{} /* nothing to scan from outer */, joinNodeIDs);
auto prevContext = joinOrderEnumerator.enterSubquery(joinNodeIDs);
auto plans = joinOrderEnumerator.enumerate(queryGraphCollection, predicatesToPushDown);
joinOrderEnumerator.exitSubquery(std::move(prevContext));
auto bestPlan = getBestPlan(std::move(plans));
Expand All @@ -238,8 +236,7 @@ void QueryPlanner::planExistsSubquery(
if (ExpressionUtil::allExpressionsHaveDataType(correlatedExpressions, INTERNAL_ID)) {
auto joinNodeIDs = getJoinNodeIDs(correlatedExpressions);
// Unnest as mark join. See planOptionalMatch for unnesting logic.
auto prevContext = joinOrderEnumerator.enterSubquery(
&outerPlan, expression_vector{} /* nothing to scan from outer */, joinNodeIDs);
auto prevContext = joinOrderEnumerator.enterSubquery(joinNodeIDs);
auto predicates = subquery->hasWhereExpression() ?
subquery->getWhereExpression()->splitOnAND() :
expression_vector{};
Expand Down