Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix plan deep copy #1345

Merged
merged 1 commit into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ class LogicalOperator {
inline LogicalOperatorType getOperatorType() const { return operatorType; }

inline Schema* getSchema() const { return schema.get(); }
void computeSchemaRecursive();
virtual void computeSchema() = 0;

virtual std::string getExpressionsForPrinting() const = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/optimizer/remove_factorization_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ std::shared_ptr<planner::LogicalOperator> RemoveFactorizationRewriter::visitOper
for (auto i = 0; i < op->getNumChildren(); ++i) {
op->setChild(i, visitOperator(op->getChild(i)));
}
op->getSchema()->clear();
assert(op->getSchema() == nullptr);
return visitOperatorReplaceSwitch(op);
}

Expand Down
7 changes: 0 additions & 7 deletions src/planner/operator/base_logical_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,6 @@ LogicalOperator::LogicalOperator(
}
}

void LogicalOperator::computeSchemaRecursive() {
for (auto& child : children) {
child->computeSchemaRecursive();
}
computeSchema();
}

std::string LogicalOperator::toString(uint64_t depth) const {
auto padding = std::string(depth * 4, ' ');
std::string result = padding;
Expand Down
1 change: 0 additions & 1 deletion src/planner/operator/logical_plan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ std::unique_ptr<LogicalPlan> LogicalPlan::deepCopy() const {
assert(!isEmpty());
auto plan = std::make_unique<LogicalPlan>();
plan->lastOperator = lastOperator->copy(); // deep copy sub-plan
plan->lastOperator->computeSchemaRecursive();
plan->estCardinality = estCardinality;
plan->cost = cost;
return plan;
Expand Down
57 changes: 29 additions & 28 deletions src/planner/planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,40 @@ namespace planner {
std::unique_ptr<LogicalPlan> Planner::getBestPlan(const Catalog& catalog,
const NodesStatisticsAndDeletedIDs& nodesStatistics, const RelsStatistics& relsStatistics,
const BoundStatement& statement) {
std::unique_ptr<LogicalPlan> plan;
switch (statement.getStatementType()) {
case StatementType::QUERY: {
return QueryPlanner(catalog, nodesStatistics, relsStatistics).getBestPlan(statement);
}
plan = QueryPlanner(catalog, nodesStatistics, relsStatistics).getBestPlan(statement);
} break;
case StatementType::CREATE_NODE_CLAUSE: {
return planCreateNodeTable(statement);
}
plan = planCreateNodeTable(statement);
} break;
case StatementType::CREATE_REL_CLAUSE: {
return planCreateRelTable(statement);
}
plan = planCreateRelTable(statement);
} break;
case StatementType::COPY_CSV: {
return planCopy(statement);
}
plan = planCopy(statement);
} break;
case StatementType::DROP_TABLE: {
return planDropTable(statement);
}
plan = planDropTable(statement);
} break;
case StatementType::RENAME_TABLE: {
return planRenameTable(statement);
}
plan = planRenameTable(statement);
} break;
case StatementType::ADD_PROPERTY: {
return planAddProperty(statement);
}
plan = planAddProperty(statement);
} break;
case StatementType::DROP_PROPERTY: {
return planDropProperty(statement);
}
plan = planDropProperty(statement);
} break;
case StatementType::RENAME_PROPERTY: {
return planRenameProperty(statement);
}
plan = planRenameProperty(statement);
} break;
default:
assert(false);
throw common::NotImplementedException("getBestPlan()");
}
// Avoid sharing operator across plans.
return plan->deepCopy();
}

std::vector<std::unique_ptr<LogicalPlan>> Planner::getAllPlans(const Catalog& catalog,
Expand All @@ -66,7 +69,13 @@ std::vector<std::unique_ptr<LogicalPlan>> Planner::getAllPlans(const Catalog& ca
// We enumerate all plans for our testing framework. This API should only be used for QUERY
// but not DDL or COPY_CSV.
assert(statement.getStatementType() == StatementType::QUERY);
return QueryPlanner(catalog, nodesStatistics, relsStatistics).getAllPlans(statement);
auto planner = QueryPlanner(catalog, nodesStatistics, relsStatistics);
std::vector<std::unique_ptr<LogicalPlan>> plans;
for (auto& plan : planner.getAllPlans(statement)) {
// Avoid sharing operator across plans.
plans.push_back(plan->deepCopy());
}
return plans;
}

std::unique_ptr<LogicalPlan> Planner::planCreateNodeTable(const BoundStatement& statement) {
Expand All @@ -75,7 +84,6 @@ std::unique_ptr<LogicalPlan> Planner::planCreateNodeTable(const BoundStatement&
auto createNodeTable = make_shared<LogicalCreateNodeTable>(createNodeClause.getTableName(),
createNodeClause.getPropertyNameDataTypes(), createNodeClause.getPrimaryKeyIdx(),
statement.getStatementResult()->getSingleExpressionToCollect());
createNodeTable->computeSchema();
plan->setLastOperator(std::move(createNodeTable));
return plan;
}
Expand All @@ -87,7 +95,6 @@ std::unique_ptr<LogicalPlan> Planner::planCreateRelTable(const BoundStatement& s
createRelClause.getPropertyNameDataTypes(), createRelClause.getRelMultiplicity(),
createRelClause.getSrcTableID(), createRelClause.getDstTableID(),
statement.getStatementResult()->getSingleExpressionToCollect());
createRelTable->computeSchema();
plan->setLastOperator(std::move(createRelTable));
return plan;
}
Expand All @@ -98,7 +105,6 @@ std::unique_ptr<LogicalPlan> Planner::planDropTable(const BoundStatement& statem
auto dropTable =
make_shared<LogicalDropTable>(dropTableClause.getTableID(), dropTableClause.getTableName(),
statement.getStatementResult()->getSingleExpressionToCollect());
dropTable->computeSchema();
plan->setLastOperator(std::move(dropTable));
return plan;
}
Expand All @@ -109,7 +115,6 @@ std::unique_ptr<LogicalPlan> Planner::planRenameTable(const BoundStatement& stat
auto renameTable = make_shared<LogicalRenameTable>(renameTableClause.getTableID(),
renameTableClause.getTableName(), renameTableClause.getNewName(),
statement.getStatementResult()->getSingleExpressionToCollect());
renameTable->computeSchema();
plan->setLastOperator(std::move(renameTable));
return plan;
}
Expand All @@ -121,7 +126,6 @@ std::unique_ptr<LogicalPlan> Planner::planAddProperty(const BoundStatement& stat
addPropertyClause.getPropertyName(), addPropertyClause.getDataType(),
addPropertyClause.getDefaultValue(), addPropertyClause.getTableName(),
statement.getStatementResult()->getSingleExpressionToCollect());
addProperty->computeSchema();
plan->setLastOperator(std::move(addProperty));
return plan;
}
Expand All @@ -132,7 +136,6 @@ std::unique_ptr<LogicalPlan> Planner::planDropProperty(const BoundStatement& sta
auto dropProperty = make_shared<LogicalDropProperty>(dropPropertyClause.getTableID(),
dropPropertyClause.getPropertyID(), dropPropertyClause.getTableName(),
statement.getStatementResult()->getSingleExpressionToCollect());
dropProperty->computeSchema();
plan->setLastOperator(std::move(dropProperty));
return plan;
}
Expand All @@ -144,7 +147,6 @@ std::unique_ptr<LogicalPlan> Planner::planRenameProperty(const BoundStatement& s
renamePropertyClause.getTableName(), renamePropertyClause.getPropertyID(),
renamePropertyClause.getNewName(),
statement.getStatementResult()->getSingleExpressionToCollect());
renameProperty->computeSchema();
plan->setLastOperator(std::move(renameProperty));
return plan;
}
Expand All @@ -154,7 +156,6 @@ std::unique_ptr<LogicalPlan> Planner::planCopy(const BoundStatement& statement)
auto plan = std::make_unique<LogicalPlan>();
auto copyCSV = make_shared<LogicalCopy>(copyCSVClause.getCopyDescription(),
copyCSVClause.getTableID(), copyCSVClause.getTableName());
copyCSV->computeSchema();
plan->setLastOperator(std::move(copyCSV));
return plan;
}
Expand Down
14 changes: 4 additions & 10 deletions src/planner/query_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,7 @@ std::vector<std::unique_ptr<LogicalPlan>> QueryPlanner::planSingleQuery(
for (auto i = 0u; i < singleQuery.getNumQueryParts(); ++i) {
plans = planQueryPart(*singleQuery.getQueryPart(i), std::move(plans));
}
std::vector<std::unique_ptr<LogicalPlan>> result;
for (auto& plan : plans) {
// This is copy is to avoid sharing operator across plans. Later optimization requires
// each plan to be independent.
result.push_back(plan->deepCopy());
}
return result;
return plans;
}

std::vector<std::unique_ptr<LogicalPlan>> QueryPlanner::planQueryPart(
Expand Down Expand Up @@ -405,15 +399,15 @@ std::vector<std::vector<std::unique_ptr<LogicalPlan>>> QueryPlanner::cartesianPr
for (auto& childLogicalPlan : childLogicalPlans) {
if (resultChildrenPlans.empty()) {
std::vector<std::unique_ptr<LogicalPlan>> logicalPlans;
logicalPlans.push_back(childLogicalPlan->deepCopy());
logicalPlans.push_back(childLogicalPlan->shallowCopy());
curChildResultLogicalPlans.push_back(std::move(logicalPlans));
} else {
for (auto& resultChildPlans : resultChildrenPlans) {
std::vector<std::unique_ptr<LogicalPlan>> logicalPlans;
for (auto& resultChildPlan : resultChildPlans) {
logicalPlans.push_back(resultChildPlan->deepCopy());
logicalPlans.push_back(resultChildPlan->shallowCopy());
}
logicalPlans.push_back(childLogicalPlan->deepCopy());
logicalPlans.push_back(childLogicalPlan->shallowCopy());
curChildResultLogicalPlans.push_back(std::move(logicalPlans));
}
}
Expand Down