Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flat schema #1357

Merged
merged 1 commit into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/include/optimizer/projection_push_down_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class ProjectionPushDownOptimizer : public LogicalOperatorVisitor {

binder::expression_vector pruneExpressions(const binder::expression_vector& expressions);

void preAppendProjection(
planner::LogicalOperator* op, uint32_t childIdx, binder::expression_vector expressions);

private:
binder::expression_set propertiesInUse;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class LogicalOperator {
inline LogicalOperatorType getOperatorType() const { return operatorType; }

inline Schema* getSchema() const { return schema.get(); }
virtual void computeSchema() = 0;
virtual void computeFactorizedSchema() = 0;
virtual void computeFlatSchema() = 0;

virtual std::string getExpressionsForPrinting() const = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,21 @@ namespace planner {

class LogicalAccumulate : public LogicalOperator {
public:
LogicalAccumulate(binder::expression_vector expressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)}, expressions{std::move(
expressions)} {}
LogicalAccumulate(std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)} {}

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return binder::ExpressionUtil::toString(expressions);
}
inline std::string getExpressionsForPrinting() const override { return std::string{}; }

inline void setExpressions(binder::expression_vector expressions_) {
expressions = std::move(expressions_);
inline binder::expression_vector getExpressions() const {
return children[0]->getSchema()->getExpressionsInScope();
}
inline binder::expression_vector getExpressions() const { return expressions; }
inline Schema* getSchemaBeforeSink() const { return children[0]->getSchema(); }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAccumulate>(expressions, children[0]->copy());
return make_unique<LogicalAccumulate>(children[0]->copy());
}

private:
binder::expression_vector expressions;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ class LogicalAggregate : public LogicalOperator {
expressionsToGroupBy{std::move(expressionsToGroupBy)}, expressionsToAggregate{std::move(
expressionsToAggregate)} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;

f_group_pos_set getGroupsPosToFlattenForGroupBy();
f_group_pos_set getGroupsPosToFlattenForAggregate();

void computeSchema() override;

std::string getExpressionsForPrinting() const override;

inline bool hasExpressionsToGroupBy() const { return !expressionsToGroupBy.empty(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class LogicalCopy : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::COPY_CSV},
copyDescription{copyDescription}, tableID{tableID}, tableName{std::move(tableName)} {}

inline void computeSchema() override { createEmptySchema(); }
inline void computeFactorizedSchema() override { createEmptySchema(); }
inline void computeFlatSchema() override { createEmptySchema(); }

inline std::string getExpressionsForPrinting() const override { return tableName; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ class LogicalCreateNode : public LogicalUpdateNode {
primaryKeys{std::move(primaryKeys)} {}
~LogicalCreateNode() override = default;

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline f_group_pos_set getGroupsPosToFlatten() {
// Flatten all inputs. E.g. MATCH (a) CREATE (b). We need to create b for each tuple in the
Expand Down Expand Up @@ -45,6 +46,9 @@ class LogicalCreateRel : public LogicalUpdateRel {
setItemsPerRel{std::move(setItemsPerRel)} {}
~LogicalCreateRel() override = default;

inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

inline f_group_pos_set getGroupsPosToFlatten() {
auto childSchema = children[0]->getSchema();
return factorization::FlattenAll::getGroupsPosToFlatten(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class LogicalCrossProduct : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::CROSS_PRODUCT, std::move(probeSideChild),
std::move(buildSideChild)} {}

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override { return std::string(); }

Expand Down
10 changes: 3 additions & 7 deletions src/include/planner/logical_plan/logical_operator/logical_ddl.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,16 @@ class LogicalDDL : public LogicalOperator {
: LogicalOperator{operatorType}, tableName{std::move(tableName)},
outputExpression{std::move(outputExpression)} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getTableName() const { return tableName; }
inline std::shared_ptr<binder::Expression> getOutputExpression() const {
return outputExpression;
}

inline std::string getExpressionsForPrinting() const override { return tableName; }

inline void computeSchema() override {
schema = std::make_unique<Schema>();
auto groupPos = schema->createGroup();
schema->insertToGroupAndScope(outputExpression, groupPos);
schema->setGroupAsSingleState(groupPos);
}

protected:
std::string tableName;
std::shared_ptr<binder::Expression> outputExpression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class LogicalDeleteNode : public LogicalUpdateNode {
primaryKeys{std::move(primaryKeys)} {}
~LogicalDeleteNode() override = default;

inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

inline std::shared_ptr<binder::Expression> getPrimaryKey(size_t idx) const {
return primaryKeys[idx];
}
Expand All @@ -33,6 +36,9 @@ class LogicalDeleteRel : public LogicalUpdateRel {
: LogicalUpdateRel{LogicalOperatorType::DELETE_REL, std::move(rels), std::move(child)} {}
~LogicalDeleteRel() override = default;

inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

inline f_group_pos_set getGroupsPosToFlatten(uint32_t relIdx) {
f_group_pos_set result;
auto rel = rels[relIdx];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ class LogicalDistinct : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)},
expressionsToDistinct{std::move(expressionsToDistinct)} {}

f_group_pos_set getGroupsPosToFlatten();
void computeFactorizedSchema() override;
void computeFlatSchema() override;

void computeSchema() override;
f_group_pos_set getGroupsPosToFlatten();

std::string getExpressionsForPrinting() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ class LogicalExpressionsScan : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::EXPRESSIONS_SCAN}, expressions{
std::move(expressions)} {}

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return binder::ExpressionUtil::toString(expressions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class LogicalExtend : public LogicalOperator {

f_group_pos_set getGroupsPosToFlatten();

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return boundNode->toString() + (direction == common::RelDirection::FWD ? "->" : "<-") +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ class LogicalFilter : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::FILTER, std::move(child)}, expression{std::move(
expression)} {}

f_group_pos_set getGroupsPosToFlatten();
inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

inline void computeSchema() override { copyChildSchema(0); }
f_group_pos_set getGroupsPosToFlatten();

inline std::string getExpressionsForPrinting() const override { return expression->toString(); }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ class LogicalFlatten : public LogicalOperator {
LogicalFlatten(f_group_pos groupPos, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::FLATTEN, std::move(child)}, groupPos{groupPos} {}

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override { return std::string{}; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class LogicalFTableScan : public LogicalOperator {
expressionsToScan{std::move(expressionsToScan)}, schemaToScanFrom{
std::move(schemaToScanFrom)} {}

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return binder::ExpressionUtil::toString(expressionsToScan);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,46 +14,38 @@ class LogicalHashJoin : public LogicalOperator {
public:
// Inner and left join.
LogicalHashJoin(binder::expression_vector joinNodeIDs, common::JoinType joinType,
bool isProbeAcc, binder::expression_vector expressionsToMaterialize,
std::shared_ptr<LogicalOperator> probeSideChild,
bool isProbeAcc, std::shared_ptr<LogicalOperator> probeSideChild,
std::shared_ptr<LogicalOperator> buildSideChild)
: LogicalHashJoin{std::move(joinNodeIDs), joinType, nullptr, isProbeAcc,
std::move(expressionsToMaterialize), std::move(probeSideChild),
std::move(buildSideChild)} {}
std::move(probeSideChild), std::move(buildSideChild)} {}

// Mark join.
LogicalHashJoin(binder::expression_vector joinNodeIDs, std::shared_ptr<binder::Expression> mark,
bool isProbeAcc, std::shared_ptr<LogicalOperator> probeSideChild,
std::shared_ptr<LogicalOperator> buildSideChild)
: LogicalHashJoin{std::move(joinNodeIDs), common::JoinType::MARK, std::move(mark),
isProbeAcc, binder::expression_vector{} /* expressionsToMaterialize */,
std::move(probeSideChild), std::move(buildSideChild)} {}
isProbeAcc, std::move(probeSideChild), std::move(buildSideChild)} {}

LogicalHashJoin(binder::expression_vector joinNodeIDs, common::JoinType joinType,
std::shared_ptr<binder::Expression> mark, bool isProbeAcc,
binder::expression_vector expressionsToMaterialize,
std::shared_ptr<LogicalOperator> probeSideChild,
std::shared_ptr<LogicalOperator> buildSideChild)
: LogicalOperator{LogicalOperatorType::HASH_JOIN, std::move(probeSideChild),
std::move(buildSideChild)},
joinNodeIDs(std::move(joinNodeIDs)), joinType{joinType}, mark{std::move(mark)},
isProbeAcc{isProbeAcc}, expressionsToMaterialize{std::move(expressionsToMaterialize)} {}
isProbeAcc{isProbeAcc} {}

f_group_pos_set getGroupsPosToFlattenOnProbeSide();
f_group_pos_set getGroupsPosToFlattenOnBuildSide();

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

inline std::string getExpressionsForPrinting() const override {
return binder::ExpressionUtil::toString(joinNodeIDs);
}

inline void setExpressionsToMaterialize(binder::expression_vector expressions) {
expressionsToMaterialize = std::move(expressions);
}
inline binder::expression_vector getExpressionsToMaterialize() const {
return expressionsToMaterialize;
}
binder::expression_vector getExpressionsToMaterialize() const;
inline binder::expression_vector getJoinNodeIDs() const { return joinNodeIDs; }
inline common::JoinType getJoinType() const { return joinType; }

Expand All @@ -62,11 +54,10 @@ class LogicalHashJoin : public LogicalOperator {
return mark;
}
inline bool getIsProbeAcc() const { return isProbeAcc; }
inline Schema* getBuildSideSchema() const { return children[1]->getSchema(); }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalHashJoin>(joinNodeIDs, joinType, mark, isProbeAcc,
expressionsToMaterialize, children[0]->copy(), children[1]->copy());
return make_unique<LogicalHashJoin>(
joinNodeIDs, joinType, mark, isProbeAcc, children[0]->copy(), children[1]->copy());
}

private:
Expand All @@ -86,7 +77,6 @@ class LogicalHashJoin : public LogicalOperator {
common::JoinType joinType;
std::shared_ptr<binder::Expression> mark; // when joinType is Mark
bool isProbeAcc;
binder::expression_vector expressionsToMaterialize;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,13 @@
namespace kuzu {
namespace planner {

struct LogicalIntersectBuildInfo {
LogicalIntersectBuildInfo(
std::shared_ptr<binder::Expression> keyNodeID, binder::expression_vector expressions)
: keyNodeID{std::move(keyNodeID)}, expressionsToMaterialize{std::move(expressions)} {}

inline void setExpressionsToMaterialize(binder::expression_vector expressions) {
expressionsToMaterialize = std::move(expressions);
}
inline binder::expression_vector getExpressionsToMaterialize() const {
return expressionsToMaterialize;
}

inline std::unique_ptr<LogicalIntersectBuildInfo> copy() {
return make_unique<LogicalIntersectBuildInfo>(keyNodeID, expressionsToMaterialize);
}

std::shared_ptr<binder::Expression> keyNodeID;
binder::expression_vector expressionsToMaterialize;
};

class LogicalIntersect : public LogicalOperator {
public:
LogicalIntersect(std::shared_ptr<binder::Expression> intersectNodeID,
std::shared_ptr<LogicalOperator> probeChild,
std::vector<std::shared_ptr<LogicalOperator>> buildChildren,
std::vector<std::unique_ptr<LogicalIntersectBuildInfo>> buildInfos)
binder::expression_vector keyNodeIDs, std::shared_ptr<LogicalOperator> probeChild,
std::vector<std::shared_ptr<LogicalOperator>> buildChildren)
: LogicalOperator{LogicalOperatorType::INTERSECT, std::move(probeChild)},
intersectNodeID{std::move(intersectNodeID)}, buildInfos{std::move(buildInfos)} {
intersectNodeID{std::move(intersectNodeID)}, keyNodeIDs{std::move(keyNodeIDs)} {
for (auto& child : buildChildren) {
children.push_back(std::move(child));
}
Expand All @@ -43,23 +22,25 @@ class LogicalIntersect : public LogicalOperator {
f_group_pos_set getGroupsPosToFlattenOnProbeSide();
f_group_pos_set getGroupsPosToFlattenOnBuildSide(uint32_t buildIdx);

void computeSchema() override;
void computeFactorizedSchema() override;
void computeFlatSchema() override;

std::string getExpressionsForPrinting() const override { return intersectNodeID->toString(); }

inline std::shared_ptr<binder::Expression> getIntersectNodeID() const {
return intersectNodeID;
}
inline LogicalIntersectBuildInfo* getBuildInfo(uint32_t idx) const {
return buildInfos[idx].get();

inline uint32_t getNumBuilds() const { return keyNodeIDs.size(); }
inline std::shared_ptr<binder::Expression> getKeyNodeID(uint32_t idx) const {
return keyNodeIDs[idx];
}
inline uint32_t getNumBuilds() const { return buildInfos.size(); }

std::unique_ptr<LogicalOperator> copy() override;

private:
std::shared_ptr<binder::Expression> intersectNodeID;
std::vector<std::unique_ptr<LogicalIntersectBuildInfo>> buildInfos;
binder::expression_vector keyNodeIDs;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class LogicalLimit : public LogicalOperator {

f_group_pos_set getGroupsPosToFlatten();

inline void computeSchema() override { copyChildSchema(0); }
inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

inline std::string getExpressionsForPrinting() const override {
return std::to_string(limitNumber);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ class LogicalMultiplicityReducer : public LogicalOperator {
explicit LogicalMultiplicityReducer(std::shared_ptr<LogicalOperator> child)
: LogicalOperator(LogicalOperatorType::MULTIPLICITY_REDUCER, std::move(child)) {}

inline void computeSchema() override { copyChildSchema(0); }
inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

inline std::string getExpressionsForPrinting() const override { return std::string(); }

Expand Down
Loading