diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/OlapInsertExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/OlapInsertExecutor.java index 0153700863d3be..43a3327a378884 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/OlapInsertExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/OlapInsertExecutor.java @@ -40,6 +40,7 @@ import org.apache.doris.planner.DataSink; import org.apache.doris.planner.DataStreamSink; import org.apache.doris.planner.ExchangeNode; +import org.apache.doris.planner.MultiCastDataSink; import org.apache.doris.planner.OlapTableSink; import org.apache.doris.planner.PlanFragment; import org.apache.doris.qe.ConnectContext; @@ -139,7 +140,28 @@ public void finalizeSink(PlanFragment fragment, DataSink sink, PhysicalSink phys // set schema and partition info for tablet id shuffle exchange if (fragment.getPlanRoot() instanceof ExchangeNode && fragment.getDataPartition().getType() == TPartitionType.TABLET_SINK_SHUFFLE_PARTITIONED) { - DataStreamSink dataStreamSink = (DataStreamSink) (fragment.getChild(0).getSink()); + DataSink childFragmentSink = fragment.getChild(0).getSink(); + DataStreamSink dataStreamSink = null; + if (childFragmentSink instanceof MultiCastDataSink) { + MultiCastDataSink multiCastDataSink = (MultiCastDataSink) childFragmentSink; + int outputExchangeId = (fragment.getPlanRoot()).getId().asInt(); + // which DataStreamSink link to the output exchangeNode? + for (DataStreamSink currentDataStreamSink : multiCastDataSink.getDataStreamSinks()) { + int sinkExchangeId = currentDataStreamSink.getExchNodeId().asInt(); + if (outputExchangeId == sinkExchangeId) { + dataStreamSink = currentDataStreamSink; + break; + } + } + if (dataStreamSink == null) { + throw new IllegalStateException("Can not find DataStreamSink in the MultiCastDataSink"); + } + } else if (childFragmentSink instanceof DataStreamSink) { + dataStreamSink = (DataStreamSink) childFragmentSink; + } else { + throw new IllegalStateException("Unsupported DataSink: " + childFragmentSink); + } + Analyzer analyzer = new Analyzer(Env.getCurrentEnv(), ConnectContext.get()); dataStreamSink.setTabletSinkSchemaParam(olapTableSink.createSchema( database.getId(), olapTableSink.getDstTable(), analyzer)); diff --git a/regression-test/suites/insert_p0/insert.groovy b/regression-test/suites/insert_p0/insert.groovy index 573d5d8366c6d6..83a1a472781c3d 100644 --- a/regression-test/suites/insert_p0/insert.groovy +++ b/regression-test/suites/insert_p0/insert.groovy @@ -83,4 +83,33 @@ suite("insert") { sql "sync" qt_insert """ select * from mutable_datatype order by c_bigint, c_double, c_string, c_date, c_timestamp, c_boolean, c_short_decimal""" + + multi_sql """ + drop table if exists table_select_test1; + CREATE TABLE table_select_test1 ( + `id` int + ) + distributed by hash(id) + properties('replication_num'='1'); + + insert into table_select_test1 values(2); + + drop table if exists table_test_insert1; + create table table_test_insert1 (id int) + partition by range(id) + ( + partition p1 values[('1'), ('50')), + partition p2 values[('50'), ('100')) + ) + distributed by hash(id) buckets 100 + properties('replication_num'='1') + + insert into table_test_insert1 values(1), (10); + + insert into table_test_insert1 + with + a as (select * from table_select_test1 where id > 10), + b as (select * from a union all select * from a) + select id from b; + """ }