Skip to content

Commit

Permalink
[refactor](Nereids): cascades refactor (#9470)
Browse files Browse the repository at this point in the history
Describe the overview of changes.

- rename GroupExpression
- use `HashSet<GroupExpression> groupExpressions` in `memo`
- add label of `Nereids` for CI
- remove `GroupExpr` from Plan
  • Loading branch information
jackwener authored May 11, 2022
1 parent ad88eb7 commit 74352c8
Show file tree
Hide file tree
Showing 14 changed files with 107 additions and 114 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/labeler/scope-label-conf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ kind/test:
area/vectorization:
- be/src/vec/**/*

area/nereids:
- fe/fe-core/src/main/java/org/apache/doris/nereids/**/*

area/planner:
- fe/fe-core/src/main/java/org/apache/doris/planner/**/*
- fe/fe-core/src/main/java/org/apache/doris/analysis/**/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import org.apache.doris.common.AnalysisException;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleSet;
import org.apache.doris.nereids.trees.plans.Plan;
Expand All @@ -46,7 +46,7 @@ public RuleSet getRuleSet() {
return context.getOptimizerContext().getRuleSet();
}

public void prunedInvalidRules(PlanReference planReference, List<Rule<Plan>> candidateRules) {
public void prunedInvalidRules(GroupExpression groupExpression, List<Rule<Plan>> candidateRules) {

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.pattern.PatternMatching;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
Expand All @@ -30,30 +30,30 @@
import java.util.List;

/**
* Job to apply rule on {@link PlanReference}.
* Job to apply rule on {@link GroupExpression}.
*/
public class ApplyRuleJob extends Job {
private final PlanReference planReference;
private final GroupExpression groupExpression;
private final Rule<Plan> rule;
private final boolean exploredOnly;

/**
* Constructor of ApplyRuleJob.
*
* @param planReference apply rule on this {@link PlanReference}
* @param groupExpression apply rule on this {@link GroupExpression}
* @param rule rule to be applied
* @param context context of optimization
*/
public ApplyRuleJob(PlanReference planReference, Rule<Plan> rule, PlannerContext context) {
public ApplyRuleJob(GroupExpression groupExpression, Rule<Plan> rule, PlannerContext context) {
super(JobType.APPLY_RULE, context);
this.planReference = planReference;
this.groupExpression = groupExpression;
this.rule = rule;
this.exploredOnly = false;
}

@Override
public void execute() throws AnalysisException {
if (planReference.hasExplored(rule)) {
if (groupExpression.hasExplored(rule)) {
return;
}

Expand All @@ -65,20 +65,20 @@ public void execute() throws AnalysisException {
}
List<Plan> newPlanList = rule.transform(plan, context);
for (Plan newPlan : newPlanList) {
PlanReference newReference = context.getOptimizerContext().getMemo()
.newPlanReference(newPlan, planReference.getParent());
GroupExpression newGroupExpression = context.getOptimizerContext().getMemo()
.newGroupExpression(newPlan, groupExpression.getParent());
// TODO need to check return is a new Reference, other wise will be into a dead loop
if (newPlan instanceof LogicalPlan) {
pushTask(new DeriveStatsJob(newReference, context));
pushTask(new DeriveStatsJob(newGroupExpression, context));
if (exploredOnly) {
pushTask(new ExplorePlanJob(newReference, context));
pushTask(new ExplorePlanJob(newGroupExpression, context));
}
pushTask(new OptimizePlanJob(newReference, context));
pushTask(new OptimizePlanJob(newGroupExpression, context));
} else {
pushTask(new CostAndEnforcerJob(newReference, context));
pushTask(new CostAndEnforcerJob(newGroupExpression, context));
}
}
}
planReference.setExplored(rule);
groupExpression.setExplored(rule);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;

/**
* Job to compute cost and add enforcer.
*/
public class CostAndEnforcerJob extends Job {
private final PlanReference planReference;
private final GroupExpression groupExpression;

public CostAndEnforcerJob(PlanReference planReference, PlannerContext context) {
public CostAndEnforcerJob(GroupExpression groupExpression, PlannerContext context) {
super(JobType.OPTIMIZE_CHILDREN, context);
this.planReference = planReference;
this.groupExpression = groupExpression;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,24 @@
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;

/**
* Job to derive stats for {@link PlanReference} in {@link org.apache.doris.nereids.memo.Memo}.
* Job to derive stats for {@link GroupExpression} in {@link org.apache.doris.nereids.memo.Memo}.
*/
public class DeriveStatsJob extends Job {
private final PlanReference planReference;
private final GroupExpression groupExpression;
private boolean deriveChildren;

/**
* Constructor for DeriveStatsJob.
*
* @param planReference Derive stats on this {@link PlanReference}
* @param groupExpression Derive stats on this {@link GroupExpression}
* @param context context of optimization
*/
public DeriveStatsJob(PlanReference planReference, PlannerContext context) {
public DeriveStatsJob(GroupExpression groupExpression, PlannerContext context) {
super(JobType.DERIVE_STATS, context);
this.planReference = planReference;
this.groupExpression = groupExpression;
this.deriveChildren = false;
}

Expand All @@ -49,7 +49,7 @@ public DeriveStatsJob(PlanReference planReference, PlannerContext context) {
*/
public DeriveStatsJob(DeriveStatsJob other) {
super(JobType.DERIVE_STATS, other.context);
this.planReference = other.planReference;
this.groupExpression = other.groupExpression;
this.deriveChildren = other.deriveChildren;
}

Expand All @@ -58,14 +58,14 @@ public void execute() {
if (!deriveChildren) {
deriveChildren = true;
pushTask(new DeriveStatsJob(this));
for (Group childSet : planReference.getChildren()) {
for (Group childSet : groupExpression.getChildren()) {
if (!childSet.getLogicalPlanList().isEmpty()) {
pushTask(new DeriveStatsJob(childSet.getLogicalPlanList().get(0), context));
}
}
} else {
// TODO: derive stat here
planReference.setStatDerived(true);
groupExpression.setStatDerived(true);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;

/**
* Job to explore {@link Group} in {@link org.apache.doris.nereids.memo.Memo}.
Expand All @@ -45,8 +45,8 @@ public void execute() {
if (group.isExplored()) {
return;
}
for (PlanReference planReference : group.getLogicalPlanList()) {
pushTask(new ExplorePlanJob(planReference, context));
for (GroupExpression groupExpression : group.getLogicalPlanList()) {
pushTask(new ExplorePlanJob(groupExpression, context));
}
group.setExplored(true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
Expand All @@ -30,34 +30,34 @@
import java.util.List;

/**
* Job to explore {@link PlanReference} in {@link org.apache.doris.nereids.memo.Memo}.
* Job to explore {@link GroupExpression} in {@link org.apache.doris.nereids.memo.Memo}.
*/
public class ExplorePlanJob extends Job {
private final PlanReference planReference;
private final GroupExpression groupExpression;

/**
* Constructor for ExplorePlanJob.
*
* @param planReference {@link PlanReference} to be explored
* @param groupExpression {@link GroupExpression} to be explored
* @param context context of optimization
*/
public ExplorePlanJob(PlanReference planReference, PlannerContext context) {
public ExplorePlanJob(GroupExpression groupExpression, PlannerContext context) {
super(JobType.EXPLORE_PLAN, context);
this.planReference = planReference;
this.groupExpression = groupExpression;
}

@Override
public void execute() {
List<Rule<Plan>> explorationRules = getRuleSet().getExplorationRules();
prunedInvalidRules(planReference, explorationRules);
prunedInvalidRules(groupExpression, explorationRules);
explorationRules.sort(Comparator.comparingInt(o -> o.getRulePromise().promise()));

for (Rule rule : explorationRules) {
pushTask(new ApplyRuleJob(planReference, rule, context));
pushTask(new ApplyRuleJob(groupExpression, rule, context));
for (int i = 0; i < rule.getPattern().children().size(); ++i) {
Pattern childPattern = rule.getPattern().child(i);
if (childPattern.arity() > 0) {
Group childSet = planReference.getChildren().get(i);
Group childSet = groupExpression.getChildren().get(i);
pushTask(new ExploreGroupJob(childSet, context));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;

/**
* Job to optimize {@link Group} in {@link org.apache.doris.nereids.memo.Memo}.
Expand All @@ -41,12 +41,12 @@ public void execute() {
return;
}
if (!group.isExplored()) {
for (PlanReference logicalPlanReference : group.getLogicalPlanList()) {
context.getOptimizerContext().pushTask(new OptimizePlanJob(logicalPlanReference, context));
for (GroupExpression logicalGroupExpression : group.getLogicalPlanList()) {
context.getOptimizerContext().pushTask(new OptimizePlanJob(logicalGroupExpression, context));
}
}
for (PlanReference physicalPlanReference : group.getPhysicalPlanList()) {
context.getOptimizerContext().pushTask(new CostAndEnforcerJob(physicalPlanReference, context));
for (GroupExpression physicalGroupExpression : group.getPhysicalPlanList()) {
context.getOptimizerContext().pushTask(new CostAndEnforcerJob(physicalGroupExpression, context));
}
group.setExplored(true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.PlanReference;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
Expand All @@ -34,33 +34,33 @@
* Job to optimize {@link org.apache.doris.nereids.trees.plans.Plan} in {@link org.apache.doris.nereids.memo.Memo}.
*/
public class OptimizePlanJob extends Job {
private final PlanReference planReference;
private final GroupExpression groupExpression;

public OptimizePlanJob(PlanReference planReference, PlannerContext context) {
public OptimizePlanJob(GroupExpression groupExpression, PlannerContext context) {
super(JobType.OPTIMIZE_PLAN, context);
this.planReference = planReference;
this.groupExpression = groupExpression;
}

@Override
public void execute() {
List<Rule<Plan>> validRules = new ArrayList<>();
List<Rule<Plan>> explorationRules = getRuleSet().getExplorationRules();
List<Rule<Plan>> implementationRules = getRuleSet().getImplementationRules();
prunedInvalidRules(planReference, explorationRules);
prunedInvalidRules(planReference, implementationRules);
prunedInvalidRules(groupExpression, explorationRules);
prunedInvalidRules(groupExpression, implementationRules);
validRules.addAll(explorationRules);
validRules.addAll(implementationRules);
validRules.sort(Comparator.comparingInt(o -> o.getRulePromise().promise()));

for (Rule rule : validRules) {
pushTask(new ApplyRuleJob(planReference, rule, context));
pushTask(new ApplyRuleJob(groupExpression, rule, context));

// If child_pattern has any more children (i.e non-leaf), then we will explore the
// child before applying the rule. (assumes task pool is effectively a stack)
for (int i = 0; i < rule.getPattern().children().size(); ++i) {
Pattern childPattern = rule.getPattern().child(i);
if (childPattern.arity() > 0) {
Group childSet = planReference.getChildren().get(i);
Group childSet = groupExpression.getChildren().get(i);
pushTask(new ExploreGroupJob(childSet, context));
}
}
Expand Down
Loading

0 comments on commit 74352c8

Please sign in to comment.