Skip to content

Commit

Permalink
Merge pull request #1333 from kuzudb/logical-operator-visitor
Browse files Browse the repository at this point in the history
Add logical operator visitor
  • Loading branch information
andyfengHKU committed Mar 2, 2023
2 parents d025c83 + 409379f commit 308fe34
Show file tree
Hide file tree
Showing 16 changed files with 404 additions and 185 deletions.
37 changes: 19 additions & 18 deletions src/include/optimizer/factorization_rewriter.h
Original file line number Diff line number Diff line change
@@ -1,31 +1,32 @@
#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

class FactorizationRewriter {
class FactorizationRewriter : public LogicalOperatorVisitor {
public:
void rewrite(planner::LogicalPlan* plan);

private:
void visitOperator(planner::LogicalOperator* op);
void visitExtend(planner::LogicalOperator* op);
void visitHashJoin(planner::LogicalOperator* op);
void visitIntersect(planner::LogicalOperator* op);
void visitProjection(planner::LogicalOperator* op);
void visitAggregate(planner::LogicalOperator* op);
void visitOrderBy(planner::LogicalOperator* op);
void visitSkip(planner::LogicalOperator* op);
void visitLimit(planner::LogicalOperator* op);
void visitDistinct(planner::LogicalOperator* op);
void visitUnwind(planner::LogicalOperator* op);
void visitUnion(planner::LogicalOperator* op);
void visitFilter(planner::LogicalOperator* op);
void visitSetNodeProperty(planner::LogicalOperator* op);
void visitSetRelProperty(planner::LogicalOperator* op);
void visitDeleteRel(planner::LogicalOperator* op);
void visitCreateNode(planner::LogicalOperator* op);
void visitCreateRel(planner::LogicalOperator* op);
void visitExtend(planner::LogicalOperator* op) override;
void visitHashJoin(planner::LogicalOperator* op) override;
void visitIntersect(planner::LogicalOperator* op) override;
void visitProjection(planner::LogicalOperator* op) override;
void visitAggregate(planner::LogicalOperator* op) override;
void visitOrderBy(planner::LogicalOperator* op) override;
void visitSkip(planner::LogicalOperator* op) override;
void visitLimit(planner::LogicalOperator* op) override;
void visitDistinct(planner::LogicalOperator* op) override;
void visitUnwind(planner::LogicalOperator* op) override;
void visitUnion(planner::LogicalOperator* op) override;
void visitFilter(planner::LogicalOperator* op) override;
void visitSetNodeProperty(planner::LogicalOperator* op) override;
void visitSetRelProperty(planner::LogicalOperator* op) override;
void visitDeleteRel(planner::LogicalOperator* op) override;
void visitCreateNode(planner::LogicalOperator* op) override;
void visitCreateRel(planner::LogicalOperator* op) override;

std::shared_ptr<planner::LogicalOperator> appendFlattens(
std::shared_ptr<planner::LogicalOperator> op,
Expand Down
17 changes: 8 additions & 9 deletions src/include/optimizer/index_nested_loop_join_optimizer.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include <vector>

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
Expand All @@ -14,22 +13,22 @@ namespace optimizer {
// In the absense of a generic hash join operator.
// We should merge this operator to filter push down + ASP when the generic hash join is
// implemented.
class IndexNestedLoopJoinOptimizer {
class IndexNestedLoopJoinOptimizer : public LogicalOperatorVisitor {
public:
static void rewrite(planner::LogicalPlan* plan);
void rewrite(planner::LogicalPlan* plan);

private:
static std::shared_ptr<planner::LogicalOperator> rewrite(
std::shared_ptr<planner::LogicalOperator> visitOperator(
std::shared_ptr<planner::LogicalOperator> op);

static std::shared_ptr<planner::LogicalOperator> rewriteFilter(
std::shared_ptr<planner::LogicalOperator> op);
std::shared_ptr<planner::LogicalOperator> visitFilterReplace(
std::shared_ptr<planner::LogicalOperator> op) override;

static std::shared_ptr<planner::LogicalOperator> rewriteCrossProduct(
std::shared_ptr<planner::LogicalOperator> rewriteCrossProduct(
std::shared_ptr<planner::LogicalOperator> op,
std::shared_ptr<binder::Expression> predicate);

static planner::LogicalOperator* searchScanNodeOnPipeline(planner::LogicalOperator* op);
planner::LogicalOperator* searchScanNodeOnPipeline(planner::LogicalOperator* op);
};

} // namespace optimizer
Expand Down
140 changes: 140 additions & 0 deletions src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#pragma once

#include "planner/logical_plan/logical_operator/base_logical_operator.h"

namespace kuzu {
namespace optimizer {

class LogicalOperatorVisitor {
public:
LogicalOperatorVisitor() = default;
virtual ~LogicalOperatorVisitor() = default;

protected:
void visitOperatorSwitch(planner::LogicalOperator* op);
std::shared_ptr<planner::LogicalOperator> visitOperatorReplaceSwitch(
std::shared_ptr<planner::LogicalOperator> op);

virtual void visitFlatten(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitFlattenReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitExtend(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitExtendReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitHashJoin(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitHashJoinReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitIntersect(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitIntersectReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitProjection(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitProjectionReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitAggregate(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitAggregateReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitOrderBy(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitOrderByReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitSkip(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitSkipReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitLimit(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitLimitReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitAccumulate(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitAccumulateReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitDistinct(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitDistinctReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitUnwind(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitUnwindReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitUnion(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitUnionReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitFilter(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitFilterReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitSetNodeProperty(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitSetNodePropertyReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitSetRelProperty(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitSetRelPropertyReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitDeleteNode(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitDeleteNodeReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitDeleteRel(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitDeleteRelReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitCreateNode(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitCreateNodeReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitCreateRel(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitCreateRelReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}
};

} // namespace optimizer
} // namespace kuzu
29 changes: 15 additions & 14 deletions src/include/optimizer/projection_push_down_optimizer.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
Expand All @@ -11,26 +12,26 @@ namespace optimizer {
// it's hard to figure out what expression is in-use, e.g. COUNT(a.age) + 1, it could be either the
// whole expression was evaluated in a WITH clause or only COUNT(a.age) was evaluated or only a.age
// is evaluate. For simplicity, we only consider the push down for property.
class ProjectionPushDownOptimizer {
class ProjectionPushDownOptimizer : public LogicalOperatorVisitor {
public:
void rewrite(planner::LogicalPlan* plan);

private:
void visitOperator(planner::LogicalOperator* op);

void visitAccumulate(planner::LogicalOperator* op);
void visitFilter(planner::LogicalOperator* op);
void visitHashJoin(planner::LogicalOperator* op);
void visitIntersect(planner::LogicalOperator* op);
void visitProjection(planner::LogicalOperator* op);
void visitOrderBy(planner::LogicalOperator* op);
void visitUnwind(planner::LogicalOperator* op);
void visitSetNodeProperty(planner::LogicalOperator* op);
void visitSetRelProperty(planner::LogicalOperator* op);
void visitCreateNode(planner::LogicalOperator* op);
void visitCreateRel(planner::LogicalOperator* op);
void visitDeleteNode(planner::LogicalOperator* op);
void visitDeleteRel(planner::LogicalOperator* op);
void visitAccumulate(planner::LogicalOperator* op) override;
void visitFilter(planner::LogicalOperator* op) override;
void visitHashJoin(planner::LogicalOperator* op) override;
void visitIntersect(planner::LogicalOperator* op) override;
void visitProjection(planner::LogicalOperator* op) override;
void visitOrderBy(planner::LogicalOperator* op) override;
void visitUnwind(planner::LogicalOperator* op) override;
void visitSetNodeProperty(planner::LogicalOperator* op) override;
void visitSetRelProperty(planner::LogicalOperator* op) override;
void visitCreateNode(planner::LogicalOperator* op) override;
void visitCreateRel(planner::LogicalOperator* op) override;
void visitDeleteNode(planner::LogicalOperator* op) override;
void visitDeleteRel(planner::LogicalOperator* op) override;

void collectPropertiesInUse(std::shared_ptr<binder::Expression> expression);

Expand Down
25 changes: 22 additions & 3 deletions src/include/optimizer/remove_factorization_rewriter.h
Original file line number Diff line number Diff line change
@@ -1,17 +1,36 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

class RemoveFactorizationRewriter {
class RemoveFactorizationRewriter : public LogicalOperatorVisitor {
public:
void rewrite(planner::LogicalPlan* plan);

private:
std::shared_ptr<planner::LogicalOperator> rewriteOperator(
std::shared_ptr<planner::LogicalOperator> visitOperator(
std::shared_ptr<planner::LogicalOperator> op);

bool subPlanHasFlatten(planner::LogicalOperator* op);
std::shared_ptr<planner::LogicalOperator> visitFlattenReplace(
std::shared_ptr<planner::LogicalOperator> op) override;

class Verifier : public LogicalOperatorVisitor {
public:
Verifier() : containsFlatten_{false} {}

inline bool containsFlatten() const { return containsFlatten_; }

void visit(planner::LogicalOperator* op);

private:
void visitFlatten(planner::LogicalOperator* op) override;

private:
bool containsFlatten_;
};
};

} // namespace optimizer
Expand Down
12 changes: 7 additions & 5 deletions src/include/optimizer/remove_unnecessary_join_optimizer.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
Expand All @@ -15,15 +16,16 @@ namespace optimizer {
// |
// S(a)
// This optimizer prunes such redundant joins.
class RemoveUnnecessaryJoinOptimizer {
class RemoveUnnecessaryJoinOptimizer : public LogicalOperatorVisitor {
public:
static void rewrite(planner::LogicalPlan* plan);
void rewrite(planner::LogicalPlan* plan);

private:
static std::shared_ptr<planner::LogicalOperator> visitOperator(
std::shared_ptr<planner::LogicalOperator> op);
static std::shared_ptr<planner::LogicalOperator> visitHashJoin(
std::shared_ptr<planner::LogicalOperator> visitOperator(
std::shared_ptr<planner::LogicalOperator> op);

std::shared_ptr<planner::LogicalOperator> visitHashJoinReplace(
std::shared_ptr<planner::LogicalOperator> op) override;
};

} // namespace optimizer
Expand Down
1 change: 1 addition & 0 deletions src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_library(kuzu_optimizer
OBJECT
factorization_rewriter.cpp
index_nested_loop_join_optimizer.cpp
logical_operator_visitor.cpp
optimizer.cpp
projection_push_down_optimizer.cpp
remove_factorization_rewriter.cpp
Expand Down
Loading

0 comments on commit 308fe34

Please sign in to comment.