Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factorization rewriter #1307

Merged
merged 1 commit into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/include/optimizer/factorization_rewriter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

class FactorizationRewriter {
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
public:
void rewrite(planner::LogicalPlan* plan);

private:
void visitOperator(planner::LogicalOperator* op);
void visitExtend(planner::LogicalOperator* op);
void visitHashJoin(planner::LogicalOperator* op);
void visitIntersect(planner::LogicalOperator* op);
void visitProjection(planner::LogicalOperator* op);
void visitAggregate(planner::LogicalOperator* op);
void visitOrderBy(planner::LogicalOperator* op);
void visitSkip(planner::LogicalOperator* op);
void visitLimit(planner::LogicalOperator* op);
void visitDistinct(planner::LogicalOperator* op);
void visitUnwind(planner::LogicalOperator* op);
void visitUnion(planner::LogicalOperator* op);
void visitFilter(planner::LogicalOperator* op);
void visitSetNodeProperty(planner::LogicalOperator* op);
void visitSetRelProperty(planner::LogicalOperator* op);
void visitDeleteRel(planner::LogicalOperator* op);
void visitCreateNode(planner::LogicalOperator* op);
void visitCreateRel(planner::LogicalOperator* op);

std::shared_ptr<planner::LogicalOperator> appendFlattens(
std::shared_ptr<planner::LogicalOperator> op,
const std::unordered_set<planner::f_group_pos>& groupsPos);
std::shared_ptr<planner::LogicalOperator> appendFlattenIfNecessary(
std::shared_ptr<planner::LogicalOperator> op, planner::f_group_pos groupPos);
};

} // namespace optimizer
} // namespace kuzu
18 changes: 18 additions & 0 deletions src/include/optimizer/remove_factorization_rewriter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "planner/logical_plan/logical_plan.h"

namespace kuzu {
namespace optimizer {

class RemoveFactorizationRewriter {
public:
void rewrite(planner::LogicalPlan* plan);

private:
std::shared_ptr<planner::LogicalOperator> rewriteOperator(
std::shared_ptr<planner::LogicalOperator> op);

bool subPlanHasFlatten(planner::LogicalOperator* op);
};

} // namespace optimizer
} // namespace kuzu
6 changes: 2 additions & 4 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ class JoinOrderEnumerator {
void appendIndexScanNode(std::shared_ptr<NodeExpression>& node,
std::shared_ptr<Expression> indexExpression, LogicalPlan& plan);

bool needFlatInput(
RelExpression& rel, NodeExpression& boundNode, common::RelDirection direction);
bool needExtendToNewGroup(
RelExpression& rel, NodeExpression& boundNode, common::RelDirection direction);
void appendExtend(std::shared_ptr<NodeExpression> boundNode,
Expand All @@ -129,8 +127,8 @@ class JoinOrderEnumerator {
static void appendMarkJoin(const binder::expression_vector& joinNodeIDs,
const std::shared_ptr<Expression>& mark, bool isProbeAcc, LogicalPlan& probePlan,
LogicalPlan& buildPlan);
static void appendIntersect(const std::shared_ptr<NodeExpression>& intersectNode,
std::vector<std::shared_ptr<NodeExpression>>& boundNodes, LogicalPlan& probePlan,
static void appendIntersect(const std::shared_ptr<Expression>& intersectNodeID,
binder::expression_vector& boundNodeIDs, LogicalPlan& probePlan,
std::vector<std::unique_ptr<LogicalPlan>>& buildPlans);
static void appendCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class LogicalOperator {
// Used for operators with more than two children e.g. Union
inline void addChild(std::shared_ptr<LogicalOperator> op) { children.push_back(std::move(op)); }
inline std::shared_ptr<LogicalOperator> getChild(uint64_t idx) const { return children[idx]; }
inline std::vector<std::shared_ptr<LogicalOperator>> getChildren() const { return children; }
inline void setChild(uint64_t idx, std::shared_ptr<LogicalOperator> child) {
children[idx] = std::move(child);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include "planner/logical_plan/logical_operator/base_logical_operator.h"

namespace kuzu {
namespace planner {
namespace factorization {

struct FlattenAllButOne {
static f_group_pos_set getGroupsPosToFlatten(const f_group_pos_set& groupsPos, Schema* schema);
};

struct FlattenAll {
static f_group_pos_set getGroupsPosToFlatten(const f_group_pos_set& groupsPos, Schema* schema);
};

} // namespace factorization
} // namespace planner
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ class LogicalAggregate : public LogicalOperator {
expressionsToGroupBy{std::move(expressionsToGroupBy)}, expressionsToAggregate{std::move(
expressionsToAggregate)} {}

f_group_pos_set getGroupsPosToFlattenForGroupBy();
f_group_pos_set getGroupsPosToFlattenForAggregate();

void computeSchema() override;

std::string getExpressionsForPrinting() const override;
Expand All @@ -31,6 +34,9 @@ class LogicalAggregate : public LogicalOperator {
expressionsToGroupBy, expressionsToAggregate, children[0]->copy());
}

private:
bool hasDistinctAggregate();

private:
binder::expression_vector expressionsToGroupBy;
binder::expression_vector expressionsToAggregate;
Expand Down
15 changes: 15 additions & 0 deletions src/include/planner/logical_plan/logical_operator/logical_create.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "flatten_resolver.h"
#include "logical_update.h"

namespace kuzu {
Expand All @@ -15,6 +16,14 @@ class LogicalCreateNode : public LogicalUpdateNode {

void computeSchema() override;

inline f_group_pos_set getGroupsPosToFlatten() {
// Flatten all inputs. E.g. MATCH (a) CREATE (b). We need to create b for each tuple in the
// match clause. This is to simplify operator implementation.
auto childSchema = children[0]->getSchema();
return factorization::FlattenAll::getGroupsPosToFlatten(
childSchema->getGroupsPosInScope(), childSchema);
}

inline std::shared_ptr<binder::Expression> getPrimaryKey(size_t idx) const {
return primaryKeys[idx];
}
Expand All @@ -36,6 +45,12 @@ class LogicalCreateRel : public LogicalUpdateRel {
setItemsPerRel{std::move(setItemsPerRel)} {}
~LogicalCreateRel() override = default;

inline f_group_pos_set getGroupsPosToFlatten() {
auto childSchema = children[0]->getSchema();
return factorization::FlattenAll::getGroupsPosToFlatten(
childSchema->getGroupsPosInScope(), childSchema);
}

inline std::vector<binder::expression_pair> getSetItems(uint32_t idx) const {
return setItemsPerRel[idx];
}
Expand Down
10 changes: 10 additions & 0 deletions src/include/planner/logical_plan/logical_operator/logical_delete.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "logical_update.h"
#include "planner/logical_plan/logical_operator/flatten_resolver.h"

namespace kuzu {
namespace planner {
Expand Down Expand Up @@ -32,6 +33,15 @@ class LogicalDeleteRel : public LogicalUpdateRel {
: LogicalUpdateRel{LogicalOperatorType::DELETE_REL, std::move(rels), std::move(child)} {}
~LogicalDeleteRel() override = default;

inline f_group_pos_set getGroupsPosToFlatten(uint32_t relIdx) {
f_group_pos_set result;
auto rel = rels[relIdx];
auto childSchema = children[0]->getSchema();
result.insert(childSchema->getGroupPos(*rel->getSrcNode()->getInternalIDProperty()));
result.insert(childSchema->getGroupPos(*rel->getDstNode()->getInternalIDProperty()));
return factorization::FlattenAll::getGroupsPosToFlatten(result, childSchema);
}

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalDeleteRel>(rels, children[0]->copy());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class LogicalDistinct : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)},
expressionsToDistinct{std::move(expressionsToDistinct)} {}

f_group_pos_set getGroupsPosToFlatten();

void computeSchema() override;

std::string getExpressionsForPrinting() const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class LogicalExtend : public LogicalOperator {
nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, direction{direction},
properties{std::move(properties)}, extendToNewGroup{extendToNewGroup} {}

f_group_pos_set getGroupsPosToFlatten();

void computeSchema() override;

inline std::string getExpressionsForPrinting() const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class LogicalFilter : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::FILTER, std::move(child)}, expression{std::move(
expression)} {}

f_group_pos_set getGroupsPosToFlatten();

inline void computeSchema() override { copyChildSchema(0); }

inline std::string getExpressionsForPrinting() const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,21 @@ namespace planner {

class LogicalFlatten : public LogicalOperator {
public:
LogicalFlatten(
std::shared_ptr<binder::Expression> expression, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::FLATTEN, std::move(child)}, expression{std::move(
expression)} {}
LogicalFlatten(f_group_pos groupPos, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::FLATTEN, std::move(child)}, groupPos{groupPos} {}

void computeSchema() override;

inline std::string getExpressionsForPrinting() const override {
return expression->getUniqueName();
}
inline std::string getExpressionsForPrinting() const override { return std::string{}; }

inline std::shared_ptr<binder::Expression> getExpression() const { return expression; }
inline f_group_pos getGroupPos() const { return groupPos; }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalFlatten>(expression, children[0]->copy());
return make_unique<LogicalFlatten>(groupPos, children[0]->copy());
}

private:
std::shared_ptr<binder::Expression> expression;
f_group_pos groupPos;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class LogicalHashJoin : public LogicalOperator {
joinNodeIDs(std::move(joinNodeIDs)), joinType{joinType}, mark{std::move(mark)},
isProbeAcc{isProbeAcc}, expressionsToMaterialize{std::move(expressionsToMaterialize)} {}

f_group_pos_set getGroupsPosToFlattenOnProbeSide();
f_group_pos_set getGroupsPosToFlattenOnBuildSide();

void computeSchema() override;

inline std::string getExpressionsForPrinting() const override {
Expand All @@ -63,6 +66,18 @@ class LogicalHashJoin : public LogicalOperator {
expressionsToMaterialize, children[0]->copy(), children[1]->copy());
}

private:
// Flat probe side key group in either of the following two cases:
// 1. there are multiple join nodes;
// 2. if the build side contains more than one group or the build side has projected out data
// chunks, which may increase the multiplicity of data chunks in the build side. The key is to
// keep probe side key unflat only when we know that there is only 0 or 1 match for each key.
// TODO(Guodong): when the build side has only flat payloads, we should consider getting rid of
// flattening probe key, instead duplicating keys as in vectorized processing if necessary.
bool requireFlatProbeKeys();

bool isJoinKeyUniqueOnBuildSide(const binder::Expression& joinNodeID);

private:
binder::expression_vector joinNodeIDs;
common::JoinType joinType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class LogicalIntersect : public LogicalOperator {
}
}

f_group_pos_set getGroupsPosToFlattenOnProbeSide();
f_group_pos_set getGroupsPosToFlattenOnBuildSide(uint32_t buildIdx);

void computeSchema() override;

std::string getExpressionsForPrinting() const override { return intersectNodeID->getRawName(); }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "base_logical_operator.h"
#include "planner/logical_plan/logical_operator/flatten_resolver.h"

namespace kuzu {
namespace planner {
Expand All @@ -10,6 +11,8 @@ class LogicalLimit : public LogicalOperator {
LogicalLimit(uint64_t limitNumber, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::LIMIT, std::move(child)}, limitNumber{limitNumber} {}

f_group_pos_set getGroupsPosToFlatten();

inline void computeSchema() override { copyChildSchema(0); }

inline std::string getExpressionsForPrinting() const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class LogicalOrderBy : public LogicalOperator {
expressionsToOrderBy{std::move(expressionsToOrderBy)}, isAscOrders{std::move(sortOrders)},
expressionsToMaterialize{std::move(expressionsToMaterialize)} {}

f_group_pos_set getGroupsPosToFlatten();

void computeSchema() override;

inline std::string getExpressionsForPrinting() const override {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "logical_update.h"
#include "planner/logical_plan/logical_operator/flatten_resolver.h"

namespace kuzu {
namespace planner {
Expand Down Expand Up @@ -39,6 +40,8 @@ class LogicalSetRelProperty : public LogicalUpdateRel {
std::move(child)},
setItems{std::move(setItems)} {}

f_group_pos_set getGroupsPosToFlatten(uint32_t setItemIdx);

inline std::string getExpressionsForPrinting() const override {
std::string result;
for (auto& [lhs, rhs] : setItems) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "base_logical_operator.h"
#include "planner/logical_plan/logical_operator/flatten_resolver.h"

namespace kuzu {
namespace planner {
Expand All @@ -10,6 +11,8 @@ class LogicalSkip : public LogicalOperator {
LogicalSkip(uint64_t skipNumber, std::shared_ptr<LogicalOperator> child)
: LogicalOperator(LogicalOperatorType::SKIP, std::move(child)), skipNumber{skipNumber} {}

f_group_pos_set getGroupsPosToFlatten();

inline void computeSchema() override { copyChildSchema(0); }

inline std::string getExpressionsForPrinting() const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,23 @@ class LogicalUnion : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::UNION_ALL, std::move(children)},
expressionsToUnion{std::move(expressions)} {}

f_group_pos_set getGroupsPosToFlatten(uint32_t childIdx);

void computeSchema() override;

inline std::string getExpressionsForPrinting() const override { return std::string(); }
inline std::string getExpressionsForPrinting() const override { return std::string{}; }

inline binder::expression_vector getExpressionsToUnion() { return expressionsToUnion; }

inline Schema* getSchemaBeforeUnion(uint32_t idx) { return children[idx]->getSchema(); }

std::unique_ptr<LogicalOperator> copy() override;

private:
// If an expression to union has different flat/unflat state in different child, we
// need to flatten that expression in all the single queries.
bool requireFlatExpression(uint32_t expressionIdx);

private:
binder::expression_vector expressionsToUnion;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class LogicalUnwind : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::UNWIND, std::move(childOperator)},
expression{std::move(expression)}, aliasExpression{std::move(aliasExpression)} {}

f_group_pos_set getGroupsPosToFlatten();

void computeSchema() override;

inline std::shared_ptr<binder::Expression> getExpression() { return expression; }
Expand Down
1 change: 1 addition & 0 deletions src/include/planner/logical_plan/logical_operator/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace kuzu {
namespace planner {

typedef uint32_t f_group_pos;
typedef std::unordered_set<f_group_pos> f_group_pos_set;
constexpr f_group_pos INVALID_F_GROUP_POS = UINT32_MAX;

class FactorizationGroup {
Expand Down
Loading