diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index e920036247c782..d09204403f9edd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -190,10 +190,14 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional allPushDownExprs = - Sets.union(groupingByExprs, Sets.union(needPushSelf, needPushInputSlots)); - NormalizeToSlotContext bottomSlotContext = - NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs); + // We need to distinguish between expressions in aggregate function arguments and group by expressions. + NormalizeToSlotContext groupByExprContext = NormalizeToSlotContext.buildContext(existsAlias, groupingByExprs); + /* aliases handed to the aggregate-argument context: the original ones plus the Alias entries found in the group by context */ Set existsAliasAndGroupByAlias = getExistsAlias(existsAlias, groupByExprContext.getNormalizeToSlotMap()); + Set argsOfAggFuncNeedPushDown = Sets.union(needPushSelf, needPushInputSlots); + NormalizeToSlotContext argsOfAggFuncNeedPushDownContext = NormalizeToSlotContext + .buildContext(existsAliasAndGroupByAlias, argsOfAggFuncNeedPushDown); + NormalizeToSlotContext bottomSlotContext = argsOfAggFuncNeedPushDownContext.mergeContext(groupByExprContext); + Set pushedGroupByExprs = bottomSlotContext.pushDownToNamedExpression(groupingByExprs); Set pushedTrivialAggChildren = @@ -256,8 +260,12 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional upperProjects = normalizeOutput(aggregateOutput, - bottomSlotContext, normalizedAggFuncsToSlotContext); + groupByExprContext, argsOfAggFuncNeedPushDownContext, normalizedAggFuncsToSlotContext); // create a parent project node LogicalProject project = new LogicalProject<>(upperProjects, newAggregate); @@ -302,11 +310,18 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, 
Optional normalizeOutput(List aggregateOutput, - NormalizeToSlotContext groupByToSlotContext, NormalizeToSlotContext normalizedAggFuncsToSlotContext) { + NormalizeToSlotContext groupByToSlotContext, NormalizeToSlotContext argsOfAggFuncNeedPushDownContext, + NormalizeToSlotContext normalizedAggFuncsToSlotContext) { // build upper project, use two context to do pop up, because agg output maybe contain two part: - // group by keys and agg expressions - List upperProjects = groupByToSlotContext - .normalizeToUseSlotRefWithoutWindowFunction(aggregateOutput); + // group by keys and agg expressions + List upperProjects = new ArrayList<>(); + /* first rewrite only the arguments of aggregate functions with the aggregate-argument context; the group by context and the aggregate-function context are applied to the resulting list afterwards */ for (Expression expr : aggregateOutput) { + Expression rewrittenExpr = expr.rewriteDownShortCircuit( + e -> normalizeAggFuncChildren( + argsOfAggFuncNeedPushDownContext, e)); + upperProjects.add((NamedExpression) rewrittenExpr); + } + upperProjects = groupByToSlotContext.normalizeToUseSlotRefWithoutWindowFunction(upperProjects); upperProjects = normalizedAggFuncsToSlotContext.normalizeToUseSlotRefWithoutWindowFunction(upperProjects); Builder builder = new ImmutableList.Builder<>(); @@ -338,4 +353,28 @@ private List collectAllUsedSlots(List expressions) { slots.addAll(ExpressionUtils.getInputSlotSet(expressions)); return slots; } + + /** Returns a new hash set that contains every element of originAliases plus the pushedExpr of each triplet in groupingExprMap whose pushedExpr is an Alias; the input set is not modified. */ private Set getExistsAlias(Set originAliases, + Map groupingExprMap) { + Set existsAlias = Sets.newHashSet(); + existsAlias.addAll(originAliases); + for (NormalizeToSlotTriplet triplet : groupingExprMap.values()) { + if (triplet.pushedExpr instanceof Alias) { + Alias alias = (Alias) triplet.pushedExpr; + existsAlias.add(alias); + } + } + return existsAlias; + } + + /** If expr is an AggregateFunction, returns the result of withChildren applied to its arguments after they were rewritten by context.normalizeToUseSlotRef; any other expression is returned unchanged. */ private Expression normalizeAggFuncChildren(NormalizeToSlotContext context, Expression expr) { + if (expr instanceof AggregateFunction) { + AggregateFunction function = (AggregateFunction) expr; + List normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments()); + function = 
function.withChildren(normalizedRealExpressions); + return function; + } else { + return expr; + } + } } diff --git a/regression-test/data/nereids_rules_p0/normalize_aggregate/normalize_aggregate_test.out b/regression-test/data/nereids_rules_p0/normalize_aggregate/normalize_aggregate_test.out new file mode 100644 index 00000000000000..860f5ff8f630f2 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/normalize_aggregate/normalize_aggregate_test.out @@ -0,0 +1,6 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !test_upper_project_projections_rewrite -- +4094 \N + +-- !test_upper_project_projections_rewrite2 -- + diff --git a/regression-test/suites/nereids_rules_p0/normalize_aggregate/normalize_aggregate_test.groovy b/regression-test/suites/nereids_rules_p0/normalize_aggregate/normalize_aggregate_test.groovy new file mode 100644 index 00000000000000..1bf5e07c9698c0 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/normalize_aggregate/normalize_aggregate_test.groovy @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+suite("normalize_aggregate") { /* Regression suite with two tagged queries; both run with the Nereids planner enabled and fallback to the original planner disabled, as set by the two statements that follow. */ + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + /* First case: a constant-only query without a FROM clause that mixes COUNT, AVG, CASE and COALESCE. */ qt_test_upper_project_projections_rewrite """ + SELECT DISTINCT + + ( ( + + 46 ) ) * 89 AS col0, COUNT( * ) + + - 72 + - - 87 - AVG ( ALL - 56 ) * COUNT( * ) + - CASE + 49 WHEN 6 * + 76 + - + + CAST( NULL AS SIGNED ) THEN NULL WHEN - COUNT( DISTINCT + + CAST( NULL AS SIGNED ) ) + 23 THEN NULL ELSE - + 43 * 32 - + 97 + - ( + 65 ) * + + + CASE - 77 WHEN 5 THEN - 56 * + 26 ELSE NULL END / + COUNT( * ) + 20 + + 78 END * COALESCE ( COUNT( * ), - 60 - 90, + 42 * 27 - 98 * ( - 83 + 47 / 7 ), + - ( NULLIF ( 61, 83 + 88 ) ) ) * 94; + """ + /* Second case: the table is dropped and recreated here and no rows are inserted before the query below runs. NOTE(review): presumably an empty result is the intended expectation - confirm. */ sql "drop table if exists normalize_aggregate_tab" + sql """CREATE TABLE normalize_aggregate_tab(col0 INTEGER, col1 INTEGER, col2 INTEGER) distributed by hash(col0) buckets 10 + properties('replication_num' = '1'); """ + qt_test_upper_project_projections_rewrite2 """ + SELECT - + AVG ( DISTINCT - col0 ) * - col0 FROM + normalize_aggregate_tab WHERE + - col0 IS NULL GROUP BY col0 HAVING NULL IS NULL;""" +} \ No newline at end of file