Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 25, 2023
1 parent 8729d7e commit 4b47530
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 150 deletions.
11 changes: 2 additions & 9 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,11 @@ class JoinOrderEnumerator {
std::vector<std::unique_ptr<LogicalPlan>> enumerate(
QueryGraph* queryGraph, binder::expression_vector& predicates);

// Level 1 contains base table scans.
void planBaseTableScan();
void planNodeScan(uint32_t nodePos);
void appendScanNodeAndFilter(std::shared_ptr<NodeExpression> node, LogicalPlan& plan);
void planRelScan(uint32_t relPos);
void appendExtendAndFilter(
std::shared_ptr<RelExpression> rel, common::RelDirection direction, LogicalPlan& plan);
void appendExtendAndFilter(std::shared_ptr<RelExpression> rel, common::RelDirection direction,
const expression_vector& predicates, LogicalPlan& plan);

void planLevel(uint32_t level);
void planLevelExactly(uint32_t level);
Expand All @@ -88,11 +86,6 @@ class JoinOrderEnumerator {

bool tryPlanINLJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph,
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes);

bool canApplyINLJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph,
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes);
void planInnerINLJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph,
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes);
void planInnerHashJoin(const SubqueryGraph& subgraph, const SubqueryGraph& otherSubgraph,
std::vector<std::shared_ptr<NodeExpression>> joinNodes, bool flipPlan);
// Filter push down for hash join.
Expand Down
21 changes: 2 additions & 19 deletions src/include/planner/join_order_enumerator_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,8 @@ class JoinOrderEnumeratorContext {
queryGraph{nullptr} {}

void init(QueryGraph* queryGraph, const expression_vector& predicates);
void initPredicateCollection(const expression_vector& predicates);

// Returns, by value, the predicates stored in predicateSet->varName2Predicates under the unique
// name of `expression`. Yields an empty vector when no entry exists for that name.
inline expression_vector getSingleVarPredicates(const Expression& expression) const {
    const auto& predicatesByVarName = predicateSet->varName2Predicates;
    const auto entry = predicatesByVarName.find(expression.getUniqueName());
    if (entry == predicatesByVarName.end()) {
        return expression_vector{};
    }
    return entry->second;
}
// Returns a copy of predicateSet->multiVarPredicates.
inline expression_vector getMultiVarPredicates() const {
return predicateSet->multiVarPredicates;
}
// Returns a copy of whereExpressionsSplitOnAND. Const-qualified because it only reads the member,
// so it can also be called through a const context.
inline expression_vector getWhereExpressions() const { return whereExpressionsSplitOnAND; }

inline bool containPlans(const SubqueryGraph& subqueryGraph) const {
return subPlansTable->containSubgraphPlans(subqueryGraph);
Expand Down Expand Up @@ -56,15 +47,7 @@ class JoinOrderEnumeratorContext {
void resetState();

private:
// Predicates partitioned by how many variables they depend on.
struct PredicateSet {
// Predicates that depend on exactly one variable, keyed by that variable's name.
std::unordered_map<std::string, expression_vector> varName2Predicates;
// All remaining predicates, i.e. those that do not depend on exactly one variable.
expression_vector multiVarPredicates;
};

private:
std::unique_ptr<PredicateSet> predicateSet;
expression_vector whereExpressionsSplitOnAND;

uint32_t currentLevel;
uint32_t maxLevel;
Expand Down
131 changes: 28 additions & 103 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void JoinOrderEnumerator::exitSubquery(std::unique_ptr<JoinOrderEnumeratorContex
}

void JoinOrderEnumerator::planLevel(uint32_t level) {
assert(level > 2);
assert(level > 1);
if (level > MAX_LEVEL_TO_PLAN_EXACTLY) {
planLevelApproximately(level);
} else {
Expand Down Expand Up @@ -137,27 +137,20 @@ void JoinOrderEnumerator::planNodeScan(uint32_t nodePos) {
// scanning storage twice, we keep track of node table "a" and make sure when planning inner
// query, we only scan internal ID of "a".
if (!context->nodeToScanFromInnerAndOuter(node.get())) {
appendScanNodeAndFilter(node, *plan);
appendScanNodeID(node, *plan);
auto properties = queryPlanner->getPropertiesForNode(*node);
queryPlanner->appendScanNodePropIfNecessary(properties, node, *plan);
auto predicates = getNewlyMatchedExpressions(
context->getEmptySubqueryGraph(), newSubgraph, context->getWhereExpressions());
queryPlanner->appendFilters(predicates, *plan);
} else {
appendScanNodeID(node, *plan);
}
context->addPlan(newSubgraph, std::move(plan));
}

// Appends to `plan`, in order: the node ID scan for `node`, a scan of the properties the query
// planner reports for the node (the helper name suggests it is skipped when not needed), and
// filters for the predicates registered under the node's unique name.
void JoinOrderEnumerator::appendScanNodeAndFilter(
std::shared_ptr<NodeExpression> node, LogicalPlan& plan) {
appendScanNodeID(node, plan);
auto properties = queryPlanner->getPropertiesForNode(*node);
queryPlanner->appendScanNodePropIfNecessary(properties, node, plan);
// Only the single-variable predicates looked up for this node are applied here.
auto predicates = context->getSingleVarPredicates(*node);
queryPlanner->appendFilters(predicates, plan);
}

using bound_nbr_nodes_pair_t =
std::pair<std::shared_ptr<NodeExpression>, std::shared_ptr<NodeExpression>>;

static bound_nbr_nodes_pair_t getBoundAndNbrNodes(
const RelExpression& rel, RelDirection direction) {
static std::pair<std::shared_ptr<NodeExpression>, std::shared_ptr<NodeExpression>>
getBoundAndNbrNodes(const RelExpression& rel, RelDirection direction) {
auto boundNode = direction == FWD ? rel.getSrcNode() : rel.getDstNode();
auto dstNode = direction == FWD ? rel.getDstNode() : rel.getSrcNode();
return make_pair(boundNode, dstNode);
Expand All @@ -167,22 +160,22 @@ void JoinOrderEnumerator::planRelScan(uint32_t relPos) {
auto rel = context->queryGraph->getQueryRel(relPos);
auto newSubgraph = context->getEmptySubqueryGraph();
newSubgraph.addQueryRel(relPos);
auto predicates = context->getSingleVarPredicates(*rel);
auto predicates = getNewlyMatchedExpressions(
context->getEmptySubqueryGraph(), newSubgraph, context->getWhereExpressions());
for (auto direction : REL_DIRECTIONS) {
auto plan = std::make_unique<LogicalPlan>();
auto [boundNode, _] = getBoundAndNbrNodes(*rel, direction);
appendScanNodeID(boundNode, *plan);
appendExtendAndFilter(rel, direction, *plan);
appendExtendAndFilter(rel, direction, predicates, *plan);
context->addPlan(newSubgraph, std::move(plan));
}
}

void JoinOrderEnumerator::appendExtendAndFilter(
std::shared_ptr<RelExpression> rel, common::RelDirection direction, LogicalPlan& plan) {
void JoinOrderEnumerator::appendExtendAndFilter(std::shared_ptr<RelExpression> rel,
common::RelDirection direction, const expression_vector& predicates, LogicalPlan& plan) {
auto [boundNode, dstNode] = getBoundAndNbrNodes(*rel, direction);
auto properties = queryPlanner->getPropertiesForRel(*rel);
appendExtend(boundNode, dstNode, rel, direction, properties, plan);
auto predicates = context->getSingleVarPredicates(*rel);
queryPlanner->appendFilters(predicates, plan);
}

Expand Down Expand Up @@ -228,28 +221,9 @@ void JoinOrderEnumerator::planWCOJoin(uint32_t leftLevel, uint32_t rightLevel) {
}
}

//std::shared_ptr<NodeExpression> LogicalPlanUtil::getSequentialNode(LogicalPlan& plan) {
// auto pipelineSource = getCurrentPipelineSourceOperator(plan);
// if (pipelineSource->getOperatorType() != LogicalOperatorType::SCAN_NODE) {
// // Pipeline source is not ScanNodeID, meaning at least one sink has happened (e.g. HashJoin)
// // and we lose any sequential guarantees.
// return nullptr;
// }
// return ((LogicalScanNode*)pipelineSource)->getNode();
//}
//
//LogicalOperator* LogicalPlanUtil::getCurrentPipelineSourceOperator(LogicalPlan& plan) {
// auto op = plan.getLastOperator().get();
// // Operator with more than one child will be broken into different pipelines.
// while (op->getNumChildren() == 1) {
// op = op->getChild(0).get();
// }
// assert(op != nullptr);
// return op;
//}

static LogicalScanNode* getSequentialScanNodeOperator(LogicalOperator* op) {
switch (op->getOperatorType()) {
case LogicalOperatorType::FLATTEN:
case LogicalOperatorType::FILTER:
case LogicalOperatorType::SCAN_NODE_PROPERTY:
case LogicalOperatorType::EXTEND:
Expand All @@ -266,8 +240,11 @@ static LogicalScanNode* getSequentialScanNodeOperator(LogicalOperator* op) {

// Check whether given node ID has sequential guarantee on the plan.
// Returns false when no sequential scan-node operator is found starting from the plan's last
// operator, or when that operator has no node; otherwise compares the scanned node with `node`
// by unique name.
static bool isNodeSequentialOnPlan(LogicalPlan& plan, const NodeExpression& node) {
    auto sequentialScanNode = getSequentialScanNodeOperator(plan.getLastOperator().get());
    // The lookup can yield nullptr, so it has to be checked before getNode() is called on it.
    if (sequentialScanNode == nullptr) {
        return false;
    }
    auto sequentialNode = sequentialScanNode->getNode();
    return sequentialNode != nullptr && sequentialNode->getUniqueName() == node.getUniqueName();
}

// As a heuristic for wcoj, we always pick rel scan that starts from the bound node.
Expand Down Expand Up @@ -313,7 +290,7 @@ void JoinOrderEnumerator::planWCOJoin(const SubqueryGraph& subgraph,
relPlans.push_back(std::move(relPlan));
}
auto predicates =
getNewlyMatchedExpressions(prevSubgraphs, newSubgraph, context->getMultiVarPredicates());
getNewlyMatchedExpressions(prevSubgraphs, newSubgraph, context->getWhereExpressions());
for (auto& leftPlan : context->getPlans(subgraph)) {
auto leftPlanCopy = leftPlan->shallowCopy();
std::vector<std::unique_ptr<LogicalPlan>> rightPlansCopy;
Expand Down Expand Up @@ -365,13 +342,10 @@ void JoinOrderEnumerator::planInnerJoin(uint32_t leftLevel, uint32_t rightLevel)
continue;
}
// If index nested loop (INL) join is possible, we prune hash join plans
if (canApplyINLJoin(rightSubgraph, nbrSubgraph, joinNodes)) {
planInnerINLJoin(rightSubgraph, nbrSubgraph, joinNodes);
} else if (canApplyINLJoin(nbrSubgraph, rightSubgraph, joinNodes)) {
planInnerINLJoin(nbrSubgraph, rightSubgraph, joinNodes);
} else {
planInnerHashJoin(rightSubgraph, nbrSubgraph, joinNodes, leftLevel != rightLevel);
if (tryPlanINLJoin(rightSubgraph, nbrSubgraph, joinNodes)) {
continue;
}
planInnerHashJoin(rightSubgraph, nbrSubgraph, joinNodes, leftLevel != rightLevel);
}
}
}
Expand All @@ -390,7 +364,7 @@ bool JoinOrderEnumerator::tryPlanINLJoin(const SubqueryGraph& subgraph,
}
auto relPos = UINT32_MAX;
for (auto i = 0u; i < context->queryGraph->getNumQueryRels(); ++i) {
if (subgraph.queryRelsSelector[i]) {
if (otherSubgraph.queryRelsSelector[i]) {
relPos = i;
}
}
Expand All @@ -401,76 +375,27 @@ bool JoinOrderEnumerator::tryPlanINLJoin(const SubqueryGraph& subgraph,
auto newSubgraph = subgraph;
newSubgraph.addQueryRel(relPos);
auto predicates =
getNewlyMatchedExpressions(subgraph, newSubgraph, context->getMultiVarPredicates());
getNewlyMatchedExpressions(subgraph, newSubgraph, context->getWhereExpressions());
bool hasAppliedINLJoin = false;
for (auto& prevPlan : context->getPlans(subgraph)) {
if (isNodeSequentialOnPlan(*prevPlan, *boundNode)) {
auto plan = prevPlan->shallowCopy();
appendExtendAndFilter(rel, direction, *plan);
appendExtendAndFilter(rel, direction, predicates, *plan);
context->addPlan(newSubgraph, std::move(plan));
hasAppliedINLJoin = true;
}
}
return hasAppliedINLJoin;
}

// We apply index nested loop join if the following two conditions are satisfied
// - otherSubgraph is an edge; and
// - join node is sequential on at least one plan corresponding to subgraph. (Otherwise INLJ will
// trigger non-sequential read).
//bool JoinOrderEnumerator::canApplyINLJoin(const SubqueryGraph& subgraph,
// const SubqueryGraph& otherSubgraph,
// const std::vector<std::shared_ptr<NodeExpression>>& joinNodes) {
// if (!otherSubgraph.isSingleRel() || joinNodes.size() > 1) {
// return false;
// }
// for (auto& plan : context->getPlans(subgraph)) {
// if (isNodeSequential(*plan, joinNodes[0].get())) {
// return true;
// }
// }
// return false;
//}
//
//static uint32_t extractJoinRelPos(const SubqueryGraph& subgraph, const QueryGraph& queryGraph) {
// for (auto relPos = 0u; relPos < queryGraph.getNumQueryRels(); ++relPos) {
// if (subgraph.queryRelsSelector[relPos]) {
// return relPos;
// }
// }
// throw InternalException("Cannot extract relPos.");
//}

// Plans an index nested loop join that extends plans of `subgraph` along the single rel contained
// in `otherSubgraph`.
// Preconditions (asserted): `otherSubgraph` holds exactly one query rel and `joinNodes` holds
// exactly one node, which is used as the bound node of the extend.
// Every plan of `subgraph` accepted by isNodeSequential for the bound node is shallow-copied,
// extended, and registered under `subgraph` plus the rel. Plans that are not accepted are skipped.
void JoinOrderEnumerator::planInnerINLJoin(const SubqueryGraph& subgraph,
const SubqueryGraph& otherSubgraph,
const std::vector<std::shared_ptr<NodeExpression>>& joinNodes) {
assert(otherSubgraph.getNumQueryRels() == 1 && joinNodes.size() == 1);
auto boundNode = joinNodes[0].get();
auto queryGraph = context->getQueryGraph();
// Position of the rel selected in otherSubgraph.
auto relPos = extractJoinRelPos(otherSubgraph, *queryGraph);
auto rel = queryGraph->getQueryRel(relPos);
// The resulting subgraph is the input subgraph plus the joined rel.
auto newSubgraph = subgraph;
newSubgraph.addQueryRel(relPos);
// Candidate predicates come from the multi-variable set; presumably narrowed to those that become
// evaluable once the rel is added (semantics of getNewlyMatchedExpressions not visible here).
auto predicates =
getNewlyMatchedExpressions(subgraph, newSubgraph, context->getMultiVarPredicates());
for (auto& prevPlan : context->getPlans(subgraph)) {
if (isNodeSequential(*prevPlan, boundNode)) {
auto plan = prevPlan->shallowCopy();
// Extend forward when the bound node is the rel's source node, backward otherwise.
auto direction = boundNode->getUniqueName() == rel->getSrcNodeName() ? FWD : BWD;
planExtendAndFilters(rel, direction, predicates, *plan);
context->addPlan(newSubgraph, move(plan));
}
}
}

void JoinOrderEnumerator::planInnerHashJoin(const SubqueryGraph& subgraph,
const SubqueryGraph& otherSubgraph, std::vector<std::shared_ptr<NodeExpression>> joinNodes,
bool flipPlan) {
auto newSubgraph = subgraph;
newSubgraph.addSubqueryGraph(otherSubgraph);
auto predicates =
getNewlyMatchedExpressions(std::vector<SubqueryGraph>{subgraph, otherSubgraph}, newSubgraph,
context->getMultiVarPredicates());
context->getWhereExpressions());
for (auto& leftPlan : context->getPlans(subgraph)) {
for (auto& rightPlan : context->getPlans(otherSubgraph)) {
auto leftPlanProbeCopy = leftPlan->shallowCopy();
Expand Down
18 changes: 1 addition & 17 deletions src/planner/join_order_enumerator_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace planner {

void JoinOrderEnumeratorContext::init(
QueryGraph* queryGraph_, const expression_vector& predicates) {
initPredicateCollection(predicates);
whereExpressionsSplitOnAND = predicates;
this->queryGraph = queryGraph_;
// clear and resize subPlansTable
subPlansTable->clear();
Expand All @@ -16,22 +16,6 @@ void JoinOrderEnumeratorContext::init(
currentLevel = 1;
}

// Distributes `predicates` into predicateSet: a predicate that depends on exactly one variable is
// appended to that variable's list in varName2Predicates; every other predicate is appended to
// multiVarPredicates. Relative order of the input is preserved within each destination.
void JoinOrderEnumeratorContext::initPredicateCollection(
    const kuzu::binder::expression_vector& predicates) {
    for (auto& candidate : predicates) {
        auto variableNames = candidate->getDependentVariableNames();
        if (variableNames.size() != 1) {
            predicateSet->multiVarPredicates.push_back(candidate);
            continue;
        }
        // operator[] creates an empty list the first time a variable name is seen.
        auto soleVariable = *variableNames.begin();
        predicateSet->varName2Predicates[soleVariable].push_back(candidate);
    }
}

SubqueryGraph JoinOrderEnumeratorContext::getFullyMatchedSubqueryGraph() const {
auto subqueryGraph = SubqueryGraph(*queryGraph);
for (auto i = 0u; i < queryGraph->getNumQueryNodes(); ++i) {
Expand Down
2 changes: 0 additions & 2 deletions src/planner/operator/logical_plan_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ using namespace kuzu::binder;
namespace kuzu {
namespace planner {



void LogicalPlanUtil::encodeJoinRecursive(
LogicalOperator* logicalOperator, std::string& encodeString) {
switch (logicalOperator->getOperatorType()) {
Expand Down

0 comments on commit 4b47530

Please sign in to comment.