Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add logical operator visitor #1333

Merged
merged 1 commit into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions src/include/optimizer/factorization_rewriter.h
Original file line number Diff line number Diff line change
@@ -1,31 +1,32 @@
#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

class FactorizationRewriter {
class FactorizationRewriter : public LogicalOperatorVisitor {
public:
void rewrite(planner::LogicalPlan* plan);

private:
void visitOperator(planner::LogicalOperator* op);
void visitExtend(planner::LogicalOperator* op);
void visitHashJoin(planner::LogicalOperator* op);
void visitIntersect(planner::LogicalOperator* op);
void visitProjection(planner::LogicalOperator* op);
void visitAggregate(planner::LogicalOperator* op);
void visitOrderBy(planner::LogicalOperator* op);
void visitSkip(planner::LogicalOperator* op);
void visitLimit(planner::LogicalOperator* op);
void visitDistinct(planner::LogicalOperator* op);
void visitUnwind(planner::LogicalOperator* op);
void visitUnion(planner::LogicalOperator* op);
void visitFilter(planner::LogicalOperator* op);
void visitSetNodeProperty(planner::LogicalOperator* op);
void visitSetRelProperty(planner::LogicalOperator* op);
void visitDeleteRel(planner::LogicalOperator* op);
void visitCreateNode(planner::LogicalOperator* op);
void visitCreateRel(planner::LogicalOperator* op);
void visitExtend(planner::LogicalOperator* op) override;
void visitHashJoin(planner::LogicalOperator* op) override;
void visitIntersect(planner::LogicalOperator* op) override;
void visitProjection(planner::LogicalOperator* op) override;
void visitAggregate(planner::LogicalOperator* op) override;
void visitOrderBy(planner::LogicalOperator* op) override;
void visitSkip(planner::LogicalOperator* op) override;
void visitLimit(planner::LogicalOperator* op) override;
void visitDistinct(planner::LogicalOperator* op) override;
void visitUnwind(planner::LogicalOperator* op) override;
void visitUnion(planner::LogicalOperator* op) override;
void visitFilter(planner::LogicalOperator* op) override;
void visitSetNodeProperty(planner::LogicalOperator* op) override;
void visitSetRelProperty(planner::LogicalOperator* op) override;
void visitDeleteRel(planner::LogicalOperator* op) override;
void visitCreateNode(planner::LogicalOperator* op) override;
void visitCreateRel(planner::LogicalOperator* op) override;

std::shared_ptr<planner::LogicalOperator> appendFlattens(
std::shared_ptr<planner::LogicalOperator> op,
Expand Down
17 changes: 8 additions & 9 deletions src/include/optimizer/index_nested_loop_join_optimizer.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include <vector>

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
Expand All @@ -14,22 +13,22 @@ namespace optimizer {
// In the absense of a generic hash join operator.
// We should merge this operator to filter push down + ASP when the generic hash join is
// implemented.
class IndexNestedLoopJoinOptimizer {
class IndexNestedLoopJoinOptimizer : public LogicalOperatorVisitor {
public:
static void rewrite(planner::LogicalPlan* plan);
void rewrite(planner::LogicalPlan* plan);

private:
static std::shared_ptr<planner::LogicalOperator> rewrite(
std::shared_ptr<planner::LogicalOperator> visitOperator(
std::shared_ptr<planner::LogicalOperator> op);

static std::shared_ptr<planner::LogicalOperator> rewriteFilter(
std::shared_ptr<planner::LogicalOperator> op);
std::shared_ptr<planner::LogicalOperator> visitFilterReplace(
std::shared_ptr<planner::LogicalOperator> op) override;

static std::shared_ptr<planner::LogicalOperator> rewriteCrossProduct(
std::shared_ptr<planner::LogicalOperator> rewriteCrossProduct(
std::shared_ptr<planner::LogicalOperator> op,
std::shared_ptr<binder::Expression> predicate);

static planner::LogicalOperator* searchScanNodeOnPipeline(planner::LogicalOperator* op);
planner::LogicalOperator* searchScanNodeOnPipeline(planner::LogicalOperator* op);
};

} // namespace optimizer
Expand Down
140 changes: 140 additions & 0 deletions src/include/optimizer/logical_operator_visitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#pragma once

#include "planner/logical_plan/logical_operator/base_logical_operator.h"

namespace kuzu {
namespace optimizer {

class LogicalOperatorVisitor {
public:
LogicalOperatorVisitor() = default;
virtual ~LogicalOperatorVisitor() = default;

protected:
void visitOperatorSwitch(planner::LogicalOperator* op);
std::shared_ptr<planner::LogicalOperator> visitOperatorReplaceSwitch(
std::shared_ptr<planner::LogicalOperator> op);

virtual void visitFlatten(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitFlattenReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitExtend(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitExtendReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitHashJoin(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitHashJoinReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitIntersect(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitIntersectReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitProjection(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitProjectionReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitAggregate(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitAggregateReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitOrderBy(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitOrderByReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitSkip(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitSkipReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitLimit(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitLimitReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitAccumulate(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitAccumulateReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitDistinct(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitDistinctReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitUnwind(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitUnwindReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitUnion(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitUnionReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitFilter(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitFilterReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitSetNodeProperty(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitSetNodePropertyReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitSetRelProperty(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitSetRelPropertyReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitDeleteNode(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitDeleteNodeReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitDeleteRel(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitDeleteRelReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitCreateNode(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitCreateNodeReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}

virtual void visitCreateRel(planner::LogicalOperator* op) {}
virtual std::shared_ptr<planner::LogicalOperator> visitCreateRelReplace(
std::shared_ptr<planner::LogicalOperator> op) {
return op;
}
};

} // namespace optimizer
} // namespace kuzu
29 changes: 15 additions & 14 deletions src/include/optimizer/projection_push_down_optimizer.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
Expand All @@ -11,26 +12,26 @@ namespace optimizer {
// it's hard to figure out what expression is in-use, e.g. COUNT(a.age) + 1, it could be either the
// whole expression was evaluated in a WITH clause or only COUNT(a.age) was evaluated or only a.age
// is evaluate. For simplicity, we only consider the push down for property.
class ProjectionPushDownOptimizer {
class ProjectionPushDownOptimizer : public LogicalOperatorVisitor {
public:
void rewrite(planner::LogicalPlan* plan);

private:
void visitOperator(planner::LogicalOperator* op);

void visitAccumulate(planner::LogicalOperator* op);
void visitFilter(planner::LogicalOperator* op);
void visitHashJoin(planner::LogicalOperator* op);
void visitIntersect(planner::LogicalOperator* op);
void visitProjection(planner::LogicalOperator* op);
void visitOrderBy(planner::LogicalOperator* op);
void visitUnwind(planner::LogicalOperator* op);
void visitSetNodeProperty(planner::LogicalOperator* op);
void visitSetRelProperty(planner::LogicalOperator* op);
void visitCreateNode(planner::LogicalOperator* op);
void visitCreateRel(planner::LogicalOperator* op);
void visitDeleteNode(planner::LogicalOperator* op);
void visitDeleteRel(planner::LogicalOperator* op);
void visitAccumulate(planner::LogicalOperator* op) override;
void visitFilter(planner::LogicalOperator* op) override;
void visitHashJoin(planner::LogicalOperator* op) override;
void visitIntersect(planner::LogicalOperator* op) override;
void visitProjection(planner::LogicalOperator* op) override;
void visitOrderBy(planner::LogicalOperator* op) override;
void visitUnwind(planner::LogicalOperator* op) override;
void visitSetNodeProperty(planner::LogicalOperator* op) override;
void visitSetRelProperty(planner::LogicalOperator* op) override;
void visitCreateNode(planner::LogicalOperator* op) override;
void visitCreateRel(planner::LogicalOperator* op) override;
void visitDeleteNode(planner::LogicalOperator* op) override;
void visitDeleteRel(planner::LogicalOperator* op) override;

void collectPropertiesInUse(std::shared_ptr<binder::Expression> expression);

Expand Down
25 changes: 22 additions & 3 deletions src/include/optimizer/remove_factorization_rewriter.h
Original file line number Diff line number Diff line change
@@ -1,17 +1,36 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

class RemoveFactorizationRewriter {
class RemoveFactorizationRewriter : public LogicalOperatorVisitor {
public:
void rewrite(planner::LogicalPlan* plan);

private:
std::shared_ptr<planner::LogicalOperator> rewriteOperator(
std::shared_ptr<planner::LogicalOperator> visitOperator(
std::shared_ptr<planner::LogicalOperator> op);

bool subPlanHasFlatten(planner::LogicalOperator* op);
std::shared_ptr<planner::LogicalOperator> visitFlattenReplace(
std::shared_ptr<planner::LogicalOperator> op) override;

class Verifier : public LogicalOperatorVisitor {
public:
Verifier() : containsFlatten_{false} {}

inline bool containsFlatten() const { return containsFlatten_; }

void visit(planner::LogicalOperator* op);

private:
void visitFlatten(planner::LogicalOperator* op) override;

private:
bool containsFlatten_;
};
};

} // namespace optimizer
Expand Down
12 changes: 7 additions & 5 deletions src/include/optimizer/remove_unnecessary_join_optimizer.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "logical_operator_visitor.h"
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
Expand All @@ -15,15 +16,16 @@ namespace optimizer {
// |
// S(a)
// This optimizer prunes such redundant joins.
class RemoveUnnecessaryJoinOptimizer {
class RemoveUnnecessaryJoinOptimizer : public LogicalOperatorVisitor {
public:
static void rewrite(planner::LogicalPlan* plan);
void rewrite(planner::LogicalPlan* plan);

private:
static std::shared_ptr<planner::LogicalOperator> visitOperator(
std::shared_ptr<planner::LogicalOperator> op);
static std::shared_ptr<planner::LogicalOperator> visitHashJoin(
std::shared_ptr<planner::LogicalOperator> visitOperator(
std::shared_ptr<planner::LogicalOperator> op);

std::shared_ptr<planner::LogicalOperator> visitHashJoinReplace(
std::shared_ptr<planner::LogicalOperator> op) override;
};

} // namespace optimizer
Expand Down
1 change: 1 addition & 0 deletions src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_library(kuzu_optimizer
OBJECT
factorization_rewriter.cpp
index_nested_loop_join_optimizer.cpp
logical_operator_visitor.cpp
optimizer.cpp
projection_push_down_optimizer.cpp
remove_factorization_rewriter.cpp
Expand Down
Loading