Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix](nereids) refine window child's local shuffle dist-expr #39384

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ public Expr visitNot(Not not, PlanTranslatorContext context) {

/**
 * Translates a Nereids {@code SlotReference} into the legacy slot registered for its ExprId.
 *
 * <p>When the context has a cloned ExprId-to-slot snapshot selected (its
 * {@code getCloneExprIdToSlot()} is non-null), the slot is looked up in that snapshot;
 * otherwise the context's regular slot lookup is used.
 *
 * @param slotReference the slot reference to translate
 * @param context translator context that owns the ExprId-to-slot mappings
 * @return the slot returned by the selected context lookup for the reference's ExprId
 */
@Override
public Expr visitSlotReference(SlotReference slotReference, PlanTranslatorContext context) {
    // Prefer the cloned snapshot while one is selected; fall back to the regular map otherwise.
    return context.getCloneExprIdToSlot() == null
            ? context.findSlotRef(slotReference.getExprId())
            : context.findCloneSlotRef(slotReference.getExprId());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2122,6 +2122,8 @@ public PlanFragment visitPhysicalQuickSort(PhysicalQuickSort<? extends Plan> sor
PlanTranslatorContext context) {
PlanFragment inputFragment = sort.child(0).accept(this, context);
List<List<Expr>> distributeExprLists = getDistributeExprs(sort.child(0));
// 1. Snapshot the current exprId-to-SlotRef map, keyed by this sort plan
context.addPlanToExprIdSlotRefMap(sort);

// 2. According to the type of sort, generate physical plan
if (!sort.getSortPhase().isMerge()) {
Expand Down Expand Up @@ -2279,8 +2281,9 @@ public PlanFragment visitPhysicalRepeat(PhysicalRepeat<? extends Plan> repeat, P
@Override
public PlanFragment visitPhysicalWindow(PhysicalWindow<? extends Plan> physicalWindow,
PlanTranslatorContext context) {
PlanFragment inputPlanFragment = physicalWindow.child(0).accept(this, context);
List<List<Expr>> distributeExprLists = getDistributeExprs(physicalWindow.child(0));
Plan childPlan = physicalWindow.child(0);
PlanFragment inputPlanFragment = childPlan.accept(this, context);
List<List<Expr>> distributeExprLists = getDistributeExprs(childPlan);

// 1. translate to old optimizer variable
// variable in Nereids
Expand Down Expand Up @@ -2322,6 +2325,26 @@ public PlanFragment visitPhysicalWindow(PhysicalWindow<? extends Plan> physicalW
// analytic window
AnalyticWindow analyticWindow = physicalWindow.translateWindowFrame(windowFrame, context);

// refresh inputPlanFragment's distributeExprLists of window operator
// to obtain better local shuffle distribution.
List<List<Expr>> newChildDistributeExprLists = Lists.newArrayList();
if (!partitionExprs.isEmpty() && inputPlanFragment.getPlanRoot() != null
&& inputPlanFragment.getPlanRoot() instanceof SortNode
&& !inputPlanFragment.getPlanRoot().getChildrenDistributeExprLists().isEmpty()) {
// For safety, only refresh plan roots that already have a valid children distribute expr list.
// The current operator tree has only two patterns: a window with a sort child, and a window with a
// two-phase global partition topn child. The latter needs no refresh, since its distribute expr list
// is expected to be the same as the window's; the former is the real candidate.
if (context.findExprIdToSlotRefFromMap(childPlan)) {
List<Expr> newPartitionExprs = partitionKeyList.stream()
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
newChildDistributeExprLists.add(newPartitionExprs);
inputPlanFragment.getPlanRoot().setChildrenDistributeExprLists(newChildDistributeExprLists);
context.resetCloneExprIdToSlot();
}
}

// 2. get bufferedTupleDesc from SortNode and compute isNullableMatched
Map<ExprId, SlotRef> bufferedSlotRefForWindow = getBufferedSlotRefForWindow(windowFrameGroup, context);
TupleDescriptor bufferedTupleDesc = context.getBufferedTupleForWindow();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEProducer;
Expand Down Expand Up @@ -78,6 +79,10 @@ public class PlanTranslatorContext {
*/
private final Map<ExprId, SlotRef> exprIdToSlotRef = Maps.newHashMap();

private final Map<Plan, Map<ExprId, SlotRef>> clonePlanToExprIdToSlotRefMap = Maps.newHashMap();

private Map<ExprId, SlotRef> cloneExprIdToSlot = null;

/**
* Inverted index from legacy slot to Nereids' slot.
*/
Expand Down Expand Up @@ -216,6 +221,14 @@ public void addExprIdSlotRefPair(ExprId exprId, SlotRef slotRef) {
slotIdToExprId.put(slotRef.getDesc().getId(), exprId);
}

/**
 * Stores a snapshot of the current ExprId-to-SlotRef mapping under the given plan.
 *
 * <p>Every SlotRef is cloned into a fresh map, so the snapshot holds its own SlotRef
 * instances rather than the ones in {@code exprIdToSlotRef}.
 *
 * @param plan the plan used as the key for the snapshot
 */
public void addPlanToExprIdSlotRefMap(Plan plan) {
    Map<ExprId, SlotRef> snapshot = Maps.newHashMap();
    exprIdToSlotRef.forEach((exprId, slotRef) -> snapshot.put(exprId, (SlotRef) slotRef.clone()));
    clonePlanToExprIdToSlotRefMap.put(plan, snapshot);
}

/**
 * Registers the column reference expression for the given ExprId.
 *
 * @param exprId key of the mapping
 * @param columnRefExpr column reference stored under {@code exprId}
 */
public void addExprIdColumnRefPair(ExprId exprId, ColumnRefExpr columnRefExpr) {
    this.exprIdToColumnRef.put(exprId, columnRefExpr);
}
Expand All @@ -236,6 +249,22 @@ public SlotRef findSlotRef(ExprId exprId) {
return exprIdToSlotRef.get(exprId);
}

/**
 * Looks up a slot in the currently selected cloned ExprId-to-SlotRef snapshot.
 *
 * @param exprId id of the expression whose cloned slot is wanted
 * @return the cloned slot for {@code exprId}, or {@code null} when no snapshot is selected
 *         or the snapshot has no entry for the id
 */
public SlotRef findCloneSlotRef(ExprId exprId) {
    // The snapshot field is null by default and after resetCloneExprIdToSlot(); guard against
    // that instead of failing with a NullPointerException.
    return cloneExprIdToSlot == null ? null : cloneExprIdToSlot.get(exprId);
}

/**
 * Selects the snapshot stored for the given plan as the active cloned ExprId-to-SlotRef map.
 *
 * <p>Note that this always overwrites the active snapshot: when no snapshot exists for
 * {@code plan}, the active snapshot becomes {@code null}.
 *
 * @param plan the plan whose snapshot should be selected
 * @return {@code true} if a snapshot was stored for {@code plan}, {@code false} otherwise
 */
public boolean findExprIdToSlotRefFromMap(Plan plan) {
    cloneExprIdToSlot = clonePlanToExprIdToSlotRefMap.get(plan);
    return cloneExprIdToSlot != null;
}

/**
 * Returns the currently selected cloned ExprId-to-SlotRef snapshot.
 *
 * @return the active snapshot, or {@code null} when none is selected
 */
public Map<ExprId, SlotRef> getCloneExprIdToSlot() {
    return this.cloneExprIdToSlot;
}

/**
 * Deselects the active cloned ExprId-to-SlotRef snapshot by setting it back to {@code null}.
 */
public void resetCloneExprIdToSlot() {
    this.cloneExprIdToSlot = null;
}

/**
 * Looks up the column reference expression registered for the given ExprId.
 *
 * @param exprId key of the mapping
 * @return the value stored in {@code exprIdToColumnRef} for {@code exprId}
 */
public ColumnRefExpr findColumnRef(ExprId exprId) {
    ColumnRefExpr columnRef = exprIdToColumnRef.get(exprId);
    return columnRef;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,10 @@ public void setPushDownAggNoGrouping(TPushAggOp pushDownAggNoGroupingOp) {
this.pushDownAggNoGroupingOp = pushDownAggNoGroupingOp;
}

/**
 * Returns the distribute expression lists currently set for this node's children.
 *
 * @return the stored list itself (not a copy)
 */
public List<List<Expr>> getChildrenDistributeExprLists() {
    return childrenDistributeExprLists;
}

/**
 * Replaces the distribute expression lists of this node's children.
 *
 * @param childrenDistributeExprLists the new lists; the reference is stored as-is, without copying
 */
public void setChildrenDistributeExprLists(List<List<Expr>> childrenDistributeExprLists) {
this.childrenDistributeExprLists = childrenDistributeExprLists;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Regression suite: for window functions evaluated over a join, checks that the EXPLAIN output
// contains the expected "distribute expr lists" line built from the window's partition keys.
// NOTE(review): the expected strings embed concrete slot ids (e.g. #13), so they presumably depend
// on the exact plan produced — confirm they are stable across planner changes.
suite("window_child_distribution_expr") {
// Setup: recreate two small tables hash-distributed on k1, load three rows into each, and set
// the session variables the checks below run under.
multi_sql """
drop table if exists baseall;
drop table if exists test;
CREATE TABLE IF NOT EXISTS `baseall` (
`k1` tinyint(4) null comment "",
`k2` smallint(6) null comment "",
`k3` int(11) null comment "",
`k4` bigint(20) null comment ""
) engine=olap
DISTRIBUTED BY HASH(`k1`) BUCKETS 3 properties("replication_num" = "1");

CREATE TABLE IF NOT EXISTS `test` (
`k1` tinyint(4) null comment "",
`k2` smallint(6) null comment "",
`k3` int(11) null comment ""
) engine=olap
DISTRIBUTED BY HASH(`k1`) BUCKETS 3 properties("replication_num" = "1");

insert into baseall values (1,1,1,1);
insert into baseall values (2,2,2,2);
insert into baseall values (3,3,3,3);
insert into test values (1,1,1);
insert into test values (2,2,2);
insert into test values (3,3,3);

set enable_nereids_distribute_planner=true;
set enable_pipeline_x_engine=true;
set disable_join_reorder=true;
set enable_local_shuffle=true;
set force_to_local_shuffle=true;
"""
// Each scenario below is checked twice: once as-is, and once with an outer `rn <= 1` filter.

// Join on k1; window partitioned by (k1, k2), i.e. the partition keys include the join key.
explain {
sql """
select * from (select t2.k2, row_number() over (partition by t1.k1, t1.k2 order by t1.k3) rn from baseall t1 join test t2 on t1.k1=t2.k1) tmp;
"""
contains "distribute expr lists: k1[#13], k2[#14]"
}
explain {
sql """
select * from (select t2.k2, row_number() over (partition by t1.k1, t1.k2 order by t1.k3) rn from baseall t1 join test t2 on t1.k1=t2.k1) tmp where rn <=1;
"""
contains "distribute expr lists: k1[#17], k2[#18]"
}
// Join on k1; window partitioned by (k2, k3), i.e. the partition keys do not include the join key.
explain {
sql """
select * from (select t2.k2, row_number() over (partition by t1.k2, t1.k3 order by t1.k4) rn from baseall t1 join test t2 on t1.k1=t2.k1) tmp;
"""
contains "distribute expr lists: k2[#14], k3[#15]"
}
explain {
sql """
select * from (select t2.k2, row_number() over (partition by t1.k2, t1.k3 order by t1.k4) rn from baseall t1 join test t2 on t1.k1=t2.k1) tmp where rn <=1;
"""
contains "distribute expr lists: k2[#18], k3[#19]"
}
// Join on k2 (not the tables' distribution column); window partitioned by (k2, k3).
explain {
sql """
select * from (select t2.k2, row_number() over (partition by t1.k2, t1.k3 order by t1.k4) rn from baseall t1 join test t2 on t1.k2=t2.k2) tmp;
"""
contains "distribute expr lists: k2[#11], k3[#12]"
}
explain {
sql """
select * from (select t2.k2, row_number() over (partition by t1.k2, t1.k3 order by t1.k4) rn from baseall t1 join test t2 on t1.k2=t2.k2) tmp where rn <=1;
"""
contains "distribute expr lists: k2[#15], k3[#16]"
}
// Join on k2; window partitioned by the single key k1.
explain {
sql """
select * from (select t2.k2, row_number() over (partition by t1.k1 order by t1.k3) rn from baseall t1 join test t2 on t1.k2=t2.k2) tmp;
"""
contains "distribute expr lists: k1[#12]"
}
explain {
sql """
select * from (select t2.k2, row_number() over (partition by t1.k1 order by t1.k3) rn from baseall t1 join test t2 on t1.k2=t2.k2) tmp where rn <=1;
"""
contains "distribute expr lists: k1[#15]"
}
}
Loading