Skip to content

Commit

Permalink
[fix](Nereids) producer to consumer should be multimap in cte (#39850) (
Browse files Browse the repository at this point in the history
#40544)

pick from master #39850

Because a consumer can refer to a single producer slot multiple times,
the producer-to-consumer slot map should be a multimap.
  • Loading branch information
morrySnow authored Sep 9, 2024
1 parent f5f74dd commit 770e1fd
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import org.apache.hadoop.util.Lists;

import java.util.ArrayList;
Expand Down Expand Up @@ -600,8 +601,8 @@ public boolean couldPruneColumnOnProducer(CTEId cteId) {
return consumerIds.size() == this.statementContext.getCteIdToConsumers().get(cteId).size();
}

public void addCTEConsumerGroup(CTEId cteId, Group g, Map<Slot, Slot> producerSlotToConsumerSlot) {
List<Pair<Map<Slot, Slot>, Group>> consumerGroups =
public void addCTEConsumerGroup(CTEId cteId, Group g, Multimap<Slot, Slot> producerSlotToConsumerSlot) {
List<Pair<Multimap<Slot, Slot>, Group>> consumerGroups =
this.statementContext.getCteIdToConsumerGroup().computeIfAbsent(cteId, k -> new ArrayList<>());
consumerGroups.add(Pair.of(producerSlotToConsumerSlot, g));
}
Expand All @@ -610,12 +611,18 @@ public void addCTEConsumerGroup(CTEId cteId, Group g, Map<Slot, Slot> producerSl
* Update the statistics of each CTE consumer group when the producer's statistics are updated.
*/
public void updateConsumerStats(CTEId cteId, Statistics statistics) {
List<Pair<Map<Slot, Slot>, Group>> consumerGroups = this.statementContext.getCteIdToConsumerGroup().get(cteId);
for (Pair<Map<Slot, Slot>, Group> p : consumerGroups) {
Map<Slot, Slot> producerSlotToConsumerSlot = p.first;
List<Pair<Multimap<Slot, Slot>, Group>> consumerGroups
= this.statementContext.getCteIdToConsumerGroup().get(cteId);
for (Pair<Multimap<Slot, Slot>, Group> p : consumerGroups) {
Multimap<Slot, Slot> producerSlotToConsumerSlot = p.first;
Statistics updatedConsumerStats = new Statistics(statistics);
for (Entry<Expression, ColumnStatistic> entry : statistics.columnStatistics().entrySet()) {
updatedConsumerStats.addColumnStats(producerSlotToConsumerSlot.get(entry.getKey()), entry.getValue());
if (!(entry.getKey() instanceof Slot)) {
continue;
}
for (Slot consumer : producerSlotToConsumerSlot.get((Slot) entry.getKey())) {
updatedConsumerStats.addColumnStats(consumer, entry.getValue());
}
}
p.value().setStatistics(updatedConsumerStats);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;

import java.util.ArrayList;
Expand Down Expand Up @@ -92,7 +93,7 @@ public class StatementContext {
private final Map<RelationId, Set<Expression>> consumerIdToFilters = new HashMap<>();
private final Map<CTEId, Set<RelationId>> cteIdToConsumerUnderProjects = new HashMap<>();
// Used to update consumer's stats
private final Map<CTEId, List<Pair<Map<Slot, Slot>, Group>>> cteIdToConsumerGroup = new HashMap<>();
private final Map<CTEId, List<Pair<Multimap<Slot, Slot>, Group>>> cteIdToConsumerGroup = new HashMap<>();

private final Map<CTEId, LogicalPlan> rewrittenCteProducer = new HashMap<>();
private final Map<CTEId, LogicalPlan> rewrittenCteConsumer = new HashMap<>();
Expand Down Expand Up @@ -229,7 +230,7 @@ public Map<CTEId, Set<RelationId>> getCteIdToConsumerUnderProjects() {
return cteIdToConsumerUnderProjects;
}

public Map<CTEId, List<Pair<Map<Slot, Slot>, Group>>> getCteIdToConsumerGroup() {
public Map<CTEId, List<Pair<Multimap<Slot, Slot>, Group>>> getCteIdToConsumerGroup() {
return cteIdToConsumerGroup;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -987,9 +987,10 @@ public PlanFragment visitPhysicalCTEConsumer(PhysicalCTEConsumer cteConsumer,

// update expr to slot mapping
for (Slot producerSlot : cteProducer.getOutput()) {
Slot consumerSlot = cteConsumer.getProducerToConsumerSlotMap().get(producerSlot);
SlotRef slotRef = context.findSlotRef(producerSlot.getExprId());
context.addExprIdSlotRefPair(consumerSlot.getExprId(), slotRef);
for (Slot consumerSlot : cteConsumer.getProducerToConsumerSlotMap().get(producerSlot)) {
SlotRef slotRef = context.findSlotRef(producerSlot.getExprId());
context.addExprIdSlotRefPair(consumerSlot.getExprId(), slotRef);
}
}
return multiCastFragment;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;

import java.util.LinkedHashMap;
import java.util.List;
Expand Down Expand Up @@ -238,14 +240,15 @@ public Plan visitLogicalPartitionTopN(LogicalPartitionTopN<? extends Plan> parti
@Override
public Plan visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, Map<ExprId, Slot> replaceMap) {
Map<Slot, Slot> consumerToProducerOutputMap = new LinkedHashMap<>();
Map<Slot, Slot> producerToConsumerOutputMap = new LinkedHashMap<>();
Multimap<Slot, Slot> producerToConsumerOutputMap = LinkedHashMultimap.create();
for (Slot producerOutputSlot : cteConsumer.getConsumerToProducerOutputMap().values()) {
Slot newProducerOutputSlot = updateExpression(producerOutputSlot, replaceMap);
Slot newConsumerOutputSlot = cteConsumer.getProducerToConsumerOutputMap().get(producerOutputSlot)
.withNullable(newProducerOutputSlot.nullable());
producerToConsumerOutputMap.put(newProducerOutputSlot, newConsumerOutputSlot);
consumerToProducerOutputMap.put(newConsumerOutputSlot, newProducerOutputSlot);
replaceMap.put(newConsumerOutputSlot.getExprId(), newConsumerOutputSlot);
for (Slot consumerOutputSlot : cteConsumer.getProducerToConsumerOutputMap().get(producerOutputSlot)) {
Slot newConsumerOutputSlot = consumerOutputSlot.withNullable(newProducerOutputSlot.nullable());
producerToConsumerOutputMap.put(newProducerOutputSlot, newConsumerOutputSlot);
consumerToProducerOutputMap.put(newConsumerOutputSlot, newProducerOutputSlot);
replaceMap.put(newConsumerOutputSlot.getExprId(), newConsumerOutputSlot);
}
}
return cteConsumer.withTwoMaps(consumerToProducerOutputMap, producerToConsumerOutputMap);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,13 @@ private Plan expandLeftAntiJoin(CascadesContext ctx,
ctx.putCTEIdToConsumer(left);
ctx.putCTEIdToConsumer(right);

Map<Slot, Slot> replaced = new HashMap<>(left.getProducerToConsumerOutputMap());
replaced.putAll(right.getProducerToConsumerOutputMap());
Map<Slot, Slot> replaced = new HashMap<>();
for (Map.Entry<Slot, Slot> entry : left.getConsumerToProducerOutputMap().entrySet()) {
replaced.put(entry.getValue(), entry.getKey());
}
for (Map.Entry<Slot, Slot> entry : right.getConsumerToProducerOutputMap().entrySet()) {
replaced.put(entry.getValue(), entry.getKey());
}
List<Expression> disjunctions = hashOtherConditions.first;
List<Expression> otherConditions = hashOtherConditions.second;
List<Expression> newOtherConditions = otherConditions.stream()
Expand All @@ -189,8 +194,13 @@ private Plan expandLeftAntiJoin(CascadesContext ctx,
LogicalCTEConsumer newRight = new LogicalCTEConsumer(
ctx.getStatementContext().getNextRelationId(), rightProducer.getCteId(), "", rightProducer);
ctx.putCTEIdToConsumer(newRight);
Map<Slot, Slot> newReplaced = new HashMap<>(left.getProducerToConsumerOutputMap());
newReplaced.putAll(newRight.getProducerToConsumerOutputMap());
Map<Slot, Slot> newReplaced = new HashMap<>();
for (Map.Entry<Slot, Slot> entry : left.getConsumerToProducerOutputMap().entrySet()) {
newReplaced.put(entry.getValue(), entry.getKey());
}
for (Map.Entry<Slot, Slot> entry : newRight.getConsumerToProducerOutputMap().entrySet()) {
newReplaced.put(entry.getValue(), entry.getKey());
}
newOtherConditions = otherConditions.stream()
.map(e -> e.rewriteUp(s -> newReplaced.containsKey(s) ? newReplaced.get(s) : s))
.collect(Collectors.toList());
Expand Down Expand Up @@ -246,8 +256,13 @@ private List<Plan> expandInnerJoin(CascadesContext ctx, Pair<List<Expression>,
ctx.putCTEIdToConsumer(right);

//rewrite conjuncts to replace the old slots with CTE slots
Map<Slot, Slot> replaced = new HashMap<>(left.getProducerToConsumerOutputMap());
replaced.putAll(right.getProducerToConsumerOutputMap());
Map<Slot, Slot> replaced = new HashMap<>();
for (Map.Entry<Slot, Slot> entry : left.getConsumerToProducerOutputMap().entrySet()) {
replaced.put(entry.getValue(), entry.getKey());
}
for (Map.Entry<Slot, Slot> entry : right.getConsumerToProducerOutputMap().entrySet()) {
replaced.put(entry.getValue(), entry.getKey());
}
List<Expression> hashCond = pair.first.stream()
.map(e -> e.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s))
.collect(Collectors.toList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;

import java.util.LinkedHashMap;
import java.util.List;
Expand Down Expand Up @@ -448,7 +450,7 @@ public Plan visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, DeepCopierCo
return context.getRelationReplaceMap().get(cteConsumer.getRelationId());
}
Map<Slot, Slot> consumerToProducerOutputMap = new LinkedHashMap<>();
Map<Slot, Slot> producerToConsumerOutputMap = new LinkedHashMap<>();
Multimap<Slot, Slot> producerToConsumerOutputMap = LinkedHashMultimap.create();
for (Slot consumerOutput : cteConsumer.getOutput()) {
Slot newOutput = (Slot) ExpressionDeepCopier.INSTANCE.deepCopy(consumerOutput, context);
consumerToProducerOutputMap.put(newOutput, cteConsumer.getProducerSlot(consumerOutput));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.RelationId;
Expand All @@ -30,8 +31,10 @@

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Multimap;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -45,20 +48,15 @@ public class LogicalCTEConsumer extends LogicalRelation {
private final String name;
private final CTEId cteId;
private final Map<Slot, Slot> consumerToProducerOutputMap;
private final Map<Slot, Slot> producerToConsumerOutputMap;
private final Multimap<Slot, Slot> producerToConsumerOutputMap;

/**
* Logical CTE consumer.
*/
public LogicalCTEConsumer(RelationId relationId, CTEId cteId, String name,
Map<Slot, Slot> consumerToProducerOutputMap, Map<Slot, Slot> producerToConsumerOutputMap) {
super(relationId, PlanType.LOGICAL_CTE_CONSUMER, Optional.empty(), Optional.empty());
this.cteId = Objects.requireNonNull(cteId, "cteId should not null");
this.name = Objects.requireNonNull(name, "name should not null");
this.consumerToProducerOutputMap = Objects.requireNonNull(consumerToProducerOutputMap,
"consumerToProducerOutputMap should not null");
this.producerToConsumerOutputMap = Objects.requireNonNull(producerToConsumerOutputMap,
"producerToConsumerOutputMap should not null");
Map<Slot, Slot> consumerToProducerOutputMap, Multimap<Slot, Slot> producerToConsumerOutputMap) {
this(relationId, cteId, name, consumerToProducerOutputMap, producerToConsumerOutputMap,
Optional.empty(), Optional.empty());
}

/**
Expand All @@ -68,16 +66,23 @@ public LogicalCTEConsumer(RelationId relationId, CTEId cteId, String name, Logic
super(relationId, PlanType.LOGICAL_CTE_CONSUMER, Optional.empty(), Optional.empty());
this.cteId = Objects.requireNonNull(cteId, "cteId should not null");
this.name = Objects.requireNonNull(name, "name should not null");
this.consumerToProducerOutputMap = new LinkedHashMap<>();
this.producerToConsumerOutputMap = new LinkedHashMap<>();
initOutputMaps(producerPlan);
ImmutableMap.Builder<Slot, Slot> cToPBuilder = ImmutableMap.builder();
ImmutableMultimap.Builder<Slot, Slot> pToCBuilder = ImmutableMultimap.builder();
List<Slot> producerOutput = producerPlan.getOutput();
for (Slot producerOutputSlot : producerOutput) {
Slot consumerSlot = generateConsumerSlot(this.name, producerOutputSlot);
cToPBuilder.put(consumerSlot, producerOutputSlot);
pToCBuilder.put(producerOutputSlot, consumerSlot);
}
consumerToProducerOutputMap = cToPBuilder.build();
producerToConsumerOutputMap = pToCBuilder.build();
}

/**
* Logical CTE consumer.
*/
public LogicalCTEConsumer(RelationId relationId, CTEId cteId, String name,
Map<Slot, Slot> consumerToProducerOutputMap, Map<Slot, Slot> producerToConsumerOutputMap,
Map<Slot, Slot> consumerToProducerOutputMap, Multimap<Slot, Slot> producerToConsumerOutputMap,
Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties) {
super(relationId, PlanType.LOGICAL_CTE_CONSUMER, groupExpression, logicalProperties);
this.cteId = Objects.requireNonNull(cteId, "cteId should not null");
Expand All @@ -88,21 +93,24 @@ public LogicalCTEConsumer(RelationId relationId, CTEId cteId, String name,
"producerToConsumerOutputMap should not null");
}

private void initOutputMaps(LogicalPlan childPlan) {
List<Slot> producerOutput = childPlan.getOutput();
for (Slot producerOutputSlot : producerOutput) {
Slot consumerSlot = new SlotReference(producerOutputSlot.getName(),
producerOutputSlot.getDataType(), producerOutputSlot.nullable(), ImmutableList.of(name));
producerToConsumerOutputMap.put(producerOutputSlot, consumerSlot);
consumerToProducerOutputMap.put(consumerSlot, producerOutputSlot);
}
/**
 * Creates the consumer-side slot that corresponds to one producer output slot.
 *
 * <p>The new slot receives a fresh expression id from {@link StatementScopeIdGenerator} and
 * copies the name, data type and nullability of the producer slot. The CTE name is passed as a
 * single-element list (presumably the slot qualifier -- confirm against the SlotReference
 * constructor). When the producer slot is itself a {@link SlotReference}, its column (if present)
 * and its internal name are carried over; otherwise the column is {@code null} and the internal
 * name is empty.
 *
 * @param cteName name of the CTE the consumer belongs to
 * @param producerOutputSlot producer output slot to mirror
 * @return a new consumer slot for the given producer slot
 */
public static SlotReference generateConsumerSlot(String cteName, Slot producerOutputSlot) {
    // Stays null when the producer slot is not a SlotReference.
    SlotReference producerRef = null;
    if (producerOutputSlot instanceof SlotReference) {
        producerRef = (SlotReference) producerOutputSlot;
    }
    return new SlotReference(
            StatementScopeIdGenerator.newExprId(),
            producerOutputSlot.getName(),
            producerOutputSlot.getDataType(),
            producerOutputSlot.nullable(),
            ImmutableList.of(cteName),
            producerRef == null ? null : producerRef.getColumn().orElse(null),
            producerRef == null ? Optional.empty() : Optional.of(producerRef.getInternalName()));
}

/**
 * Returns the map from each consumer output slot to the producer output slot it mirrors.
 * The stored map instance is returned directly; no copy is made here.
 */
public Map<Slot, Slot> getConsumerToProducerOutputMap() {
    return this.consumerToProducerOutputMap;
}

public Map<Slot, Slot> getProducerToConsumerOutputMap() {
public Multimap<Slot, Slot> getProducerToConsumerOutputMap() {
return producerToConsumerOutputMap;
}

Expand All @@ -111,7 +119,8 @@ public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
return visitor.visitLogicalCTEConsumer(this, context);
}

public Plan withTwoMaps(Map<Slot, Slot> consumerToProducerOutputMap, Map<Slot, Slot> producerToConsumerOutputMap) {
public Plan withTwoMaps(Map<Slot, Slot> consumerToProducerOutputMap,
Multimap<Slot, Slot> producerToConsumerOutputMap) {
return new LogicalCTEConsumer(relationId, cteId, name,
consumerToProducerOutputMap, producerToConsumerOutputMap,
Optional.empty(), Optional.empty());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Multimap;

import java.util.List;
import java.util.Map;
Expand All @@ -43,14 +45,14 @@
public class PhysicalCTEConsumer extends PhysicalRelation {

private final CTEId cteId;
private final Map<Slot, Slot> producerToConsumerSlotMap;
private final Map<Slot, Slot> consumerToProducerSlotMap;
private final Multimap<Slot, Slot> producerToConsumerSlotMap;

/**
* Constructor
*/
public PhysicalCTEConsumer(RelationId relationId, CTEId cteId, Map<Slot, Slot> consumerToProducerSlotMap,
Map<Slot, Slot> producerToConsumerSlotMap, LogicalProperties logicalProperties) {
Multimap<Slot, Slot> producerToConsumerSlotMap, LogicalProperties logicalProperties) {
this(relationId, cteId, consumerToProducerSlotMap, producerToConsumerSlotMap,
Optional.empty(), logicalProperties);
}
Expand All @@ -59,7 +61,7 @@ public PhysicalCTEConsumer(RelationId relationId, CTEId cteId, Map<Slot, Slot> c
* Constructor
*/
public PhysicalCTEConsumer(RelationId relationId, CTEId cteId,
Map<Slot, Slot> consumerToProducerSlotMap, Map<Slot, Slot> producerToConsumerSlotMap,
Map<Slot, Slot> consumerToProducerSlotMap, Multimap<Slot, Slot> producerToConsumerSlotMap,
Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties) {
this(relationId, cteId, consumerToProducerSlotMap, producerToConsumerSlotMap,
groupExpression, logicalProperties, null, null);
Expand All @@ -69,22 +71,22 @@ public PhysicalCTEConsumer(RelationId relationId, CTEId cteId,
* Constructor
*/
public PhysicalCTEConsumer(RelationId relationId, CTEId cteId, Map<Slot, Slot> consumerToProducerSlotMap,
Map<Slot, Slot> producerToConsumerSlotMap, Optional<GroupExpression> groupExpression,
Multimap<Slot, Slot> producerToConsumerSlotMap, Optional<GroupExpression> groupExpression,
LogicalProperties logicalProperties, PhysicalProperties physicalProperties, Statistics statistics) {
super(relationId, PlanType.PHYSICAL_CTE_CONSUMER, groupExpression,
logicalProperties, physicalProperties, statistics);
this.cteId = cteId;
this.consumerToProducerSlotMap = ImmutableMap.copyOf(Objects.requireNonNull(
consumerToProducerSlotMap, "consumerToProducerSlotMap should not null"));
this.producerToConsumerSlotMap = ImmutableMap.copyOf(Objects.requireNonNull(
this.producerToConsumerSlotMap = ImmutableMultimap.copyOf(Objects.requireNonNull(
producerToConsumerSlotMap, "consumerToProducerSlotMap should not null"));
}

/**
 * Returns the CTE id this consumer was constructed with.
 */
public CTEId getCteId() {
    return this.cteId;
}

public Map<Slot, Slot> getProducerToConsumerSlotMap() {
public Multimap<Slot, Slot> getProducerToConsumerSlotMap() {
return producerToConsumerSlotMap;
}

Expand Down
Loading

0 comments on commit 770e1fd

Please sign in to comment.