Skip to content

Commit

Permalink
Add logical operator visitor
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 2, 2023
1 parent f498110 commit cdbb48d
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 145 deletions.
17 changes: 8 additions & 9 deletions src/include/optimizer/index_nested_loop_join_optimizer.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include <vector>

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
Expand All @@ -14,22 +13,22 @@ namespace optimizer {
// In the absense of a generic hash join operator.
// We should merge this operator to filter push down + ASP when the generic hash join is
// implemented.
class IndexNestedLoopJoinOptimizer {
class IndexNestedLoopJoinOptimizer : public LogicalOperatorVisitor {
public:
static void rewrite(planner::LogicalPlan* plan);
void rewrite(planner::LogicalPlan* plan);

private:
static std::shared_ptr<planner::LogicalOperator> rewrite(
std::shared_ptr<planner::LogicalOperator> visitOperator(
std::shared_ptr<planner::LogicalOperator> op);

static std::shared_ptr<planner::LogicalOperator> rewriteFilter(
std::shared_ptr<planner::LogicalOperator> op);
std::shared_ptr<planner::LogicalOperator> visitFilterReplace(
std::shared_ptr<planner::LogicalOperator> op) override;

static std::shared_ptr<planner::LogicalOperator> rewriteCrossProduct(
std::shared_ptr<planner::LogicalOperator> rewriteCrossProduct(
std::shared_ptr<planner::LogicalOperator> op,
std::shared_ptr<binder::Expression> predicate);

static planner::LogicalOperator* searchScanNodeOnPipeline(planner::LogicalOperator* op);
planner::LogicalOperator* searchScanNodeOnPipeline(planner::LogicalOperator* op);
};

} // namespace optimizer
Expand Down
14 changes: 13 additions & 1 deletion src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ class LogicalOperatorVisitor {
return op;
}

virtual void visitAccumulate(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitAccumulateReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitDistinct(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitDistinctReplace(
std::shared_ptr<planner::LogicalOperator> op) {
Expand Down Expand Up @@ -105,6 +111,12 @@ class LogicalOperatorVisitor {
return op;
}

virtual void visitDeleteNode(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitDeleteNodeReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitDeleteRel(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitDeleteRelReplace(
std::shared_ptr<planner::LogicalOperator> op) {
Expand All @@ -125,4 +137,4 @@ class LogicalOperatorVisitor {
};

} // namespace optimizer
} // namespace kuzu
} // namespace kuzu
29 changes: 15 additions & 14 deletions src/include/optimizer/projection_push_down_optimizer.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
Expand All @@ -11,26 +12,26 @@ namespace optimizer {
// it's hard to figure out what expression is in-use, e.g. COUNT(a.age) + 1, it could be either the
// whole expression was evaluated in a WITH clause or only COUNT(a.age) was evaluated or only a.age
// is evaluate. For simplicity, we only consider the push down for property.
class ProjectionPushDownOptimizer {
class ProjectionPushDownOptimizer : public LogicalOperatorVisitor {
public:
void rewrite(planner::LogicalPlan* plan);

private:
void visitOperator(planner::LogicalOperator* op);

void visitAccumulate(planner::LogicalOperator* op);
void visitFilter(planner::LogicalOperator* op);
void visitHashJoin(planner::LogicalOperator* op);
void visitIntersect(planner::LogicalOperator* op);
void visitProjection(planner::LogicalOperator* op);
void visitOrderBy(planner::LogicalOperator* op);
void visitUnwind(planner::LogicalOperator* op);
void visitSetNodeProperty(planner::LogicalOperator* op);
void visitSetRelProperty(planner::LogicalOperator* op);
void visitCreateNode(planner::LogicalOperator* op);
void visitCreateRel(planner::LogicalOperator* op);
void visitDeleteNode(planner::LogicalOperator* op);
void visitDeleteRel(planner::LogicalOperator* op);
void visitAccumulate(planner::LogicalOperator* op) override;
void visitFilter(planner::LogicalOperator* op) override;
void visitHashJoin(planner::LogicalOperator* op) override;
void visitIntersect(planner::LogicalOperator* op) override;
void visitProjection(planner::LogicalOperator* op) override;
void visitOrderBy(planner::LogicalOperator* op) override;
void visitUnwind(planner::LogicalOperator* op) override;
void visitSetNodeProperty(planner::LogicalOperator* op) override;
void visitSetRelProperty(planner::LogicalOperator* op) override;
void visitCreateNode(planner::LogicalOperator* op) override;
void visitCreateRel(planner::LogicalOperator* op) override;
void visitDeleteNode(planner::LogicalOperator* op) override;
void visitDeleteRel(planner::LogicalOperator* op) override;

void collectPropertiesInUse(std::shared_ptr<binder::Expression> expression);

Expand Down
56 changes: 0 additions & 56 deletions src/optimizer/factorization_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,62 +34,6 @@ void FactorizationRewriter::visitOperator(planner::LogicalOperator* op) {
}
visitOperatorSwitch(op);
op->computeSchema();
// switch (op->getOperatorType()) {
// case LogicalOperatorType::EXTEND: {
// visitExtend(op);
// } break;
// case LogicalOperatorType::HASH_JOIN: {
// visitHashJoin(op);
// } break;
// case LogicalOperatorType::INTERSECT: {
// visitIntersect(op);
// } break;
// case LogicalOperatorType::PROJECTION: {
// visitProjection(op);
// } break;
// case LogicalOperatorType::AGGREGATE: {
// visitAggregate(op);
// } break;
// case LogicalOperatorType::ORDER_BY: {
// visitOrderBy(op);
// } break;
// case LogicalOperatorType::SKIP: {
// visitSkip(op);
// } break;
// case LogicalOperatorType::LIMIT: {
// visitLimit(op);
// } break;
// case LogicalOperatorType::DISTINCT: {
// visitDistinct(op);
// } break;
// case LogicalOperatorType::UNWIND: {
// visitUnwind(op);
// } break;
// case LogicalOperatorType::UNION_ALL: {
// visitUnion(op);
// } break;
// case LogicalOperatorType::FILTER: {
// visitFilter(op);
// } break;
// case LogicalOperatorType::SET_NODE_PROPERTY: {
// visitSetNodeProperty(op);
// } break;
// case LogicalOperatorType::SET_REL_PROPERTY: {
// visitSetRelProperty(op);
// } break;
// case LogicalOperatorType::DELETE_REL: {
// visitDeleteRel(op);
// } break;
// case LogicalOperatorType::CREATE_NODE: {
// visitCreateNode(op);
// } break;
// case LogicalOperatorType::CREATE_REL: {
// visitCreateRel(op);
// } break;
// default:
// break;
// }

}

void FactorizationRewriter::visitExtend(planner::LogicalOperator* op) {
Expand Down
13 changes: 5 additions & 8 deletions src/optimizer/index_nested_loop_join_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,15 @@ namespace kuzu {
namespace optimizer {

void IndexNestedLoopJoinOptimizer::rewrite(planner::LogicalPlan* plan) {
rewrite(plan->getLastOperator());
visitOperator(plan->getLastOperator());
}

std::shared_ptr<planner::LogicalOperator> IndexNestedLoopJoinOptimizer::rewrite(
std::shared_ptr<planner::LogicalOperator> IndexNestedLoopJoinOptimizer::visitOperator(
std::shared_ptr<planner::LogicalOperator> op) {
if (op->getOperatorType() == LogicalOperatorType::FILTER) {
return rewriteFilter(op);
}
for (auto i = 0u; i < op->getNumChildren(); ++i) {
op->setChild(i, rewrite(op->getChild(i)));
op->setChild(i, visitOperator(op->getChild(i)));
}
return op;
return visitOperatorReplaceSwitch(op);
}

static bool isPrimaryKey(const binder::Expression& expression) {
Expand All @@ -44,7 +41,7 @@ static bool isPrimaryKeyEqualityComparison(const Expression& expression) {
return false;
}

std::shared_ptr<LogicalOperator> IndexNestedLoopJoinOptimizer::rewriteFilter(
std::shared_ptr<LogicalOperator> IndexNestedLoopJoinOptimizer::visitFilterReplace(
std::shared_ptr<LogicalOperator> op) {
// Match filter on primary key
auto filter = (LogicalFilter*)op.get();
Expand Down
12 changes: 12 additions & 0 deletions src/optimizer/logical_operator_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ void LogicalOperatorVisitor::visitOperatorSwitch(planner::LogicalOperator* op) {
case LogicalOperatorType::LIMIT: {
visitLimit(op);
} break;
case LogicalOperatorType::ACCUMULATE: {
visitAccumulate(op);
} break;
case LogicalOperatorType::DISTINCT: {
visitDistinct(op);
} break;
Expand All @@ -52,6 +55,9 @@ void LogicalOperatorVisitor::visitOperatorSwitch(planner::LogicalOperator* op) {
case LogicalOperatorType::SET_REL_PROPERTY: {
visitSetRelProperty(op);
} break;
case LogicalOperatorType::DELETE_NODE: {
visitDeleteNode(op);
} break;
case LogicalOperatorType::DELETE_REL: {
visitDeleteRel(op);
} break;
Expand Down Expand Up @@ -96,6 +102,9 @@ std::shared_ptr<planner::LogicalOperator> LogicalOperatorVisitor::visitOperatorR
case LogicalOperatorType::LIMIT: {
return visitLimitReplace(op);
}
case LogicalOperatorType::ACCUMULATE: {
return visitAccumulateReplace(op);
}
case LogicalOperatorType::DISTINCT: {
return visitDistinctReplace(op);
}
Expand All @@ -114,6 +123,9 @@ std::shared_ptr<planner::LogicalOperator> LogicalOperatorVisitor::visitOperatorR
case LogicalOperatorType::SET_REL_PROPERTY: {
return visitSetRelPropertyReplace(op);
}
case LogicalOperatorType::DELETE_NODE: {
return visitDeleteNodeReplace(op);
}
case LogicalOperatorType::DELETE_REL: {
return visitDeleteRelReplace(op);
}
Expand Down
3 changes: 2 additions & 1 deletion src/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ void Optimizer::optimize(planner::LogicalPlan* plan) {
auto removeUnnecessaryJoinOptimizer = RemoveUnnecessaryJoinOptimizer();
removeUnnecessaryJoinOptimizer.rewrite(plan);

IndexNestedLoopJoinOptimizer::rewrite(plan);
auto indexNestedLoopJoinOptimizer = IndexNestedLoopJoinOptimizer();
indexNestedLoopJoinOptimizer.rewrite(plan);

auto projectionPushDownOptimizer = ProjectionPushDownOptimizer();
projectionPushDownOptimizer.rewrite(plan);
Expand Down
46 changes: 4 additions & 42 deletions src/optimizer/projection_push_down_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,50 +23,12 @@ void ProjectionPushDownOptimizer::rewrite(planner::LogicalPlan* plan) {
}

void ProjectionPushDownOptimizer::visitOperator(LogicalOperator* op) {
switch (op->getOperatorType()) {
case LogicalOperatorType::ACCUMULATE: {
visitAccumulate(op);
} break;
case LogicalOperatorType::FILTER: {
visitFilter(op);
} break;
case LogicalOperatorType::HASH_JOIN: {
visitHashJoin(op);
} break;
case LogicalOperatorType::PROJECTION: {
visitProjection(op);
visitOperatorSwitch(op);
if (op->getOperatorType() == LogicalOperatorType::PROJECTION) {
// We will start a new optimizer once a projection is encountered.
return;
}
case LogicalOperatorType::INTERSECT: {
visitIntersect(op);
} break;
case LogicalOperatorType::ORDER_BY: {
visitOrderBy(op);
} break;
case LogicalOperatorType::UNWIND: {
visitUnwind(op);
} break;
case LogicalOperatorType::CREATE_NODE: {
visitCreateNode(op);
} break;
case LogicalOperatorType::CREATE_REL: {
visitCreateRel(op);
} break;
case LogicalOperatorType::DELETE_NODE: {
visitDeleteNode(op);
} break;
case LogicalOperatorType::DELETE_REL: {
visitDeleteRel(op);
} break;
case LogicalOperatorType::SET_NODE_PROPERTY: {
visitSetNodeProperty(op);
} break;
case LogicalOperatorType::SET_REL_PROPERTY: {
visitSetRelProperty(op);
} break;
default:
break;
}
// top-down traversal
for (auto i = 0; i < op->getNumChildren(); ++i) {
visitOperator(op->getChild(i).get());
}
Expand Down
11 changes: 0 additions & 11 deletions test/test_files/tinysnb/match/node.test
Original file line number Diff line number Diff line change
@@ -1,14 +1,3 @@
-NAME ReturnIntRelProp
-QUERY MATCH (a:person)-[e:marries]->(b:person) RETURN e.times
---- 7
5
2
3
7
9
11
13

-NAME node1
-QUERY MATCH (a:person) RETURN COUNT(*)
---- 1
Expand Down
2 changes: 1 addition & 1 deletion test/test_files/tinysnb/projection/single_label.test
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ Dan|Carol
[3,9]

-NAME ReturnIntRelProp
-QUERY MATCH (a:person)-[e:marries]->(b:person) RETURN e.times
-QUERY MATCH (a:person)-[e:meets]->(b:person) RETURN e.times
---- 7
5
2
Expand Down
11 changes: 9 additions & 2 deletions test/test_helper/test_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,19 @@ void TestHelper::initializeConnection(TestQueryConfig* config, Connection& conn)
bool TestHelper::testQuery(TestQueryConfig* config, Connection& conn) {
initializeConnection(config, conn);
auto preparedStatement = conn.prepareNoLock(config->query, config->enumerate);
assert(!preparedStatement->logicalPlans.empty());
if (!preparedStatement->isSuccess()) {
spdlog::error(preparedStatement->getErrorMessage());
return false;
}
auto numPlans = preparedStatement->logicalPlans.size();
if (numPlans == 0) {
spdlog::error("Query {} has no plans" + config->name);
return false;
}
auto numPassedPlans = 0u;
for (auto i = 0u; i < numPlans; ++i) {
auto plan = preparedStatement->logicalPlans[i].get();
auto planStr = preparedStatement->logicalPlans[i]->toString();
auto planStr = plan->toString();
auto result = conn.executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get(), i);
assert(result->isSuccess());
std::vector<std::string> resultTuples =
Expand Down

0 comments on commit cdbb48d

Please sign in to comment.