Skip to content

Commit

Permalink
[feat](Nereids) support outer join and aggregate bitmap rewrite by mv (
Browse files Browse the repository at this point in the history
…apache#28596)

- Support left outer join rewrite by materialized view
- Support bitmap_union roll up to imp count(distinct)
- Support partition materialized view rewrite
  • Loading branch information
seawinde authored and HappenLee committed Jan 12, 2024
1 parent 1511cc7 commit 9bab5b2
Show file tree
Hide file tree
Showing 42 changed files with 3,666 additions and 440 deletions.
10 changes: 4 additions & 6 deletions fe/fe-core/src/main/java/org/apache/doris/catalog/MTMV.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.doris.mtmv.MTMVRelation;
import org.apache.doris.mtmv.MTMVStatus;
import org.apache.doris.persist.gson.GsonUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.Sets;
import com.google.gson.annotations.SerializedName;
Expand Down Expand Up @@ -128,10 +129,6 @@ public MTMVRelation getRelation() {
return relation;
}

public MTMVCache getCache() {
return cache;
}

public void setCache(MTMVCache cache) {
this.cache = cache;
}
Expand Down Expand Up @@ -202,12 +199,13 @@ public Set<String> getExcludedTriggerTables() {
return Sets.newHashSet(split);
}

public MTMVCache getOrGenerateCache() throws AnalysisException {
// this should use the same connectContext with query, to use the same session variable
public MTMVCache getOrGenerateCache(ConnectContext parent) throws AnalysisException {
if (cache == null) {
writeMvLock();
try {
if (cache == null) {
this.cache = MTMVCache.from(this, MTMVPlanUtil.createMTMVContext(this));
this.cache = MTMVCache.from(this, parent);
}
} finally {
writeMvUnlock();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public static MTMVCache from(MTMV mtmv, ConnectContext connectContext) {
? (Plan) ((LogicalResultSink) mvRewrittenPlan).child() : mvRewrittenPlan;
// use rewritten plan output expression currently, if expression rewrite fail,
// consider to use the analyzed plan for output expressions only
List<NamedExpression> mvOutputExpressions = mvRewrittenPlan.getExpressions().stream()
List<NamedExpression> mvOutputExpressions = mvPlan.getExpressions().stream()
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
return new MTMVCache(mvPlan, mvOutputExpressions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.doris.nereids.trees.plans.commands.info.TableNameInfo;
import org.apache.doris.persist.AlterMTMV;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.commons.collections.CollectionUtils;
Expand All @@ -49,7 +50,7 @@ public class MTMVRelationManager implements MTMVHookService {
private Map<BaseTableInfo, Set<BaseTableInfo>> tableMTMVs = Maps.newConcurrentMap();

public Set<BaseTableInfo> getMtmvsByBaseTable(BaseTableInfo table) {
return tableMTMVs.get(table);
return tableMTMVs.getOrDefault(table, ImmutableSet.of());
}

public Set<MTMV> getAvailableMTMVs(List<BaseTableInfo> tableInfos) {
Expand Down
4 changes: 2 additions & 2 deletions fe/fe-core/src/main/java/org/apache/doris/mtmv/MTMVUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ public static Collection<Partition> getMTMVCanRewritePartitions(MTMV mtmv, Conne
List<Partition> res = Lists.newArrayList();
Collection<Partition> allPartitions = mtmv.getPartitions();
// check session variable if enable rewrite
if (!ctx.getSessionVariable().isEnableMvRewrite()) {
if (!ctx.getSessionVariable().isEnableMaterializedViewRewrite()) {
return res;
}
MTMVRelation mtmvRelation = mtmv.getRelation();
Expand Down Expand Up @@ -438,7 +438,7 @@ private static long getTableMinVisibleVersionTime(OlapTable table) {
* @param relatedTable
* @return mv.partitionId ==> relatedTable.partitionId
*/
private static Map<Long, Set<Long>> getMvToBasePartitions(MTMV mtmv, OlapTable relatedTable)
public static Map<Long, Set<Long>> getMvToBasePartitions(MTMV mtmv, OlapTable relatedTable)
throws AnalysisException {
HashMap<Long, Set<Long>> res = Maps.newHashMap();
Map<Long, PartitionItem> relatedTableItems = relatedTable.getPartitionInfo().getIdToItem(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,9 @@ public void setOuterScope(@Nullable Scope outerScope) {
}

public List<MaterializationContext> getMaterializationContexts() {
return materializationContexts;
return materializationContexts.stream()
.filter(MaterializationContext::isAvailable)
.collect(Collectors.toList());
}

public void addMaterializationContext(MaterializationContext materializationContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
* @see PlaceholderCollector
*/
public class PlaceholderExpression extends Expression implements AlwaysNotNullable {

private final Class<? extends Expression> delegateClazz;
/**
* 1 based
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTranspose;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTransposeProject;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewAggregateRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewFilterJoinRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewFilterProjectJoinRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewOnlyJoinRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectAggregateRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectFilterJoinRule;
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectJoinRule;
import org.apache.doris.nereids.rules.implementation.AggregateStrategies;
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
Expand Down Expand Up @@ -222,7 +226,11 @@ public class RuleSet {
.build();

public static final List<Rule> MATERIALIZED_VIEW_RULES = planRuleFactories()
.add(MaterializedViewOnlyJoinRule.INSTANCE)
.add(MaterializedViewProjectJoinRule.INSTANCE)
.add(MaterializedViewFilterJoinRule.INSTANCE)
.add(MaterializedViewFilterProjectJoinRule.INSTANCE)
.add(MaterializedViewProjectFilterJoinRule.INSTANCE)
.add(MaterializedViewAggregateRule.INSTANCE)
.add(MaterializedViewProjectAggregateRule.INSTANCE)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@
import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext;
import org.apache.doris.nereids.rules.exploration.mv.mapping.ExpressionMapping;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.trees.expressions.Any;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.CouldRollUp;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
Expand All @@ -39,8 +43,11 @@
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand All @@ -53,6 +60,17 @@
*/
public abstract class AbstractMaterializedViewAggregateRule extends AbstractMaterializedViewRule {

protected static final Map<Expression, Expression>
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = new HashMap<>();
protected final String currentClassName = this.getClass().getSimpleName();

private final Logger logger = LogManager.getLogger(this.getClass());

static {
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE),
new BitmapUnion(Any.INSTANCE));
}

@Override
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
Expand All @@ -63,10 +81,12 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
// get view and query aggregate and top plan correspondingly
Pair<Plan, LogicalAggregate<Plan>> viewTopPlanAndAggPair = splitToTopPlanAndAggregate(viewStructInfo);
if (viewTopPlanAndAggPair == null) {
logger.warn(currentClassName + " split to view to top plan and agg fail so return null");
return null;
}
Pair<Plan, LogicalAggregate<Plan>> queryTopPlanAndAggPair = splitToTopPlanAndAggregate(queryStructInfo);
if (queryTopPlanAndAggPair == null) {
logger.warn(currentClassName + " split to query to top plan and agg fail so return null");
return null;
}
// Firstly, handle query group by expression rewrite
Expand All @@ -88,13 +108,14 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
needRollUp = !queryGroupShuttledExpression.equals(viewGroupShuttledExpression);
}
if (!needRollUp) {
List<Expression> rewrittenQueryGroupExpr = rewriteExpression(queryTopPlan.getOutput(),
List<Expression> rewrittenQueryGroupExpr = rewriteExpression(queryTopPlan.getExpressions(),
queryTopPlan,
materializationContext.getMvExprToMvScanExprMapping(),
queryToViewSlotMapping,
true);
if (rewrittenQueryGroupExpr == null) {
if (rewrittenQueryGroupExpr.isEmpty()) {
// can not rewrite, bail out.
logger.debug(currentClassName + " can not rewrite expression when not need roll up");
return null;
}
return new LogicalProject<>(
Expand All @@ -109,12 +130,14 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
viewExpr -> viewExpr.anyMatch(expr -> expr instanceof AggregateFunction
&& ((AggregateFunction) expr).isDistinct()))) {
// if mv aggregate function contains distinct, can not roll up, bail out.
logger.debug(currentClassName + " view contains distinct function so can not roll up");
return null;
}
// split the query top plan expressions to group expressions and functions, if can not, bail out.
Pair<Set<? extends Expression>, Set<? extends Expression>> queryGroupAndFunctionPair
= topPlanSplitToGroupAndFunction(queryTopPlanAndAggPair);
if (queryGroupAndFunctionPair == null) {
logger.warn(currentClassName + " query top plan split to group by and function fail so return null");
return null;
}
// Secondly, try to roll up the agg functions
Expand All @@ -132,30 +155,27 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
for (Expression topExpression : queryTopPlan.getExpressions()) {
// is agg function, try to roll up and rewrite
if (queryTopPlanFunctionSet.contains(topExpression)) {
Expression needRollupShuttledExpr = ExpressionUtils.shuttleExpressionWithLineage(
Expression queryFunctionShuttled = ExpressionUtils.shuttleExpressionWithLineage(
topExpression,
queryTopPlan);
if (!mvExprToMvScanExprQueryBased.containsKey(needRollupShuttledExpr)) {
// function can not rewrite by view
return null;
}
// try to roll up
AggregateFunction needRollupAggFunction = (AggregateFunction) topExpression.firstMatch(
AggregateFunction queryFunction = (AggregateFunction) topExpression.firstMatch(
expr -> expr instanceof AggregateFunction);
AggregateFunction rollupAggregateFunction = rollup(needRollupAggFunction,
mvExprToMvScanExprQueryBased.get(needRollupShuttledExpr));
Function rollupAggregateFunction = rollup(queryFunction, queryFunctionShuttled,
mvExprToMvScanExprQueryBased);
if (rollupAggregateFunction == null) {
return null;
}
// key is query need roll up expr, value is mv scan based roll up expr
needRollupExprMap.put(needRollupShuttledExpr, rollupAggregateFunction);
needRollupExprMap.put(queryFunctionShuttled, rollupAggregateFunction);
// rewrite query function expression by mv expression
Expression rewrittenFunctionExpression = rewriteExpression(topExpression,
queryTopPlan,
new ExpressionMapping(needRollupExprMap),
queryToViewSlotMapping,
false);
if (rewrittenFunctionExpression == null) {
logger.debug(currentClassName + " roll up expression can not rewrite by view so return null");
return null;
}
finalAggregateExpressions.add((NamedExpression) rewrittenFunctionExpression);
Expand All @@ -165,6 +185,8 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
ExpressionUtils.shuttleExpressionWithLineage(topExpression, queryTopPlan);
if (!mvExprToMvScanExprQueryBased.containsKey(queryGroupShuttledExpr)) {
// group expr can not rewrite by view
logger.debug(currentClassName
+ " view group expressions can not contains the query group by expression so return null");
return null;
}
groupRewrittenExprMap.put(queryGroupShuttledExpr,
Expand All @@ -177,6 +199,8 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
queryToViewSlotMapping,
true);
if (rewrittenGroupExpression == null) {
logger.debug(currentClassName
+ " query top expression can not be rewritten by view so return null");
return null;
}
finalAggregateExpressions.add((NamedExpression) rewrittenGroupExpression);
Expand Down Expand Up @@ -226,17 +250,33 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
}

// only support sum roll up, support other agg functions later.
private AggregateFunction rollup(AggregateFunction originFunction,
Expression mappedExpression) {
Class<? extends AggregateFunction> rollupAggregateFunction = originFunction.getRollup();
if (rollupAggregateFunction == null) {
private Function rollup(AggregateFunction queryFunction,
Expression queryFunctionShuttled,
Map<Expression, Expression> mvExprToMvScanExprQueryBased) {
if (!(queryFunction instanceof CouldRollUp)) {
return null;
}
if (Sum.class.isAssignableFrom(rollupAggregateFunction)) {
return new Sum(originFunction.isDistinct(), mappedExpression);
Expression rollupParam = null;
if (mvExprToMvScanExprQueryBased.containsKey(queryFunctionShuttled)) {
// function can rewrite by view
rollupParam = mvExprToMvScanExprQueryBased.get(queryFunctionShuttled);
} else {
// function can not rewrite by view, try to use complex roll up param
// eg: query is count(distinct param), mv sql is bitmap_union(to_bitmap(param))
for (Expression mvExprShuttled : mvExprToMvScanExprQueryBased.keySet()) {
if (!(mvExprShuttled instanceof Function)) {
continue;
}
if (isAggregateFunctionEquivalent(queryFunction, (Function) mvExprShuttled)) {
rollupParam = mvExprToMvScanExprQueryBased.get(mvExprShuttled);
}
}
}
// can rollup return null
return null;
if (rollupParam == null) {
return null;
}
// do roll up
return ((CouldRollUp) queryFunction).constructRollUp(rollupParam);
}

private Pair<Set<? extends Expression>, Set<? extends Expression>> topPlanSplitToGroupAndFunction(
Expand Down Expand Up @@ -306,4 +346,23 @@ protected boolean checkPattern(StructInfo structInfo) {
}
return true;
}

private boolean isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) {
if (queryFunction.equals(viewFunction)) {
return true;
}
// get query equivalent function
Expression equivalentFunction = null;
for (Map.Entry<Expression, Expression> entry : AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.entrySet()) {
if (entry.getKey().equals(queryFunction)) {
equivalentFunction = entry.getValue();
}
}
// check is have equivalent function or not
if (equivalentFunction == null) {
return false;
}
// current compare
return equivalentFunction.equals(viewFunction);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.List;
import java.util.stream.Collectors;

Expand All @@ -35,6 +38,10 @@
* This is responsible for common join rewriting
*/
public abstract class AbstractMaterializedViewJoinRule extends AbstractMaterializedViewRule {

protected final String currentClassName = this.getClass().getSimpleName();
private final Logger logger = LogManager.getLogger(this.getClass());

@Override
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
Expand All @@ -53,6 +60,7 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
// Can not rewrite, bail out
if (expressionsRewritten.isEmpty()
|| expressionsRewritten.stream().anyMatch(expr -> !(expr instanceof NamedExpression))) {
logger.warn(currentClassName + " expression to rewrite is not named expr so return null");
return null;
}
// record the group id in materializationContext, and when rewrite again in
Expand Down
Loading

0 comments on commit 9bab5b2

Please sign in to comment.