Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move schema to operator #1119

Merged
merged 1 commit into from
Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/binder/expression/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ string ExpressionUtil::toString(const expression_vector& expressions) {
if (expressions.empty()) {
return string{};
}
auto result = expressions[0]->getUniqueName();
auto result = expressions[0]->getRawName();
for (auto i = 1u; i < expressions.size(); ++i) {
result += "," + expressions[i]->getUniqueName();
result += "," + expressions[i]->getRawName();
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,15 @@ class LogicalOperator {
inline uint32_t getNumChildren() const { return children.size(); }

// Used for operators with more than two children e.g. Union
inline void addChild(shared_ptr<LogicalOperator> op) { children.push_back(move(op)); }

inline void addChild(shared_ptr<LogicalOperator> op) { children.push_back(std::move(op)); }
inline shared_ptr<LogicalOperator> getChild(uint64_t idx) const { return children[idx]; }

inline LogicalOperatorType getOperatorType() const { return operatorType; }

inline Schema* getSchema() const { return schema.get(); }
void computeSchemaRecursive();
virtual void computeSchema() = 0;

virtual string getExpressionsForPrinting() const = 0;

bool descendantsContainType(const unordered_set<LogicalOperatorType>& types) const;
Expand All @@ -79,8 +82,13 @@ class LogicalOperator {
// TODO: remove this function once planner do not share operator across plans
virtual unique_ptr<LogicalOperator> copy() = 0;

protected:
inline void createEmptySchema() { schema = make_unique<Schema>(); }
inline void copyChildSchema(uint32_t idx) { schema = children[idx]->getSchema()->copy(); }

protected:
LogicalOperatorType operatorType;
unique_ptr<Schema> schema;
vector<shared_ptr<LogicalOperator>> children;
};

Expand Down
Original file line number Diff line number Diff line change
@@ -1,33 +1,31 @@
#pragma once

#include "base_logical_operator.h"
#include "schema.h"

namespace kuzu {
namespace planner {

class LogicalAccumulate : public LogicalOperator {
public:
LogicalAccumulate(expression_vector expressions, unique_ptr<Schema> schemaBeforeSink,
shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)},
expressions{std::move(expressions)}, schemaBeforeSink{std::move(schemaBeforeSink)} {}
LogicalAccumulate(expression_vector expressions, shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::ACCUMULATE, std::move(child)}, expressions{std::move(
expressions)} {}

string getExpressionsForPrinting() const override {
void computeSchema() override;

inline string getExpressionsForPrinting() const override {
return ExpressionUtil::toString(expressions);
}

inline expression_vector getExpressions() const { return expressions; }
inline Schema* getSchemaBeforeSink() const { return schemaBeforeSink.get(); }
inline Schema* getSchemaBeforeSink() const { return children[0]->getSchema(); }

unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAccumulate>(
expressions, schemaBeforeSink->copy(), children[0]->copy());
inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAccumulate>(expressions, children[0]->copy());
}

private:
expression_vector expressions;
unique_ptr<Schema> schemaBeforeSink;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,35 @@
#pragma once

#include "base_logical_operator.h"
#include "schema.h"

namespace kuzu {
namespace planner {

class LogicalAggregate : public LogicalOperator {
public:
LogicalAggregate(expression_vector expressionsToGroupBy,
expression_vector expressionsToAggregate, unique_ptr<Schema> schemaBeforeAggregate,
shared_ptr<LogicalOperator> child)
expression_vector expressionsToAggregate, shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)},
expressionsToGroupBy{std::move(expressionsToGroupBy)}, expressionsToAggregate{std::move(
expressionsToAggregate)},
schemaBeforeAggregate{std::move(schemaBeforeAggregate)} {}
expressionsToAggregate)} {}

void computeSchema() override;

string getExpressionsForPrinting() const override;

inline bool hasExpressionsToGroupBy() const { return !expressionsToGroupBy.empty(); }
inline expression_vector getExpressionsToGroupBy() const { return expressionsToGroupBy; }
inline expression_vector getExpressionsToAggregate() const { return expressionsToAggregate; }
inline Schema* getSchemaBeforeAggregate() const { return schemaBeforeAggregate.get(); }
inline Schema* getSchemaBeforeAggregate() const { return children[0]->getSchema(); }

unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAggregate>(expressionsToGroupBy, expressionsToAggregate,
schemaBeforeAggregate->copy(), children[0]->copy());
inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAggregate>(
expressionsToGroupBy, expressionsToAggregate, children[0]->copy());
}

private:
expression_vector expressionsToGroupBy;
expression_vector expressionsToAggregate;
unique_ptr<Schema> schemaBeforeAggregate;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class LogicalCopyCSV : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::COPY_CSV}, csvDescription{std::move(csvDescription)},
tableSchema{std::move(tableSchema)} {}

inline void computeSchema() override { createEmptySchema(); }

inline string getExpressionsForPrinting() const override { return tableSchema.tableName; }

inline CSVDescription getCSVDescription() const { return csvDescription; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ class LogicalCreateNode : public LogicalCreateOrDeleteNode {
: LogicalCreateOrDeleteNode{
LogicalOperatorType::CREATE_NODE, std::move(nodeAndPrimaryKeys), std::move(child)} {}

unique_ptr<LogicalOperator> copy() override {
void computeSchema() override;

inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalCreateNode>(nodeAndPrimaryKeys, children[0]->copy());
}
};
Expand All @@ -28,7 +30,7 @@ class LogicalCreateRel : public LogicalCreateOrDeleteRel {

inline vector<expression_pair> getSetItems(uint32_t idx) const { return setItemsPerRel[idx]; }

unique_ptr<LogicalOperator> copy() override {
inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalCreateRel>(rels, setItemsPerRel, children[0]->copy());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class LogicalCreateOrDeleteRel : public LogicalOperator {
vector<shared_ptr<RelExpression>> rels, shared_ptr<LogicalOperator> child)
: LogicalOperator{operatorType, std::move(child)}, rels{std::move(rels)} {}

inline void computeSchema() override { copyChildSchema(0); }

inline string getExpressionsForPrinting() const override {
expression_vector expressions;
for (auto& rel : rels) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#pragma once

#include "base_logical_operator.h"
#include "logical_ddl.h"

namespace kuzu {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,27 @@
#pragma once

#include "base_logical_operator.h"
#include "sink_util.h"

namespace kuzu {
namespace planner {

class LogicalCrossProduct : public LogicalOperator {
public:
LogicalCrossProduct(unique_ptr<Schema> buildSideSchema,
LogicalCrossProduct(
shared_ptr<LogicalOperator> probeSideChild, shared_ptr<LogicalOperator> buildSideChild)
: LogicalOperator{LogicalOperatorType::CROSS_PRODUCT, std::move(probeSideChild),
std::move(buildSideChild)},
buildSideSchema{std::move(buildSideSchema)} {}
std::move(buildSideChild)} {}

void computeSchema() override;

inline string getExpressionsForPrinting() const override { return string(); }

inline Schema* getBuildSideSchema() const { return buildSideSchema.get(); }
inline Schema* getBuildSideSchema() const { return children[1]->getSchema(); }

inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalCrossProduct>(
buildSideSchema->copy(), children[0]->copy(), children[1]->copy());
return make_unique<LogicalCrossProduct>(children[0]->copy(), children[1]->copy());
}

private:
unique_ptr<Schema> buildSideSchema;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#pragma once

#include <vector>

#include "base_logical_operator.h"

namespace kuzu {
Expand All @@ -16,6 +14,8 @@ class LogicalDDL : public LogicalOperator {

inline string getExpressionsForPrinting() const override { return tableName; }

inline void computeSchema() override { schema = make_unique<Schema>(); }

inline string getTableName() const { return tableName; }

inline vector<PropertyNameDataType> getPropertyNameDataTypes() const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class LogicalDeleteNode : public LogicalCreateOrDeleteNode {
: LogicalCreateOrDeleteNode{
LogicalOperatorType::DELETE_NODE, std::move(nodeAndPrimaryKeys), std::move(child)} {}

inline void computeSchema() override { copyChildSchema(0); }

inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalDeleteNode>(nodeAndPrimaryKeys, children[0]->copy());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,23 @@ namespace planner {

class LogicalDistinct : public LogicalOperator {
public:
LogicalDistinct(expression_vector expressionsToDistinct,
unique_ptr<Schema> schemaBeforeDistinct, shared_ptr<LogicalOperator> child)
LogicalDistinct(expression_vector expressionsToDistinct, shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::DISTINCT, std::move(child)},
expressionsToDistinct{std::move(expressionsToDistinct)}, schemaBeforeDistinct{std::move(
schemaBeforeDistinct)} {}
expressionsToDistinct{std::move(expressionsToDistinct)} {}

void computeSchema() override;

string getExpressionsForPrinting() const override;

inline expression_vector getExpressionsToDistinct() const { return expressionsToDistinct; }
inline Schema* getSchemaBeforeDistinct() const { return schemaBeforeDistinct.get(); }
inline Schema* getSchemaBeforeDistinct() const { return children[0]->getSchema(); }

unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalDistinct>(
expressionsToDistinct, schemaBeforeDistinct->copy(), children[0]->copy());
return make_unique<LogicalDistinct>(expressionsToDistinct, children[0]->copy());
}

private:
expression_vector expressionsToDistinct;
unique_ptr<Schema> schemaBeforeDistinct;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class LogicalDropTable : public LogicalOperator {

inline TableSchema* getTableSchema() const { return tableSchema; }

void computeSchema() override { createEmptySchema(); }

inline string getExpressionsForPrinting() const override { return tableSchema->tableName; }

inline unique_ptr<LogicalOperator> copy() override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ class LogicalExpressionsScan : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::EXPRESSIONS_SCAN}, expressions{
std::move(expressions)} {}

void computeSchema() override;

inline string getExpressionsForPrinting() const override {
return ExpressionUtil::toString(expressions);
}

inline expression_vector getExpressions() const { return expressions; }

unique_ptr<LogicalOperator> copy() override {
inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalExpressionsScan>(expressions);
}

Expand Down
19 changes: 3 additions & 16 deletions src/include/planner/logical_plan/logical_operator/logical_extend.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,20 @@ class LogicalExtend : public LogicalOperator {
nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, direction{direction},
properties{std::move(properties)}, extendToNewGroup{extendToNewGroup} {}

void computeSchema() override;

inline string getExpressionsForPrinting() const override {
return boundNode->getRawName() + (direction == RelDirection::FWD ? "->" : "<-") +
nbrNode->getRawName();
}

inline void computeSchema(Schema& schema) {
auto boundGroupPos = schema.getGroupPos(boundNode->getInternalIDPropertyName());
uint32_t nbrGroupPos = 0u;
if (!extendToNewGroup) {
nbrGroupPos = boundGroupPos;
} else {
assert(schema.getGroup(boundGroupPos)->isFlat());
nbrGroupPos = schema.createGroup();
}
schema.insertToGroupAndScope(nbrNode->getInternalIDProperty(), nbrGroupPos);
for (auto& property : properties) {
schema.insertToGroupAndScope(property, nbrGroupPos);
}
}

inline shared_ptr<NodeExpression> getBoundNode() const { return boundNode; }
inline shared_ptr<NodeExpression> getNbrNode() const { return nbrNode; }
inline shared_ptr<RelExpression> getRel() const { return rel; }
inline RelDirection getDirection() const { return direction; }
inline expression_vector getProperties() const { return properties; }

unique_ptr<LogicalOperator> copy() override {
inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalExtend>(
boundNode, nbrNode, rel, direction, properties, extendToNewGroup, children[0]->copy());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ class LogicalFilter : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::FILTER, std::move(child)},
expression{std::move(expression)}, groupPosToSelect{groupPosToSelect} {}

string getExpressionsForPrinting() const override { return expression->getUniqueName(); }
inline void computeSchema() override { copyChildSchema(0); }

unique_ptr<LogicalOperator> copy() override {
inline string getExpressionsForPrinting() const override { return expression->getUniqueName(); }

inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalFilter>(expression, groupPosToSelect, children[0]->copy());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
#include "base_logical_operator.h"
#include "binder/expression/expression.h"

using namespace kuzu::binder;

namespace kuzu {
namespace planner {

Expand All @@ -14,16 +12,13 @@ class LogicalFlatten : public LogicalOperator {
: LogicalOperator{LogicalOperatorType::FLATTEN, std::move(child)}, expression{std::move(
expression)} {}

void computeSchema() override;

inline string getExpressionsForPrinting() const override { return expression->getUniqueName(); }

inline shared_ptr<Expression> getExpression() const { return expression; }

inline void computeSchema(Schema& schema) {
auto groupPos = schema.getGroupPos(expression->getUniqueName());
schema.flattenGroup(groupPos);
}

unique_ptr<LogicalOperator> copy() override {
inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalFlatten>(expression, children[0]->copy());
}

Expand Down
Loading