Skip to content

Commit

Permalink
[feature](nereids) add rewrite rule PushCountIntoUnionAll (apache#33530)
Browse files Browse the repository at this point in the history
  • Loading branch information
feiniaofeiafei committed Oct 24, 2024
1 parent 7d872f9 commit 65dd75a
Show file tree
Hide file tree
Showing 7 changed files with 796 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoEsScan;
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan;
import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoOdbcScan;
import org.apache.doris.nereids.rules.rewrite.PushCountIntoUnionAll;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOnPkFk;
import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide;
Expand Down Expand Up @@ -344,7 +345,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
new PushDownAggThroughJoinOneSide(),
new PushDownAggThroughJoin()
)),
costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, PushDownDistinctThroughJoin::new))
costBased(custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, PushDownDistinctThroughJoin::new)),
topDown(new PushCountIntoUnionAll())
),

// this rule should be invoked after infer predicate and push down distinct, and before push down limit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ public enum RuleType {
PUSH_CONJUNCTS_INTO_ES_SCAN(RuleTypeClass.REWRITE),
OLAP_SCAN_TABLET_PRUNE(RuleTypeClass.REWRITE),
PUSH_AGGREGATE_TO_OLAP_SCAN(RuleTypeClass.REWRITE),
PUSH_COUNT_INTO_UNION_ALL(RuleTypeClass.REWRITE),
EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION(RuleTypeClass.REWRITE),
HIDE_ONE_ROW_RELATION_UNDER_UNION(RuleTypeClass.REWRITE),
PUSH_PROJECT_THROUGH_UNION(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ private void checkAllSlotReferenceFromChildren(Plan plan) {
if (notFromChildren.isEmpty()) {
return;
}

notFromChildren = removeValidSlotsNotFromChildren(notFromChildren, childrenOutput);
if (!notFromChildren.isEmpty()) {
if (plan.arity() != 0 && plan.child(0) instanceof LogicalAggregate) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* LogicalAggregate (groupByExpr=[c1#13], outputExpr=[c1#13, count(c1#13) AS `count(c1)`#15])
* +--LogicalUnion (outputs=[c1#13], regularChildrenOutputs=[[c1#9], [a#4], [a#7]])
* |--child1 (output = [[c1#9]])
* |--child2 (output = [[a#4]])
* +--child3 (output = [[a#7]])
* transform to:
* LogicalAggregate (groupByExpr=[c1#13], outputExpr=[c1#13, sum0(count(c1)#19) AS `count(c1)`#15])
* +--LogicalUnion (outputs=[c1#13, count(c1)#19], regularChildrenOutputs=[[c1#9, count(c1)#16],
* [a#4, count(a)#17], [a#7, count(a)#18]])
* |--LogicalAggregate (groupByExpr=[c1#9], outputExpr=[c1#9, count(c1#9) AS `count(c1)`#16])
* | +--child1
* |--LogicalAggregate (groupByExpr=[a#4], outputExpr=[a#4, count(a#4) AS `count(a)`#17])
* | +--child2
* +--LogicalAggregate (groupByExpr=[a#7], outputExpr=[a#7, count(a#7) AS `count(a)`#18]]
* +--child3
*/
public class PushCountIntoUnionAll implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(logicalAggregate(logicalUnion().when(this::checkUnion))
.when(this::checkAgg)
.then(this::doPush)
.toRule(RuleType.PUSH_COUNT_INTO_UNION_ALL),
logicalAggregate(logicalProject(logicalUnion().when(this::checkUnion)))
.when(this::checkAgg)
.when(this::checkProjectUseless)
.then(this::removeProjectAndPush)
.toRule(RuleType.PUSH_COUNT_INTO_UNION_ALL)
);
}

private Plan doPush(LogicalAggregate<LogicalUnion> agg) {
LogicalUnion logicalUnion = agg.child();
List<Slot> outputs = logicalUnion.getOutput();
Map<Slot, Integer> replaceMap = new HashMap<>();
for (int i = 0; i < outputs.size(); i++) {
replaceMap.put(outputs.get(i), i);
}
int childSize = logicalUnion.children().size();
List<Expression> upperGroupByExpressions = agg.getGroupByExpressions();
List<NamedExpression> upperOutputExpressions = agg.getOutputExpressions();
Builder<Plan> newChildren = ImmutableList.builderWithExpectedSize(childSize);
Builder<List<SlotReference>> childrenOutputs = ImmutableList.builderWithExpectedSize(childSize);
// create the pushed down LogicalAggregate
List<List<SlotReference>> childSlots = logicalUnion.getRegularChildrenOutputs();
for (int i = 0; i < childSize; i++) {
List<SlotReference> childOutputs = childSlots.get(i);
List<Expression> groupByExpressions = replaceExpressionByUnionAll(upperGroupByExpressions, replaceMap,
childOutputs);
List<NamedExpression> outputExpressions = replaceExpressionByUnionAll(upperOutputExpressions, replaceMap,
childOutputs);
Plan child = logicalUnion.children().get(i);
LogicalAggregate<Plan> logicalAggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions,
child);
newChildren.add(logicalAggregate);
childrenOutputs.add((List<SlotReference>) (List) logicalAggregate.getOutput());
}

// create the new LogicalUnion
LogicalSetOperation newLogicalUnion = logicalUnion.withChildrenAndTheirOutputs(newChildren.build(),
childrenOutputs.build());
List<NamedExpression> newLogicalUnionOutputs = Lists.newArrayList();
for (NamedExpression ce : upperOutputExpressions) {
if (ce instanceof Alias) {
newLogicalUnionOutputs.add(new SlotReference(ce.getName(), ce.getDataType(), ce.nullable()));
} else if (ce instanceof SlotReference) {
newLogicalUnionOutputs.add(ce);
} else {
return logicalUnion;
}
}
newLogicalUnion = newLogicalUnion.withNewOutputs(newLogicalUnionOutputs);

// The count in the upper agg is converted to sum0, and the alias id and name of the count remain unchanged.
Builder<NamedExpression> newUpperOutputExpressions = ImmutableList.builderWithExpectedSize(
upperOutputExpressions.size());
for (int i = 0; i < upperOutputExpressions.size(); i++) {
NamedExpression sum0Child = newLogicalUnionOutputs.get(i);
Expression rewrittenExpression = upperOutputExpressions.get(i).rewriteDownShortCircuit(expr -> {
if (expr instanceof Alias && ((Alias) expr).child() instanceof Count) {
Alias alias = ((Alias) expr);
return new Alias(alias.getExprId(), new Sum0(sum0Child), alias.getName());
}
return expr;
});
newUpperOutputExpressions.add((NamedExpression) rewrittenExpression);
}
return agg.withAggOutputChild(newUpperOutputExpressions.build(), newLogicalUnion);
}

private <E extends Expression> List<E> replaceExpressionByUnionAll(List<E> expressions,
Map<Slot, Integer> replaceMap, List<? extends Slot> childOutputs) {
// Traverse expressions. If a slot in replaceMap appears, replace it with childOutputs[replaceMap[slot]]
return ExpressionUtils.rewriteDownShortCircuit(expressions, expr -> {
if (expr instanceof Alias && ((Alias) expr).child() instanceof Count) {
Count cnt = (Count) ((Alias) expr).child();
if (cnt.isCountStar()) {
return new Alias(new Count());
} else {
Expression newCntChild = cnt.child(0).rewriteDownShortCircuit(e -> {
if (e instanceof SlotReference && replaceMap.containsKey(e)) {
return childOutputs.get(replaceMap.get(e));
}
return e;
});
return new Alias(new Count(newCntChild));
}
} else if (expr instanceof SlotReference && replaceMap.containsKey(expr)) {
return childOutputs.get(replaceMap.get(expr));
}
return expr;
});
}

private boolean checkAgg(LogicalAggregate aggregate) {
Set<Count> res = ExpressionUtils.collect(aggregate.getOutputExpressions(), expr -> expr instanceof Count);
if (res.isEmpty()) {
return false;
}
return !hasUnsuportedAggFunc(aggregate);
}

private boolean checkProjectUseless(LogicalAggregate<LogicalProject<LogicalUnion>> agg) {
LogicalProject<LogicalUnion> project = agg.child();
if (project.getProjects().size() != 1) {
return false;
}
if (!(project.getProjects().get(0) instanceof Alias)) {
return false;
}
Alias alias = (Alias) project.getProjects().get(0);
if (!alias.child(0).equals(new TinyIntLiteral((byte) 1))) {
return false;
}
List<NamedExpression> aggOutputs = agg.getOutputExpressions();
Slot slot = project.getOutput().get(0);
if (ExpressionUtils.anyMatch(aggOutputs, expr -> expr.equals(slot))) {
return false;
}
return true;
}

private Plan removeProjectAndPush(LogicalAggregate<LogicalProject<LogicalUnion>> agg) {
Plan afterRemove = agg.withChildren(agg.child().child());
return doPush((LogicalAggregate<LogicalUnion>) afterRemove);
}

private boolean hasUnsuportedAggFunc(LogicalAggregate aggregate) {
// only support count, not suport sum,min... and not support count(distinct)
return ExpressionUtils.deapAnyMatch(aggregate.getOutputExpressions(), expr -> {
if (expr instanceof AggregateFunction) {
if (!(expr instanceof Count)) {
return true;
} else {
return ((Count) expr).isDistinct();
}
} else {
return false;
}
});
}

private boolean checkUnion(LogicalUnion union) {
if (union.getQualifier() != Qualifier.ALL) {
return false;
}
if (union.children() == null || union.children().isEmpty()) {
return false;
}
if (!union.getConstantExprsList().isEmpty()) {
return false;
}
return true;
}
}
Loading

0 comments on commit 65dd75a

Please sign in to comment.