Skip to content

Commit

Permalink
Add compute flat/factorzied schema
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 8, 2023
1 parent ddaa2fa commit 0205d87
Show file tree
Hide file tree
Showing 62 changed files with 443 additions and 250 deletions.
3 changes: 3 additions & 0 deletions src/include/optimizer/projection_push_down_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class ProjectionPushDownOptimizer : public LogicalOperatorVisitor {

binder::expression_vector pruneExpressions(const binder::expression_vector& expressions);

void preAppendProjection(
planner::LogicalOperator* op, uint32_t childIdx, binder::expression_vector expressions);

private:
binder::expression_set propertiesInUse;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class LogicalOperator {
inline LogicalOperatorType getOperatorType() const { return operatorType; }

inline Schema* getSchema() const { return schema.get(); }
virtual void computeSchema() = 0;
virtual void computeFactorizedSchema() = 0;
virtual void computeFlatSchema() = 0;

virtual std::string getExpressionsForPrinting() const = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,21 @@ namespace planner {

class LogicalAccumulate : public LogicalOperator {
public:
LogicalAccumulate(binder::expression_vector expressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)}, expressions{std::move(
expressions)} {}
LogicalAccumulate(std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)} {}

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return binder::ExpressionUtil::toString(expressions);
}
inline std::string getExpressionsForPrinting() const override { return std::string{}; }

inline void setExpressions(binder::expression_vector expressions_) {
expressions = std::move(expressions_);
inline binder::expression_vector getExpressions() const {
return children[0]->getSchema()->getExpressionsInScope();
}
inline binder::expression_vector getExpressions() const { return expressions; }
inline Schema* getSchemaBeforeSink() const { return children[0]->getSchema(); }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAccumulate>(expressions, children[0]->copy());
return make_unique<LogicalAccumulate>(children[0]->copy());
}

private:
binder::expression_vector expressions;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ class LogicalAggregate : public LogicalOperator {
expressionsToGroupBy{std::move(expressionsToGroupBy)}, expressionsToAggregate{std::move(
expressionsToAggregate)} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;

f_group_pos_set getGroupsPosToFlattenForGroupBy();
f_group_pos_set getGroupsPosToFlattenForAggregate();

void computeSchema() override;

std::string getExpressionsForPrinting() const override;

inline bool hasExpressionsToGroupBy() const { return !expressionsToGroupBy.empty(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class LogicalCopy : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::COPY_CSV},
copyDescription{copyDescription}, tableID{tableID}, tableName{std::move(tableName)} {}

inline void computeSchema() override { createEmptySchema(); }
inline void computeFactorizedSchema() override { createEmptySchema(); }
inline void computeFlatSchema() override { createEmptySchema(); }

inline std::string getExpressionsForPrinting() const override { return tableName; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ class LogicalCreateNode : public LogicalUpdateNode {
primaryKeys{std::move(primaryKeys)} {}
~LogicalCreateNode() override = default;

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline f_group_pos_set getGroupsPosToFlatten() {
// Flatten all inputs. E.g. MATCH (a) CREATE (b). We need to create b for each tuple in the
Expand Down Expand Up @@ -45,6 +46,9 @@ class LogicalCreateRel : public LogicalUpdateRel {
setItemsPerRel{std::move(setItemsPerRel)} {}
~LogicalCreateRel() override = default;

inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

inline f_group_pos_set getGroupsPosToFlatten() {
auto childSchema = children[0]->getSchema();
return factorization::FlattenAll::getGroupsPosToFlatten(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class LogicalCrossProduct : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::CROSS_PRODUCT, std::move(probeSideChild),
std::move(buildSideChild)} {}

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override { return std::string(); }

Expand Down
10 changes: 3 additions & 7 deletions src/include/planner/logical_plan/logical_operator/logical_ddl.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,16 @@ class LogicalDDL : public LogicalOperator {
: LogicalOperator{operatorType}, tableName{std::move(tableName)},
outputExpression{std::move(outputExpression)} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getTableName() const { return tableName; }
inline std::shared_ptr<binder::Expression> getOutputExpression() const {
return outputExpression;
}

inline std::string getExpressionsForPrinting() const override { return tableName; }

inline void computeSchema() override {
schema = std::make_unique<Schema>();
auto groupPos = schema->createGroup();
schema->insertToGroupAndScope(outputExpression, groupPos);
schema->setGroupAsSingleState(groupPos);
}

protected:
std::string tableName;
std::shared_ptr<binder::Expression> outputExpression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class LogicalDeleteNode : public LogicalUpdateNode {
primaryKeys{std::move(primaryKeys)} {}
~LogicalDeleteNode() override = default;

inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

inline std::shared_ptr<binder::Expression> getPrimaryKey(size_t idx) const {
return primaryKeys[idx];
}
Expand All @@ -33,6 +36,9 @@ class LogicalDeleteRel : public LogicalUpdateRel {
: LogicalUpdateRel{LogicalOperatorType::DELETE_REL, std::move(rels), std::move(child)} {}
~LogicalDeleteRel() override = default;

inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

inline f_group_pos_set getGroupsPosToFlatten(uint32_t relIdx) {
f_group_pos_set result;
auto rel = rels[relIdx];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ class LogicalDistinct : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)},
expressionsToDistinct{std::move(expressionsToDistinct)} {}

f_group_pos_set getGroupsPosToFlatten();
void computeFactorizedSchema() override;
void computeFlatSchema() override;

void computeSchema() override;
f_group_pos_set getGroupsPosToFlatten();

std::string getExpressionsForPrinting() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ class LogicalExpressionsScan : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::EXPRESSIONS_SCAN}, expressions{
std::move(expressions)} {}

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return binder::ExpressionUtil::toString(expressions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class LogicalExtend : public LogicalOperator {

f_group_pos_set getGroupsPosToFlatten();

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return boundNode->toString() + (direction == common::RelDirection::FWD ? "->" : "<-") +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ class LogicalFilter : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::FILTER, std::move(child)}, expression{std::move(
expression)} {}

f_group_pos_set getGroupsPosToFlatten();
inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

inline void computeSchema() override { copyChildSchema(0); }
f_group_pos_set getGroupsPosToFlatten();

inline std::string getExpressionsForPrinting() const override { return expression->toString(); }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ class LogicalFlatten : public LogicalOperator {
LogicalFlatten(f_group_pos groupPos, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::FLATTEN, std::move(child)}, groupPos{groupPos} {}

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override { return std::string{}; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class LogicalFTableScan : public LogicalOperator {
expressionsToScan{std::move(expressionsToScan)}, schemaToScanFrom{
std::move(schemaToScanFrom)} {}

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return binder::ExpressionUtil::toString(expressionsToScan);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,46 +14,38 @@ class LogicalHashJoin : public LogicalOperator {
public:
// Inner and left join.
LogicalHashJoin(binder::expression_vector joinNodeIDs, common::JoinType joinType,
bool isProbeAcc, binder::expression_vector expressionsToMaterialize,
std::shared_ptr<LogicalOperator> probeSideChild,
bool isProbeAcc, std::shared_ptr<LogicalOperator> probeSideChild,
std::shared_ptr<LogicalOperator> buildSideChild)
: LogicalHashJoin{std::move(joinNodeIDs), joinType, nullptr, isProbeAcc,
std::move(expressionsToMaterialize), std::move(probeSideChild),
std::move(buildSideChild)} {}
std::move(probeSideChild), std::move(buildSideChild)} {}

// Mark join.
LogicalHashJoin(binder::expression_vector joinNodeIDs, std::shared_ptr<binder::Expression> mark,
bool isProbeAcc, std::shared_ptr<LogicalOperator> probeSideChild,
std::shared_ptr<LogicalOperator> buildSideChild)
: LogicalHashJoin{std::move(joinNodeIDs), common::JoinType::MARK, std::move(mark),
isProbeAcc, binder::expression_vector{} /* expressionsToMaterialize */,
std::move(probeSideChild), std::move(buildSideChild)} {}
isProbeAcc, std::move(probeSideChild), std::move(buildSideChild)} {}

LogicalHashJoin(binder::expression_vector joinNodeIDs, common::JoinType joinType,
std::shared_ptr<binder::Expression> mark, bool isProbeAcc,
binder::expression_vector expressionsToMaterialize,
std::shared_ptr<LogicalOperator> probeSideChild,
std::shared_ptr<LogicalOperator> buildSideChild)
: LogicalOperator{LogicalOperatorType::HASH_JOIN, std::move(probeSideChild),
std::move(buildSideChild)},
joinNodeIDs(std::move(joinNodeIDs)), joinType{joinType}, mark{std::move(mark)},
isProbeAcc{isProbeAcc}, expressionsToMaterialize{std::move(expressionsToMaterialize)} {}
isProbeAcc{isProbeAcc} {}

f_group_pos_set getGroupsPosToFlattenOnProbeSide();
f_group_pos_set getGroupsPosToFlattenOnBuildSide();

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return binder::ExpressionUtil::toString(joinNodeIDs);
}

inline void setExpressionsToMaterialize(binder::expression_vector expressions) {
expressionsToMaterialize = std::move(expressions);
}
inline binder::expression_vector getExpressionsToMaterialize() const {
return expressionsToMaterialize;
}
binder::expression_vector getExpressionsToMaterialize() const;
inline binder::expression_vector getJoinNodeIDs() const { return joinNodeIDs; }
inline common::JoinType getJoinType() const { return joinType; }

Expand All @@ -62,11 +54,10 @@ class LogicalHashJoin : public LogicalOperator {
return mark;
}
inline bool getIsProbeAcc() const { return isProbeAcc; }
inline Schema* getBuildSideSchema() const { return children[1]->getSchema(); }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalHashJoin>(joinNodeIDs, joinType, mark, isProbeAcc,
expressionsToMaterialize, children[0]->copy(), children[1]->copy());
return make_unique<LogicalHashJoin>(
joinNodeIDs, joinType, mark, isProbeAcc, children[0]->copy(), children[1]->copy());
}

private:
Expand All @@ -86,7 +77,6 @@ class LogicalHashJoin : public LogicalOperator {
common::JoinType joinType;
std::shared_ptr<binder::Expression> mark; // when joinType is Mark
bool isProbeAcc;
binder::expression_vector expressionsToMaterialize;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,13 @@
namespace kuzu {
namespace planner {

struct LogicalIntersectBuildInfo {
LogicalIntersectBuildInfo(
std::shared_ptr<binder::Expression> keyNodeID, binder::expression_vector expressions)
: keyNodeID{std::move(keyNodeID)}, expressionsToMaterialize{std::move(expressions)} {}

inline void setExpressionsToMaterialize(binder::expression_vector expressions) {
expressionsToMaterialize = std::move(expressions);
}
inline binder::expression_vector getExpressionsToMaterialize() const {
return expressionsToMaterialize;
}

inline std::unique_ptr<LogicalIntersectBuildInfo> copy() {
return make_unique<LogicalIntersectBuildInfo>(keyNodeID, expressionsToMaterialize);
}

std::shared_ptr<binder::Expression> keyNodeID;
binder::expression_vector expressionsToMaterialize;
};

class LogicalIntersect : public LogicalOperator {
public:
LogicalIntersect(std::shared_ptr<binder::Expression> intersectNodeID,
std::shared_ptr<LogicalOperator> probeChild,
std::vector<std::shared_ptr<LogicalOperator>> buildChildren,
std::vector<std::unique_ptr<LogicalIntersectBuildInfo>> buildInfos)
binder::expression_vector keyNodeIDs, std::shared_ptr<LogicalOperator> probeChild,
std::vector<std::shared_ptr<LogicalOperator>> buildChildren)
: LogicalOperator{LogicalOperatorType::INTERSECT, std::move(probeChild)},
intersectNodeID{std::move(intersectNodeID)}, buildInfos{std::move(buildInfos)} {
intersectNodeID{std::move(intersectNodeID)}, keyNodeIDs{std::move(keyNodeIDs)} {
for (auto& child : buildChildren) {
children.push_back(std::move(child));
}
Expand All @@ -43,23 +22,25 @@ class LogicalIntersect : public LogicalOperator {
f_group_pos_set getGroupsPosToFlattenOnProbeSide();
f_group_pos_set getGroupsPosToFlattenOnBuildSide(uint32_t buildIdx);

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

std::string getExpressionsForPrinting() const override { return intersectNodeID->toString(); }

inline std::shared_ptr<binder::Expression> getIntersectNodeID() const {
return intersectNodeID;
}
inline LogicalIntersectBuildInfo* getBuildInfo(uint32_t idx) const {
return buildInfos[idx].get();

inline uint32_t getNumBuilds() const { return keyNodeIDs.size(); }
inline std::shared_ptr<binder::Expression> getKeyNodeID(uint32_t idx) const {
return keyNodeIDs[idx];
}
inline uint32_t getNumBuilds() const { return buildInfos.size(); }

std::unique_ptr<LogicalOperator> copy() override;

private:
std::shared_ptr<binder::Expression> intersectNodeID;
std::vector<std::unique_ptr<LogicalIntersectBuildInfo>> buildInfos;
binder::expression_vector keyNodeIDs;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class LogicalLimit : public LogicalOperator {

f_group_pos_set getGroupsPosToFlatten();

inline void computeSchema() override { copyChildSchema(0); }
inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

inline std::string getExpressionsForPrinting() const override {
return std::to_string(limitNumber);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ class LogicalMultiplicityReducer : public LogicalOperator {
explicit LogicalMultiplicityReducer(std::shared_ptr<LogicalOperator> child)
: LogicalOperator(LogicalOperatorType::MULTIPLICITY_REDUCER, std::move(child)) {}

inline void computeSchema() override { copyChildSchema(0); }
inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

inline std::string getExpressionsForPrinting() const override { return std::string(); }

Expand Down
Loading

0 comments on commit 0205d87

Please sign in to comment.