Skip to content

Commit

Permalink
add factorization rewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Feb 21, 2023
1 parent 93140d2 commit 6a191e6
Show file tree
Hide file tree
Showing 49 changed files with 878 additions and 296 deletions.
38 changes: 38 additions & 0 deletions src/include/optimizer/factorization_rewriter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

class FactorizationRewriter {
public:
void rewrite(planner::LogicalPlan* plan);

private:
void visitOperator(planner::LogicalOperator* op);
void visitExtend(planner::LogicalOperator* op);
void visitHashJoin(planner::LogicalOperator* op);
void visitIntersect(planner::LogicalOperator* op);
void visitProjection(planner::LogicalOperator* op);
void visitAggregate(planner::LogicalOperator* op);
void visitOrderBy(planner::LogicalOperator* op);
void visitSkip(planner::LogicalOperator* op);
void visitLimit(planner::LogicalOperator* op);
void visitDistinct(planner::LogicalOperator* op);
void visitUnwind(planner::LogicalOperator* op);
void visitUnion(planner::LogicalOperator* op);
void visitFilter(planner::LogicalOperator* op);
void visitSetNodeProperty(planner::LogicalOperator* op);
void visitSetRelProperty(planner::LogicalOperator* op);
void visitDeleteRel(planner::LogicalOperator* op);
void visitCreateNode(planner::LogicalOperator* op);
void visitCreateRel(planner::LogicalOperator* op);

std::shared_ptr<planner::LogicalOperator> appendFlattens(
std::shared_ptr<planner::LogicalOperator> op,
const std::unordered_set<planner::f_group_pos>& groupsPos);
std::shared_ptr<planner::LogicalOperator> appendFlattenIfNecessary(
std::shared_ptr<planner::LogicalOperator> op, planner::f_group_pos groupPos);
};

} // namespace optimizer
} // namespace kuzu
18 changes: 18 additions & 0 deletions src/include/optimizer/remove_factorization_rewriter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

class RemoveFactorizationRewriter {
public:
void rewrite(planner::LogicalPlan* plan);

private:
std::shared_ptr<planner::LogicalOperator> rewriteOperator(
std::shared_ptr<planner::LogicalOperator> op);

bool subPlanHasFlatten(planner::LogicalOperator* op);
};

} // namespace optimizer
} // namespace kuzu
6 changes: 2 additions & 4 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ class JoinOrderEnumerator {
void appendIndexScanNode(std::shared_ptr<NodeExpression>& node,
std::shared_ptr<Expression> indexExpression, LogicalPlan& plan);

bool needFlatInput(
RelExpression& rel, NodeExpression& boundNode, common::RelDirection direction);
bool needExtendToNewGroup(
RelExpression& rel, NodeExpression& boundNode, common::RelDirection direction);
void appendExtend(std::shared_ptr<NodeExpression> boundNode,
Expand All @@ -129,8 +127,8 @@ class JoinOrderEnumerator {
static void appendMarkJoin(const binder::expression_vector& joinNodeIDs,
const std::shared_ptr<Expression>& mark, bool isProbeAcc, LogicalPlan& probePlan,
LogicalPlan& buildPlan);
static void appendIntersect(const std::shared_ptr<NodeExpression>& intersectNode,
std::vector<std::shared_ptr<NodeExpression>>& boundNodes, LogicalPlan& probePlan,
static void appendIntersect(const std::shared_ptr<Expression>& intersectNodeID,
binder::expression_vector& boundNodeIDs, LogicalPlan& probePlan,
std::vector<std::unique_ptr<LogicalPlan>>& buildPlans);
static void appendCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class LogicalOperator {
// Used for operators with more than two children e.g. Union
inline void addChild(std::shared_ptr<LogicalOperator> op) { children.push_back(std::move(op)); }
inline std::shared_ptr<LogicalOperator> getChild(uint64_t idx) const { return children[idx]; }
inline std::vector<std::shared_ptr<LogicalOperator>> getChildren() const { return children; }
inline void setChild(uint64_t idx, std::shared_ptr<LogicalOperator> child) {
children[idx] = std::move(child);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include "planner/logical_plan/logical_operator/base_logical_operator.h"

namespace kuzu {
namespace planner {
namespace factorization {

struct FlattenAllButOne {
static f_group_pos_set getGroupsPosToFlatten(const f_group_pos_set& groupsPos, Schema* schema);
};

struct FlattenAll {
static f_group_pos_set getGroupsPosToFlatten(const f_group_pos_set& groupsPos, Schema* schema);
};

} // namespace factorization
} // namespace planner
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ class LogicalAggregate : public LogicalOperator {
expressionsToGroupBy{std::move(expressionsToGroupBy)}, expressionsToAggregate{std::move(
expressionsToAggregate)} {}

f_group_pos_set getGroupsPosToFlattenForGroupBy();
f_group_pos_set getGroupsPosToFlattenForAggregate();

void computeSchema() override;

std::string getExpressionsForPrinting() const override;
Expand All @@ -31,6 +34,9 @@ class LogicalAggregate : public LogicalOperator {
expressionsToGroupBy, expressionsToAggregate, children[0]->copy());
}

private:
bool hasDistinctAggregate();

private:
binder::expression_vector expressionsToGroupBy;
binder::expression_vector expressionsToAggregate;
Expand Down
15 changes: 15 additions & 0 deletions src/include/planner/logical_plan/logical_operator/logical_create.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "flatten_resolver.h"
#include "logical_update.h"

namespace kuzu {
Expand All @@ -15,6 +16,14 @@ class LogicalCreateNode : public LogicalUpdateNode {

void computeSchema() override;

inline f_group_pos_set getGroupsPosToFlatten() {
// Flatten all inputs. E.g. MATCH (a) CREATE (b). We need to create b for each tuple in the
// match clause. This is to simplify operator implementation.
auto childSchema = children[0]->getSchema();
return factorization::FlattenAll::getGroupsPosToFlatten(
childSchema->getGroupsPosInScope(), childSchema);
}

inline std::shared_ptr<binder::Expression> getPrimaryKey(size_t idx) const {
return primaryKeys[idx];
}
Expand All @@ -36,6 +45,12 @@ class LogicalCreateRel : public LogicalUpdateRel {
setItemsPerRel{std::move(setItemsPerRel)} {}
~LogicalCreateRel() override = default;

inline f_group_pos_set getGroupsPosToFlatten() {
auto childSchema = children[0]->getSchema();
return factorization::FlattenAll::getGroupsPosToFlatten(
childSchema->getGroupsPosInScope(), childSchema);
}

inline std::vector<binder::expression_pair> getSetItems(uint32_t idx) const {
return setItemsPerRel[idx];
}
Expand Down
10 changes: 10 additions & 0 deletions src/include/planner/logical_plan/logical_operator/logical_delete.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "logical_update.h"
#include "planner/logical_plan/logical_operator/flatten_resolver.h"

namespace kuzu {
namespace planner {
Expand Down Expand Up @@ -32,6 +33,15 @@ class LogicalDeleteRel : public LogicalUpdateRel {
: LogicalUpdateRel{LogicalOperatorType::DELETE_REL, std::move(rels), std::move(child)} {}
~LogicalDeleteRel() override = default;

inline f_group_pos_set getGroupsPosToFlatten(uint32_t relIdx) {
f_group_pos_set result;
auto rel = rels[relIdx];
auto childSchema = children[0]->getSchema();
result.insert(childSchema->getGroupPos(*rel->getSrcNode()->getInternalIDProperty()));
result.insert(childSchema->getGroupPos(*rel->getDstNode()->getInternalIDProperty()));
return factorization::FlattenAll::getGroupsPosToFlatten(result, childSchema);
}

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalDeleteRel>(rels, children[0]->copy());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class LogicalDistinct : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)},
expressionsToDistinct{std::move(expressionsToDistinct)} {}

f_group_pos_set getGroupsPosToFlatten();

void computeSchema() override;

std::string getExpressionsForPrinting() const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class LogicalExtend : public LogicalOperator {
nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, direction{direction},
properties{std::move(properties)}, extendToNewGroup{extendToNewGroup} {}

f_group_pos_set getGroupsPosToFlatten();

void computeSchema() override;

inline std::string getExpressionsForPrinting() const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class LogicalFilter : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::FILTER, std::move(child)}, expression{std::move(
expression)} {}

f_group_pos_set getGroupsPosToFlatten();

inline void computeSchema() override { copyChildSchema(0); }

inline std::string getExpressionsForPrinting() const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,21 @@ namespace planner {

class LogicalFlatten : public LogicalOperator {
public:
LogicalFlatten(
std::shared_ptr<binder::Expression> expression, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::FLATTEN, std::move(child)}, expression{std::move(
expression)} {}
LogicalFlatten(f_group_pos groupPos, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::FLATTEN, std::move(child)}, groupPos{groupPos} {}

void computeSchema() override;

inline std::string getExpressionsForPrinting() const override {
return expression->getUniqueName();
}
inline std::string getExpressionsForPrinting() const override { return std::string{}; }

inline std::shared_ptr<binder::Expression> getExpression() const { return expression; }
inline f_group_pos getGroupPos() const { return groupPos; }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalFlatten>(expression, children[0]->copy());
return make_unique<LogicalFlatten>(groupPos, children[0]->copy());
}

private:
std::shared_ptr<binder::Expression> expression;
f_group_pos groupPos;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class LogicalHashJoin : public LogicalOperator {
joinNodeIDs(std::move(joinNodeIDs)), joinType{joinType}, mark{std::move(mark)},
isProbeAcc{isProbeAcc}, expressionsToMaterialize{std::move(expressionsToMaterialize)} {}

f_group_pos_set getGroupsPosToFlattenOnProbeSide();
f_group_pos_set getGroupsPosToFlattenOnBuildSide();

void computeSchema() override;

inline std::string getExpressionsForPrinting() const override {
Expand All @@ -63,6 +66,18 @@ class LogicalHashJoin : public LogicalOperator {
expressionsToMaterialize, children[0]->copy(), children[1]->copy());
}

private:
// Flat probe side key group in either of the following two cases:
// 1. there are multiple join nodes;
// 2. if the build side contains more than one group or the build side has projected out data
// chunks, which may increase the multiplicity of data chunks in the build side. The key is to
// keep probe side key unflat only when we know that there is only 0 or 1 match for each key.
// TODO(Guodong): when the build side has only flat payloads, we should consider getting rid of
// flattening probe key, instead duplicating keys as in vectorized processing if necessary.
bool requireFlatProbeKeys();

bool isJoinKeyUniqueOnBuildSide(const binder::Expression& joinNodeID);

private:
binder::expression_vector joinNodeIDs;
common::JoinType joinType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class LogicalIntersect : public LogicalOperator {
}
}

f_group_pos_set getGroupsPosToFlattenOnProbeSide();
f_group_pos_set getGroupsPosToFlattenOnBuildSide(uint32_t buildIdx);

void computeSchema() override;

std::string getExpressionsForPrinting() const override { return intersectNodeID->getRawName(); }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "base_logical_operator.h"
#include "planner/logical_plan/logical_operator/flatten_resolver.h"

namespace kuzu {
namespace planner {
Expand All @@ -10,6 +11,8 @@ class LogicalLimit : public LogicalOperator {
LogicalLimit(uint64_t limitNumber, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::LIMIT, std::move(child)}, limitNumber{limitNumber} {}

f_group_pos_set getGroupsPosToFlatten();

inline void computeSchema() override { copyChildSchema(0); }

inline std::string getExpressionsForPrinting() const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class LogicalOrderBy : public LogicalOperator {
expressionsToOrderBy{std::move(expressionsToOrderBy)}, isAscOrders{std::move(sortOrders)},
expressionsToMaterialize{std::move(expressionsToMaterialize)} {}

f_group_pos_set getGroupsPosToFlatten();

void computeSchema() override;

inline std::string getExpressionsForPrinting() const override {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "logical_update.h"
#include "planner/logical_plan/logical_operator/flatten_resolver.h"

namespace kuzu {
namespace planner {
Expand Down Expand Up @@ -39,6 +40,8 @@ class LogicalSetRelProperty : public LogicalUpdateRel {
std::move(child)},
setItems{std::move(setItems)} {}

f_group_pos_set getGroupsPosToFlatten(uint32_t setItemIdx);

inline std::string getExpressionsForPrinting() const override {
std::string result;
for (auto& [lhs, rhs] : setItems) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "base_logical_operator.h"
#include "planner/logical_plan/logical_operator/flatten_resolver.h"

namespace kuzu {
namespace planner {
Expand All @@ -10,6 +11,8 @@ class LogicalSkip : public LogicalOperator {
LogicalSkip(uint64_t skipNumber, std::shared_ptr<LogicalOperator> child)
: LogicalOperator(LogicalOperatorType::SKIP, std::move(child)), skipNumber{skipNumber} {}

f_group_pos_set getGroupsPosToFlatten();

inline void computeSchema() override { copyChildSchema(0); }

inline std::string getExpressionsForPrinting() const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,23 @@ class LogicalUnion : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::UNION_ALL, std::move(children)},
expressionsToUnion{std::move(expressions)} {}

f_group_pos_set getGroupsPosToFlatten(uint32_t childIdx);

void computeSchema() override;

inline std::string getExpressionsForPrinting() const override { return std::string(); }
inline std::string getExpressionsForPrinting() const override { return std::string{}; }

inline binder::expression_vector getExpressionsToUnion() { return expressionsToUnion; }

inline Schema* getSchemaBeforeUnion(uint32_t idx) { return children[idx]->getSchema(); }

std::unique_ptr<LogicalOperator> copy() override;

private:
// If an expression to union has different flat/unflat state in different child, we
// need to flatten that expression in all the single queries.
bool requireFlatExpression(uint32_t expressionIdx);

private:
binder::expression_vector expressionsToUnion;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class LogicalUnwind : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::UNWIND, std::move(childOperator)},
expression{std::move(expression)}, aliasExpression{std::move(aliasExpression)} {}

f_group_pos_set getGroupsPosToFlatten();

void computeSchema() override;

inline std::shared_ptr<binder::Expression> getExpression() { return expression; }
Expand Down
1 change: 1 addition & 0 deletions src/include/planner/logical_plan/logical_operator/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace kuzu {
namespace planner {

typedef uint32_t f_group_pos;
typedef std::unordered_set<f_group_pos> f_group_pos_set;
constexpr f_group_pos INVALID_F_GROUP_POS = UINT32_MAX;

class FactorizationGroup {
Expand Down
Loading

0 comments on commit 6a191e6

Please sign in to comment.