Skip to content

Commit

Permalink
[improve](Nereids) Add all slots used by onClause to project when reo…
Browse files Browse the repository at this point in the history
…rder and fix reorder mark (apache#12701)

1. Add all slots used by onClause in project

```
(A & B) & C like
join(hash conjuncts: C.t2 = A.t2)
|---project(A.t2)
|   +---join(hash conjuncts: A.t1 = B.t1)
|       +---A
|       +---B
+---C

transform to (A & C) & B
join(hash conjuncts: A.t1 = B.t1)
|---project(A.t2)
|   +---join(hash conjuncts: C.t2 = A.t2)
|       +---A
|       +---C
+---B
```

But projection just include `A.t2`, can't find `A.t1`, we should add slots used by onClause when projection exist.

2. fix join reorder mark

Add mark `LAsscom` when apply `LAsscom`

3. remove slotReference

use `Slot` instead of `SlotReference` to avoid cast.
  • Loading branch information
jackwener authored and Yijia Su committed Oct 8, 2022
1 parent 0d91056 commit 8052beb
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;

/**
* Rule for change inner join LAsscom (associative and commutive).
Expand All @@ -37,7 +40,7 @@ public class InnerJoinLAsscom extends OneExplorationRuleFactory {
@Override
public Rule build() {
return innerLogicalJoin(innerLogicalJoin(), group())
.when(topJoin -> JoinLAsscomHelper.checkInner(topJoin, topJoin.left()))
.when(topJoin -> check(topJoin, topJoin.left()))
.then(topJoin -> {
JoinLAsscomHelper helper = new JoinLAsscomHelper(topJoin, topJoin.left());
if (!helper.initJoinOnCondition()) {
Expand All @@ -46,4 +49,10 @@ public Rule build() {
return helper.newTopJoin();
}).toRule(RuleType.LOGICAL_INNER_JOIN_LASSCOM);
}

public static boolean check(LogicalJoin<? extends Plan, GroupPlan> topJoin,
LogicalJoin<GroupPlan, GroupPlan> bottomJoin) {
return !bottomJoin.getJoinReorderContext().hasCommuteZigZag()
&& !topJoin.getJoinReorderContext().hasLAsscom();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory {
@Override
public Rule build() {
return innerLogicalJoin(logicalProject(innerLogicalJoin()), group())
.when(topJoin -> JoinLAsscomHelper.checkInner(topJoin, topJoin.left().child()))
.when(topJoin -> InnerJoinLAsscom.check(topJoin, topJoin.left().child()))
.then(topJoin -> {
JoinLAsscomHelper helper = new JoinLAsscomHelper(topJoin, topJoin.left().child());
helper.initAllProject(topJoin.left());
helper.initProject(topJoin.left());
if (!helper.initJoinOnCondition()) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@

package org.apache.doris.nereids.rules.exploration.join;

import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.Utils;

import java.util.List;

Expand All @@ -47,7 +46,7 @@ public static boolean isNotBottomJoin(LogicalJoin<GroupPlan, GroupPlan> join) {

private static boolean containJoin(GroupPlan groupPlan) {
// TODO: tmp way to judge containJoin
List<SlotReference> output = Utils.getOutputSlotReference(groupPlan);
return !output.stream().map(SlotReference::getQualifier).allMatch(output.get(0).getQualifier()::equals);
List<Slot> output = groupPlan.getOutput();
return !output.stream().map(Slot::getQualifier).allMatch(output.get(0).getQualifier()::equals);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@

import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.nereids.util.Utils;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -57,44 +56,39 @@ public JoinLAsscomHelper(LogicalJoin<? extends Plan, GroupPlan> topJoin,
* Create newTopJoin.
*/
public Plan newTopJoin() {
// Split inside-project into two part.
Map<Boolean, List<NamedExpression>> projectExprsMap = allProjects.stream()
// Split bottomJoinProject into two part.
Map<Boolean, List<NamedExpression>> projectExprsMap = bottomProjects.stream()
.collect(Collectors.partitioningBy(projectExpr -> {
Set<Slot> usedSlots = projectExpr.collect(Slot.class::isInstance);
return bOutput.containsAll(usedSlots);
return bOutputSet.containsAll(usedSlots);
}));
List<NamedExpression> newLeftProjects = projectExprsMap.get(Boolean.FALSE);
List<NamedExpression> newRightProjects = projectExprsMap.get(Boolean.TRUE);

List<NamedExpression> newLeftProjectExpr = projectExprsMap.get(Boolean.FALSE);
List<NamedExpression> newRightProjectExprs = projectExprsMap.get(Boolean.TRUE);

// If add project to B, we should add all slotReference used by hashOnCondition.
// Add all slots used by hashOnCondition when projects not empty.
// TODO: Does nonHashOnCondition also need to be considered.
Set<SlotReference> onUsedSlotRef = bottomJoin.getHashJoinConjuncts().stream()
.flatMap(expr -> {
Set<SlotReference> usedSlotRefs = expr.collect(SlotReference.class::isInstance);
Map<Boolean, List<Slot>> onUsedSlots = bottomJoin.getHashJoinConjuncts().stream()
.flatMap(onExpr -> {
Set<Slot> usedSlotRefs = onExpr.collect(Slot.class::isInstance);
return usedSlotRefs.stream();
}).filter(Utils.getOutputSlotReference(bottomJoin)::contains).collect(Collectors.toSet());
boolean existRightProject = !newRightProjectExprs.isEmpty();
boolean existLeftProject = !newLeftProjectExpr.isEmpty();
onUsedSlotRef.forEach(slotRef -> {
if (existRightProject && bOutput.contains(slotRef) && !newRightProjectExprs.contains(slotRef)) {
newRightProjectExprs.add(slotRef);
} else if (existLeftProject && aOutput.contains(slotRef) && !newLeftProjectExpr.contains(slotRef)) {
newLeftProjectExpr.add(slotRef);
}
});
}).collect(Collectors.partitioningBy(bOutputSet::contains));
List<Slot> leftUsedSlots = onUsedSlots.get(Boolean.FALSE);
List<Slot> rightUsedSlots = onUsedSlots.get(Boolean.TRUE);

if (existLeftProject) {
newLeftProjectExpr.addAll(cOutput);
addSlotsUsedByOn(rightUsedSlots, newRightProjects);
addSlotsUsedByOn(leftUsedSlots, newLeftProjects);

if (!newLeftProjects.isEmpty()) {
newLeftProjects.addAll(cOutputSet);
}
LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(),
newBottomHashJoinConjuncts, ExpressionUtils.optionalAnd(newBottomNonHashJoinConjuncts), a, c,
bottomJoin.getJoinReorderContext());
newBottomJoin.getJoinReorderContext().setHasLAsscom(false);
newBottomJoin.getJoinReorderContext().setHasCommute(false);

Plan left = PlanUtils.projectOrSelf(newLeftProjectExpr, newBottomJoin);
Plan right = PlanUtils.projectOrSelf(newRightProjectExprs, b);
Plan left = PlanUtils.projectOrSelf(newLeftProjects, newBottomJoin);
Plan right = PlanUtils.projectOrSelf(newRightProjects, b);

LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
newTopHashJoinConjuncts,
Expand All @@ -105,18 +99,16 @@ public Plan newTopJoin() {
return PlanUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin);
}

public static boolean checkInner(LogicalJoin<? extends Plan, GroupPlan> topJoin,
LogicalJoin<GroupPlan, GroupPlan> bottomJoin) {
return !bottomJoin.getJoinReorderContext().hasCommuteZigZag()
&& !topJoin.getJoinReorderContext().hasLAsscom();
}

public static boolean checkOuter(LogicalJoin<? extends Plan, GroupPlan> topJoin,
LogicalJoin<GroupPlan, GroupPlan> bottomJoin) {
// hasCommute will cause to lack of OuterJoinAssocRule:Left
return !topJoin.getJoinReorderContext().hasLeftAssociate()
&& !topJoin.getJoinReorderContext().hasRightAssociate()
&& !topJoin.getJoinReorderContext().hasExchange()
&& !bottomJoin.getJoinReorderContext().hasCommute();
// When project not empty, we add all slots used by hashOnCondition into projects.
private void addSlotsUsedByOn(List<Slot> usedSlots, List<NamedExpression> projects) {
if (projects.isEmpty()) {
return;
}
Set<NamedExpression> projectsSet = new HashSet<>(projects);
usedSlots.forEach(slot -> {
if (!projectsSet.contains(slot)) {
projects.add(slot);
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;

import com.google.common.collect.ImmutableSet;

Expand Down Expand Up @@ -50,14 +53,24 @@ public class OuterJoinLAsscom extends OneExplorationRuleFactory {
@Override
public Rule build() {
return logicalJoin(logicalJoin(), group())
.when(topJoin -> JoinLAsscomHelper.checkOuter(topJoin, topJoin.left()))
.when(topJoin -> check(topJoin, topJoin.left()))
.when(join -> VALID_TYPE_PAIR_SET.contains(Pair.of(join.left().getJoinType(), join.getJoinType())))
.then(topJoin -> {
JoinLAsscomHelper helper = new JoinLAsscomHelper(topJoin, topJoin.left());
if (!helper.initJoinOnCondition()) {
return null;
}
return helper.newTopJoin();
}).toRule(RuleType.LOGICAL_OUTER_JOIN_LASSCOM);
}

/**
* check.
*/
public static boolean check(LogicalJoin<? extends Plan, GroupPlan> topJoin,
LogicalJoin<GroupPlan, GroupPlan> bottomJoin) {
// hasCommute will cause to lack of OuterJoinAssocRule:Left
return !topJoin.getJoinReorderContext().hasLAsscom()
&& !topJoin.getJoinReorderContext().hasLeftAssociate()
&& !topJoin.getJoinReorderContext().hasRightAssociate()
&& !topJoin.getJoinReorderContext().hasExchange()
&& !bottomJoin.getJoinReorderContext().hasCommute();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory {
@Override
public Rule build() {
return logicalJoin(logicalProject(logicalJoin()), group())
.when(topJoin -> JoinLAsscomHelper.checkOuter(topJoin, topJoin.left().child()))
.when(topJoin -> OuterJoinLAsscom.check(topJoin, topJoin.left().child()))
.when(join -> OuterJoinLAsscom.VALID_TYPE_PAIR_SET.contains(
Pair.of(join.left().child().getJoinType(), join.getJoinType())))
.then(topJoin -> {
JoinLAsscomHelper helper = new JoinLAsscomHelper(topJoin, topJoin.left().child());
helper.initAllProject(topJoin.left());
helper.initProject(topJoin.left());
if (!helper.initJoinOnCondition()) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ abstract class ThreeJoinHelper {
protected final GroupPlan b;
protected final GroupPlan c;

protected final Set<Slot> aOutput;
protected final Set<Slot> bOutput;
protected final Set<Slot> cOutput;
protected final Set<Slot> aOutputSet;
protected final Set<Slot> bOutputSet;
protected final Set<Slot> cOutputSet;
protected final Set<Slot> bottomJoinOutputSet;

protected final List<NamedExpression> allProjects = Lists.newArrayList();
protected final List<NamedExpression> bottomProjects = Lists.newArrayList();

protected final List<Expression> allHashJoinConjuncts = Lists.newArrayList();
protected final List<Expression> allNonHashJoinConjuncts = Lists.newArrayList();
Expand All @@ -70,9 +71,10 @@ public ThreeJoinHelper(LogicalJoin<? extends Plan, ? extends Plan> topJoin,
this.b = b;
this.c = c;

aOutput = a.getOutputSet();
bOutput = b.getOutputSet();
cOutput = c.getOutputSet();
aOutputSet = a.getOutputSet();
bOutputSet = b.getOutputSet();
cOutputSet = c.getOutputSet();
bottomJoinOutputSet = bottomJoin.getOutputSet();

Preconditions.checkArgument(!topJoin.getHashJoinConjuncts().isEmpty(), "topJoin hashJoinConjuncts must exist.");
Preconditions.checkArgument(!bottomJoin.getHashJoinConjuncts().isEmpty(),
Expand All @@ -86,11 +88,8 @@ public ThreeJoinHelper(LogicalJoin<? extends Plan, ? extends Plan> topJoin,
ExpressionUtils.extractConjunction(otherJoinCondition)));
}

@SafeVarargs
public final void initAllProject(LogicalProject<? extends Plan>... projects) {
for (LogicalProject<? extends Plan> project : projects) {
allProjects.addAll(project.getProjects());
}
public final void initProject(LogicalProject<? extends Plan> project) {
bottomProjects.addAll(project.getProjects());
}

/**
Expand All @@ -102,14 +101,14 @@ public boolean initJoinOnCondition() {
// TODO: also need for otherJoinCondition
for (Expression topJoinOnClauseConjunct : topJoin.getHashJoinConjuncts()) {
Set<Slot> topJoinUsedSlot = topJoinOnClauseConjunct.collect(SlotReference.class::isInstance);
if (ExpressionUtils.isIntersecting(topJoinUsedSlot, aOutput) && ExpressionUtils.isIntersecting(
topJoinUsedSlot, bOutput) && ExpressionUtils.isIntersecting(topJoinUsedSlot, cOutput)) {
if (ExpressionUtils.isIntersecting(topJoinUsedSlot, aOutputSet) && ExpressionUtils.isIntersecting(
topJoinUsedSlot, bOutputSet) && ExpressionUtils.isIntersecting(topJoinUsedSlot, cOutputSet)) {
return false;
}
}

Set<Slot> newBottomJoinSlots = new HashSet<>(aOutput);
newBottomJoinSlots.addAll(cOutput);
Set<Slot> newBottomJoinSlots = new HashSet<>(aOutputSet);
newBottomJoinSlots.addAll(cOutputSet);
for (Expression hashConjunct : allHashJoinConjuncts) {
Set<SlotReference> slots = hashConjunct.collect(SlotReference.class::isInstance);
if (newBottomJoinSlots.containsAll(slots)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
Expand Down Expand Up @@ -63,7 +62,7 @@ private DistributionSpec convertDistribution(LogicalOlapScan olapScan) {
if (distributionInfo instanceof HashDistributionInfo) {
HashDistributionInfo hashDistributionInfo = (HashDistributionInfo) distributionInfo;

List<SlotReference> output = Utils.getOutputSlotReference(olapScan);
List<Slot> output = olapScan.getOutput();
List<ExprId> hashColumns = Lists.newArrayList();
List<Column> schemaColumns = olapScan.getTable().getFullSchema();
for (int i = 0; i < schemaColumns.size(); i++) {
Expand Down
Loading

0 comments on commit 8052beb

Please sign in to comment.