From da86521f92073883886580d5788ffa6152fb7058 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Mon, 7 Oct 2024 23:34:50 -0700 Subject: [PATCH] MSQ: Allow for worker gaps. In a Dart query, all Historicals are given worker IDs, but not all of them are going to actually be started or receive work orders. This can create gaps in the set of workers. For example, workers 1 and 3 could have work assigned while workers 0 and 2 do not. This patch updates ControllerStageTracker and WorkerInputs to handle such gaps, by using the set of actual worker numbers, rather than 0..workerCount, in various places. --- .../msq/input/stage/ReadablePartition.java | 12 ++ .../msq/input/stage/ReadablePartitions.java | 33 +++- .../SparseStripedReadablePartitions.java | 142 ++++++++++++++++++ .../controller/ControllerStageTracker.java | 26 ++-- .../msq/kernel/controller/WorkerInputs.java | 41 +++-- .../CollectedReadablePartitionsTest.java | 12 +- .../stage/CombinedReadablePartitionsTest.java | 2 +- .../SparseStripedReadablePartitionsTest.java | 98 ++++++++++++ .../stage/StripedReadablePartitionsTest.java | 34 ++++- .../kernel/controller/WorkerInputsTest.java | 98 +++++++++--- 10 files changed, 439 insertions(+), 59 deletions(-) create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitions.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitionsTest.java diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartition.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartition.java index 99098d1d4cb3..5f366c60009b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartition.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartition.java @@ -59,6 +59,18 
@@ public static ReadablePartition striped(final int stageNumber, final int numWork return new ReadablePartition(stageNumber, workerNumbers, partitionNumber); } + /** + * Returns an output partition that is striped across a set of {@code workerNumbers}. + */ + public static ReadablePartition striped( + final int stageNumber, + final IntSortedSet workerNumbers, + final int partitionNumber + ) + { + return new ReadablePartition(stageNumber, workerNumbers, partitionNumber); + } + /** * Returns an output partition that has been collected onto a single worker. */ diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartitions.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartitions.java index a71535fbcfce..dcf0042f68b8 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartitions.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartitions.java @@ -24,6 +24,7 @@ import it.unimi.dsi.fastutil.ints.Int2IntAVLTreeMap; import it.unimi.dsi.fastutil.ints.Int2IntSortedMap; import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; +import it.unimi.dsi.fastutil.ints.IntSortedSet; import java.util.Collections; import java.util.List; @@ -39,6 +40,7 @@ @JsonSubTypes(value = { @JsonSubTypes.Type(name = "collected", value = CollectedReadablePartitions.class), @JsonSubTypes.Type(name = "striped", value = StripedReadablePartitions.class), + @JsonSubTypes.Type(name = "sparseStriped", value = SparseStripedReadablePartitions.class), @JsonSubTypes.Type(name = "combined", value = CombinedReadablePartitions.class) }) public interface ReadablePartitions extends Iterable @@ -59,7 +61,7 @@ static ReadablePartitions empty() /** * Combines various sets of partitions into a single set. 
*/ - static CombinedReadablePartitions combine(List readablePartitions) + static ReadablePartitions combine(List readablePartitions) { return new CombinedReadablePartitions(readablePartitions); } @@ -68,7 +70,7 @@ static CombinedReadablePartitions combine(List readableParti * Returns a set of {@code numPartitions} partitions striped across {@code numWorkers} workers: each worker contains * a "stripe" of each partition. */ - static StripedReadablePartitions striped( + static ReadablePartitions striped( final int stageNumber, final int numWorkers, final int numPartitions @@ -82,11 +84,36 @@ static StripedReadablePartitions striped( return new StripedReadablePartitions(stageNumber, numWorkers, partitionNumbers); } + /** + * Returns a set of {@code numPartitions} partitions striped across {@code workers}: each worker contains + * a "stripe" of each partition. + */ + static ReadablePartitions striped( + final int stageNumber, + final IntSortedSet workers, + final int numPartitions + ) + { + final IntAVLTreeSet partitionNumbers = new IntAVLTreeSet(); + for (int i = 0; i < numPartitions; i++) { + partitionNumbers.add(i); + } + + if (workers.lastInt() == workers.size() - 1) { + // Dense worker set. Use StripedReadablePartitions for compactness (send a single number rather than the + // entire worker set) and for backwards compatibility (older workers cannot understand + // SparseStripedReadablePartitions). + return new StripedReadablePartitions(stageNumber, workers.size(), partitionNumbers); + } else { + return new SparseStripedReadablePartitions(stageNumber, workers, partitionNumbers); + } + } + /** * Returns a set of partitions that have been collected onto specific workers: each partition is on exactly * one worker. 
*/ - static CollectedReadablePartitions collected( + static ReadablePartitions collected( final int stageNumber, final Map partitionToWorkerMap ) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitions.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitions.java new file mode 100644 index 000000000000..e9a02a7d4880 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitions.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.input.stage; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.Iterators; +import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; +import it.unimi.dsi.fastutil.ints.IntSortedSet; +import org.apache.druid.msq.input.SlicerUtils; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +/** + * Set of partitions striped across a sparse set of {@code workers}. 
Each worker contains a "stripe" of each partition. + * + * @see StripedReadablePartitions dense version, where workers from [0..N) are all used. + */ +public class SparseStripedReadablePartitions implements ReadablePartitions +{ + private final int stageNumber; + private final IntSortedSet workers; + private final IntSortedSet partitionNumbers; + + /** + * Constructor. Most callers should use {@link ReadablePartitions#striped(int, IntSortedSet, int)} instead, + * which takes a partition count rather than a set of partition numbers. + */ + public SparseStripedReadablePartitions( + final int stageNumber, + final IntSortedSet workers, + final IntSortedSet partitionNumbers + ) + { + this.stageNumber = stageNumber; + this.workers = workers; + this.partitionNumbers = partitionNumbers; + } + + @JsonCreator + private SparseStripedReadablePartitions( + @JsonProperty("stageNumber") final int stageNumber, + @JsonProperty("workers") final Set workers, + @JsonProperty("partitionNumbers") final Set partitionNumbers + ) + { + this(stageNumber, new IntAVLTreeSet(workers), new IntAVLTreeSet(partitionNumbers)); + } + + @Override + public Iterator iterator() + { + return Iterators.transform( + partitionNumbers.iterator(), + partitionNumber -> ReadablePartition.striped(stageNumber, workers, partitionNumber) + ); + } + + @Override + public List split(final int maxNumSplits) + { + final List retVal = new ArrayList<>(); + + for (List entries : SlicerUtils.makeSlicesStatic(partitionNumbers.iterator(), maxNumSplits)) { + if (!entries.isEmpty()) { + retVal.add(new SparseStripedReadablePartitions(stageNumber, workers, new IntAVLTreeSet(entries))); + } + } + + return retVal; + } + + @JsonProperty + int getStageNumber() + { + return stageNumber; + } + + @JsonProperty + IntSortedSet getWorkers() + { + return workers; + } + + @JsonProperty + IntSortedSet getPartitionNumbers() + { + return partitionNumbers; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if 
(o == null || getClass() != o.getClass()) { + return false; + } + SparseStripedReadablePartitions that = (SparseStripedReadablePartitions) o; + return stageNumber == that.stageNumber + && Objects.equals(workers, that.workers) + && Objects.equals(partitionNumbers, that.partitionNumbers); + } + + @Override + public int hashCode() + { + return Objects.hash(stageNumber, workers, partitionNumbers); + } + + @Override + public String toString() + { + return "StripedReadablePartitions{" + + "stageNumber=" + stageNumber + + ", workers=" + workers + + ", partitionNumbers=" + partitionNumbers + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java index 338a35e0d244..533cb57b97fe 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java @@ -403,7 +403,7 @@ void addPartialKeyInformationForWorker( throw new ISE("Stage does not gather result key statistics"); } - if (workerNumber < 0 || workerNumber >= workerCount) { + if (!workerInputs.workers().contains(workerNumber)) { throw new IAE("Invalid workerNumber [%s]", workerNumber); } @@ -522,7 +522,7 @@ void mergeClusterByStatisticsCollectorForTimeChunk( throw new ISE("Stage does not gather result key statistics"); } - if (workerNumber < 0 || workerNumber >= workerCount) { + if (!workerInputs.workers().contains(workerNumber)) { throw new IAE("Invalid workerNumber [%s]", workerNumber); } @@ -656,7 +656,7 @@ void mergeClusterByStatisticsCollectorForAllTimeChunks( throw new ISE("Stage does not gather result key statistics"); } - if (workerNumber < 0 || workerNumber >= workerCount) { + if (!workerInputs.workers().contains(workerNumber)) { throw 
new IAE("Invalid workerNumber [%s]", workerNumber); } @@ -763,7 +763,7 @@ void setClusterByPartitionBoundaries(ClusterByPartitions clusterByPartitions) this.resultPartitionBoundaries = clusterByPartitions; this.resultPartitions = ReadablePartitions.striped( stageDef.getStageNumber(), - workerCount, + workerInputs.workers(), clusterByPartitions.size() ); @@ -788,7 +788,7 @@ void setDoneReadingInputForWorker(final int workerNumber) throw DruidException.defensive("Cannot setDoneReadingInput for stage[%s], it is not sorting", stageDef.getId()); } - if (workerNumber < 0 || workerNumber >= workerCount) { + if (!workerInputs.workers().contains(workerNumber)) { throw new IAE("Invalid workerNumber[%s] for stage[%s]", workerNumber, stageDef.getId()); } @@ -830,7 +830,7 @@ void setDoneReadingInputForWorker(final int workerNumber) @SuppressWarnings("unchecked") boolean setResultsCompleteForWorker(final int workerNumber, final Object resultObject) { - if (workerNumber < 0 || workerNumber >= workerCount) { + if (!workerInputs.workers().contains(workerNumber)) { throw new IAE("Invalid workerNumber [%s]", workerNumber); } @@ -947,14 +947,18 @@ private void generateResultPartitionsAndBoundariesWithoutKeyStatistics() resultPartitionBoundaries = maybeResultPartitionBoundaries.valueOrThrow(); resultPartitions = ReadablePartitions.striped( stageNumber, - workerCount, + workerInputs.workers(), resultPartitionBoundaries.size() ); - } else if (shuffleSpec.kind() == ShuffleKind.MIX) { - resultPartitionBoundaries = ClusterByPartitions.oneUniversalPartition(); - resultPartitions = ReadablePartitions.striped(stageNumber, workerCount, shuffleSpec.partitionCount()); } else { - resultPartitions = ReadablePartitions.striped(stageNumber, workerCount, shuffleSpec.partitionCount()); + if (shuffleSpec.kind() == ShuffleKind.MIX) { + resultPartitionBoundaries = ClusterByPartitions.oneUniversalPartition(); + } + resultPartitions = ReadablePartitions.striped( + stageNumber, + workerInputs.workers(), + 
shuffleSpec.partitionCount() + ); } } else { // No reshuffling: retain partitioning from nonbroadcast inputs. diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/WorkerInputs.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/WorkerInputs.java index 83d7a602bc10..8dcaee9c213a 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/WorkerInputs.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/WorkerInputs.java @@ -24,7 +24,9 @@ import it.unimi.dsi.fastutil.ints.Int2IntMap; import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; import it.unimi.dsi.fastutil.ints.Int2ObjectMap; -import it.unimi.dsi.fastutil.ints.IntSet; +import it.unimi.dsi.fastutil.ints.Int2ObjectSortedMap; +import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; +import it.unimi.dsi.fastutil.ints.IntSortedSet; import it.unimi.dsi.fastutil.objects.ObjectIterator; import org.apache.druid.msq.input.InputSlice; import org.apache.druid.msq.input.InputSpec; @@ -45,9 +47,9 @@ public class WorkerInputs { // Worker number -> input number -> input slice. - private final Int2ObjectMap> assignmentsMap; + private final Int2ObjectSortedMap> assignmentsMap; - private WorkerInputs(final Int2ObjectMap> assignmentsMap) + private WorkerInputs(final Int2ObjectSortedMap> assignmentsMap) { this.assignmentsMap = assignmentsMap; } @@ -64,7 +66,7 @@ public static WorkerInputs create( ) { // Split each inputSpec and assign to workers. This list maps worker number -> input number -> input slice. 
- final Int2ObjectMap> assignmentsMap = new Int2ObjectAVLTreeMap<>(); + final Int2ObjectSortedMap> assignmentsMap = new Int2ObjectAVLTreeMap<>(); final int numInputs = stageDef.getInputSpecs().size(); if (numInputs == 0) { @@ -117,8 +119,8 @@ public static WorkerInputs create( final ObjectIterator>> assignmentsIterator = assignmentsMap.int2ObjectEntrySet().iterator(); + final IntSortedSet nilWorkers = new IntAVLTreeSet(); - boolean first = true; while (assignmentsIterator.hasNext()) { final Int2ObjectMap.Entry> entry = assignmentsIterator.next(); final List slices = entry.getValue(); @@ -130,20 +132,29 @@ public static WorkerInputs create( } } - // Eliminate workers that have no non-nil, non-broadcast inputs. (Except the first one, because if all input - // is nil, *some* worker has to do *something*.) - final boolean hasNonNilNonBroadcastInput = + // Identify nil workers (workers with no non-broadcast inputs). + final boolean isNilWorker = IntStream.range(0, numInputs) - .anyMatch(i -> - !slices.get(i).equals(NilInputSlice.INSTANCE) // Non-nil - && !stageDef.getBroadcastInputNumbers().contains(i) // Non-broadcast + .allMatch(i -> + slices.get(i).equals(NilInputSlice.INSTANCE) // Nil regular input + || stageDef.getBroadcastInputNumbers().contains(i) // Broadcast ); - if (!first && !hasNonNilNonBroadcastInput) { - assignmentsIterator.remove(); + if (isNilWorker) { + nilWorkers.add(entry.getIntKey()); } + } - first = false; + if (nilWorkers.size() == assignmentsMap.size()) { + // All workers have nil regular inputs. Remove all workers except the first (*some* worker has to do *something*). + final List firstSlices = assignmentsMap.get(nilWorkers.firstInt()); + assignmentsMap.clear(); + assignmentsMap.put(nilWorkers.firstInt(), firstSlices); + } else { + // Remove all nil workers. 
+ for (final int nilWorker : nilWorkers) { + assignmentsMap.remove(nilWorker); + } } return new WorkerInputs(assignmentsMap); @@ -154,7 +165,7 @@ public List inputsForWorker(final int workerNumber) return Preconditions.checkNotNull(assignmentsMap.get(workerNumber), "worker [%s]", workerNumber); } - public IntSet workers() + public IntSortedSet workers() { return assignmentsMap.keySet(); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CollectedReadablePartitionsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CollectedReadablePartitionsTest.java index 6ed7d2d43d4f..d4db7a0a7c5f 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CollectedReadablePartitionsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CollectedReadablePartitionsTest.java @@ -33,21 +33,24 @@ public class CollectedReadablePartitionsTest @Test public void testPartitionToWorkerMap() { - final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); + final CollectedReadablePartitions partitions = + (CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); Assert.assertEquals(ImmutableMap.of(0, 1, 1, 2, 2, 1), partitions.getPartitionToWorkerMap()); } @Test public void testStageNumber() { - final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); + final CollectedReadablePartitions partitions = + (CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); Assert.assertEquals(1, partitions.getStageNumber()); } @Test public void testSplit() { - final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); + final CollectedReadablePartitions partitions = + (CollectedReadablePartitions) 
ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); Assert.assertEquals( ImmutableList.of( @@ -64,7 +67,8 @@ public void testSerde() throws Exception final ObjectMapper mapper = TestHelper.makeJsonMapper() .registerModules(new MSQIndexingModule().getJacksonModules()); - final CollectedReadablePartitions partitions = ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); + final CollectedReadablePartitions partitions = + (CollectedReadablePartitions) ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1)); Assert.assertEquals( partitions, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CombinedReadablePartitionsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CombinedReadablePartitionsTest.java index 685f4ff7a8ab..16bd047b6240 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CombinedReadablePartitionsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CombinedReadablePartitionsTest.java @@ -31,7 +31,7 @@ public class CombinedReadablePartitionsTest { - private static final CombinedReadablePartitions PARTITIONS = ReadablePartitions.combine( + private static final ReadablePartitions PARTITIONS = ReadablePartitions.combine( ImmutableList.of( ReadablePartitions.striped(0, 2, 2), ReadablePartitions.striped(1, 2, 4) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitionsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitionsTest.java new file mode 100644 index 000000000000..5268fd601809 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitionsTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.input.stage; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; +import it.unimi.dsi.fastutil.ints.IntSet; +import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.msq.guice.MSQIndexingModule; +import org.apache.druid.segment.TestHelper; +import org.junit.Assert; +import org.junit.Test; + +public class SparseStripedReadablePartitionsTest +{ + @Test + public void testPartitionNumbers() + { + final SparseStripedReadablePartitions partitions = + (SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3); + Assert.assertEquals(ImmutableSet.of(0, 1, 2), partitions.getPartitionNumbers()); + } + + @Test + public void testWorkers() + { + final SparseStripedReadablePartitions partitions = + (SparseStripedReadablePartitions) ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3); + Assert.assertEquals(IntSet.of(1, 3), partitions.getWorkers()); + } + + @Test + public void testStageNumber() + { + final SparseStripedReadablePartitions partitions = + (SparseStripedReadablePartitions) 
ReadablePartitions.striped(1, new IntAVLTreeSet(new int[]{1, 3}), 3); + Assert.assertEquals(1, partitions.getStageNumber()); + } + + @Test + public void testSplit() + { + final IntAVLTreeSet workers = new IntAVLTreeSet(new int[]{1, 3}); + final SparseStripedReadablePartitions partitions = + (SparseStripedReadablePartitions) ReadablePartitions.striped(1, workers, 3); + + Assert.assertEquals( + ImmutableList.of( + new SparseStripedReadablePartitions(1, workers, new IntAVLTreeSet(new int[]{0, 2})), + new SparseStripedReadablePartitions(1, workers, new IntAVLTreeSet(new int[]{1})) + ), + partitions.split(2) + ); + } + + @Test + public void testSerde() throws Exception + { + final ObjectMapper mapper = TestHelper.makeJsonMapper() + .registerModules(new MSQIndexingModule().getJacksonModules()); + + final IntAVLTreeSet workers = new IntAVLTreeSet(new int[]{1, 3}); + final ReadablePartitions partitions = ReadablePartitions.striped(1, workers, 3); + + Assert.assertEquals( + partitions, + mapper.readValue( + mapper.writeValueAsString(partitions), + ReadablePartitions.class + ) + ); + } + + @Test + public void testEquals() + { + EqualsVerifier.forClass(SparseStripedReadablePartitions.class).usingGetClass().verify(); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java index 38e0707f5d0e..05b42b332502 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java @@ -26,36 +26,60 @@ import nl.jqno.equalsverifier.EqualsVerifier; import org.apache.druid.msq.guice.MSQIndexingModule; import org.apache.druid.segment.TestHelper; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; 
import org.junit.Assert; import org.junit.Test; public class StripedReadablePartitionsTest { + @Test + public void testFromDenseSet() + { + // Tests that when ReadablePartitions.striped is called with a dense set, we get StripedReadablePartitions. + + final IntAVLTreeSet workers = new IntAVLTreeSet(); + workers.add(0); + workers.add(1); + + final ReadablePartitions readablePartitionsFromSet = ReadablePartitions.striped(1, workers, 3); + + MatcherAssert.assertThat( + readablePartitionsFromSet, + CoreMatchers.instanceOf(StripedReadablePartitions.class) + ); + + Assert.assertEquals( + ReadablePartitions.striped(1, 2, 3), + readablePartitionsFromSet + ); + } + @Test public void testPartitionNumbers() { - final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); + final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3); Assert.assertEquals(ImmutableSet.of(0, 1, 2), partitions.getPartitionNumbers()); } @Test public void testNumWorkers() { - final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); + final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3); Assert.assertEquals(2, partitions.getNumWorkers()); } @Test public void testStageNumber() { - final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); + final StripedReadablePartitions partitions = (StripedReadablePartitions) ReadablePartitions.striped(1, 2, 3); Assert.assertEquals(1, partitions.getStageNumber()); } @Test public void testSplit() { - final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); + final ReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); Assert.assertEquals( ImmutableList.of( @@ -72,7 +96,7 @@ public void testSerde() throws Exception final ObjectMapper mapper = TestHelper.makeJsonMapper() .registerModules(new MSQIndexingModule().getJacksonModules()); - final 
StripedReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); + final ReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3); Assert.assertEquals( partitions, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java index 605e0bf2de74..e74125b08301 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java @@ -25,9 +25,11 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; import it.unimi.dsi.fastutil.ints.IntSet; +import it.unimi.dsi.fastutil.ints.IntSortedSet; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongList; import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.error.DruidException; import org.apache.druid.msq.exec.Limits; import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.input.InputSlice; @@ -75,7 +77,7 @@ public void test_max_threeInputs_fourWorkers() final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.MAX, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -91,6 +93,35 @@ public void test_max_threeInputs_fourWorkers() ); } + @Test + public void test_max_threeInputs_fourWorkers_withGaps() + { + final StageDefinition stageDef = + StageDefinition.builder(0) + .inputs(new TestInputSpec(1, 2, 3)) + .maxWorkerCount(4) + .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L)) + .build(QUERY_ID); + + final WorkerInputs inputs = WorkerInputs.create( + stageDef, + Int2IntMaps.EMPTY_MAP, + new 
TestInputSpecSlicer(new IntAVLTreeSet(new int[]{1, 3, 4, 5}), true), + WorkerAssignmentStrategy.MAX, + Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER + ); + + Assert.assertEquals( + ImmutableMap.>builder() + .put(1, Collections.singletonList(new TestInputSlice(1))) + .put(3, Collections.singletonList(new TestInputSlice(2))) + .put(4, Collections.singletonList(new TestInputSlice(3))) + .put(5, Collections.singletonList(new TestInputSlice())) + .build(), + inputs.assignmentsMap() + ); + } + @Test public void test_max_zeroInputs_fourWorkers() { @@ -104,7 +135,7 @@ public void test_max_zeroInputs_fourWorkers() final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.MAX, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -133,7 +164,7 @@ public void test_auto_zeroInputSpecs_fourWorkers() final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -159,7 +190,7 @@ public void test_auto_zeroInputSlices_fourWorkers() final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -186,7 +217,7 @@ public void test_auto_zeroInputSlices_broadcast_fourWorkers() final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -212,7 +243,7 @@ public void test_auto_threeInputs_fourWorkers() final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new 
TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -324,7 +355,7 @@ public void test_auto_threeBigInputs_fourWorkers() final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(4), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -351,7 +382,7 @@ public void test_auto_tenSmallAndOneBigInputs_twoWorkers() final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(2), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -384,7 +415,7 @@ public void test_auto_threeBigInputs_oneWorker() final WorkerInputs inputs = WorkerInputs.create( stageDef, Int2IntMaps.EMPTY_MAP, - new TestInputSpecSlicer(true), + new TestInputSpecSlicer(denseWorkers(1), true), WorkerAssignmentStrategy.AUTO, Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER ); @@ -411,7 +442,7 @@ public void test_max_shouldAlwaysSplitStatic() .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L)) .build(QUERY_ID); - TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true)); + TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true)); final WorkerInputs inputs = WorkerInputs.create( stageDef, @@ -455,7 +486,7 @@ public void test_auto_shouldSplitDynamicIfPossible() .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L)) .build(QUERY_ID); - TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true)); + TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true)); final WorkerInputs inputs = WorkerInputs.create( stageDef, @@ -498,7 +529,7 @@ public void test_auto_shouldUseLeastWorkersPossible() .processorFactory(new OffsetLimitFrameProcessorFactory(0, 0L)) .build(QUERY_ID); - 
TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(true)); + TestInputSpecSlicer testInputSpecSlicer = spy(new TestInputSpecSlicer(denseWorkers(3), true)); final WorkerInputs inputs = WorkerInputs.create( stageDef, @@ -585,11 +616,23 @@ public String toString() private static class TestInputSpecSlicer implements InputSpecSlicer { + private final IntSortedSet workers; private final boolean canSliceDynamic; - public TestInputSpecSlicer(boolean canSliceDynamic) + /** + * Create a test slicer. + * + * @param workers Set of workers to consider assigning work to. + * @param canSliceDynamic Whether this slicer can slice dynamically. + */ + public TestInputSpecSlicer(final IntSortedSet workers, final boolean canSliceDynamic) { + this.workers = workers; this.canSliceDynamic = canSliceDynamic; + + if (workers.isEmpty()) { + throw DruidException.defensive("Need more than one worker in workers[%s]", workers); + } } @Override @@ -606,9 +649,9 @@ public List sliceStatic(InputSpec inputSpec, int maxNumSlices) SlicerUtils.makeSlicesStatic( testInputSpec.values.iterator(), i -> i, - maxNumSlices + Math.min(maxNumSlices, workers.size()) ); - return makeSlices(assignments); + return makeSlices(workers, assignments); } @Override @@ -624,24 +667,39 @@ public List sliceDynamic( SlicerUtils.makeSlicesDynamic( testInputSpec.values.iterator(), i -> i, - maxNumSlices, + Math.min(maxNumSlices, workers.size()), maxFilesPerSlice, maxBytesPerSlice ); - return makeSlices(assignments); + return makeSlices(workers, assignments); } private static List makeSlices( + final IntSortedSet workers, final List> assignments ) { final List retVal = new ArrayList<>(assignments.size()); - - for (final List assignment : assignments) { - retVal.add(new TestInputSlice(new LongArrayList(assignment))); + for (int assignment = 0, workerNumber = 0; + workerNumber <= workers.lastInt() && assignment < assignments.size(); + workerNumber++) { + if (workers.contains(workerNumber)) { + retVal.add(new 
TestInputSlice(new LongArrayList(assignments.get(assignment++)))); + } else { + retVal.add(NilInputSlice.INSTANCE); + } } return retVal; } } + + private static IntSortedSet denseWorkers(final int numWorkers) + { + final IntAVLTreeSet workers = new IntAVLTreeSet(); + for (int i = 0; i < numWorkers; i++) { + workers.add(i); + } + return workers; + } }