diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java index f04286dd7c420..d2370b057935b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java @@ -117,9 +117,9 @@ void resultsComplete( ); /** - * Returns the current list of task ids, ordered by worker number. The Nth task has worker number N. + * Returns the current list of worker IDs, ordered by worker number. The Nth worker has worker number N. */ - List getTaskIds(); + List getWorkerIds(); @Nullable TaskReport.ReportMap liveReports(); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java index 405ff4fb9026f..428ce59cd8fac 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerClient.java @@ -23,20 +23,25 @@ import org.apache.druid.msq.indexing.error.MSQErrorReport; import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; import javax.annotation.Nullable; +import java.io.Closeable; import java.io.IOException; import java.util.List; /** - * Client for the multi-stage query controller. Used by a Worker task. + * Client for the multi-stage query controller. Used by a {@link Worker}. Each instance is specific to a single query, + * meaning it communicates with a single controller. */ -public interface ControllerClient extends AutoCloseable +public interface ControllerClient extends Closeable { /** - * Client side method to update the controller with partial key statistics information for a particular stage and worker. - * Controller's implementation collates all the information for a stage to fetch key statistics from workers. + * Client side method to update the controller with partial key statistics information for a particular stage + * and worker. The controller collates all the information for a stage to fetch key statistics from workers. + * + * Only used when {@link StageDefinition#mustGatherResultKeyStatistics()}. */ void postPartialKeyStatistics( StageId stageId, @@ -86,11 +91,16 @@ void postWorkerError( /** * Client side method to inform the controller about the warnings generated by the given worker. */ - void postWorkerWarning( - List MSQErrorReports - ) throws IOException; + void postWorkerWarning(List MSQErrorReports) throws IOException; - List getTaskList() throws IOException; + /** + * Client side method for retrieving the list of worker IDs from the controller. These IDs can be passed to + * {@link WorkerClient} methods to communicate with other workers. Not necessary when the {@link WorkOrder} has + * {@link WorkOrder#getWorkerIds()} set. + * + * @see Controller#getWorkerIds() for the controller side + */ + List getWorkerIds() throws IOException; /** * Close this client. Idempotent. 
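A minimal worker-side sketch of the renamed call, assuming a ControllerClient instance named controllerClient (the variable names are illustrative and IOException handling is omitted). Because the returned list is ordered by worker number, entry N identifies the worker with worker number N:

    List<String> workerIds = controllerClient.getWorkerIds();
    // Entry N corresponds to worker number N, so this is the ID of the worker with worker number 2.
    String workerTwoId = workerIds.get(2);
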
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index d2d5cc657e6ae..839839db4e42b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -1171,7 +1171,7 @@ private List generateSegmentIdsWithShardSpecsForReplace( } @Override - public List getTaskIds() + public List getWorkerIds() { if (workerManager == null) { return Collections.emptyList(); @@ -1260,7 +1260,7 @@ private void contactWorkersForStage( { // Sorted copy of target worker numbers to ensure consistent iteration order. final List workersCopy = Ordering.natural().sortedCopy(workers); - final List workerIds = getTaskIds(); + final List workerIds = getWorkerIds(); final List> workerFutures = new ArrayList<>(workersCopy.size()); try { @@ -1488,7 +1488,7 @@ private List findIntervalsToDrop(final Set publishedSegme private CounterSnapshotsTree getCountersFromAllTasks() { final CounterSnapshotsTree retVal = new CounterSnapshotsTree(); - final List taskList = getTaskIds(); + final List taskList = getWorkerIds(); final List> futures = new ArrayList<>(); @@ -1508,7 +1508,7 @@ private CounterSnapshotsTree getCountersFromAllTasks() private void postFinishToAllTasks() { - final List taskList = getTaskIds(); + final List taskList = getWorkerIds(); final List> futures = new ArrayList<>(); @@ -2963,7 +2963,7 @@ private void startQueryResultsReader() } final StageId finalStageId = queryKernel.getStageId(queryDef.getFinalStageDefinition().getStageNumber()); - final List taskIds = getTaskIds(); + final List taskIds = getWorkerIds(); final InputChannelFactory inputChannelFactory; diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java index 8e6fc72b6aa72..2ab016e10e486 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerMemoryParameters.java @@ -91,7 +91,8 @@ public static ControllerMemoryParameters createProductionInstance( memoryIntrospector.totalMemoryInJvm(), usableMemoryInJvm, numControllersInJvm, - memoryIntrospector.numProcessorsInJvm() + memoryIntrospector.numProcessorsInJvm(), + 0 ) ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ListeningOutputChannelFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ListeningOutputChannelFactory.java new file mode 100644 index 0000000000000..ebaad07638723 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ListeningOutputChannelFactory.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import org.apache.druid.frame.processor.OutputChannel; +import org.apache.druid.frame.processor.OutputChannelFactory; +import org.apache.druid.frame.processor.PartitionedOutputChannel; + +import java.io.IOException; + +/** + * Decorator for {@link OutputChannelFactory} that notifies a {@link Listener} whenever a channel is opened. + */ +public class ListeningOutputChannelFactory implements OutputChannelFactory +{ + private final OutputChannelFactory delegate; + private final Listener listener; + + public ListeningOutputChannelFactory(final OutputChannelFactory delegate, final Listener listener) + { + this.delegate = delegate; + this.listener = listener; + } + + @Override + public OutputChannel openChannel(final int partitionNumber) throws IOException + { + return notifyListener(delegate.openChannel(partitionNumber)); + } + + + @Override + public OutputChannel openNilChannel(final int partitionNumber) + { + return notifyListener(delegate.openNilChannel(partitionNumber)); + } + + @Override + public PartitionedOutputChannel openPartitionedChannel( + final String name, + final boolean deleteAfterRead + ) + { + throw new UnsupportedOperationException("Listening to partitioned channels is not supported"); + } + + private OutputChannel notifyListener(OutputChannel channel) + { + listener.channelOpened(channel); + return channel; + } + + public interface Listener + { + void channelOpened(OutputChannel channel); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java index 7e7fc3d3d6f3d..f42d558a76ce6 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/OutputChannelMode.java @@ -32,9 +32,12 @@ public enum OutputChannelMode { /** - * In-memory output channels. Stage shuffle data does not hit disk. This mode requires a consumer stage to run - * at the same time as its corresponding producer stage. See {@link ControllerQueryKernelUtils#computeStageGroups} for the - * logic that determines when we can use in-memory channels. + * In-memory output channels. Stage shuffle data does not hit disk. In-memory channels do not fully buffer stage + * output. They use a blocking queue; see {@link RunWorkOrder#makeStageOutputChannelFactory()}. + * + * Because stage output is not fully buffered, this mode requires a consumer stage to run at the same time as its + * corresponding producer stage. See {@link ControllerQueryKernelUtils#computeStageGroups} for the logic that + * determines when we can use in-memory channels. 
*/ MEMORY("memory"), diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java new file mode 100644 index 0000000000000..0173979efeed3 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrder.java @@ -0,0 +1,1051 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.SettableFuture; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.error.DruidException; +import org.apache.druid.frame.allocation.ArenaMemoryAllocator; +import org.apache.druid.frame.allocation.ArenaMemoryAllocatorFactory; +import org.apache.druid.frame.channel.BlockingQueueFrameChannel; +import org.apache.druid.frame.channel.ByteTracker; +import org.apache.druid.frame.channel.FrameWithPartition; +import org.apache.druid.frame.key.ClusterByPartitions; +import org.apache.druid.frame.processor.BlockingQueueOutputChannelFactory; +import org.apache.druid.frame.processor.Bouncer; +import org.apache.druid.frame.processor.ComposingOutputChannelFactory; +import org.apache.druid.frame.processor.FileOutputChannelFactory; +import org.apache.druid.frame.processor.FrameChannelHashPartitioner; +import org.apache.druid.frame.processor.FrameChannelMixer; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessorExecutor; +import org.apache.druid.frame.processor.OutputChannel; +import org.apache.druid.frame.processor.OutputChannelFactory; +import org.apache.druid.frame.processor.OutputChannels; +import org.apache.druid.frame.processor.PartitionedOutputChannel; +import org.apache.druid.frame.processor.SuperSorter; +import org.apache.druid.frame.processor.SuperSorterProgressTracker; +import org.apache.druid.frame.processor.manager.ProcessorManager; +import org.apache.druid.frame.processor.manager.ProcessorManagers; +import org.apache.druid.frame.util.DurableStorageUtils; +import org.apache.druid.frame.write.FrameWriters; +import org.apache.druid.java.util.common.FileUtils; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; +import 
org.apache.druid.java.util.common.UOE; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.msq.counters.CounterNames; +import org.apache.druid.msq.counters.CounterTracker; +import org.apache.druid.msq.indexing.CountingOutputChannelFactory; +import org.apache.druid.msq.indexing.InputChannelFactory; +import org.apache.druid.msq.indexing.InputChannelsImpl; +import org.apache.druid.msq.indexing.processor.KeyStatisticsCollectionProcessor; +import org.apache.druid.msq.input.InputSlice; +import org.apache.druid.msq.input.InputSliceReader; +import org.apache.druid.msq.input.InputSlices; +import org.apache.druid.msq.input.MapInputSliceReader; +import org.apache.druid.msq.input.NilInputSlice; +import org.apache.druid.msq.input.NilInputSliceReader; +import org.apache.druid.msq.input.external.ExternalInputSlice; +import org.apache.druid.msq.input.external.ExternalInputSliceReader; +import org.apache.druid.msq.input.inline.InlineInputSlice; +import org.apache.druid.msq.input.inline.InlineInputSliceReader; +import org.apache.druid.msq.input.lookup.LookupInputSlice; +import org.apache.druid.msq.input.lookup.LookupInputSliceReader; +import org.apache.druid.msq.input.stage.InputChannels; +import org.apache.druid.msq.input.stage.StageInputSlice; +import org.apache.druid.msq.input.stage.StageInputSliceReader; +import org.apache.druid.msq.input.table.SegmentsInputSlice; +import org.apache.druid.msq.input.table.SegmentsInputSliceReader; +import org.apache.druid.msq.kernel.FrameContext; +import org.apache.druid.msq.kernel.FrameProcessorFactory; +import org.apache.druid.msq.kernel.ProcessorsAndChannels; +import org.apache.druid.msq.kernel.ShuffleSpec; +import org.apache.druid.msq.kernel.StageDefinition; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.shuffle.output.DurableStorageOutputChannelFactory; +import org.apache.druid.msq.statistics.ClusterByStatisticsCollector; +import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; +import org.apache.druid.utils.CloseableUtils; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; + +import javax.annotation.Nullable; +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; + +/** + * Main worker logic for executing a {@link WorkOrder} in a {@link FrameProcessorExecutor}. 
+ */ +public class RunWorkOrder +{ + private final WorkOrder workOrder; + private final InputChannelFactory inputChannelFactory; + private final CounterTracker counterTracker; + private final FrameProcessorExecutor exec; + private final String cancellationId; + private final int parallelism; + private final WorkerContext workerContext; + private final FrameContext frameContext; + private final RunWorkOrderListener listener; + private final boolean reindex; + private final boolean removeNullBytes; + private final ByteTracker intermediateSuperSorterLocalStorageTracker; + private final AtomicBoolean started = new AtomicBoolean(); + + @MonotonicNonNull + private InputSliceReader inputSliceReader; + @MonotonicNonNull + private OutputChannelFactory workOutputChannelFactory; + @MonotonicNonNull + private OutputChannelFactory shuffleOutputChannelFactory; + @MonotonicNonNull + private ResultAndChannels workResultAndOutputChannels; + @MonotonicNonNull + private SettableFuture stagePartitionBoundariesFuture; + @MonotonicNonNull + private ListenableFuture stageOutputChannelsFuture; + + public RunWorkOrder( + final WorkOrder workOrder, + final InputChannelFactory inputChannelFactory, + final CounterTracker counterTracker, + final FrameProcessorExecutor exec, + final String cancellationId, + final WorkerContext workerContext, + final FrameContext frameContext, + final RunWorkOrderListener listener, + final boolean reindex, + final boolean removeNullBytes + ) + { + this.workOrder = workOrder; + this.inputChannelFactory = inputChannelFactory; + this.counterTracker = counterTracker; + this.exec = exec; + this.cancellationId = cancellationId; + this.parallelism = workerContext.threadCount(); + this.workerContext = workerContext; + this.frameContext = frameContext; + this.listener = listener; + this.reindex = reindex; + this.removeNullBytes = removeNullBytes; + this.intermediateSuperSorterLocalStorageTracker = + new ByteTracker( + frameContext.storageParameters().isIntermediateStorageLimitConfigured() + ? frameContext.storageParameters().getIntermediateSuperSorterStorageMaxLocalBytes() + : Long.MAX_VALUE + ); + } + + /** + * Start execution of the provided {@link WorkOrder} in the provided {@link FrameProcessorExecutor}. + * + * Execution proceeds asynchronously after this method returns. The {@link RunWorkOrderListener} passed to the + * constructor of this instance can be used to track progress. + */ + public void start() throws IOException + { + if (started.getAndSet(true)) { + throw new ISE("Already started"); + } + + final StageDefinition stageDef = workOrder.getStageDefinition(); + + try { + makeInputSliceReader(); + makeWorkOutputChannelFactory(); + makeShuffleOutputChannelFactory(); + makeAndRunWorkProcessors(); + + if (stageDef.doesShuffle()) { + makeAndRunShuffleProcessors(); + } else { + // No shuffling: work output _is_ stage output. Retain read-only versions to reduce memory footprint. + stageOutputChannelsFuture = + Futures.immediateFuture(workResultAndOutputChannels.getOutputChannels().readOnly()); + } + + setUpCompletionCallbacks(); + } + catch (Throwable t) { + // If start() has problems, cancel anything that was already kicked off, and close the FrameContext. + try { + exec.cancel(cancellationId); + } + catch (Throwable t2) { + t.addSuppressed(t2); + } + + CloseableUtils.closeAndSuppressExceptions(frameContext, t::addSuppressed); + throw t; + } + } + + /** + * Settable {@link ClusterByPartitions} future for global sort. 
Necessary because we don't know ahead of time + * what the boundaries will be. The controller decides based on statistics from all workers. Once the controller + * decides, its decision is written to this future, which allows sorting on workers to proceed. + */ + @Nullable + public SettableFuture getStagePartitionBoundariesFuture() + { + return stagePartitionBoundariesFuture; + } + + private void makeInputSliceReader() + { + if (inputSliceReader != null) { + throw new ISE("inputSliceReader already created"); + } + + final String queryId = workOrder.getQueryDefinition().getQueryId(); + + final InputChannels inputChannels = + new InputChannelsImpl( + workOrder.getQueryDefinition(), + InputSlices.allReadablePartitions(workOrder.getInputs()), + inputChannelFactory, + () -> ArenaMemoryAllocator.createOnHeap(frameContext.memoryParameters().getStandardFrameSize()), + exec, + cancellationId, + removeNullBytes + ); + + inputSliceReader = new MapInputSliceReader( + ImmutableMap., InputSliceReader>builder() + .put(NilInputSlice.class, NilInputSliceReader.INSTANCE) + .put(StageInputSlice.class, new StageInputSliceReader(queryId, inputChannels)) + .put(ExternalInputSlice.class, new ExternalInputSliceReader(frameContext.tempDir("external"))) + .put(InlineInputSlice.class, new InlineInputSliceReader(frameContext.segmentWrangler())) + .put(LookupInputSlice.class, new LookupInputSliceReader(frameContext.segmentWrangler())) + .put(SegmentsInputSlice.class, new SegmentsInputSliceReader(frameContext, reindex)) + .build() + ); + } + + private void makeWorkOutputChannelFactory() + { + if (workOutputChannelFactory != null) { + throw new ISE("processorOutputChannelFactory already created"); + } + + final OutputChannelFactory baseOutputChannelFactory; + + if (workOrder.getStageDefinition().doesShuffle()) { + // Writing to a consumer in the same JVM (which will be set up later on in this method). Use the large frame + // size if we're writing to a SuperSorter, since we'll generate fewer temp files if we use larger frames. + // Otherwise, use the standard frame size. + final int frameSize; + + if (workOrder.getStageDefinition().getShuffleSpec().kind().isSort()) { + frameSize = frameContext.memoryParameters().getLargeFrameSize(); + } else { + frameSize = frameContext.memoryParameters().getStandardFrameSize(); + } + + baseOutputChannelFactory = new BlockingQueueOutputChannelFactory(frameSize); + } else { + // Writing stage output. + baseOutputChannelFactory = makeStageOutputChannelFactory(); + } + + workOutputChannelFactory = new CountingOutputChannelFactory( + baseOutputChannelFactory, + counterTracker.channel(CounterNames.outputChannel()) + ); + } + + private void makeShuffleOutputChannelFactory() + { + shuffleOutputChannelFactory = + new CountingOutputChannelFactory( + makeStageOutputChannelFactory(), + counterTracker.channel(CounterNames.shuffleChannel()) + ); + } + + /** + * Use {@link FrameProcessorFactory#makeProcessors} to create {@link ProcessorsAndChannels}. Executes the + * processors using {@link #exec} and sets the output channels in {@link #workResultAndOutputChannels}. 
+ * + * @param type of {@link StageDefinition#getProcessorFactory()} + * @param return type of {@link FrameProcessor} created by the manager + * @param result type of {@link ProcessorManager#result()} + * @param type of {@link WorkOrder#getExtraInfo()} + */ + private , ProcessorReturnType, ManagerReturnType, ExtraInfoType> void makeAndRunWorkProcessors() + throws IOException + { + if (workResultAndOutputChannels != null) { + throw new ISE("workResultAndOutputChannels already set"); + } + + @SuppressWarnings("unchecked") + final FactoryType processorFactory = (FactoryType) workOrder.getStageDefinition().getProcessorFactory(); + + @SuppressWarnings("unchecked") + final ProcessorsAndChannels processors = + processorFactory.makeProcessors( + workOrder.getStageDefinition(), + workOrder.getWorkerNumber(), + workOrder.getInputs(), + inputSliceReader, + (ExtraInfoType) workOrder.getExtraInfo(), + workOutputChannelFactory, + frameContext, + parallelism, + counterTracker, + listener::onWarning, + removeNullBytes + ); + + final ProcessorManager processorManager = processors.getProcessorManager(); + + final int maxOutstandingProcessors; + + if (processors.getOutputChannels().getAllChannels().isEmpty()) { + // No output channels: run up to "parallelism" processors at once. + maxOutstandingProcessors = Math.max(1, parallelism); + } else { + // If there are output channels, that acts as a ceiling on the number of processors that can run at once. + maxOutstandingProcessors = + Math.max(1, Math.min(parallelism, processors.getOutputChannels().getAllChannels().size())); + } + + final ListenableFuture workResultFuture = exec.runAllFully( + processorManager, + maxOutstandingProcessors, + frameContext.processorBouncer(), + cancellationId + ); + + workResultAndOutputChannels = new ResultAndChannels<>(workResultFuture, processors.getOutputChannels()); + } + + private void makeAndRunShuffleProcessors() + { + if (stageOutputChannelsFuture != null) { + throw new ISE("stageOutputChannelsFuture already set"); + } + + final ShuffleSpec shuffleSpec = workOrder.getStageDefinition().getShuffleSpec(); + + final ShufflePipelineBuilder shufflePipeline = new ShufflePipelineBuilder( + workOrder, + counterTracker, + exec, + cancellationId, + frameContext + ); + + shufflePipeline.initialize(workResultAndOutputChannels); + shufflePipeline.gatherResultKeyStatisticsAndReportDoneReadingInputIfNeeded(); + + switch (shuffleSpec.kind()) { + case MIX: + shufflePipeline.mix(shuffleOutputChannelFactory); + break; + + case HASH: + shufflePipeline.hashPartition(shuffleOutputChannelFactory); + break; + + case HASH_LOCAL_SORT: + final OutputChannelFactory hashOutputChannelFactory; + + if (shuffleSpec.partitionCount() == 1) { + // Single partition; no need to write temporary files. + hashOutputChannelFactory = + new BlockingQueueOutputChannelFactory(frameContext.memoryParameters().getStandardFrameSize()); + } else { + // Multi-partition; write temporary files and then sort each one file-by-file. 
+ hashOutputChannelFactory = + new FileOutputChannelFactory( + frameContext.tempDir("hash-parts"), + frameContext.memoryParameters().getStandardFrameSize(), + null + ); + } + + shufflePipeline.hashPartition(hashOutputChannelFactory); + shufflePipeline.localSort(shuffleOutputChannelFactory); + break; + + case GLOBAL_SORT: + shufflePipeline.globalSort(shuffleOutputChannelFactory, makeGlobalSortPartitionBoundariesFuture()); + break; + + default: + throw new UOE("Cannot handle shuffle kind [%s]", shuffleSpec.kind()); + } + + stageOutputChannelsFuture = shufflePipeline.build(); + } + + private ListenableFuture makeGlobalSortPartitionBoundariesFuture() + { + if (workOrder.getStageDefinition().mustGatherResultKeyStatistics()) { + if (stagePartitionBoundariesFuture != null) { + throw new ISE("Cannot call 'makeGlobalSortPartitionBoundariesFuture' twice"); + } + + return (stagePartitionBoundariesFuture = SettableFuture.create()); + } else { + // Result key stats aren't needed, so the partition boundaries are knowable ahead of time. Compute them now. + final ClusterByPartitions boundaries = + workOrder.getStageDefinition() + .generatePartitionBoundariesForShuffle(null) + .valueOrThrow(); + + return Futures.immediateFuture(boundaries); + } + } + + private void setUpCompletionCallbacks() + { + Futures.addCallback( + Futures.allAsList( + Arrays.asList( + workResultAndOutputChannels.getResultFuture(), + stageOutputChannelsFuture + ) + ), + new FutureCallback>() + { + @Override + public void onSuccess(final List workerResultAndOutputChannelsResolved) + { + final Object resultObject = workerResultAndOutputChannelsResolved.get(0); + final OutputChannels outputChannels = (OutputChannels) workerResultAndOutputChannelsResolved.get(1); + + if (workOrder.getOutputChannelMode() != OutputChannelMode.MEMORY) { + // In non-MEMORY output channel modes, call onOutputChannelAvailable when all work is done. + // (In MEMORY mode, we would have called onOutputChannelAvailable when the channels were created.) + for (final OutputChannel channel : outputChannels.getAllChannels()) { + listener.onOutputChannelAvailable(channel); + } + } + + if (workOrder.getOutputChannelMode().isDurable()) { + // In DURABLE_STORAGE output channel mode, write a success file once all work is done. + writeDurableStorageSuccessFile(); + } + + listener.onSuccess(resultObject); + } + + @Override + public void onFailure(final Throwable t) + { + listener.onFailure(t); + } + }, + Execs.directExecutor() + ); + } + + /** + * Write {@link DurableStorageUtils#SUCCESS_MARKER_FILENAME} for a particular stage, if durable storage is enabled. + */ + private void writeDurableStorageSuccessFile() + { + final DurableStorageOutputChannelFactory durableStorageOutputChannelFactory = + makeDurableStorageOutputChannelFactory( + frameContext.tempDir("durable"), + frameContext.memoryParameters().getStandardFrameSize(), + workOrder.getOutputChannelMode() == OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS + ); + + try { + durableStorageOutputChannelFactory.createSuccessFile(workerContext.workerId()); + } + catch (IOException e) { + throw new ISE( + e, + "Unable to create success file at location[%s]", + durableStorageOutputChannelFactory.getSuccessFilePath() + ); + } + } + + private OutputChannelFactory makeStageOutputChannelFactory() + { + // Use the standard frame size, since we assume this size when computing how much is needed to merge output + // files from different workers. 
+ final int frameSize = frameContext.memoryParameters().getStandardFrameSize(); + final OutputChannelMode outputChannelMode = workOrder.getOutputChannelMode(); + + switch (outputChannelMode) { + case MEMORY: + // Use ListeningOutputChannelFactory to capture output channels as they are created, rather than when + // work is complete. + return new ListeningOutputChannelFactory( + new BlockingQueueOutputChannelFactory(frameSize), + listener::onOutputChannelAvailable + ); + + case LOCAL_STORAGE: + final File fileChannelDirectory = + frameContext.tempDir(StringUtils.format("output_stage_%06d", workOrder.getStageNumber())); + return new FileOutputChannelFactory(fileChannelDirectory, frameSize, null); + + case DURABLE_STORAGE_INTERMEDIATE: + case DURABLE_STORAGE_QUERY_RESULTS: + return makeDurableStorageOutputChannelFactory( + frameContext.tempDir("durable"), + frameSize, + outputChannelMode == OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS + ); + + default: + throw DruidException.defensive("No handling for outputChannelMode[%s]", outputChannelMode); + } + } + + private OutputChannelFactory makeSuperSorterIntermediateOutputChannelFactory(final File tmpDir) + { + final int frameSize = frameContext.memoryParameters().getLargeFrameSize(); + final File fileChannelDirectory = + new File(tmpDir, StringUtils.format("intermediate_output_stage_%06d", workOrder.getStageNumber())); + final FileOutputChannelFactory fileOutputChannelFactory = + new FileOutputChannelFactory(fileChannelDirectory, frameSize, intermediateSuperSorterLocalStorageTracker); + + if (workOrder.getOutputChannelMode().isDurable() + && frameContext.storageParameters().isIntermediateStorageLimitConfigured()) { + final boolean isQueryResults = + workOrder.getOutputChannelMode() == OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS; + return new ComposingOutputChannelFactory( + ImmutableList.of( + fileOutputChannelFactory, + makeDurableStorageOutputChannelFactory(tmpDir, frameSize, isQueryResults) + ), + frameSize + ); + } else { + return fileOutputChannelFactory; + } + } + + private DurableStorageOutputChannelFactory makeDurableStorageOutputChannelFactory( + final File tmpDir, + final int frameSize, + final boolean isQueryResults + ) + { + return DurableStorageOutputChannelFactory.createStandardImplementation( + workOrder.getQueryDefinition().getQueryId(), + workOrder.getWorkerNumber(), + workOrder.getStageNumber(), + workerContext.workerId(), + frameSize, + MSQTasks.makeStorageConnector(workerContext.injector()), + tmpDir, + isQueryResults + ); + } + + /** + * Helper for {@link RunWorkOrder#makeAndRunShuffleProcessors()}. Builds a {@link FrameProcessor} pipeline to + * handle the shuffle. + */ + private class ShufflePipelineBuilder + { + private final WorkOrder workOrder; + private final CounterTracker counterTracker; + private final FrameProcessorExecutor exec; + private final String cancellationId; + private final FrameContext frameContext; + + // Current state of the pipeline. It's a future to allow pipeline construction to be deferred if necessary. + private ListenableFuture> pipelineFuture; + + public ShufflePipelineBuilder( + final WorkOrder workOrder, + final CounterTracker counterTracker, + final FrameProcessorExecutor exec, + final String cancellationId, + final FrameContext frameContext + ) + { + this.workOrder = workOrder; + this.counterTracker = counterTracker; + this.exec = exec; + this.cancellationId = cancellationId; + this.frameContext = frameContext; + } + + /** + * Start the pipeline with the outputs of the main processor. 
+ */ + public void initialize(final ResultAndChannels resultAndChannels) + { + if (pipelineFuture != null) { + throw new ISE("already initialized"); + } + + pipelineFuture = Futures.immediateFuture(resultAndChannels); + } + + /** + * Add {@link FrameChannelMixer}, which mixes all current outputs into a single channel from the provided factory. + */ + public void mix(final OutputChannelFactory outputChannelFactory) + { + // No sorting or statistics gathering, just combining all outputs into one big partition. Use a mixer to get + // everything into one file. Note: even if there is only one output channel, we'll run it through the mixer + // anyway, to ensure the data gets written to a file. (httpGetChannelData requires files.) + + push( + resultAndChannels -> { + final OutputChannel outputChannel = outputChannelFactory.openChannel(0); + + final FrameChannelMixer mixer = + new FrameChannelMixer( + resultAndChannels.getOutputChannels().getAllReadableChannels(), + outputChannel.getWritableChannel() + ); + + return new ResultAndChannels<>( + exec.runFully(mixer, cancellationId), + OutputChannels.wrap(Collections.singletonList(outputChannel.readOnly())) + ); + } + ); + } + + /** + * Add {@link KeyStatisticsCollectionProcessor} if {@link StageDefinition#mustGatherResultKeyStatistics()}. + * + * Calls {@link RunWorkOrderListener#onDoneReadingInput(ClusterByStatisticsSnapshot)} when statistics are gathered. + * If statistics were not needed, calls the listener immediately. + */ + public void gatherResultKeyStatisticsAndReportDoneReadingInputIfNeeded() + { + push( + resultAndChannels -> { + final StageDefinition stageDefinition = workOrder.getStageDefinition(); + final OutputChannels channels = resultAndChannels.getOutputChannels(); + + if (channels.getAllChannels().isEmpty()) { + // No data coming out of this stage. Report empty statistics, if the kernel is expecting statistics. + if (stageDefinition.mustGatherResultKeyStatistics()) { + listener.onDoneReadingInput(ClusterByStatisticsSnapshot.empty()); + } else { + listener.onDoneReadingInput(null); + } + + // Generate one empty channel so the next part of the pipeline has something to do. + final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal(); + channel.writable().close(); + + final OutputChannel outputChannel = OutputChannel.readOnly( + channel.readable(), + FrameWithPartition.NO_PARTITION + ); + + return new ResultAndChannels<>( + Futures.immediateFuture(null), + OutputChannels.wrap(Collections.singletonList(outputChannel)) + ); + } else if (stageDefinition.mustGatherResultKeyStatistics()) { + return gatherResultKeyStatistics(channels); + } else { + // Report "done reading input" when the input future resolves. + // No need to add any processors to the pipeline. + resultAndChannels.resultFuture.addListener( + () -> listener.onDoneReadingInput(null), + Execs.directExecutor() + ); + return resultAndChannels; + } + } + ); + } + + /** + * Add a {@link SuperSorter} using {@link StageDefinition#getSortKey()} and partition boundaries + * from {@code partitionBoundariesFuture}. 
+ */ + public void globalSort( + final OutputChannelFactory outputChannelFactory, + final ListenableFuture partitionBoundariesFuture + ) + { + pushAsync( + resultAndChannels -> { + final StageDefinition stageDefinition = workOrder.getStageDefinition(); + + final File sorterTmpDir = frameContext.tempDir("super-sort"); + FileUtils.mkdirp(sorterTmpDir); + if (!sorterTmpDir.isDirectory()) { + throw new IOException("Cannot create directory: " + sorterTmpDir); + } + + final WorkerMemoryParameters memoryParameters = frameContext.memoryParameters(); + final SuperSorter sorter = new SuperSorter( + resultAndChannels.getOutputChannels().getAllReadableChannels(), + stageDefinition.getFrameReader(), + stageDefinition.getSortKey(), + partitionBoundariesFuture, + exec, + outputChannelFactory, + makeSuperSorterIntermediateOutputChannelFactory(sorterTmpDir), + memoryParameters.getSuperSorterMaxActiveProcessors(), + memoryParameters.getSuperSorterMaxChannelsPerProcessor(), + -1, + cancellationId, + counterTracker.sortProgress(), + removeNullBytes + ); + + return FutureUtils.transform( + sorter.run(), + sortedChannels -> new ResultAndChannels<>(Futures.immediateFuture(null), sortedChannels) + ); + } + ); + } + + /** + * Add a {@link FrameChannelHashPartitioner} using {@link StageDefinition#getSortKey()}. + */ + public void hashPartition(final OutputChannelFactory outputChannelFactory) + { + pushAsync( + resultAndChannels -> { + final ShuffleSpec shuffleSpec = workOrder.getStageDefinition().getShuffleSpec(); + final int partitions = shuffleSpec.partitionCount(); + + final List outputChannels = new ArrayList<>(); + + for (int i = 0; i < partitions; i++) { + outputChannels.add(outputChannelFactory.openChannel(i)); + } + + final FrameChannelHashPartitioner partitioner = new FrameChannelHashPartitioner( + resultAndChannels.getOutputChannels().getAllReadableChannels(), + outputChannels.stream().map(OutputChannel::getWritableChannel).collect(Collectors.toList()), + workOrder.getStageDefinition().getFrameReader(), + workOrder.getStageDefinition().getClusterBy().getColumns().size(), + FrameWriters.makeRowBasedFrameWriterFactory( + new ArenaMemoryAllocatorFactory(frameContext.memoryParameters().getStandardFrameSize()), + workOrder.getStageDefinition().getSignature(), + workOrder.getStageDefinition().getSortKey(), + removeNullBytes + ) + ); + + final ListenableFuture partitionerFuture = exec.runFully(partitioner, cancellationId); + + final ResultAndChannels retVal = + new ResultAndChannels<>(partitionerFuture, OutputChannels.wrap(outputChannels)); + + if (retVal.getOutputChannels().areReadableChannelsReady()) { + return Futures.immediateFuture(retVal); + } else { + return FutureUtils.transform(partitionerFuture, ignored -> retVal); + } + } + ); + } + + /** + * Add a sequence of {@link SuperSorter}, operating on each current output channel in order, one at a time. 
+ */ + public void localSort(final OutputChannelFactory outputChannelFactory) + { + pushAsync( + resultAndChannels -> { + final StageDefinition stageDefinition = workOrder.getStageDefinition(); + final OutputChannels channels = resultAndChannels.getOutputChannels(); + final List> sortedChannelFutures = new ArrayList<>(); + + ListenableFuture nextFuture = Futures.immediateFuture(null); + + for (final OutputChannel channel : channels.getAllChannels()) { + final File sorterTmpDir = frameContext.tempDir( + StringUtils.format("hash-parts-super-sort-%06d", channel.getPartitionNumber()) + ); + + FileUtils.mkdirp(sorterTmpDir); + + // SuperSorter will try to write to output partition zero; we remap it to the correct partition number. + final OutputChannelFactory partitionOverrideOutputChannelFactory = new OutputChannelFactory() + { + @Override + public OutputChannel openChannel(int expectedZero) throws IOException + { + if (expectedZero != 0) { + throw new ISE("Unexpected part [%s]", expectedZero); + } + + return outputChannelFactory.openChannel(channel.getPartitionNumber()); + } + + @Override + public PartitionedOutputChannel openPartitionedChannel(String name, boolean deleteAfterRead) + { + throw new UnsupportedOperationException(); + } + + @Override + public OutputChannel openNilChannel(int expectedZero) + { + if (expectedZero != 0) { + throw new ISE("Unexpected part [%s]", expectedZero); + } + + return outputChannelFactory.openNilChannel(channel.getPartitionNumber()); + } + }; + + // Chain futures so we only sort one partition at a time. + nextFuture = Futures.transformAsync( + nextFuture, + ignored -> { + final SuperSorter sorter = new SuperSorter( + Collections.singletonList(channel.getReadableChannel()), + stageDefinition.getFrameReader(), + stageDefinition.getSortKey(), + Futures.immediateFuture(ClusterByPartitions.oneUniversalPartition()), + exec, + partitionOverrideOutputChannelFactory, + makeSuperSorterIntermediateOutputChannelFactory(sorterTmpDir), + 1, + 2, + -1, + cancellationId, + + // Tracker is not actually tracked, since it doesn't quite fit into the way we report counters. + // There's a single SuperSorterProgressTrackerCounter per worker, but workers that do local + // sorting have a SuperSorter per partition. + new SuperSorterProgressTracker(), + removeNullBytes + ); + + return FutureUtils.transform(sorter.run(), r -> Iterables.getOnlyElement(r.getAllChannels())); + }, + MoreExecutors.directExecutor() + ); + + sortedChannelFutures.add(nextFuture); + } + + return FutureUtils.transform( + Futures.allAsList(sortedChannelFutures), + sortedChannels -> new ResultAndChannels<>( + Futures.immediateFuture(null), + OutputChannels.wrap(sortedChannels) + ) + ); + } + ); + } + + /** + * Return the (future) output channels for this pipeline. + */ + public ListenableFuture build() + { + if (pipelineFuture == null) { + throw new ISE("Not initialized"); + } + + return Futures.transformAsync( + pipelineFuture, + resultAndChannels -> + Futures.transform( + resultAndChannels.getResultFuture(), + (Function) input -> { + sanityCheckOutputChannels(resultAndChannels.getOutputChannels()); + return resultAndChannels.getOutputChannels(); + }, + Execs.directExecutor() + ), + Execs.directExecutor() + ); + } + + /** + * Adds {@link KeyStatisticsCollectionProcessor}. Called by {@link #gatherResultKeyStatisticsAndReportDoneReadingInputIfNeeded()}. 
+ */ + private ResultAndChannels gatherResultKeyStatistics(final OutputChannels channels) + { + final StageDefinition stageDefinition = workOrder.getStageDefinition(); + final List retVal = new ArrayList<>(); + final List processors = new ArrayList<>(); + + for (final OutputChannel outputChannel : channels.getAllChannels()) { + final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal(); + retVal.add(OutputChannel.readOnly(channel.readable(), outputChannel.getPartitionNumber())); + + processors.add( + new KeyStatisticsCollectionProcessor( + outputChannel.getReadableChannel(), + channel.writable(), + stageDefinition.getFrameReader(), + stageDefinition.getClusterBy(), + stageDefinition.createResultKeyStatisticsCollector( + frameContext.memoryParameters().getPartitionStatisticsMaxRetainedBytes() + ) + ) + ); + } + + final ListenableFuture clusterByStatisticsCollectorFuture = + exec.runAllFully( + ProcessorManagers.of(processors) + .withAccumulation( + stageDefinition.createResultKeyStatisticsCollector( + frameContext.memoryParameters().getPartitionStatisticsMaxRetainedBytes() + ), + ClusterByStatisticsCollector::addAll + ), + // Run all processors simultaneously. They are lightweight and this keeps things moving. + processors.size(), + Bouncer.unlimited(), + cancellationId + ); + + Futures.addCallback( + clusterByStatisticsCollectorFuture, + new FutureCallback() + { + @Override + public void onSuccess(final ClusterByStatisticsCollector result) + { + listener.onDoneReadingInput(result.snapshot()); + } + + @Override + public void onFailure(Throwable t) + { + listener.onFailure( + new ISE(t, "Failed to gather clusterBy statistics for stage[%s]", stageDefinition.getId()) + ); + } + }, + Execs.directExecutor() + ); + + return new ResultAndChannels<>( + clusterByStatisticsCollectorFuture, + OutputChannels.wrap(retVal) + ); + } + + /** + * Update the {@link #pipelineFuture}. + */ + private void push(final ExceptionalFunction, ResultAndChannels> fn) + { + pushAsync( + channels -> + Futures.immediateFuture(fn.apply(channels)) + ); + } + + /** + * Update the {@link #pipelineFuture} asynchronously. + */ + private void pushAsync(final ExceptionalFunction, ListenableFuture>> fn) + { + if (pipelineFuture == null) { + throw new ISE("Not initialized"); + } + + pipelineFuture = FutureUtils.transform( + Futures.transformAsync( + pipelineFuture, + fn::apply, + Execs.directExecutor() + ), + resultAndChannels -> new ResultAndChannels<>( + resultAndChannels.getResultFuture(), + resultAndChannels.getOutputChannels().readOnly() + ) + ); + } + + /** + * Verifies there is exactly one channel per partition. 
+ */ + private void sanityCheckOutputChannels(final OutputChannels outputChannels) + { + for (int partitionNumber : outputChannels.getPartitionNumbers()) { + final List outputChannelsForPartition = + outputChannels.getChannelsForPartition(partitionNumber); + + Preconditions.checkState(partitionNumber >= 0, "Expected partitionNumber >= 0, but got [%s]", partitionNumber); + Preconditions.checkState( + outputChannelsForPartition.size() == 1, + "Expected one channel for partition [%s], but got [%s]", + partitionNumber, + outputChannelsForPartition.size() + ); + } + } + } + + private static class ResultAndChannels + { + private final ListenableFuture resultFuture; + private final OutputChannels outputChannels; + + public ResultAndChannels( + ListenableFuture resultFuture, + OutputChannels outputChannels + ) + { + this.resultFuture = resultFuture; + this.outputChannels = outputChannels; + } + + public ListenableFuture getResultFuture() + { + return resultFuture; + } + + public OutputChannels getOutputChannels() + { + return outputChannels; + } + } + + private interface ExceptionalFunction + { + R apply(T t) throws Exception; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrderListener.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrderListener.java new file mode 100644 index 0000000000000..19c3c6570fe9f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/RunWorkOrderListener.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import org.apache.druid.frame.processor.OutputChannel; +import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; + +import javax.annotation.Nullable; + +/** + * Listener for various things that may happen during execution of {@link RunWorkOrder#start()}. Listener methods are + * fired in processing threads, so they must be thread-safe, and it is important that they run quickly. + */ +public interface RunWorkOrderListener +{ + /** + * Called when done reading input. If key statistics were gathered, they are provided. + */ + void onDoneReadingInput(@Nullable ClusterByStatisticsSnapshot snapshot); + + /** + * Called when an output channel becomes available for reading by downstream stages. + */ + void onOutputChannelAvailable(OutputChannel outputChannel); + + /** + * Called when the work order has succeeded. + */ + void onSuccess(Object resultObject); + + /** + * Called when a non-fatal exception is encountered. Work continues after this listener fires. + */ + void onWarning(Throwable t); + + /** + * Called when the work order has failed. 
+ */ + void onFailure(Throwable t); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java index cc5f0fae17322..a90068060d81e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Worker.java @@ -19,40 +19,44 @@ package org.apache.druid.msq.exec; +import com.google.common.util.concurrent.ListenableFuture; import org.apache.druid.frame.key.ClusterByPartitions; -import org.apache.druid.indexer.TaskStatus; import org.apache.druid.msq.counters.CounterSnapshotsTree; -import org.apache.druid.msq.indexing.MSQWorkerTask; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; -import javax.annotation.Nullable; -import java.io.IOException; import java.io.InputStream; +/** + * Interface for a multi-stage query (MSQ) worker. Workers are long-lived and are able to run multiple {@link WorkOrder} + * prior to exiting. + * + * @see WorkerImpl the production implementation + */ public interface Worker { /** - * Unique ID for this worker. + * Identifier for this worker. Same as {@link WorkerContext#workerId()}. */ String id(); /** - * The task which this worker runs. + * Runs the worker in the current thread. Surrounding classes provide the execution thread. */ - MSQWorkerTask task(); + void run(); /** - * Runs the worker in the current thread. Surrounding classes provide - * the execution thread. + * Terminate the worker upon a cancellation request. Causes a concurrently-running {@link #run()} method in + * a separate thread to cancel all outstanding work and exit. Does not block. Use {@link #awaitStop()} if you + * would like to wait for {@link #run()} to finish. */ - TaskStatus run() throws Exception; + void stop(); /** - * Terminate the worker upon a cancellation request. + * Wait for {@link #run()} to finish. */ - void stopGracefully(); + void awaitStop(); /** * Report that the controller has failed. The worker must cease work immediately. Cleanup then exit. @@ -63,20 +67,20 @@ public interface Worker // Controller-to-worker, and worker-to-worker messages /** - * Called when the worker chat handler receives a request for a work order. Accepts the work order and schedules it for - * execution + * Called when the worker receives a new work order. Accepts the work order and schedules it for execution. */ void postWorkOrder(WorkOrder workOrder); /** * Returns the statistics snapshot for the given stageId. This is called from {@link WorkerSketchFetcher} under - * PARALLEL OR AUTO modes. + * {@link ClusterStatisticsMergeMode#PARALLEL} OR {@link ClusterStatisticsMergeMode#AUTO} modes. */ ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId); /** * Returns the statistics snapshot for the given stageId which contains only the sketch for the specified timeChunk. - * This is called from {@link WorkerSketchFetcher} under SEQUENTIAL OR AUTO modes. + * This is called from {@link WorkerSketchFetcher} under {@link ClusterStatisticsMergeMode#SEQUENTIAL} or + * {@link ClusterStatisticsMergeMode#AUTO} modes. 
*/ ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk); @@ -84,26 +88,30 @@ public interface Worker * Called when the worker chat handler recieves the result partition boundaries for a particular stageNumber * and queryId */ - boolean postResultPartitionBoundaries( - ClusterByPartitions stagePartitionBoundaries, - String queryId, - int stageNumber - ); + boolean postResultPartitionBoundaries(StageId stageId, ClusterByPartitions stagePartitionBoundaries); /** * Returns an InputStream of the worker output for a particular queryId, stageNumber and partitionNumber. * Offset indicates the number of bytes to skip the channel data, and is used to prevent re-reading the same data - * during retry in case of a connection error + * during retry in case of a connection error. + * + * The returned future resolves when at least one byte of data is available, or when the channel is finished. + * If the channel is finished, an empty {@link InputStream} is returned. + * + * With {@link OutputChannelMode#MEMORY}, once this method is called with a certain offset, workers are free to + * delete data prior to that offset. (Already-requested offsets will not be re-requested, because + * {@link OutputChannelMode#MEMORY} requires a single reader.) In this mode, if an already-requested offset is + * re-requested for some reason, an error future is returned. * - * Returns a null if the workerOutput for a particular queryId, stageNumber, and partitionNumber is not found. + * The returned future resolves to null if stage output for a particular queryId, stageNumber, and + * partitionNumber is not found. * - * @throws IOException when the worker output is found but there is an error while reading it. + * Throws an exception when worker output is found, but there is an error while reading it. */ - @Nullable - InputStream readChannel(String queryId, int stageNumber, int partitionNumber, long offset) throws IOException; + ListenableFuture readStageOutput(StageId stageId, int partitionNumber, long offset); /** - * Returns the snapshot of the worker counters + * Returns a snapshot of counters. */ CounterSnapshotsTree getCounters(); @@ -115,7 +123,7 @@ boolean postResultPartitionBoundaries( void postCleanupStage(StageId stageId); /** - * Called when the work required for the query has been finished + * Called when the worker is no longer needed, and should shut down. */ void postFinish(); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java index f5e86039c23f1..666115d774cf1 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerContext.java @@ -21,11 +21,12 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.inject.Injector; -import org.apache.druid.frame.processor.Bouncer; -import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.msq.indexing.MSQWorkerTask; import org.apache.druid.msq.kernel.FrameContext; +import org.apache.druid.msq.kernel.FrameProcessorFactory; import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.server.DruidNode; import java.io.File; @@ -33,10 +34,21 @@ /** * Context used by multi-stage query workers. 
* - * Useful because it allows test fixtures to provide their own implementations. + * Each context is scoped to a {@link Worker} and is shared across all {@link WorkOrder} run by that worker. */ public interface WorkerContext { + /** + * Query ID for this context. + */ + String queryId(); + + /** + * Identifier for this worker that enables the controller, and other workers, to find it. For tasks this is the + * task ID from {@link MSQWorkerTask#getId()}. For persistent servers, this is the server URI. + */ + String workerId(); + ObjectMapper jsonMapper(); // Using an Injector directly because tasks do not have a way to provide their own Guice modules. @@ -49,9 +61,15 @@ public interface WorkerContext void registerWorker(Worker worker, Closer closer); /** - * Creates and fetches the controller client for the provided controller ID. + * Maximum number of {@link WorkOrder} that a {@link Worker} with this context will be asked to execute + * simultaneously. + */ + int maxConcurrentStages(); + + /** + * Creates a controller client. */ - ControllerClient makeControllerClient(String controllerId); + ControllerClient makeControllerClient(); /** * Creates and fetches a {@link WorkerClient}. It is independent of the workerId because the workerId is passed @@ -60,24 +78,24 @@ public interface WorkerContext WorkerClient makeWorkerClient(); /** - * Fetch a directory for temporary outputs + * Directory for temporary outputs. */ File tempDir(); - FrameContext frameContext(QueryDefinition queryDef, int stageNumber); + /** + * Create a context with useful objects required by {@link FrameProcessorFactory#makeProcessors}. + */ + FrameContext frameContext(QueryDefinition queryDef, int stageNumber, OutputChannelMode outputChannelMode); + /** + * Number of available processing threads. + */ int threadCount(); /** - * Fetch node info about self + * Fetch node info about self. 
*/ DruidNode selfNode(); - Bouncer processorBouncer(); DataServerQueryHandlerFactory dataServerQueryHandlerFactory(); - - default File tempDir(int stageNumber, String id) - { - return new File(StringUtils.format("%s/stage_%02d/%s", tempDir(), stageNumber, id)); - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java index 61939d8237319..7d2964eb2f8c7 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java @@ -19,111 +19,58 @@ package org.apache.druid.msq.exec; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Function; -import com.google.common.base.Preconditions; import com.google.common.base.Suppliers; +import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; -import com.google.common.util.concurrent.AsyncFunction; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; -import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; -import it.unimi.dsi.fastutil.bytes.ByteArrays; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.ints.IntObjectPair; import org.apache.druid.common.guava.FutureUtils; -import org.apache.druid.frame.allocation.ArenaMemoryAllocator; -import org.apache.druid.frame.allocation.ArenaMemoryAllocatorFactory; -import org.apache.druid.frame.channel.BlockingQueueFrameChannel; -import org.apache.druid.frame.channel.ByteTracker; -import org.apache.druid.frame.channel.FrameWithPartition; -import org.apache.druid.frame.channel.ReadableFileFrameChannel; +import org.apache.druid.error.DruidException; import org.apache.druid.frame.channel.ReadableFrameChannel; -import org.apache.druid.frame.channel.ReadableNilFrameChannel; -import org.apache.druid.frame.file.FrameFile; -import org.apache.druid.frame.file.FrameFileWriter; import org.apache.druid.frame.key.ClusterByPartitions; -import org.apache.druid.frame.processor.BlockingQueueOutputChannelFactory; -import org.apache.druid.frame.processor.Bouncer; -import org.apache.druid.frame.processor.ComposingOutputChannelFactory; -import org.apache.druid.frame.processor.FileOutputChannelFactory; -import org.apache.druid.frame.processor.FrameChannelHashPartitioner; -import org.apache.druid.frame.processor.FrameChannelMixer; -import org.apache.druid.frame.processor.FrameProcessor; import org.apache.druid.frame.processor.FrameProcessorExecutor; import org.apache.druid.frame.processor.OutputChannel; -import org.apache.druid.frame.processor.OutputChannelFactory; -import org.apache.druid.frame.processor.OutputChannels; -import org.apache.druid.frame.processor.PartitionedOutputChannel; -import org.apache.druid.frame.processor.SuperSorter; -import org.apache.druid.frame.processor.SuperSorterProgressTracker; -import org.apache.druid.frame.processor.manager.ProcessorManager; -import org.apache.druid.frame.processor.manager.ProcessorManagers; import 
org.apache.druid.frame.util.DurableStorageUtils; -import org.apache.druid.frame.write.FrameWriters; -import org.apache.druid.indexer.TaskStatus; import org.apache.druid.java.util.common.FileUtils; import org.apache.druid.java.util.common.ISE; -import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.RE; import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.java.util.common.UOE; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.msq.counters.CounterNames; import org.apache.druid.msq.counters.CounterSnapshotsTree; import org.apache.druid.msq.counters.CounterTracker; -import org.apache.druid.msq.indexing.CountingOutputChannelFactory; import org.apache.druid.msq.indexing.InputChannelFactory; -import org.apache.druid.msq.indexing.InputChannelsImpl; import org.apache.druid.msq.indexing.MSQWorkerTask; import org.apache.druid.msq.indexing.destination.MSQSelectDestination; import org.apache.druid.msq.indexing.error.CanceledFault; import org.apache.druid.msq.indexing.error.CannotParseExternalDataFault; import org.apache.druid.msq.indexing.error.MSQErrorReport; import org.apache.druid.msq.indexing.error.MSQException; -import org.apache.druid.msq.indexing.error.MSQFaultUtils; import org.apache.druid.msq.indexing.error.MSQWarningReportLimiterPublisher; import org.apache.druid.msq.indexing.error.MSQWarningReportPublisher; import org.apache.druid.msq.indexing.error.MSQWarningReportSimplePublisher; import org.apache.druid.msq.indexing.error.MSQWarnings; -import org.apache.druid.msq.indexing.processor.KeyStatisticsCollectionProcessor; -import org.apache.druid.msq.input.InputSlice; -import org.apache.druid.msq.input.InputSliceReader; import org.apache.druid.msq.input.InputSlices; -import org.apache.druid.msq.input.MapInputSliceReader; -import org.apache.druid.msq.input.NilInputSlice; -import org.apache.druid.msq.input.NilInputSliceReader; -import org.apache.druid.msq.input.external.ExternalInputSlice; -import org.apache.druid.msq.input.external.ExternalInputSliceReader; -import org.apache.druid.msq.input.inline.InlineInputSlice; -import org.apache.druid.msq.input.inline.InlineInputSliceReader; -import org.apache.druid.msq.input.lookup.LookupInputSlice; -import org.apache.druid.msq.input.lookup.LookupInputSliceReader; -import org.apache.druid.msq.input.stage.InputChannels; import org.apache.druid.msq.input.stage.ReadablePartition; -import org.apache.druid.msq.input.stage.StageInputSlice; -import org.apache.druid.msq.input.stage.StageInputSliceReader; -import org.apache.druid.msq.input.table.SegmentsInputSlice; -import org.apache.druid.msq.input.table.SegmentsInputSliceReader; import org.apache.druid.msq.kernel.FrameContext; -import org.apache.druid.msq.kernel.FrameProcessorFactory; -import org.apache.druid.msq.kernel.ProcessorsAndChannels; -import org.apache.druid.msq.kernel.ShuffleSpec; import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageId; -import org.apache.druid.msq.kernel.StagePartition; import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelUtils; import org.apache.druid.msq.kernel.worker.WorkerStageKernel; import org.apache.druid.msq.kernel.worker.WorkerStagePhase; import org.apache.druid.msq.shuffle.input.DurableStorageInputChannelFactory; +import org.apache.druid.msq.shuffle.input.MetaInputChannelFactory; import 
org.apache.druid.msq.shuffle.input.WorkerInputChannelFactory; -import org.apache.druid.msq.shuffle.output.DurableStorageOutputChannelFactory; -import org.apache.druid.msq.statistics.ClusterByStatisticsCollector; +import org.apache.druid.msq.shuffle.input.WorkerOrLocalInputChannelFactory; +import org.apache.druid.msq.shuffle.output.StageOutputHolder; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; import org.apache.druid.msq.util.DecoratedExecutorService; @@ -132,23 +79,14 @@ import org.apache.druid.query.PrioritizedRunnable; import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryProcessingPool; -import org.apache.druid.rpc.ServiceClosedException; import org.apache.druid.server.DruidNode; import javax.annotation.Nullable; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.File; +import java.io.Closeable; import java.io.IOException; import java.io.InputStream; -import java.io.RandomAccessFile; -import java.nio.channels.Channels; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; import java.util.Comparator; -import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -156,9 +94,9 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -166,105 +104,93 @@ /** * Interface for a worker of a multi-stage query. + * + * Not scoped to any particular query. There is one of these per {@link MSQWorkerTask}, and one per server for + * long-lived workers. */ public class WorkerImpl implements Worker { private static final Logger log = new Logger(WorkerImpl.class); + /** + * Task object, if this {@link WorkerImpl} was launched from a task. Ideally, this would not be needed, and we + * would be able to get everything we need from {@link WorkerContext}. + */ + @Nullable private final MSQWorkerTask task; private final WorkerContext context; private final DruidNode selfDruidNode; - private final Bouncer processorBouncer; - private final BlockingQueue> kernelManipulationQueue = new LinkedBlockingDeque<>(); - private final ConcurrentHashMap> stageOutputs = new ConcurrentHashMap<>(); - private final ConcurrentHashMap stageCounters = new ConcurrentHashMap<>(); - private final ConcurrentHashMap stageKernelMap = new ConcurrentHashMap<>(); - private final ByteTracker intermediateSuperSorterLocalStorageTracker; - private final boolean durableStageStorageEnabled; - private final WorkerStorageParameters workerStorageParameters; - private final boolean isRemoveNullBytes; + private final BlockingQueue> kernelManipulationQueue = new LinkedBlockingDeque<>(); + private final ConcurrentHashMap> stageOutputs = new ConcurrentHashMap<>(); /** - * Only set for select jobs. + * Pair of {workerNumber, stageId} -> counters. 
*/ - @Nullable - private final MSQSelectDestination selectDestination; + private final ConcurrentHashMap, CounterTracker> stageCounters = new ConcurrentHashMap<>(); + + /** + * Atomic that is set to true when {@link #run()} starts (or when {@link #stop()} is called before {@link #run()}). + */ + private final AtomicBoolean didRun = new AtomicBoolean(); /** - * Set once in {@link #runTask} and never reassigned. + * Future that resolves when {@link #run()} completes. + */ + private final SettableFuture runFuture = SettableFuture.create(); + + /** + * Set once in {@link #run} and never reassigned. This is in a field so {@link #doCancel()} can close it. */ private volatile ControllerClient controllerClient; /** - * Set once in {@link #runTask} and never reassigned. Used by processing threads so we can contact other workers + * Set once in {@link #runInternal} and never reassigned. Used by processing threads so we can contact other workers * during a shuffle. */ private volatile WorkerClient workerClient; /** - * Set to false by {@link #controllerFailed()} as a way of enticing the {@link #runTask} method to exit promptly. + * Set to false by {@link #controllerFailed()} as a way of enticing the {@link #runInternal} method to exit promptly. */ private volatile boolean controllerAlive = true; - public WorkerImpl(MSQWorkerTask task, WorkerContext context) - { - this( - task, - context, - WorkerStorageParameters.createProductionInstance( - context.injector(), - MultiStageQueryContext.isDurableStorageEnabled(QueryContext.of(task.getContext())) - // If Durable Storage is enabled, then super sorter intermediate storage can be enabled. - ) - ); - } - - @VisibleForTesting - public WorkerImpl(MSQWorkerTask task, WorkerContext context, WorkerStorageParameters workerStorageParameters) + public WorkerImpl(@Nullable final MSQWorkerTask task, final WorkerContext context) { this.task = task; this.context = context; this.selfDruidNode = context.selfNode(); - this.processorBouncer = context.processorBouncer(); - QueryContext queryContext = QueryContext.of(task.getContext()); - this.durableStageStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(queryContext); - this.selectDestination = MultiStageQueryContext.getSelectDestinationOrNull(queryContext); - this.isRemoveNullBytes = MultiStageQueryContext.removeNullBytes(queryContext); - this.workerStorageParameters = workerStorageParameters; - - long maxBytes = workerStorageParameters.isIntermediateStorageLimitConfigured() - ? 
workerStorageParameters.getIntermediateSuperSorterStorageMaxLocalBytes() - : Long.MAX_VALUE; - this.intermediateSuperSorterLocalStorageTracker = new ByteTracker(maxBytes); } @Override public String id() { - return task.getId(); + return context.workerId(); } @Override - public MSQWorkerTask task() + public void run() { - return task; - } + if (!didRun.compareAndSet(false, true)) { + throw new ISE("already run"); + } - @Override - public TaskStatus run() throws Exception - { try (final Closer closer = Closer.create()) { + final KernelHolders kernelHolders = KernelHolders.create(context, closer); + controllerClient = kernelHolders.getControllerClient(); + + Throwable t = null; Optional maybeErrorReport; try { - maybeErrorReport = runTask(closer); + maybeErrorReport = runInternal(kernelHolders, closer); } catch (Throwable e) { + t = e; maybeErrorReport = Optional.of( MSQErrorReport.fromException( - id(), - MSQTasks.getHostFromSelfNode(selfDruidNode), + context.workerId(), + MSQTasks.getHostFromSelfNode(context.selfNode()), null, e ) @@ -273,203 +199,112 @@ public TaskStatus run() throws Exception if (maybeErrorReport.isPresent()) { final MSQErrorReport errorReport = maybeErrorReport.get(); - final String errorLogMessage = MSQTasks.errorReportToLogMessage(errorReport); - log.warn(errorLogMessage); + final String logMessage = MSQTasks.errorReportToLogMessage(errorReport); + log.warn("%s", logMessage); - closer.register(() -> { - if (controllerAlive && controllerClient != null && selfDruidNode != null) { - controllerClient.postWorkerError(id(), errorReport); - } - }); + if (controllerAlive) { + controllerClient.postWorkerError(context.workerId(), errorReport); + } - return TaskStatus.failure(id(), MSQFaultUtils.generateMessageWithErrorCode(errorReport.getFault())); - } else { - return TaskStatus.success(id()); + if (t != null) { + Throwables.throwIfInstanceOf(t, MSQException.class); + throw new MSQException(t, maybeErrorReport.get().getFault()); + } else { + throw new MSQException(maybeErrorReport.get().getFault()); + } } } + catch (IOException e) { + throw new RuntimeException(e); + } + finally { + runFuture.set(null); + } } /** * Runs worker logic. Returns an empty Optional on success. On failure, returns an error report for errors that * happened in other threads; throws exceptions for errors that happened in the main worker loop. 
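   *
   * Abbreviated sketch of the loop's shape (handlePhase is a stand-in for the per-phase handler methods below;
   * counters, error reporting, and kernel removal are omitted):
   *
   * <pre>
   *   while (!kernelHolders.isDone()) {
   *     boolean didSomething = false;
   *     for (KernelHolder holder : kernelHolders.getAllKernelHolders()) {
   *       didSomething |= handlePhase(holder);
   *     }
   *     if (!didSomething) {
   *       // Block for the next command from another thread, posting counters to the controller while waiting.
   *       do {
   *         postCountersToController(controllerClient);
   *       } while ((nextCommand = kernelManipulationQueue.poll(5, TimeUnit.SECONDS)) == null);
   *       nextCommand.accept(kernelHolders);
   *     }
   *   }
   * </pre>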
*/ - public Optional runTask(final Closer closer) throws Exception + private Optional runInternal(final KernelHolders kernelHolders, final Closer workerCloser) + throws Exception { - this.controllerClient = context.makeControllerClient(task.getControllerTaskId()); - closer.register(controllerClient::close); - closer.register(context.dataServerQueryHandlerFactory()); - context.registerWorker(this, closer); // Uses controllerClient, so must be called after that is initialized - - this.workerClient = new ExceptionWrappingWorkerClient(context.makeWorkerClient()); - closer.register(workerClient::close); - - final KernelHolder kernelHolder = new KernelHolder(); - final String cancellationId = id(); - + context.registerWorker(this, workerCloser); + workerCloser.register(context.dataServerQueryHandlerFactory()); + this.workerClient = workerCloser.register(new ExceptionWrappingWorkerClient(context.makeWorkerClient())); final FrameProcessorExecutor workerExec = new FrameProcessorExecutor(makeProcessingPool()); - // Delete all the stage outputs - closer.register(() -> { - for (final StageId stageId : stageOutputs.keySet()) { - cleanStageOutput(stageId, false); - } - }); - - // Close stage output processors and running futures (if present) - closer.register(() -> { - try { - workerExec.cancel(cancellationId); - } - catch (InterruptedException e) { - // Strange that cancellation would itself be interrupted. Throw an exception, since this is unexpected. - throw new RuntimeException(e); - } - }); + final long maxAllowedParseExceptions; - long maxAllowedParseExceptions = Long.parseLong(task.getContext().getOrDefault( - MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, - Long.MAX_VALUE - ).toString()); + if (task != null) { + maxAllowedParseExceptions = + Long.parseLong(task.getContext() + .getOrDefault(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, Long.MAX_VALUE) + .toString()); + } else { + maxAllowedParseExceptions = 0; + } - long maxVerboseParseExceptions; + final long maxVerboseParseExceptions; if (maxAllowedParseExceptions == -1L) { maxVerboseParseExceptions = Limits.MAX_VERBOSE_PARSE_EXCEPTIONS; } else { maxVerboseParseExceptions = Math.min(maxAllowedParseExceptions, Limits.MAX_VERBOSE_PARSE_EXCEPTIONS); } - Set criticalWarningCodes; + final Set criticalWarningCodes; if (maxAllowedParseExceptions == 0) { criticalWarningCodes = ImmutableSet.of(CannotParseExternalDataFault.CODE); } else { criticalWarningCodes = ImmutableSet.of(); } - final MSQWarningReportPublisher msqWarningReportPublisher = new MSQWarningReportLimiterPublisher( - new MSQWarningReportSimplePublisher( - id(), - controllerClient, - id(), - MSQTasks.getHostFromSelfNode(selfDruidNode) - ), - Limits.MAX_VERBOSE_WARNINGS, - ImmutableMap.of(CannotParseExternalDataFault.CODE, maxVerboseParseExceptions), - criticalWarningCodes, - controllerClient, - id(), - MSQTasks.getHostFromSelfNode(selfDruidNode) - ); - - closer.register(msqWarningReportPublisher); - - final Map> partitionBoundariesFutureMap = new HashMap<>(); + // Delay removal of kernels so we don't interfere with iteration of kernelHolders.getAllKernelHolders(). 
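    // Kernels that reach a terminal phase during an iteration are collected here and removed once the iteration finishes.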
+ final Set kernelsToRemove = new HashSet<>(); - final Map stageFrameContexts = new HashMap<>(); - - while (!kernelHolder.isDone()) { + while (!kernelHolders.isDone()) { boolean didSomething = false; - for (final WorkerStageKernel kernel : kernelHolder.getStageKernelMap().values()) { + for (final KernelHolder kernelHolder : kernelHolders.getAllKernelHolders()) { + final WorkerStageKernel kernel = kernelHolder.kernel; final StageDefinition stageDefinition = kernel.getStageDefinition(); - if (kernel.getPhase() == WorkerStagePhase.NEW) { - - log.info("Processing work order for stage [%d]" + - (log.isDebugEnabled() - ? StringUtils.format( - " with payload [%s]", - context.jsonMapper().writeValueAsString(kernel.getWorkOrder()) - ) : ""), stageDefinition.getId().getStageNumber()); - - // Create separate inputChannelFactory per stage, because the list of tasks can grow between stages, and - // so we need to avoid the memoization in baseInputChannelFactory. - final InputChannelFactory inputChannelFactory = makeBaseInputChannelFactory(closer); - - // Compute memory parameters for all stages, even ones that haven't been assigned yet, so we can fail-fast - // if some won't work. (We expect that all stages will get assigned to the same pool of workers.) - for (final StageDefinition stageDef : kernel.getWorkOrder().getQueryDefinition().getStageDefinitions()) { - stageFrameContexts.computeIfAbsent( - stageDef.getId(), - stageId -> context.frameContext( - kernel.getWorkOrder().getQueryDefinition(), - stageId.getStageNumber() - ) - ); - } - - // Start working on this stage immediately. - kernel.startReading(); - - final RunWorkOrder runWorkOrder = new RunWorkOrder( - kernel, - inputChannelFactory, - stageCounters.computeIfAbsent(stageDefinition.getId(), ignored -> new CounterTracker()), + // Workers run all work orders they get. There is not (currently) any limit on the number of concurrent work + // orders; we rely on the controller to avoid overloading workers. 
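        // A work order in phase NEW is only started while runningKernelCount() is below context.maxConcurrentStages();
        // beyond that it remains in NEW until a running stage finishes.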
+ if (kernel.getPhase() == WorkerStagePhase.NEW + && kernelHolders.runningKernelCount() < context.maxConcurrentStages()) { + handleNewWorkOrder( + kernelHolder, + controllerClient, workerExec, - cancellationId, - context.threadCount(), - stageFrameContexts.get(stageDefinition.getId()), - msqWarningReportPublisher + criticalWarningCodes, + maxVerboseParseExceptions ); - - runWorkOrder.start(); - - final SettableFuture partitionBoundariesFuture = - runWorkOrder.getStagePartitionBoundariesFuture(); - - if (partitionBoundariesFuture != null) { - if (partitionBoundariesFutureMap.put(stageDefinition.getId(), partitionBoundariesFuture) != null) { - throw new ISE("Work order collision for stage [%s]", stageDefinition.getId()); - } - } - + logKernelStatus(kernelHolders.getAllKernels()); didSomething = true; - logKernelStatus(kernelHolder.getStageKernelMap().values()); } - if (kernel.getPhase() == WorkerStagePhase.READING_INPUT && kernel.hasResultKeyStatisticsSnapshot()) { - if (controllerAlive) { - PartialKeyStatisticsInformation partialKeyStatisticsInformation = - kernel.getResultKeyStatisticsSnapshot() - .partialKeyStatistics(); - - controllerClient.postPartialKeyStatistics( - stageDefinition.getId(), - kernel.getWorkOrder().getWorkerNumber(), - partialKeyStatisticsInformation - ); - } - kernel.startPreshuffleWaitingForResultPartitionBoundaries(); - + if (kernel.getPhase() == WorkerStagePhase.READING_INPUT + && handleReadingInput(kernelHolder, controllerClient)) { didSomething = true; - logKernelStatus(kernelHolder.getStageKernelMap().values()); + logKernelStatus(kernelHolders.getAllKernels()); } - logKernelStatus(kernelHolder.getStageKernelMap().values()); if (kernel.getPhase() == WorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES - && kernel.hasResultPartitionBoundaries()) { - partitionBoundariesFutureMap.get(stageDefinition.getId()).set(kernel.getResultPartitionBoundaries()); - kernel.startPreshuffleWritingOutput(); - + && handleWaitingForResultPartitionBoundaries(kernelHolder)) { didSomething = true; - logKernelStatus(kernelHolder.getStageKernelMap().values()); + logKernelStatus(kernelHolders.getAllKernels()); } - if (kernel.getPhase() == WorkerStagePhase.RESULTS_READY - && kernel.addPostedResultsComplete(Pair.of( - stageDefinition.getId(), - kernel.getWorkOrder().getWorkerNumber() - ))) { - if (controllerAlive) { - controllerClient.postResultsComplete( - stageDefinition.getId(), - kernel.getWorkOrder().getWorkerNumber(), - kernel.getResultObject() - ); - } + if (kernel.getPhase() == WorkerStagePhase.RESULTS_COMPLETE + && handleResultsReady(kernelHolder, controllerClient)) { + didSomething = true; + logKernelStatus(kernelHolders.getAllKernels()); } if (kernel.getPhase() == WorkerStagePhase.FAILED) { - // Better than throwing an exception, because we can include the stage number. + // Return an error report when a work order fails. This is better than throwing an exception, because we can + // include the stage number. 
return Optional.of( MSQErrorReport.fromException( id(), @@ -479,17 +314,37 @@ public Optional runTask(final Closer closer) throws Exception ) ); } + + if (kernel.getPhase().isTerminal()) { + handleTerminated(kernelHolder); + kernelsToRemove.add(stageDefinition.getId()); + } + } + + for (final StageId stageId : kernelsToRemove) { + kernelHolders.removeKernel(stageId); } - if (!didSomething && !kernelHolder.isDone()) { - Consumer nextCommand; + kernelsToRemove.clear(); + if (!didSomething && !kernelHolders.isDone()) { + Consumer nextCommand; + + // Run the next command, waiting for it if necessary. Post counters to the controller every 5 seconds + // while waiting. do { - postCountersToController(); + postCountersToController(kernelHolders.getControllerClient()); } while ((nextCommand = kernelManipulationQueue.poll(5, TimeUnit.SECONDS)) == null); - nextCommand.accept(kernelHolder); - logKernelStatus(kernelHolder.getStageKernelMap().values()); + nextCommand.accept(kernelHolders); + + // Run all pending commands after that one. Helps avoid deep queues. + // After draining the command queue, move on to the next iteration of the worker loop. + while ((nextCommand = kernelManipulationQueue.poll()) != null) { + nextCommand.accept(kernelHolders); + } + + logKernelStatus(kernelHolders.getAllKernels()); } } @@ -497,123 +352,288 @@ public Optional runTask(final Closer closer) throws Exception return Optional.empty(); } - @Override - public void stopGracefully() + /** + * Handle a kernel in state {@link WorkerStagePhase#NEW}. The kernel is transitioned to + * {@link WorkerStagePhase#READING_INPUT} and a {@link RunWorkOrder} instance is created to start executing work. + */ + private void handleNewWorkOrder( + final KernelHolder kernelHolder, + final ControllerClient controllerClient, + final FrameProcessorExecutor workerExec, + final Set criticalWarningCodes, + final long maxVerboseParseExceptions + ) throws IOException { - // stopGracefully() is called when the containing process is terminated, or when the task is canceled. - log.info("Worker task[%s] canceled.", task.getId()); - doCancel(); - } + final WorkerStageKernel kernel = kernelHolder.kernel; + final WorkOrder workOrder = kernel.getWorkOrder(); + final StageDefinition stageDefinition = workOrder.getStageDefinition(); + final String cancellationId = cancellationIdFor(stageDefinition.getId()); + + log.info( + "Processing work order for stage[%s]%s", + stageDefinition.getId(), + (log.isDebugEnabled() + ? StringUtils.format(", payload[%s]", context.jsonMapper().writeValueAsString(workOrder)) : "") + ); - @Override - public void controllerFailed() - { - log.info("Controller task[%s] for worker task[%s] failed. Canceling.", task.getControllerTaskId(), task.getId()); - doCancel(); + final FrameContext frameContext = kernelHolder.processorCloser.register( + context.frameContext( + workOrder.getQueryDefinition(), + stageDefinition.getStageNumber(), + workOrder.getOutputChannelMode() + ) + ); + kernelHolder.processorCloser.register(() -> { + try { + workerExec.cancel(cancellationId); + } + catch (InterruptedException e) { + // Strange that cancellation would itself be interrupted. Log and suppress. + log.warn(e, "Cancellation interrupted for stage[%s]", stageDefinition.getId()); + Thread.currentThread().interrupt(); + } + }); + + // Set up cleanup functions for this work order. 
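    // The resultsCloser is released only when the kernel is removed (not when processing merely completes), so
    // stage output remains readable by downstream consumers in the meantime.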
+ kernelHolder.resultsCloser.register(() -> FileUtils.deleteDirectory(frameContext.tempDir())); + kernelHolder.resultsCloser.register(() -> removeStageOutputChannels(stageDefinition.getId())); + + // Create separate inputChannelFactory per stage, because the list of tasks can grow between stages, and + // so we need to avoid the memoization of controllerClient.getWorkerIds() in baseInputChannelFactory. + final InputChannelFactory inputChannelFactory = + makeBaseInputChannelFactory(workOrder, controllerClient, kernelHolder.processorCloser); + + // Start working on this stage immediately. + kernel.startReading(); + + final QueryContext queryContext = task != null ? QueryContext.of(task.getContext()) : QueryContext.empty(); + final RunWorkOrder runWorkOrder = new RunWorkOrder( + workOrder, + inputChannelFactory, + stageCounters.computeIfAbsent( + IntObjectPair.of(workOrder.getWorkerNumber(), stageDefinition.getId()), + ignored -> new CounterTracker() + ), + workerExec, + cancellationId, + context, + frameContext, + makeRunWorkOrderListener(workOrder, controllerClient, criticalWarningCodes, maxVerboseParseExceptions), + MultiStageQueryContext.isReindex(queryContext), + MultiStageQueryContext.removeNullBytes(queryContext) + ); + + runWorkOrder.start(); + kernelHolder.partitionBoundariesFuture = runWorkOrder.getStagePartitionBoundariesFuture(); } - @Override - public InputStream readChannel( - final String queryId, - final int stageNumber, - final int partitionNumber, - final long offset + /** + * Handle a kernel in state {@link WorkerStagePhase#READING_INPUT}. + * + * If the worker has finished generating result key statistics, they are posted to the controller and the kernel is + * transitioned to {@link WorkerStagePhase#PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES}. + * + * @return whether kernel state changed + */ + private boolean handleReadingInput( + final KernelHolder kernelHolder, + final ControllerClient controllerClient ) throws IOException { - final StageId stageId = new StageId(queryId, stageNumber); - final StagePartition stagePartition = new StagePartition(stageId, partitionNumber); - final ConcurrentHashMap partitionOutputsForStage = stageOutputs.get(stageId); + final WorkerStageKernel kernel = kernelHolder.kernel; + if (kernel.hasResultKeyStatisticsSnapshot()) { + if (controllerAlive) { + PartialKeyStatisticsInformation partialKeyStatisticsInformation = + kernel.getResultKeyStatisticsSnapshot() + .partialKeyStatistics(); + + controllerClient.postPartialKeyStatistics( + kernel.getStageDefinition().getId(), + kernel.getWorkOrder().getWorkerNumber(), + partialKeyStatisticsInformation + ); + } - if (partitionOutputsForStage == null) { - return null; + kernel.startPreshuffleWaitingForResultPartitionBoundaries(); + return true; + } else if (kernel.isDoneReadingInput() + && kernel.getStageDefinition().doesSortDuringShuffle() + && !kernel.getStageDefinition().mustGatherResultKeyStatistics()) { + // Skip postDoneReadingInput when context.maxConcurrentStages() == 1, for backwards compatibility. + // See Javadoc comment on ControllerClient#postDoneReadingInput. + if (controllerAlive && context.maxConcurrentStages() > 1) { + controllerClient.postDoneReadingInput( + kernel.getStageDefinition().getId(), + kernel.getWorkOrder().getWorkerNumber() + ); + } + + kernel.startPreshuffleWritingOutput(); + return true; + } else { + return false; + } + } + + /** + * Handle a kernel in state {@link WorkerStagePhase#PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES}. 
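   * This phase is only entered for stages that gather result key statistics
   * ({@link StageDefinition#mustGatherResultKeyStatistics()}); the controller computes the partition boundaries
   * from the statistics posted by workers.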
+ * + * If partition boundaries have become available, the {@link KernelHolder#partitionBoundariesFuture} is updated and + * the kernel is transitioned to state {@link WorkerStagePhase#PRESHUFFLE_WRITING_OUTPUT}. + * + * @return whether kernel state changed + */ + private boolean handleWaitingForResultPartitionBoundaries(final KernelHolder kernelHolder) + { + if (kernelHolder.kernel.hasResultPartitionBoundaries()) { + kernelHolder.partitionBoundariesFuture.set(kernelHolder.kernel.getResultPartitionBoundaries()); + kernelHolder.kernel.startPreshuffleWritingOutput(); + return true; + } else { + return false; } - final ReadableFrameChannel channel = partitionOutputsForStage.get(partitionNumber); + } - if (channel == null) { - return null; + /** + * Handle a kernel in state {@link WorkerStagePhase#RESULTS_COMPLETE}. If {@link ControllerClient#postResultsComplete} + * has not yet been posted to the controller, it is posted at this time. Otherwise nothing happens. + * + * @return whether kernel state changed + */ + private boolean handleResultsReady(final KernelHolder kernelHolder, final ControllerClient controllerClient) + throws IOException + { + final WorkerStageKernel kernel = kernelHolder.kernel; + final boolean didNotPostYet = + kernel.addPostedResultsComplete(kernel.getStageDefinition().getId(), kernel.getWorkOrder().getWorkerNumber()); + + if (controllerAlive && didNotPostYet) { + controllerClient.postResultsComplete( + kernel.getStageDefinition().getId(), + kernel.getWorkOrder().getWorkerNumber(), + kernel.getResultObject() + ); } - if (channel instanceof ReadableNilFrameChannel) { - // Build an empty frame file. - final ByteArrayOutputStream baos = new ByteArrayOutputStream(); - FrameFileWriter.open(Channels.newChannel(baos), null, ByteTracker.unboundedTracker()).close(); + return didNotPostYet; + } - final ByteArrayInputStream in = new ByteArrayInputStream(baos.toByteArray()); + /** + * Handle a kernel in state where {@link WorkerStagePhase#isTerminal()} is true. + */ + private void handleTerminated(final KernelHolder kernelHolder) + { + final WorkerStageKernel kernel = kernelHolder.kernel; + removeStageOutputChannels(kernel.getStageDefinition().getId()); - //noinspection ResultOfMethodCallIgnored: OK to ignore since "skip" always works for ByteArrayInputStream. - in.skip(offset); + if (kernelHolder.kernel.getWorkOrder().getOutputChannelMode().isDurable()) { + removeStageDurableStorageOutput(kernel.getStageDefinition().getId()); + } + } - return in; - } else if (channel instanceof ReadableFileFrameChannel) { - // Close frameFile once we've returned an input stream: no need to retain a reference to the mmap after that, - // since we aren't using it. - try (final FrameFile frameFile = ((ReadableFileFrameChannel) channel).newFrameFileReference()) { - final RandomAccessFile randomAccessFile = new RandomAccessFile(frameFile.file(), "r"); + @Override + public void stop() + { + // stopGracefully() is called when the containing process is terminated, or when the task is canceled. + log.info("Worker id[%s] canceled.", context.workerId()); - if (offset >= randomAccessFile.length()) { - randomAccessFile.close(); - return new ByteArrayInputStream(ByteArrays.EMPTY_ARRAY); - } else { - randomAccessFile.seek(offset); - return Channels.newInputStream(randomAccessFile.getChannel()); - } - } + if (didRun.compareAndSet(false, true)) { + // run() hasn't been called yet. Set runFuture so awaitStop() still works. 
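      // Nothing is running yet on this path, so there is nothing further to cancel.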
+ runFuture.set(null); } else { - String errorMsg = StringUtils.format( - "Returned server error to client because channel for [%s] is not nil or file-based (class = %s)", - stagePartition, - channel.getClass().getName() - ); - log.error(StringUtils.encodeForFormat(errorMsg)); - - throw new IOException(errorMsg); + doCancel(); } } + @Override + public void awaitStop() + { + FutureUtils.getUnchecked(runFuture, false); + } + + @Override + public void controllerFailed() + { + log.info( + "Controller task[%s] for worker[%s] failed. Canceling.", + task != null ? task.getControllerTaskId() : null, + id() + ); + doCancel(); + } + + @Override + public ListenableFuture readStageOutput( + final StageId stageId, + final int partitionNumber, + final long offset + ) + { + return getOrCreateStageOutputHolder(stageId, partitionNumber).readRemotelyFrom(offset); + } + @Override public void postWorkOrder(final WorkOrder workOrder) { - log.info("Got work order for stage [%d]", workOrder.getStageNumber()); - if (task.getWorkerNumber() != workOrder.getWorkerNumber()) { - throw new ISE("Worker number mismatch: expected [%d]", task.getWorkerNumber()); + log.info( + "Got work order for stage[%s], workerNumber[%s]", + workOrder.getStageDefinition().getId(), + workOrder.getWorkerNumber() + ); + + if (task != null && task.getWorkerNumber() != workOrder.getWorkerNumber()) { + throw new ISE( + "Worker number mismatch: expected workerNumber[%d], got[%d]", + task.getWorkerNumber(), + workOrder.getWorkerNumber() + ); + } + + final OutputChannelMode outputChannelMode; + + // This stack of conditions can be removed once we can rely on OutputChannelMode always being in the WorkOrder. + // (It will be there for newer controllers; this is a backwards-compatibility thing.) + if (workOrder.hasOutputChannelMode()) { + outputChannelMode = workOrder.getOutputChannelMode(); + } else { + final MSQSelectDestination selectDestination = + task != null + ? MultiStageQueryContext.getSelectDestination(QueryContext.of(task.getContext())) + : MSQSelectDestination.TASKREPORT; + + outputChannelMode = ControllerQueryKernelUtils.getOutputChannelMode( + workOrder.getQueryDefinition(), + workOrder.getStageNumber(), + selectDestination, + task != null && MultiStageQueryContext.isDurableStorageEnabled(QueryContext.of(task.getContext())), + false + ); } - // Do not add to queue if workerOrder already present. + final WorkOrder workOrderToUse = workOrder.withOutputChannelMode(outputChannelMode); kernelManipulationQueue.add( - kernelHolder -> - kernelHolder.getStageKernelMap().putIfAbsent( - workOrder.getStageDefinition().getId(), - WorkerStageKernel.create(workOrder) - ) + kernelHolders -> + kernelHolders.addKernel(WorkerStageKernel.create(workOrderToUse)) ); } @Override public boolean postResultPartitionBoundaries( - final ClusterByPartitions stagePartitionBoundaries, - final String queryId, - final int stageNumber + final StageId stageId, + final ClusterByPartitions stagePartitionBoundaries ) { - final StageId stageId = new StageId(queryId, stageNumber); - kernelManipulationQueue.add( - kernelHolder -> { - final WorkerStageKernel stageKernel = kernelHolder.getStageKernelMap().get(stageId); + kernelHolders -> { + final WorkerStageKernel stageKernel = kernelHolders.getKernelFor(stageId); if (stageKernel != null) { if (!stageKernel.hasResultPartitionBoundaries()) { stageKernel.setResultPartitionBoundaries(stagePartitionBoundaries); } else { // Ignore if partition boundaries are already set. 
- log.warn( - "Stage[%s] already has result partition boundaries set. Ignoring the latest partition boundaries recieved.", - stageId - ); + log.warn("Stage[%s] already has result partition boundaries set. Ignoring new ones.", stageId); } - } else { - // Ignore the update if we don't have a kernel for this stage. - log.warn("Ignored result partition boundaries call for unknown stage [%s]", stageId); } } ); @@ -623,167 +643,230 @@ public boolean postResultPartitionBoundaries( @Override public void postCleanupStage(final StageId stageId) { - log.info("Cleanup order for stage [%s] received", stageId); - kernelManipulationQueue.add( - holder -> { - cleanStageOutput(stageId, true); - // Mark the stage as FINISHED - WorkerStageKernel stageKernel = holder.getStageKernelMap().get(stageId); - if (stageKernel == null) { - log.warn("Stage id [%s] non existent. Unable to mark the stage kernel for it as FINISHED", stageId); - } else { - stageKernel.setStageFinished(); - } - } - ); + log.debug("Received cleanup order for stage[%s].", stageId); + kernelManipulationQueue.add(holder -> { + holder.finishProcessing(stageId); + final WorkerStageKernel kernel = holder.getKernelFor(stageId); + if (kernel != null) { + kernel.setStageFinished(); + } + }); } @Override public void postFinish() { - log.info("Finish received for task [%s]", task.getId()); - kernelManipulationQueue.add(KernelHolder::setDone); + log.debug("Received finish call."); + kernelManipulationQueue.add(KernelHolders::setDone); } @Override public ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId) { - log.info("Fetching statistics for stage [%d]", stageId.getStageNumber()); - if (stageKernelMap.get(stageId) == null) { - throw new ISE("Requested statistics snapshot for non-existent stageId %s.", stageId); - } else if (stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot() == null) { - throw new ISE( - "Requested statistics snapshot is not generated yet for stageId [%s]", - stageId - ); - } else { - return stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot(); - } + log.debug("Fetching statistics for stage[%s]", stageId); + final SettableFuture snapshotFuture = SettableFuture.create(); + kernelManipulationQueue.add( + holder -> { + try { + final WorkerStageKernel kernel = holder.getKernelFor(stageId); + if (kernel != null) { + final ClusterByStatisticsSnapshot snapshot = kernel.getResultKeyStatisticsSnapshot(); + if (snapshot == null) { + throw new ISE("Requested statistics snapshot is not generated yet for stage [%s]", stageId); + } + + snapshotFuture.set(snapshot); + } else { + snapshotFuture.setException(new ISE("Stage[%s] has terminated", stageId)); + } + } + catch (Throwable t) { + snapshotFuture.setException(t); + } + } + ); + return FutureUtils.getUnchecked(snapshotFuture, true); } @Override public ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk) { - log.debug( - "Fetching statistics for stage [%d] with time chunk [%d] ", - stageId.getStageNumber(), - timeChunk - ); - if (stageKernelMap.get(stageId) == null) { - throw new ISE("Requested statistics snapshot for non-existent stageId [%s].", stageId); - } else if (stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot() == null) { - throw new ISE( - "Requested statistics snapshot is not generated yet for stageId [%s]", - stageId - ); - } else { - return stageKernelMap.get(stageId) - .getResultKeyStatisticsSnapshot() - .getSnapshotForTimeChunk(timeChunk); - } - + return 
fetchStatisticsSnapshot(stageId).getSnapshotForTimeChunk(timeChunk); } - @Override public CounterSnapshotsTree getCounters() { final CounterSnapshotsTree retVal = new CounterSnapshotsTree(); - for (final Map.Entry entry : stageCounters.entrySet()) { - retVal.put(entry.getKey().getStageNumber(), task().getWorkerNumber(), entry.getValue().snapshot()); + for (final Map.Entry, CounterTracker> entry : stageCounters.entrySet()) { + retVal.put( + entry.getKey().right().getStageNumber(), + entry.getKey().leftInt(), + entry.getValue().snapshot() + ); } return retVal; } - private InputChannelFactory makeBaseInputChannelFactory(final Closer closer) + /** + * Create a {@link RunWorkOrderListener} for {@link RunWorkOrder} that hooks back into the {@link KernelHolders} + * in the main loop. + */ + private RunWorkOrderListener makeRunWorkOrderListener( + final WorkOrder workOrder, + final ControllerClient controllerClient, + final Set criticalWarningCodes, + final long maxVerboseParseExceptions + ) { - final Supplier> workerTaskList = Suppliers.memoize( - () -> { - try { - return controllerClient.getTaskList(); - } - catch (IOException e) { - throw new RuntimeException(e); - } - } - )::get; + final StageId stageId = workOrder.getStageDefinition().getId(); + final MSQWarningReportPublisher msqWarningReportPublisher = new MSQWarningReportLimiterPublisher( + new MSQWarningReportSimplePublisher( + id(), + controllerClient, + id(), + MSQTasks.getHostFromSelfNode(selfDruidNode) + ), + Limits.MAX_VERBOSE_WARNINGS, + ImmutableMap.of(CannotParseExternalDataFault.CODE, maxVerboseParseExceptions), + criticalWarningCodes, + controllerClient, + id(), + MSQTasks.getHostFromSelfNode(selfDruidNode) + ); - if (durableStageStorageEnabled) { - return DurableStorageInputChannelFactory.createStandardImplementation( - task.getControllerTaskId(), - MSQTasks.makeStorageConnector(context.injector()), - closer, - false - ); - } else { - return new WorkerOrLocalInputChannelFactory(workerTaskList); - } - } + return new RunWorkOrderListener() + { + @Override + public void onDoneReadingInput(@Nullable ClusterByStatisticsSnapshot snapshot) + { + kernelManipulationQueue.add( + holder -> { + final WorkerStageKernel kernel = holder.getKernelFor(stageId); + if (kernel != null) { + kernel.setResultKeyStatisticsSnapshot(snapshot); + } + } + ); + } - private OutputChannelFactory makeStageOutputChannelFactory( - final FrameContext frameContext, - final int stageNumber, - boolean isFinalStage - ) - { - // Use the standard frame size, since we assume this size when computing how much is needed to merge output - // files from different workers. 
- final int frameSize = frameContext.memoryParameters().getStandardFrameSize(); - - if (durableStageStorageEnabled || (isFinalStage - && MSQSelectDestination.DURABLESTORAGE.equals(selectDestination))) { - return DurableStorageOutputChannelFactory.createStandardImplementation( - task.getControllerTaskId(), - task().getWorkerNumber(), - stageNumber, - task().getId(), - frameSize, - MSQTasks.makeStorageConnector(context.injector()), - context.tempDir(), - (isFinalStage && MSQSelectDestination.DURABLESTORAGE.equals(selectDestination)) - ); - } else { - final File fileChannelDirectory = - new File(context.tempDir(), StringUtils.format("output_stage_%06d", stageNumber)); + @Override + public void onOutputChannelAvailable(OutputChannel channel) + { + ReadableFrameChannel readableChannel = null; - return new FileOutputChannelFactory(fileChannelDirectory, frameSize, null); - } - } + try { + readableChannel = channel.getReadableChannel(); + getOrCreateStageOutputHolder(stageId, channel.getPartitionNumber()) + .setChannel(readableChannel); + } + catch (Exception e) { + if (readableChannel != null) { + try { + readableChannel.close(); + } + catch (Throwable e2) { + e.addSuppressed(e2); + } + } - private OutputChannelFactory makeSuperSorterIntermediateOutputChannelFactory( - final FrameContext frameContext, - final int stageNumber, - final File tmpDir - ) - { - final int frameSize = frameContext.memoryParameters().getLargeFrameSize(); - final File fileChannelDirectory = - new File(tmpDir, StringUtils.format("intermediate_output_stage_%06d", stageNumber)); - final FileOutputChannelFactory fileOutputChannelFactory = - new FileOutputChannelFactory(fileChannelDirectory, frameSize, intermediateSuperSorterLocalStorageTracker); - - if (durableStageStorageEnabled && workerStorageParameters.isIntermediateStorageLimitConfigured()) { - return new ComposingOutputChannelFactory( - ImmutableList.of( - fileOutputChannelFactory, - DurableStorageOutputChannelFactory.createStandardImplementation( - task.getControllerTaskId(), - task().getWorkerNumber(), - stageNumber, - task().getId(), - frameSize, - MSQTasks.makeStorageConnector(context.injector()), - tmpDir, - false - ) - ), - frameSize - ); - } else { - return fileOutputChannelFactory; - } + kernelManipulationQueue.add(holder -> { + throw new RE(e, "Worker completion callback error for stage [%s]", stageId); + }); + } + } + + @Override + public void onSuccess(Object resultObject) + { + kernelManipulationQueue.add( + holder -> { + // Call finishProcessing prior to transitioning to RESULTS_COMPLETE, so the FrameContext is closed + // and resources are released. 
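            // Output channels for the stage are not removed here; they remain available to readers until the kernel is removed.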
+ holder.finishProcessing(stageId); + + final WorkerStageKernel kernel = holder.getKernelFor(stageId); + if (kernel != null) { + kernel.setResultsComplete(resultObject); + } + } + ); + } + + @Override + public void onWarning(Throwable t) + { + msqWarningReportPublisher.publishException(stageId.getStageNumber(), t); + } + + @Override + public void onFailure(Throwable t) + { + kernelManipulationQueue.add( + holder -> { + final WorkerStageKernel kernel = holder.getKernelFor(stageId); + if (kernel != null) { + kernel.fail(t); + } + } + ); + } + }; + } + + private InputChannelFactory makeBaseInputChannelFactory( + final WorkOrder workOrder, + final ControllerClient controllerClient, + final Closer closer + ) + { + return MetaInputChannelFactory.create( + InputSlices.allStageSlices(workOrder.getInputs()), + workOrder.getOutputChannelMode(), + outputChannelMode -> { + switch (outputChannelMode) { + case MEMORY: + case LOCAL_STORAGE: + final Supplier> workerIds; + + if (workOrder.getWorkerIds() != null) { + workerIds = workOrder::getWorkerIds; + } else { + workerIds = Suppliers.memoize( + () -> { + try { + return controllerClient.getWorkerIds(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + ); + } + + return new WorkerOrLocalInputChannelFactory( + id(), + workerIds, + new WorkerInputChannelFactory(workerClient, workerIds), + this::getOrCreateStageOutputHolder + ); + + case DURABLE_STORAGE_INTERMEDIATE: + case DURABLE_STORAGE_QUERY_RESULTS: + return DurableStorageInputChannelFactory.createStandardImplementation( + task.getControllerTaskId(), + MSQTasks.makeStorageConnector(context.injector()), + closer, + outputChannelMode == OutputChannelMode.DURABLE_STORAGE_QUERY_RESULTS + ); + + default: + throw DruidException.defensive("No handling for output channel mode[%s]", outputChannelMode); + } + } + ); } /** @@ -846,69 +929,75 @@ public void run() /** * Posts all counters for this worker to the controller. */ - private void postCountersToController() throws IOException + private void postCountersToController(final ControllerClient controllerClient) throws IOException { final CounterSnapshotsTree snapshotsTree = getCounters(); if (controllerAlive && !snapshotsTree.isEmpty()) { - try { - controllerClient.postCounters(id(), snapshotsTree); - } - catch (IOException e) { - if (e.getCause() instanceof ServiceClosedException) { - // Suppress. This can happen if the controller goes away while a postCounters call is in flight. - log.debug(e, "Ignoring failure on postCounters, because controller has gone away."); - } else { - throw e; - } - } + controllerClient.postCounters(id(), snapshotsTree); } } /** - * Cleans up the stage outputs corresponding to the provided stage id. It essentially calls {@code doneReading()} on - * the readable channels corresponding to all the partitions for that stage, and removes it from the {@code stageOutputs} - * map + * Removes and closes all output channels for a stage from {@link #stageOutputs}. */ - private void cleanStageOutput(final StageId stageId, boolean removeDurableStorageFiles) + private void removeStageOutputChannels(final StageId stageId) { // This code is thread-safe because remove() on ConcurrentHashMap will remove and return the removed channel only for // one thread. 
For the other threads it will return null, therefore we will call doneReading for a channel only once - final ConcurrentHashMap partitionOutputsForStage = stageOutputs.remove(stageId); + final ConcurrentHashMap partitionOutputsForStage = stageOutputs.remove(stageId); // Check for null, this can be the case if this method is called simultaneously from multiple threads. if (partitionOutputsForStage == null) { return; } for (final int partition : partitionOutputsForStage.keySet()) { - final ReadableFrameChannel output = partitionOutputsForStage.remove(partition); - if (output == null) { - continue; + final StageOutputHolder output = partitionOutputsForStage.remove(partition); + if (output != null) { + output.close(); } - output.close(); } + } + + /** + * Remove outputs from durable storage for a particular stage. + */ + private void removeStageDurableStorageOutput(final StageId stageId) + { // One caveat with this approach is that in case of a worker crash, while the MM/Indexer systems will delete their // temp directories where intermediate results were stored, it won't be the case for the external storage. // Therefore, the logic for cleaning the stage output in case of a worker/machine crash has to be external. // We currently take care of this in the controller. - if (durableStageStorageEnabled && removeDurableStorageFiles) { - final String folderName = DurableStorageUtils.getTaskIdOutputsFolderName( - task.getControllerTaskId(), - stageId.getStageNumber(), - task.getWorkerNumber(), - task.getId() - ); - try { - MSQTasks.makeStorageConnector(context.injector()).deleteRecursively(folderName); - } - catch (Exception e) { - // If an error is thrown while cleaning up a file, log it and try to continue with the cleanup - log.warn(e, "Error while cleaning up folder at path " + folderName); - } + final String folderName = DurableStorageUtils.getTaskIdOutputsFolderName( + task.getControllerTaskId(), + stageId.getStageNumber(), + task.getWorkerNumber(), + context.workerId() + ); + try { + MSQTasks.makeStorageConnector(context.injector()).deleteRecursively(folderName); + } + catch (Exception e) { + // If an error is thrown while cleaning up a file, log it and try to continue with the cleanup + log.warn(e, "Error while cleaning up durable storage path[%s].", folderName); } } + private StageOutputHolder getOrCreateStageOutputHolder(final StageId stageId, final int partitionNumber) + { + return stageOutputs.computeIfAbsent(stageId, ignored1 -> new ConcurrentHashMap<>()) + .computeIfAbsent(partitionNumber, ignored -> new StageOutputHolder()); + } + /** - * Called by {@link #stopGracefully()} (task canceled, or containing process shut down) and + * Returns cancellation ID for a particular stage, to be used in {@link FrameProcessorExecutor#cancel(String)}. + */ + private static String cancellationIdFor(final StageId stageId) + { + return stageId.toString(); + } + + /** + * Called by {@link #stop()} (task canceled, or containing process shut down) and * {@link #controllerFailed()}. */ private void doCancel() @@ -935,15 +1024,15 @@ private void doCancel() /** * Log (at DEBUG level) a string explaining the status of all work assigned to this worker. 
*/ - private static void logKernelStatus(final Collection kernels) + private static void logKernelStatus(final Iterable kernels) { if (log.isDebugEnabled()) { log.debug( "Stages: %s", - kernels.stream() - .sorted(Comparator.comparing(k -> k.getStageDefinition().getStageNumber())) - .map(WorkerImpl::makeKernelStageStatusString) - .collect(Collectors.joining("; ")) + StreamSupport.stream(kernels.spliterator(), false) + .sorted(Comparator.comparing(k -> k.getStageDefinition().getStageNumber())) + .map(WorkerImpl::makeKernelStageStatusString) + .collect(Collectors.joining("; ")) ); } } @@ -978,936 +1067,205 @@ private static String makeKernelStageStatusString(final WorkerStageKernel kernel } /** - * An {@link InputChannelFactory} that loads data locally when possible, and otherwise connects directly to other - * workers. Used when durable shuffle storage is off. + * Holds {@link WorkerStageKernel} and {@link Closer}, one per {@link WorkOrder}. Also holds {@link ControllerClient}. + * Only manipulated by the main loop. Other threads that need to manipulate kernels must do so through + * {@link #kernelManipulationQueue}. */ - private class WorkerOrLocalInputChannelFactory implements InputChannelFactory + private static class KernelHolders implements Closeable { - private final Supplier> taskList; - private final WorkerInputChannelFactory workerInputChannelFactory; - - public WorkerOrLocalInputChannelFactory(final Supplier> taskList) - { - this.workerInputChannelFactory = new WorkerInputChannelFactory(workerClient, taskList); - this.taskList = taskList; - } + private final WorkerContext workerContext; + private final ControllerClient controllerClient; - @Override - public ReadableFrameChannel openChannel(StageId stageId, int workerNumber, int partitionNumber) - { - final String taskId = taskList.get().get(workerNumber); - if (taskId.equals(id())) { - final ConcurrentMap partitionOutputsForStage = stageOutputs.get(stageId); - if (partitionOutputsForStage == null) { - throw new ISE("Unable to find outputs for stage [%s]", stageId); - } - - final ReadableFrameChannel myChannel = partitionOutputsForStage.get(partitionNumber); + /** + * Stage number -> kernel holder. + */ + private final Int2ObjectMap holderMap = new Int2ObjectOpenHashMap<>(); - if (myChannel instanceof ReadableFileFrameChannel) { - // Must duplicate the channel to avoid double-closure upon task cleanup. - final FrameFile frameFile = ((ReadableFileFrameChannel) myChannel).newFrameFileReference(); - return new ReadableFileFrameChannel(frameFile); - } else if (myChannel instanceof ReadableNilFrameChannel) { - return myChannel; - } else { - throw new ISE("Output for stage [%s] are stored in an instance of %s which is not " - + "supported", stageId, myChannel.getClass()); - } - } else { - return workerInputChannelFactory.openChannel(stageId, workerNumber, partitionNumber); - } - } - } + private boolean done = false; - /** - * Main worker logic for executing a {@link WorkOrder}. 
- */ - private class RunWorkOrder - { - private final WorkerStageKernel kernel; - private final InputChannelFactory inputChannelFactory; - private final CounterTracker counterTracker; - private final FrameProcessorExecutor exec; - private final String cancellationId; - private final int parallelism; - private final FrameContext frameContext; - private final MSQWarningReportPublisher warningPublisher; - - private InputSliceReader inputSliceReader; - private OutputChannelFactory workOutputChannelFactory; - private OutputChannelFactory shuffleOutputChannelFactory; - private ResultAndChannels workResultAndOutputChannels; - private SettableFuture stagePartitionBoundariesFuture; - private ListenableFuture shuffleOutputChannelsFuture; - - public RunWorkOrder( - final WorkerStageKernel kernel, - final InputChannelFactory inputChannelFactory, - final CounterTracker counterTracker, - final FrameProcessorExecutor exec, - final String cancellationId, - final int parallelism, - final FrameContext frameContext, - final MSQWarningReportPublisher warningPublisher - ) + private KernelHolders(final WorkerContext workerContext, final ControllerClient controllerClient) { - this.kernel = kernel; - this.inputChannelFactory = inputChannelFactory; - this.counterTracker = counterTracker; - this.exec = exec; - this.cancellationId = cancellationId; - this.parallelism = parallelism; - this.frameContext = frameContext; - this.warningPublisher = warningPublisher; + this.workerContext = workerContext; + this.controllerClient = controllerClient; } - private void start() throws IOException + public static KernelHolders create(final WorkerContext workerContext, final Closer closer) { - final WorkOrder workOrder = kernel.getWorkOrder(); - final StageDefinition stageDef = workOrder.getStageDefinition(); - - final boolean isFinalStage = stageDef.getStageNumber() == workOrder.getQueryDefinition() - .getFinalStageDefinition() - .getStageNumber(); - - makeInputSliceReader(); - makeWorkOutputChannelFactory(isFinalStage); - makeShuffleOutputChannelFactory(isFinalStage); - makeAndRunWorkProcessors(); - - if (stageDef.doesShuffle()) { - makeAndRunShuffleProcessors(); - } else { - // No shuffling: work output _is_ shuffle output. Retain read-only versions to reduce memory footprint. - shuffleOutputChannelsFuture = - Futures.immediateFuture(workResultAndOutputChannels.getOutputChannels().readOnly()); - } - - setUpCompletionCallbacks(isFinalStage); + return closer.register(new KernelHolders(workerContext, closer.register(workerContext.makeControllerClient()))); } /** - * Settable {@link ClusterByPartitions} future for global sort. Necessary because we don't know ahead of time - * what the boundaries will be. The controller decides based on statistics from all workers. Once the controller - * decides, its decision is written to this future, which allows sorting on workers to proceed. + * Add a {@link WorkerStageKernel} to this holder. Also creates a {@link ControllerClient} for the query ID + * if one does not yet exist. Does nothing if a kernel with the same {@link StageId} is already being tracked. 
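     *
     * Invoked from the main loop; other threads reach it through the kernel manipulation queue. For example,
     * {@link WorkerImpl#postWorkOrder} enqueues (abbreviated from that method):
     *
     * <pre>
     *   kernelManipulationQueue.add(
     *       kernelHolders -> kernelHolders.addKernel(WorkerStageKernel.create(workOrderToUse))
     *   );
     * </pre>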
*/ - @Nullable - public SettableFuture getStagePartitionBoundariesFuture() + public void addKernel(final WorkerStageKernel kernel) { - return stagePartitionBoundariesFuture; - } + final StageId stageId = verifyQueryId(kernel.getWorkOrder().getStageDefinition().getId()); - private void makeInputSliceReader() - { - if (inputSliceReader != null) { - throw new ISE("inputSliceReader already created"); + if (holderMap.putIfAbsent(stageId.getStageNumber(), new KernelHolder(kernel)) != null) { + // Already added. Do nothing. } - - final WorkOrder workOrder = kernel.getWorkOrder(); - final String queryId = workOrder.getQueryDefinition().getQueryId(); - - final InputChannels inputChannels = - new InputChannelsImpl( - workOrder.getQueryDefinition(), - InputSlices.allReadablePartitions(workOrder.getInputs()), - inputChannelFactory, - () -> ArenaMemoryAllocator.createOnHeap(frameContext.memoryParameters().getStandardFrameSize()), - exec, - cancellationId, - MultiStageQueryContext.removeNullBytes(QueryContext.of(task.getContext())) - ); - - inputSliceReader = new MapInputSliceReader( - ImmutableMap., InputSliceReader>builder() - .put(NilInputSlice.class, NilInputSliceReader.INSTANCE) - .put(StageInputSlice.class, new StageInputSliceReader(queryId, inputChannels)) - .put(ExternalInputSlice.class, new ExternalInputSliceReader(frameContext.tempDir())) - .put(InlineInputSlice.class, new InlineInputSliceReader(frameContext.segmentWrangler())) - .put(LookupInputSlice.class, new LookupInputSliceReader(frameContext.segmentWrangler())) - .put( - SegmentsInputSlice.class, - new SegmentsInputSliceReader( - frameContext, - MultiStageQueryContext.isReindex(QueryContext.of(task().getContext())) - ) - ) - .build() - ); - } - - private void makeWorkOutputChannelFactory(boolean isFinalStage) - { - if (workOutputChannelFactory != null) { - throw new ISE("processorOutputChannelFactory already created"); - } - - final OutputChannelFactory baseOutputChannelFactory; - - if (kernel.getStageDefinition().doesShuffle()) { - // Writing to a consumer in the same JVM (which will be set up later on in this method). Use the large frame - // size if we're writing to a SuperSorter, since we'll generate fewer temp files if we use larger frames. - // Otherwise, use the standard frame size. - final int frameSize; - - if (kernel.getStageDefinition().getShuffleSpec().kind().isSort()) { - frameSize = frameContext.memoryParameters().getLargeFrameSize(); - } else { - frameSize = frameContext.memoryParameters().getStandardFrameSize(); - } - - baseOutputChannelFactory = new BlockingQueueOutputChannelFactory(frameSize); - } else { - // Writing stage output. - baseOutputChannelFactory = - makeStageOutputChannelFactory(frameContext, kernel.getStageDefinition().getStageNumber(), isFinalStage); - } - - workOutputChannelFactory = new CountingOutputChannelFactory( - baseOutputChannelFactory, - counterTracker.channel(CounterNames.outputChannel()) - ); - } - - private void makeShuffleOutputChannelFactory(boolean isFinalStage) - { - shuffleOutputChannelFactory = - new CountingOutputChannelFactory( - makeStageOutputChannelFactory(frameContext, kernel.getStageDefinition().getStageNumber(), isFinalStage), - counterTracker.channel(CounterNames.shuffleChannel()) - ); } /** - * Use {@link FrameProcessorFactory#makeProcessors} to create {@link ProcessorsAndChannels}. Executes the - * processors using {@link #exec} and sets the output channels in {@link #workResultAndOutputChannels}. + * Called when processing for a stage is complete. 
Releases processing resources associated with the stage, i.e., + * those that are part of {@link KernelHolder#processorCloser}. * - * @param type of {@link StageDefinition#getProcessorFactory()} - * @param return type of {@link FrameProcessor} created by the manager - * @param result type of {@link ProcessorManager#result()} - * @param type of {@link WorkOrder#getExtraInfo()} + * Does not release results-fetching resources, i.e., does not release {@link KernelHolder#resultsCloser}. Those + * resources are released on {@link #removeKernel(StageId)} only. */ - private , ProcessorReturnType, ManagerReturnType, ExtraInfoType> void makeAndRunWorkProcessors() - throws IOException - { - if (workResultAndOutputChannels != null) { - throw new ISE("workResultAndOutputChannels already set"); - } - - @SuppressWarnings("unchecked") - final FactoryType processorFactory = (FactoryType) kernel.getStageDefinition().getProcessorFactory(); - - @SuppressWarnings("unchecked") - final ProcessorsAndChannels processors = - processorFactory.makeProcessors( - kernel.getStageDefinition(), - kernel.getWorkOrder().getWorkerNumber(), - kernel.getWorkOrder().getInputs(), - inputSliceReader, - (ExtraInfoType) kernel.getWorkOrder().getExtraInfo(), - workOutputChannelFactory, - frameContext, - parallelism, - counterTracker, - e -> warningPublisher.publishException(kernel.getStageDefinition().getStageNumber(), e), - isRemoveNullBytes - ); - - final ProcessorManager processorManager = processors.getProcessorManager(); - - final int maxOutstandingProcessors; - - if (processors.getOutputChannels().getAllChannels().isEmpty()) { - // No output channels: run up to "parallelism" processors at once. - maxOutstandingProcessors = Math.max(1, parallelism); - } else { - // If there are output channels, that acts as a ceiling on the number of processors that can run at once. - maxOutstandingProcessors = - Math.max(1, Math.min(parallelism, processors.getOutputChannels().getAllChannels().size())); - } - - final ListenableFuture workResultFuture = exec.runAllFully( - processorManager, - maxOutstandingProcessors, - processorBouncer, - cancellationId - ); - - workResultAndOutputChannels = new ResultAndChannels<>(workResultFuture, processors.getOutputChannels()); - } - - private void makeAndRunShuffleProcessors() + public void finishProcessing(final StageId stageId) { - if (shuffleOutputChannelsFuture != null) { - throw new ISE("shuffleOutputChannelsFuture already set"); - } + final KernelHolder kernel = holderMap.get(verifyQueryId(stageId).getStageNumber()); - final ShuffleSpec shuffleSpec = kernel.getWorkOrder().getStageDefinition().getShuffleSpec(); - - final ShufflePipelineBuilder shufflePipeline = new ShufflePipelineBuilder( - kernel, - counterTracker, - exec, - cancellationId, - frameContext - ); - - shufflePipeline.initialize(workResultAndOutputChannels); - - switch (shuffleSpec.kind()) { - case MIX: - shufflePipeline.mix(shuffleOutputChannelFactory); - break; - - case HASH: - shufflePipeline.hashPartition(shuffleOutputChannelFactory); - break; - - case HASH_LOCAL_SORT: - final OutputChannelFactory hashOutputChannelFactory; - - if (shuffleSpec.partitionCount() == 1) { - // Single partition; no need to write temporary files. - hashOutputChannelFactory = - new BlockingQueueOutputChannelFactory(frameContext.memoryParameters().getStandardFrameSize()); - } else { - // Multi-partition; write temporary files and then sort each one file-by-file. 
- hashOutputChannelFactory = - new FileOutputChannelFactory( - context.tempDir(kernel.getStageDefinition().getStageNumber(), "hash-parts"), - frameContext.memoryParameters().getStandardFrameSize(), - null - ); - } - - shufflePipeline.hashPartition(hashOutputChannelFactory); - shufflePipeline.localSort(shuffleOutputChannelFactory); - break; - - case GLOBAL_SORT: - shufflePipeline.gatherResultKeyStatisticsIfNeeded(); - shufflePipeline.globalSort(shuffleOutputChannelFactory, makeGlobalSortPartitionBoundariesFuture()); - break; - - default: - throw new UOE("Cannot handle shuffle kind [%s]", shuffleSpec.kind()); - } - - shuffleOutputChannelsFuture = shufflePipeline.build(); - } - - private ListenableFuture makeGlobalSortPartitionBoundariesFuture() - { - if (kernel.getStageDefinition().mustGatherResultKeyStatistics()) { - if (stagePartitionBoundariesFuture != null) { - throw new ISE("Cannot call 'makeGlobalSortPartitionBoundariesFuture' twice"); + if (kernel != null) { + try { + kernel.processorCloser.close(); + } + catch (IOException e) { + throw new RuntimeException(e); } - - return (stagePartitionBoundariesFuture = SettableFuture.create()); - } else { - return Futures.immediateFuture(kernel.getResultPartitionBoundaries()); } } - private void setUpCompletionCallbacks(boolean isFinalStage) - { - final StageDefinition stageDef = kernel.getStageDefinition(); - - Futures.addCallback( - Futures.allAsList( - Arrays.asList( - workResultAndOutputChannels.getResultFuture(), - shuffleOutputChannelsFuture - ) - ), - new FutureCallback>() - { - @Override - public void onSuccess(final List workerResultAndOutputChannelsResolved) - { - final Object resultObject = workerResultAndOutputChannelsResolved.get(0); - final OutputChannels outputChannels = (OutputChannels) workerResultAndOutputChannelsResolved.get(1); - - for (OutputChannel channel : outputChannels.getAllChannels()) { - try { - stageOutputs.computeIfAbsent(stageDef.getId(), ignored1 -> new ConcurrentHashMap<>()) - .computeIfAbsent(channel.getPartitionNumber(), ignored2 -> channel.getReadableChannel()); - } - catch (Exception e) { - kernelManipulationQueue.add(holder -> { - throw new RE(e, "Worker completion callback error for stage [%s]", stageDef.getId()); - }); - - // Don't make the "setResultsComplete" call below. - return; - } - } - - // Once the outputs channels have been resolved and are ready for reading, write success file, if - // using durable storage. - writeDurableStorageSuccessFileIfNeeded(stageDef.getStageNumber(), isFinalStage); - - kernelManipulationQueue.add(holder -> holder.getStageKernelMap() - .get(stageDef.getId()) - .setResultsComplete(resultObject)); - } - - @Override - public void onFailure(final Throwable t) - { - kernelManipulationQueue.add( - kernelHolder -> - kernelHolder.getStageKernelMap().get(stageDef.getId()).fail(t) - ); - } - }, - MoreExecutors.directExecutor() - ); - } - /** - * Write {@link DurableStorageUtils#SUCCESS_MARKER_FILENAME} for a particular stage, if durable storage is enabled. + * Remove the {@link WorkerStageKernel} for a given {@link StageId} from this holder. Closes all the associated + * {@link Closeable}. 
Removes and closes the {@link ControllerClient} for this query ID, if there are no longer + * any active work orders for that query ID + * + * @throws IllegalStateException if there is no active kernel for this stage */ - private void writeDurableStorageSuccessFileIfNeeded(final int stageNumber, boolean isFinalStage) + public void removeKernel(final StageId stageId) { - final DurableStorageOutputChannelFactory durableStorageOutputChannelFactory; - if (durableStageStorageEnabled || (isFinalStage - && MSQSelectDestination.DURABLESTORAGE.equals(selectDestination))) { - durableStorageOutputChannelFactory = DurableStorageOutputChannelFactory.createStandardImplementation( - task.getControllerTaskId(), - task().getWorkerNumber(), - stageNumber, - task().getId(), - frameContext.memoryParameters().getStandardFrameSize(), - MSQTasks.makeStorageConnector(context.injector()), - context.tempDir(), - (isFinalStage && MSQSelectDestination.DURABLESTORAGE.equals(selectDestination)) - ); - } else { - return; + final KernelHolder removed = holderMap.remove(verifyQueryId(stageId).getStageNumber()); + + if (removed == null) { + throw new ISE("No kernel for stage[%s]", stageId); } + try { - durableStorageOutputChannelFactory.createSuccessFile(task.getId()); + removed.processorCloser.close(); + removed.resultsCloser.close(); } catch (IOException e) { - throw new ISE( - e, - "Unable to create the success file [%s] at the location [%s]", - DurableStorageUtils.SUCCESS_MARKER_FILENAME, - durableStorageOutputChannelFactory.getSuccessFilePath() - ); + throw new RuntimeException(e); } } - } - - /** - * Helper for {@link RunWorkOrder#makeAndRunShuffleProcessors()}. Builds a {@link FrameProcessor} pipeline to - * handle the shuffle. - */ - private class ShufflePipelineBuilder - { - private final WorkerStageKernel kernel; - private final CounterTracker counterTracker; - private final FrameProcessorExecutor exec; - private final String cancellationId; - private final FrameContext frameContext; - - // Current state of the pipeline. It's a future to allow pipeline construction to be deferred if necessary. - private ListenableFuture> pipelineFuture; - - public ShufflePipelineBuilder( - final WorkerStageKernel kernel, - final CounterTracker counterTracker, - final FrameProcessorExecutor exec, - final String cancellationId, - final FrameContext frameContext - ) - { - this.kernel = kernel; - this.counterTracker = counterTracker; - this.exec = exec; - this.cancellationId = cancellationId; - this.frameContext = frameContext; - } /** - * Start the pipeline with the outputs of the main processor. + * Returns all currently-active kernel holders. */ - public void initialize(final ResultAndChannels resultAndChannels) + public Iterable getAllKernelHolders() { - if (pipelineFuture != null) { - throw new ISE("already initialized"); - } - - pipelineFuture = Futures.immediateFuture(resultAndChannels); + return holderMap.values(); } /** - * Add {@link FrameChannelMixer}, which mixes all current outputs into a single channel from the provided factory. + * Returns all currently-active kernels. */ - public void mix(final OutputChannelFactory outputChannelFactory) + public Iterable getAllKernels() { - // No sorting or statistics gathering, just combining all outputs into one big partition. Use a mixer to get - // everything into one file. Note: even if there is only one output channel, we'll run it through the mixer - // anyway, to ensure the data gets written to a file. (httpGetChannelData requires files.) 
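To make the new KernelHolders lifecycle concrete, here is a minimal usage sketch. It only uses methods shown in this patch (create, addKernel, finishProcessing, removeKernel); the workerContext and workOrder variables, and the WorkerStageKernel.create(workOrder) factory, are assumed for illustration and are not part of the patch.

    // Track a stage, release its processing resources once output is produced, and release
    // its result-serving resources only when the stage is removed.
    final Closer closer = Closer.create();
    final KernelHolders holders = KernelHolders.create(workerContext, closer);
    final StageId stageId = workOrder.getStageDefinition().getId();

    holders.addKernel(WorkerStageKernel.create(workOrder));
    // ... processors run and stage output is written ...
    holders.finishProcessing(stageId);  // closes KernelHolder#processorCloser only
    // ... downstream consumers finish reading this stage's results ...
    holders.removeKernel(stageId);      // also closes KernelHolder#resultsCloser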
- - push( - resultAndChannels -> { - final OutputChannel outputChannel = outputChannelFactory.openChannel(0); - - final FrameChannelMixer mixer = - new FrameChannelMixer( - resultAndChannels.getOutputChannels().getAllReadableChannels(), - outputChannel.getWritableChannel() - ); - - return new ResultAndChannels<>( - exec.runFully(mixer, cancellationId), - OutputChannels.wrap(Collections.singletonList(outputChannel.readOnly())) - ); - } - ); + return Iterables.transform(holderMap.values(), holder -> holder.kernel); } /** - * Add {@link KeyStatisticsCollectionProcessor} if {@link StageDefinition#mustGatherResultKeyStatistics()}. + * Returns the number of kernels that are in running states, where {@link WorkerStagePhase#isRunning()}. */ - public void gatherResultKeyStatisticsIfNeeded() + public int runningKernelCount() { - push( - resultAndChannels -> { - final StageDefinition stageDefinition = kernel.getStageDefinition(); - final OutputChannels channels = resultAndChannels.getOutputChannels(); - - if (channels.getAllChannels().isEmpty()) { - // No data coming out of this processor. Report empty statistics, if the kernel is expecting statistics. - if (stageDefinition.mustGatherResultKeyStatistics()) { - kernelManipulationQueue.add( - holder -> - holder.getStageKernelMap().get(stageDefinition.getId()) - .setResultKeyStatisticsSnapshot(ClusterByStatisticsSnapshot.empty()) - ); - } - - // Generate one empty channel so the SuperSorter has something to do. - final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal(); - channel.writable().close(); - - final OutputChannel outputChannel = OutputChannel.readOnly( - channel.readable(), - FrameWithPartition.NO_PARTITION - ); - - return new ResultAndChannels<>( - Futures.immediateFuture(null), - OutputChannels.wrap(Collections.singletonList(outputChannel)) - ); - } else if (stageDefinition.mustGatherResultKeyStatistics()) { - return gatherResultKeyStatistics(channels); - } else { - return resultAndChannels; - } - } - ); - } - - /** - * Add a {@link SuperSorter} using {@link StageDefinition#getSortKey()} and partition boundaries - * from {@code partitionBoundariesFuture}. 
- */ - public void globalSort( - final OutputChannelFactory outputChannelFactory, - final ListenableFuture partitionBoundariesFuture - ) - { - pushAsync( - resultAndChannels -> { - final StageDefinition stageDefinition = kernel.getStageDefinition(); - - final File sorterTmpDir = context.tempDir(stageDefinition.getStageNumber(), "super-sort"); - FileUtils.mkdirp(sorterTmpDir); - if (!sorterTmpDir.isDirectory()) { - throw new IOException("Cannot create directory: " + sorterTmpDir); - } + int retVal = 0; + for (final KernelHolder holder : holderMap.values()) { + if (holder.kernel.getPhase().isRunning()) { + retVal++; + } + } - final WorkerMemoryParameters memoryParameters = frameContext.memoryParameters(); - final SuperSorter sorter = new SuperSorter( - resultAndChannels.getOutputChannels().getAllReadableChannels(), - stageDefinition.getFrameReader(), - stageDefinition.getSortKey(), - partitionBoundariesFuture, - exec, - outputChannelFactory, - makeSuperSorterIntermediateOutputChannelFactory( - frameContext, - stageDefinition.getStageNumber(), - sorterTmpDir - ), - memoryParameters.getSuperSorterMaxActiveProcessors(), - memoryParameters.getSuperSorterMaxChannelsPerProcessor(), - -1, - cancellationId, - counterTracker.sortProgress(), - isRemoveNullBytes - ); - - return FutureUtils.transform( - sorter.run(), - sortedChannels -> new ResultAndChannels<>(Futures.immediateFuture(null), sortedChannels) - ); - } - ); + return retVal; } /** - * Add a {@link FrameChannelHashPartitioner} using {@link StageDefinition#getSortKey()}. + * Return the kernel for a particular {@link StageId}. + * + * @return kernel, or null if there is no active kernel for this stage */ - public void hashPartition(final OutputChannelFactory outputChannelFactory) + @Nullable + public WorkerStageKernel getKernelFor(final StageId stageId) { - pushAsync( - resultAndChannels -> { - final ShuffleSpec shuffleSpec = kernel.getStageDefinition().getShuffleSpec(); - final int partitions = shuffleSpec.partitionCount(); - - final List outputChannels = new ArrayList<>(); - - for (int i = 0; i < partitions; i++) { - outputChannels.add(outputChannelFactory.openChannel(i)); - } - - final FrameChannelHashPartitioner partitioner = new FrameChannelHashPartitioner( - resultAndChannels.getOutputChannels().getAllReadableChannels(), - outputChannels.stream().map(OutputChannel::getWritableChannel).collect(Collectors.toList()), - kernel.getStageDefinition().getFrameReader(), - kernel.getStageDefinition().getClusterBy().getColumns().size(), - FrameWriters.makeRowBasedFrameWriterFactory( - new ArenaMemoryAllocatorFactory(frameContext.memoryParameters().getStandardFrameSize()), - kernel.getStageDefinition().getSignature(), - kernel.getStageDefinition().getSortKey(), - isRemoveNullBytes - ) - ); - - final ListenableFuture partitionerFuture = exec.runFully(partitioner, cancellationId); - - final ResultAndChannels retVal = - new ResultAndChannels<>(partitionerFuture, OutputChannels.wrap(outputChannels)); - - if (retVal.getOutputChannels().areReadableChannelsReady()) { - return Futures.immediateFuture(retVal); - } else { - return FutureUtils.transform(partitionerFuture, ignored -> retVal); - } - } - ); + final KernelHolder holder = holderMap.get(verifyQueryId(stageId).getStageNumber()); + if (holder != null) { + return holder.kernel; + } else { + return null; + } } /** - * Add a sequence of {@link SuperSorter}, operating on each current output channel in order, one at a time. 
+ * Retrieves the {@link ControllerClient}, which is shared across all {@link WorkOrder} for this worker. */ - public void localSort(final OutputChannelFactory outputChannelFactory) + public ControllerClient getControllerClient() { - pushAsync( - resultAndChannels -> { - final StageDefinition stageDefinition = kernel.getStageDefinition(); - final OutputChannels channels = resultAndChannels.getOutputChannels(); - final List> sortedChannelFutures = new ArrayList<>(); - - ListenableFuture nextFuture = Futures.immediateFuture(null); - - for (final OutputChannel channel : channels.getAllChannels()) { - final File sorterTmpDir = context.tempDir( - stageDefinition.getStageNumber(), - StringUtils.format("hash-parts-super-sort-%06d", channel.getPartitionNumber()) - ); - - FileUtils.mkdirp(sorterTmpDir); - - // SuperSorter will try to write to output partition zero; we remap it to the correct partition number. - final OutputChannelFactory partitionOverrideOutputChannelFactory = new OutputChannelFactory() - { - @Override - public OutputChannel openChannel(int expectedZero) throws IOException - { - if (expectedZero != 0) { - throw new ISE("Unexpected part [%s]", expectedZero); - } - - return outputChannelFactory.openChannel(channel.getPartitionNumber()); - } - - @Override - public PartitionedOutputChannel openPartitionedChannel(String name, boolean deleteAfterRead) - { - throw new UnsupportedOperationException(); - } - - @Override - public OutputChannel openNilChannel(int expectedZero) - { - if (expectedZero != 0) { - throw new ISE("Unexpected part [%s]", expectedZero); - } - - return outputChannelFactory.openNilChannel(channel.getPartitionNumber()); - } - }; - - // Chain futures so we only sort one partition at a time. - nextFuture = Futures.transformAsync( - nextFuture, - (AsyncFunction) ignored -> { - final SuperSorter sorter = new SuperSorter( - Collections.singletonList(channel.getReadableChannel()), - stageDefinition.getFrameReader(), - stageDefinition.getSortKey(), - Futures.immediateFuture(ClusterByPartitions.oneUniversalPartition()), - exec, - partitionOverrideOutputChannelFactory, - makeSuperSorterIntermediateOutputChannelFactory( - frameContext, - stageDefinition.getStageNumber(), - sorterTmpDir - ), - 1, - 2, - -1, - cancellationId, - - // Tracker is not actually tracked, since it doesn't quite fit into the way we report counters. - // There's a single SuperSorterProgressTrackerCounter per worker, but workers that do local - // sorting have a SuperSorter per partition. - new SuperSorterProgressTracker(), - isRemoveNullBytes - ); - - return FutureUtils.transform(sorter.run(), r -> Iterables.getOnlyElement(r.getAllChannels())); - }, - MoreExecutors.directExecutor() - ); - - sortedChannelFutures.add(nextFuture); - } - - return FutureUtils.transform( - Futures.allAsList(sortedChannelFutures), - sortedChannels -> new ResultAndChannels<>( - Futures.immediateFuture(null), - OutputChannels.wrap(sortedChannels) - ) - ); - } - ); + return controllerClient; } /** - * Return the (future) output channels for this pipeline. + * Remove all {@link WorkerStageKernel} and close all {@link ControllerClient}. 
*/ - public ListenableFuture build() + @Override + public void close() { - if (pipelineFuture == null) { - throw new ISE("Not initialized"); - } + for (final int stageNumber : ImmutableList.copyOf(holderMap.keySet())) { + final StageId stageId = new StageId(workerContext.queryId(), stageNumber); - return Futures.transformAsync( - pipelineFuture, - (AsyncFunction, OutputChannels>) resultAndChannels -> - Futures.transform( - resultAndChannels.getResultFuture(), - (Function) input -> { - sanityCheckOutputChannels(resultAndChannels.getOutputChannels()); - return resultAndChannels.getOutputChannels(); - }, - MoreExecutors.directExecutor() - ), - MoreExecutors.directExecutor() - ); - } - - /** - * Adds {@link KeyStatisticsCollectionProcessor}. Called by {@link #gatherResultKeyStatisticsIfNeeded()}. - */ - private ResultAndChannels gatherResultKeyStatistics(final OutputChannels channels) - { - final StageDefinition stageDefinition = kernel.getStageDefinition(); - final List retVal = new ArrayList<>(); - final List processors = new ArrayList<>(); - - for (final OutputChannel outputChannel : channels.getAllChannels()) { - final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal(); - retVal.add(OutputChannel.readOnly(channel.readable(), outputChannel.getPartitionNumber())); - - processors.add( - new KeyStatisticsCollectionProcessor( - outputChannel.getReadableChannel(), - channel.writable(), - stageDefinition.getFrameReader(), - stageDefinition.getClusterBy(), - stageDefinition.createResultKeyStatisticsCollector( - frameContext.memoryParameters().getPartitionStatisticsMaxRetainedBytes() - ) - ) - ); + try { + removeKernel(stageId); + } + catch (Exception e) { + log.warn(e, "Failed to remove kernel for stage[%s].", stageId); + } } - - final ListenableFuture clusterByStatisticsCollectorFuture = - exec.runAllFully( - ProcessorManagers.of(processors) - .withAccumulation( - stageDefinition.createResultKeyStatisticsCollector( - frameContext.memoryParameters().getPartitionStatisticsMaxRetainedBytes() - ), - ClusterByStatisticsCollector::addAll - ), - // Run all processors simultaneously. They are lightweight and this keeps things moving. - processors.size(), - Bouncer.unlimited(), - cancellationId - ); - - Futures.addCallback( - clusterByStatisticsCollectorFuture, - new FutureCallback() - { - @Override - public void onSuccess(final ClusterByStatisticsCollector result) - { - result.logSketches(); - kernelManipulationQueue.add( - holder -> - holder.getStageKernelMap().get(stageDefinition.getId()) - .setResultKeyStatisticsSnapshot(result.snapshot()) - ); - } - - @Override - public void onFailure(Throwable t) - { - kernelManipulationQueue.add( - holder -> { - log.noStackTrace() - .warn(t, "Failed to gather clusterBy statistics for stage [%s]", stageDefinition.getId()); - holder.getStageKernelMap().get(stageDefinition.getId()).fail(t); - } - ); - } - }, - MoreExecutors.directExecutor() - ); - - return new ResultAndChannels<>( - clusterByStatisticsCollectorFuture, - OutputChannels.wrap(retVal) - ); } /** - * Update the {@link #pipelineFuture}. + * Check whether {@link #setDone()} has been called. */ - private void push(final ExceptionalFunction, ResultAndChannels> fn) + public boolean isDone() { - pushAsync( - channels -> - Futures.immediateFuture(fn.apply(channels)) - ); + return done; } /** - * Update the {@link #pipelineFuture} asynchronously. + * Mark the holder as "done", signaling to the main loop that it should clean up and exit as soon as possible. 
*/ - private void pushAsync(final ExceptionalFunction, ListenableFuture>> fn) + public void setDone() { - if (pipelineFuture == null) { - throw new ISE("Not initialized"); - } - - pipelineFuture = FutureUtils.transform( - Futures.transformAsync( - pipelineFuture, - new AsyncFunction, ResultAndChannels>() - { - @Override - public ListenableFuture> apply(ResultAndChannels t) throws Exception - { - return fn.apply(t); - } - }, - MoreExecutors.directExecutor() - ), - resultAndChannels -> new ResultAndChannels<>( - resultAndChannels.getResultFuture(), - resultAndChannels.getOutputChannels().readOnly() - ) - ); + this.done = true; } - /** - * Verifies there is exactly one channel per partition. - */ - private void sanityCheckOutputChannels(final OutputChannels outputChannels) + private StageId verifyQueryId(final StageId stageId) { - for (int partitionNumber : outputChannels.getPartitionNumbers()) { - final List outputChannelsForPartition = - outputChannels.getChannelsForPartition(partitionNumber); - - Preconditions.checkState(partitionNumber >= 0, "Expected partitionNumber >= 0, but got [%s]", partitionNumber); - Preconditions.checkState( - outputChannelsForPartition.size() == 1, - "Expected one channel for partition [%s], but got [%s]", - partitionNumber, - outputChannelsForPartition.size() - ); + if (!stageId.getQueryId().equals(workerContext.queryId())) { + throw new ISE("Unexpected queryId[%s], expected queryId[%s]", stageId.getQueryId(), workerContext.queryId()); } - } - } - - private class KernelHolder - { - private boolean done = false; - - public Map getStageKernelMap() - { - return stageKernelMap; - } - - public boolean isDone() - { - return done; - } - public void setDone() - { - this.done = true; + return stageId; } } - private static class ResultAndChannels + /** + * Holder for a single {@link WorkerStageKernel} and associated items, contained within {@link KernelHolders}. + */ + private static class KernelHolder { - private final ListenableFuture resultFuture; - private final OutputChannels outputChannels; - - public ResultAndChannels( - ListenableFuture resultFuture, - OutputChannels outputChannels - ) - { - this.resultFuture = resultFuture; - this.outputChannels = outputChannels; - } - - public ListenableFuture getResultFuture() - { - return resultFuture; - } + private final WorkerStageKernel kernel; + private final Closer processorCloser; + private final Closer resultsCloser; + private SettableFuture partitionBoundariesFuture; - public OutputChannels getOutputChannels() + public KernelHolder(WorkerStageKernel kernel) { - return outputChannels; + this.kernel = kernel; + this.processorCloser = Closer.create(); + this.resultsCloser = Closer.create(); } } - - private interface ExceptionalFunction - { - R apply(T t) throws Exception; - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java index b36b1b4155a83..aeaae030e613e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java @@ -168,29 +168,14 @@ public class WorkerMemoryParameters this.partitionStatisticsMaxRetainedBytes = partitionStatisticsMaxRetainedBytes; } - /** - * Create a production instance for {@link org.apache.druid.msq.indexing.MSQControllerTask}. 
- */ - public static WorkerMemoryParameters createProductionInstanceForController(final Injector injector) - { - long totalLookupFootprint = computeTotalLookupFootprint(injector); - return createInstance( - Runtime.getRuntime().maxMemory(), - computeNumWorkersInJvm(injector), - computeNumProcessorsInJvm(injector), - 0, - 0, - totalLookupFootprint - ); - } - /** * Create a production instance for {@link org.apache.druid.msq.indexing.MSQWorkerTask}. */ public static WorkerMemoryParameters createProductionInstanceForWorker( final Injector injector, final QueryDefinition queryDef, - final int stageNumber + final int stageNumber, + final int maxConcurrentStages ) { final StageDefinition stageDef = queryDef.getStageDefinition(stageNumber); @@ -212,6 +197,7 @@ public static WorkerMemoryParameters createProductionInstanceForWorker( Runtime.getRuntime().maxMemory(), computeNumWorkersInJvm(injector), computeNumProcessorsInJvm(injector), + maxConcurrentStages, numInputWorkers, numHashOutputPartitions, totalLookupFootprint @@ -228,6 +214,7 @@ public static WorkerMemoryParameters createProductionInstanceForWorker( * @param numWorkersInJvm number of workers that can run concurrently in this JVM. Generally equal to * the task capacity. * @param numProcessingThreadsInJvm size of the processing thread pool in the JVM. + * @param maxConcurrentStages maximum number of concurrent stages per worker. * @param numInputWorkers total number of workers across all input stages. * @param numHashOutputPartitions total number of output partitions, if using hash partitioning; zero if not using * hash partitioning. @@ -237,6 +224,7 @@ public static WorkerMemoryParameters createInstance( final long maxMemoryInJvm, final int numWorkersInJvm, final int numProcessingThreadsInJvm, + final int maxConcurrentStages, final int numInputWorkers, final int numHashOutputPartitions, final long totalLookupFootprint @@ -257,7 +245,8 @@ public static WorkerMemoryParameters createInstance( ); final long usableMemoryInJvm = computeUsableMemoryInJvm(maxMemoryInJvm, totalLookupFootprint); final long workerMemory = memoryPerWorker(usableMemoryInJvm, numWorkersInJvm); - final long bundleMemory = memoryPerBundle(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm); + final long bundleMemory = + memoryPerBundle(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm) / maxConcurrentStages; final long bundleMemoryForInputChannels = memoryNeededForInputChannels(numInputWorkers); final long bundleMemoryForHashPartitioning = memoryNeededForHashPartitioning(numHashOutputPartitions); final long bundleMemoryForProcessing = @@ -268,6 +257,7 @@ public static WorkerMemoryParameters createInstance( usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm, + maxConcurrentStages, numHashOutputPartitions ); @@ -281,12 +271,14 @@ public static WorkerMemoryParameters createInstance( estimateUsableMemory( numWorkersInJvm, numProcessingThreadsInJvm, - PROCESSING_MINIMUM_BYTES + BUFFER_BYTES_FOR_ESTIMATION + bundleMemoryForInputChannels + PROCESSING_MINIMUM_BYTES + BUFFER_BYTES_FOR_ESTIMATION + bundleMemoryForInputChannels, + maxConcurrentStages ), totalLookupFootprint), maxMemoryInJvm, usableMemoryInJvm, numWorkersInJvm, - numProcessingThreadsInJvm + numProcessingThreadsInJvm, + maxConcurrentStages ) ); } @@ -301,14 +293,16 @@ public static WorkerMemoryParameters createInstance( calculateSuggestedMinMemoryFromUsableMemory( estimateUsableMemory( numWorkersInJvm, - (MIN_SUPER_SORTER_FRAMES + BUFFER_BYTES_FOR_ESTIMATION) * LARGE_FRAME_SIZE + 
(MIN_SUPER_SORTER_FRAMES + BUFFER_BYTES_FOR_ESTIMATION) * LARGE_FRAME_SIZE, + maxConcurrentStages ), totalLookupFootprint ), maxMemoryInJvm, usableMemoryInJvm, numWorkersInJvm, - numProcessingThreadsInJvm + numProcessingThreadsInJvm, + maxConcurrentStages ) ); } @@ -338,12 +332,14 @@ public static WorkerMemoryParameters createInstance( estimateUsableMemory( numWorkersInJvm, numProcessingThreadsInJvm, - PROCESSING_MINIMUM_BYTES + BUFFER_BYTES_FOR_ESTIMATION + bundleMemoryForInputChannels + PROCESSING_MINIMUM_BYTES + BUFFER_BYTES_FOR_ESTIMATION + bundleMemoryForInputChannels, + maxConcurrentStages ), totalLookupFootprint), maxMemoryInJvm, usableMemoryInJvm, numWorkersInJvm, - numProcessingThreadsInJvm + numProcessingThreadsInJvm, + maxConcurrentStages ) ); } @@ -352,7 +348,9 @@ public static WorkerMemoryParameters createInstance( bundleMemoryForProcessing, superSorterMaxActiveProcessors, superSorterMaxChannelsPerProcessor, - Ints.checkedCast(workerMemory) // 100% of worker memory is devoted to partition statistics + + // 100% of worker memory is devoted to partition statistics + Ints.checkedCast(workerMemory / maxConcurrentStages) ); } @@ -459,18 +457,19 @@ static int computeMaxWorkers( final long usableMemoryInJvm, final int numWorkersInJvm, final int numProcessingThreadsInJvm, + final int maxConcurrentStages, final int numHashOutputPartitions ) { final long bundleMemory = memoryPerBundle(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm); - // Compute number of workers that gives us PROCESSING_MINIMUM_BYTES of memory per bundle, while accounting for - // memoryNeededForInputChannels + memoryNeededForHashPartitioning. + // Compute number of workers that gives us PROCESSING_MINIMUM_BYTES of memory per bundle per concurrent stage, while + // accounting for memoryNeededForInputChannels + memoryNeededForHashPartitioning. final int isHashing = numHashOutputPartitions > 0 ? 1 : 0; - return Math.max( - 0, - Ints.checkedCast((bundleMemory - PROCESSING_MINIMUM_BYTES) / ((long) STANDARD_FRAME_SIZE * (1 + isHashing)) - 1) - ); + final long bundleMemoryPerStage = bundleMemory / maxConcurrentStages; + final long maxWorkers = + (bundleMemoryPerStage - PROCESSING_MINIMUM_BYTES) / ((long) STANDARD_FRAME_SIZE * (1 + isHashing)) - 1; + return Math.max(0, Ints.checkedCast(maxWorkers)); } /** @@ -528,7 +527,8 @@ private static long memoryPerWorker( } /** - * Compute the memory allocated to each processing bundle. Any computation changes done to this method should also be done in its corresponding method {@link WorkerMemoryParameters#estimateUsableMemory(int, int, long)} + * Compute the memory allocated to each processing bundle. Any computation changes done to this method should also be + * done in its corresponding method {@link WorkerMemoryParameters#estimateUsableMemory} */ private static long memoryPerBundle( final long usableMemoryInJvm, @@ -536,6 +536,8 @@ private static long memoryPerBundle( final int numProcessingThreadsInJvm ) { + // One bundle per worker + one per processor. The worker bundles are used for sorting (SuperSorter) and the + // processing bundles are used for reading input and doing per-partition processing. 
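To make the new maxConcurrentStages argument concrete, here is an illustrative call to the updated createInstance signature; all numbers are made up and are not recommendations. With the change above, each concurrently running stage receives bundleMemory / maxConcurrentStages, and the suggested-minimum-memory estimate scales up by the same factor.

    final WorkerMemoryParameters params = WorkerMemoryParameters.createInstance(
        8_000_000_000L, // maxMemoryInJvm
        1,              // numWorkersInJvm
        4,              // numProcessingThreadsInJvm
        2,              // maxConcurrentStages (new in this patch)
        10,             // numInputWorkers
        0,              // numHashOutputPartitions (not hash partitioning)
        0               // totalLookupFootprint (no lookups loaded)
    );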
final int bundleCount = numWorkersInJvm + numProcessingThreadsInJvm; // Need to subtract memoryForWorkers off the top of usableMemoryInJvm, since this is reserved for @@ -553,24 +555,28 @@ private static long memoryPerBundle( private static long estimateUsableMemory( final int numWorkersInJvm, final int numProcessingThreadsInJvm, - final long estimatedEachBundleMemory + final long estimatedEachBundleMemory, + final int maxConcurrentStages ) { final int bundleCount = numWorkersInJvm + numProcessingThreadsInJvm; - return estimateUsableMemory(numWorkersInJvm, estimatedEachBundleMemory * bundleCount); - + return estimateUsableMemory(numWorkersInJvm, estimatedEachBundleMemory * bundleCount, maxConcurrentStages); } /** * Add overheads to the estimated bundle memoery for all the workers. Checkout {@link WorkerMemoryParameters#memoryPerWorker(long, int)} * for the overhead calculation outside the processing bundles. */ - private static long estimateUsableMemory(final int numWorkersInJvm, final long estimatedTotalBundleMemory) + private static long estimateUsableMemory( + final int numWorkersInJvm, + final long estimatedTotalBundleMemory, + final int maxConcurrentStages + ) { - // Currently, we only add the partition stats overhead since it will be the single largest overhead per worker. final long estimateStatOverHeadPerWorker = PARTITION_STATS_MEMORY_MAX_BYTES; - return estimatedTotalBundleMemory + (estimateStatOverHeadPerWorker * numWorkersInJvm); + final long requiredUsableMemory = estimatedTotalBundleMemory + (estimateStatOverHeadPerWorker * numWorkersInJvm); + return requiredUsableMemory * maxConcurrentStages; } private static long memoryNeededForHashPartitioning(final int numOutputPartitions) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStorageParameters.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStorageParameters.java index 59576ec90bfba..53e12dd2ab4d8 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStorageParameters.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerStorageParameters.java @@ -70,11 +70,13 @@ private WorkerStorageParameters(final long intermediateSuperSorterStorageMaxLoca public static WorkerStorageParameters createProductionInstance( final Injector injector, - final boolean isIntermediateSuperSorterStorageEnabled + final OutputChannelMode outputChannelMode ) { long tmpStorageBytesPerTask = injector.getInstance(TaskConfig.class).getTmpStorageBytesPerTask(); - return createInstance(tmpStorageBytesPerTask, isIntermediateSuperSorterStorageEnabled); + + // If durable storage is enabled, then super sorter intermediate storage should be enabled as well. 
+ return createInstance(tmpStorageBytesPerTask, outputChannelMode.isDurable()); } @VisibleForTesting diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java index e0de5bdc27e25..fb6e4a0079f1e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerFrameContext.java @@ -20,9 +20,13 @@ package org.apache.druid.msq.indexing; import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.frame.processor.Bouncer; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; import org.apache.druid.msq.exec.WorkerMemoryParameters; +import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.kernel.FrameContext; +import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.querykit.DataSegmentProvider; import org.apache.druid.query.groupby.GroupingEngine; import org.apache.druid.segment.IndexIO; @@ -35,25 +39,31 @@ public class IndexerFrameContext implements FrameContext { + private final StageId stageId; private final IndexerWorkerContext context; private final IndexIO indexIO; private final DataSegmentProvider dataSegmentProvider; private final WorkerMemoryParameters memoryParameters; + private final WorkerStorageParameters storageParameters; private final DataServerQueryHandlerFactory dataServerQueryHandlerFactory; public IndexerFrameContext( + StageId stageId, IndexerWorkerContext context, IndexIO indexIO, DataSegmentProvider dataSegmentProvider, DataServerQueryHandlerFactory dataServerQueryHandlerFactory, - WorkerMemoryParameters memoryParameters + WorkerMemoryParameters memoryParameters, + WorkerStorageParameters storageParameters ) { + this.stageId = stageId; this.context = context; this.indexIO = indexIO; this.dataSegmentProvider = dataSegmentProvider; - this.dataServerQueryHandlerFactory = dataServerQueryHandlerFactory; this.memoryParameters = memoryParameters; + this.storageParameters = storageParameters; + this.dataServerQueryHandlerFactory = dataServerQueryHandlerFactory; } @Override @@ -90,7 +100,8 @@ public DataServerQueryHandlerFactory dataServerQueryHandlerFactory() @Override public File tempDir() { - return context.tempDir(); + // No need to include query ID; each task handles a single query, so there is no ambiguity. + return new File(context.tempDir(), StringUtils.format("stage_%06d", stageId.getStageNumber())); } @Override @@ -128,4 +139,22 @@ public WorkerMemoryParameters memoryParameters() { return memoryParameters; } + + @Override + public Bouncer processorBouncer() + { + return context.injector().getInstance(Bouncer.class); + } + + @Override + public WorkerStorageParameters storageParameters() + { + return storageParameters; + } + + @Override + public void close() + { + // Nothing to close. 
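A small illustration of the per-stage scratch directory introduced by the tempDir() override above; the /tmp/task path and the stage number are made up.

    // For a task temp dir of /tmp/task and stageNumber = 3, intermediate files now land in
    // /tmp/task/stage_000003.
    final File stageDir = new File("/tmp/task", StringUtils.format("stage_%06d", 3));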
+ } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java index 30bc75282fa46..2dedaf204ec7c 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerResourcePermissionMapper.java @@ -52,4 +52,10 @@ public List getAdminPermissions() ) ); } + + @Override + public List getQueryPermissions(String queryId) + { + return getAdminPermissions(); + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java index 1bd789df76905..63358467489be 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerWorkerContext.java @@ -24,9 +24,7 @@ import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.inject.Injector; import com.google.inject.Key; -import org.apache.druid.frame.processor.Bouncer; import org.apache.druid.guice.annotations.EscalatedGlobal; -import org.apache.druid.guice.annotations.Self; import org.apache.druid.guice.annotations.Smile; import org.apache.druid.indexing.common.SegmentCacheManagerFactory; import org.apache.druid.indexing.common.TaskToolbox; @@ -35,16 +33,21 @@ import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.exec.ControllerClient; import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.exec.TaskDataSegmentProvider; import org.apache.druid.msq.exec.Worker; import org.apache.druid.msq.exec.WorkerClient; import org.apache.druid.msq.exec.WorkerContext; import org.apache.druid.msq.exec.WorkerMemoryParameters; +import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.indexing.client.IndexerControllerClient; import org.apache.druid.msq.indexing.client.IndexerWorkerClient; import org.apache.druid.msq.indexing.client.WorkerChatHandler; import org.apache.druid.msq.kernel.FrameContext; import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.util.MultiStageQueryContext; +import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryToolChestWarehouse; import org.apache.druid.rpc.ServiceClientFactory; import org.apache.druid.rpc.ServiceLocations; @@ -67,37 +70,49 @@ public class IndexerWorkerContext implements WorkerContext private static final long FREQUENCY_CHECK_MILLIS = 1000; private static final long FREQUENCY_CHECK_JITTER = 30; + private final MSQWorkerTask task; private final TaskToolbox toolbox; private final Injector injector; + private final OverlordClient overlordClient; private final IndexIO indexIO; private final TaskDataSegmentProvider dataSegmentProvider; private final DataServerQueryHandlerFactory dataServerQueryHandlerFactory; private final ServiceClientFactory clientFactory; - - @GuardedBy("this") - private OverlordClient overlordClient; + private final MemoryIntrospector memoryIntrospector; + private final int maxConcurrentStages; 
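The new maxConcurrentStages field is populated from the query context in the constructor below. As a rough sketch, assuming the context key is literally "maxConcurrentStages" (check MultiStageQueryContext for the canonical key name and default value):

    final int maxConcurrentStages = MultiStageQueryContext.getMaxConcurrentStages(
        QueryContext.of(ImmutableMap.of("maxConcurrentStages", 2))
    );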
@GuardedBy("this") private ServiceLocator controllerLocator; public IndexerWorkerContext( + final MSQWorkerTask task, final TaskToolbox toolbox, final Injector injector, + final OverlordClient overlordClient, final IndexIO indexIO, final TaskDataSegmentProvider dataSegmentProvider, - final DataServerQueryHandlerFactory dataServerQueryHandlerFactory, - final ServiceClientFactory clientFactory + final ServiceClientFactory clientFactory, + final MemoryIntrospector memoryIntrospector, + final DataServerQueryHandlerFactory dataServerQueryHandlerFactory ) { + this.task = task; this.toolbox = toolbox; this.injector = injector; + this.overlordClient = overlordClient; this.indexIO = indexIO; this.dataSegmentProvider = dataSegmentProvider; - this.dataServerQueryHandlerFactory = dataServerQueryHandlerFactory; this.clientFactory = clientFactory; + this.memoryIntrospector = memoryIntrospector; + this.dataServerQueryHandlerFactory = dataServerQueryHandlerFactory; + this.maxConcurrentStages = MultiStageQueryContext.getMaxConcurrentStages(QueryContext.of(task.getContext())); } - public static IndexerWorkerContext createProductionInstance(final TaskToolbox toolbox, final Injector injector) + public static IndexerWorkerContext createProductionInstance( + final MSQWorkerTask task, + final TaskToolbox toolbox, + final Injector injector + ) { final IndexIO indexIO = injector.getInstance(IndexIO.class); final SegmentCacheManager segmentCacheManager = @@ -105,28 +120,42 @@ public static IndexerWorkerContext createProductionInstance(final TaskToolbox to .manufacturate(new File(toolbox.getIndexingTmpDir(), "segment-fetch")); final ServiceClientFactory serviceClientFactory = injector.getInstance(Key.get(ServiceClientFactory.class, EscalatedGlobal.class)); + final MemoryIntrospector memoryIntrospector = injector.getInstance(MemoryIntrospector.class); + final OverlordClient overlordClient = + injector.getInstance(OverlordClient.class).withRetryPolicy(StandardRetryPolicy.unlimited()); final ObjectMapper smileMapper = injector.getInstance(Key.get(ObjectMapper.class, Smile.class)); final QueryToolChestWarehouse warehouse = injector.getInstance(QueryToolChestWarehouse.class); return new IndexerWorkerContext( + task, toolbox, injector, + overlordClient, indexIO, - new TaskDataSegmentProvider( - toolbox.getCoordinatorClient(), - segmentCacheManager, - indexIO - ), + new TaskDataSegmentProvider(toolbox.getCoordinatorClient(), segmentCacheManager, indexIO), + serviceClientFactory, + memoryIntrospector, new DataServerQueryHandlerFactory( toolbox.getCoordinatorClient(), serviceClientFactory, smileMapper, warehouse - ), - serviceClientFactory + ) ); } + @Override + public String queryId() + { + return task.getControllerTaskId(); + } + + @Override + public String workerId() + { + return task.getId(); + } + public TaskToolbox toolbox() { return toolbox; @@ -147,7 +176,8 @@ public Injector injector() @Override public void registerWorker(Worker worker, Closer closer) { - WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker); + final WorkerChatHandler chatHandler = + new WorkerChatHandler(worker, toolbox.getAuthorizerMapper(), task.getDataSource()); toolbox.getChatHandlerProvider().register(worker.id(), chatHandler, false); closer.register(() -> toolbox.getChatHandlerProvider().unregister(worker.id())); closer.register(() -> { @@ -161,7 +191,7 @@ public void registerWorker(Worker worker, Closer closer) // Register the periodic controller checker final ExecutorService periodicControllerCheckerExec = 
Execs.singleThreaded("controller-status-checker-%s"); closer.register(periodicControllerCheckerExec::shutdownNow); - final ServiceLocator controllerLocator = makeControllerLocator(worker.task().getControllerTaskId()); + final ServiceLocator controllerLocator = makeControllerLocator(task.getControllerTaskId()); periodicControllerCheckerExec.submit(() -> controllerCheckerRunnable(controllerLocator, worker)); } @@ -218,15 +248,21 @@ public File tempDir() } @Override - public ControllerClient makeControllerClient(String controllerId) + public int maxConcurrentStages() + { + return maxConcurrentStages; + } + + @Override + public ControllerClient makeControllerClient() { - final ServiceLocator locator = makeControllerLocator(controllerId); + final ServiceLocator locator = makeControllerLocator(task.getControllerTaskId()); return new IndexerControllerClient( clientFactory.makeClient( - controllerId, + task.getControllerTaskId(), locator, - new SpecificTaskRetryPolicy(controllerId, StandardRetryPolicy.unlimited()) + new SpecificTaskRetryPolicy(task.getControllerTaskId(), StandardRetryPolicy.unlimited()) ), jsonMapper(), locator @@ -237,37 +273,33 @@ public ControllerClient makeControllerClient(String controllerId) public WorkerClient makeWorkerClient() { // Ignore workerId parameter. The workerId is passed into each method of WorkerClient individually. - return new IndexerWorkerClient(clientFactory, makeOverlordClient(), jsonMapper()); + return new IndexerWorkerClient(clientFactory, overlordClient, jsonMapper()); } @Override - public FrameContext frameContext(QueryDefinition queryDef, int stageNumber) + public FrameContext frameContext(QueryDefinition queryDef, int stageNumber, OutputChannelMode outputChannelMode) { return new IndexerFrameContext( + queryDef.getStageDefinition(stageNumber).getId(), this, indexIO, dataSegmentProvider, dataServerQueryHandlerFactory, - WorkerMemoryParameters.createProductionInstanceForWorker(injector, queryDef, stageNumber) + WorkerMemoryParameters.createProductionInstanceForWorker(injector, queryDef, stageNumber, maxConcurrentStages), + WorkerStorageParameters.createProductionInstance(injector, outputChannelMode) ); } @Override public int threadCount() { - return processorBouncer().getMaxCount(); + return memoryIntrospector.numProcessorsInJvm(); } @Override public DruidNode selfNode() { - return injector.getInstance(Key.get(DruidNode.class, Self.class)); - } - - @Override - public Bouncer processorBouncer() - { - return injector.getInstance(Bouncer.class); + return toolbox.getDruidNode(); } @Override @@ -276,21 +308,13 @@ public DataServerQueryHandlerFactory dataServerQueryHandlerFactory() return dataServerQueryHandlerFactory; } - private synchronized OverlordClient makeOverlordClient() - { - if (overlordClient == null) { - overlordClient = injector.getInstance(OverlordClient.class) - .withRetryPolicy(StandardRetryPolicy.unlimited()); - } - return overlordClient; - } - private synchronized ServiceLocator makeControllerLocator(final String controllerId) { if (controllerLocator == null) { - controllerLocator = new SpecificTaskServiceLocator(controllerId, makeOverlordClient()); + controllerLocator = new SpecificTaskServiceLocator(controllerId, overlordClient); } return controllerLocator; } + } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTask.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTask.java index b4d18ea390e9e..31b03d63ba6ec 100644 --- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTask.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTask.java @@ -33,10 +33,13 @@ import org.apache.druid.indexing.common.config.TaskConfig; import org.apache.druid.indexing.common.task.AbstractTask; import org.apache.druid.indexing.common.task.Tasks; +import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.exec.MSQTasks; import org.apache.druid.msq.exec.Worker; import org.apache.druid.msq.exec.WorkerContext; import org.apache.druid.msq.exec.WorkerImpl; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.MSQFaultUtils; import org.apache.druid.server.security.ResourceAction; import javax.annotation.Nonnull; @@ -48,6 +51,7 @@ public class MSQWorkerTask extends AbstractTask { public static final String TYPE = "query_worker"; + private static final Logger log = new Logger(MSQWorkerTask.class); private final String controllerTaskId; private final int workerNumber; @@ -132,18 +136,25 @@ public boolean isReady(final TaskActionClient taskActionClient) } @Override - public TaskStatus runTask(final TaskToolbox toolbox) throws Exception + public TaskStatus runTask(final TaskToolbox toolbox) { - final WorkerContext context = IndexerWorkerContext.createProductionInstance(toolbox, injector); + final WorkerContext context = IndexerWorkerContext.createProductionInstance(this, toolbox, injector); worker = new WorkerImpl(this, context); - return worker.run(); + + try { + worker.run(); + return TaskStatus.success(context.workerId()); + } + catch (MSQException e) { + return TaskStatus.failure(context.workerId(), MSQFaultUtils.generateMessageWithErrorCode(e.getFault())); + } } @Override public void stopGracefully(TaskConfig taskConfig) { if (worker != null) { - worker.stopGracefully(); + worker.stop(); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java index 81303eb438480..1e31de71a8aca 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/IndexerControllerClient.java @@ -152,7 +152,7 @@ public void postWorkerWarning(List MSQErrorReports) throws IOExc } @Override - public List getTaskList() throws IOException + public List getWorkerIds() throws IOException { final BytesFullResponseHolder retVal = doRequest( new RequestBuilder(HttpMethod.GET, "/taskList"), diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/WorkerChatHandler.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/WorkerChatHandler.java index 70d1ab11d380c..7c8b86bb9d641 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/WorkerChatHandler.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/client/WorkerChatHandler.java @@ -19,310 +19,25 @@ package org.apache.druid.msq.indexing.client; -import com.google.common.collect.ImmutableMap; -import it.unimi.dsi.fastutil.bytes.ByteArrays; -import org.apache.commons.lang.mutable.MutableLong; -import org.apache.druid.frame.file.FrameFileHttpResponseHandler; -import 
org.apache.druid.frame.key.ClusterByPartitions; -import org.apache.druid.indexing.common.TaskToolbox; -import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.msq.exec.Worker; -import org.apache.druid.msq.indexing.MSQWorkerTask; -import org.apache.druid.msq.kernel.StageId; -import org.apache.druid.msq.kernel.WorkOrder; -import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; -import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde; +import org.apache.druid.msq.indexing.IndexerResourcePermissionMapper; +import org.apache.druid.msq.rpc.WorkerResource; import org.apache.druid.segment.realtime.ChatHandler; -import org.apache.druid.segment.realtime.ChatHandlers; -import org.apache.druid.server.security.Action; -import org.apache.druid.utils.CloseableUtils; +import org.apache.druid.segment.realtime.ChatHandlerProvider; +import org.apache.druid.server.security.AuthorizerMapper; -import javax.annotation.Nullable; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.Consumes; -import javax.ws.rs.GET; -import javax.ws.rs.POST; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.StreamingOutput; -import java.io.IOException; -import java.io.InputStream; - -public class WorkerChatHandler implements ChatHandler +/** + * Subclass of {@link WorkerResource} that implements {@link ChatHandler}, suitable for registration + * with a {@link ChatHandlerProvider}. + */ +public class WorkerChatHandler extends WorkerResource implements ChatHandler { - private static final Logger log = new Logger(WorkerChatHandler.class); - - /** - * Callers must be able to store an entire chunk in memory. It can't be too large. - */ - private static final long CHANNEL_DATA_CHUNK_SIZE = 1_000_000; - - private final Worker worker; - private final MSQWorkerTask task; - private final TaskToolbox toolbox; - - public WorkerChatHandler(TaskToolbox toolbox, Worker worker) - { - this.worker = worker; - this.task = worker.task(); - this.toolbox = toolbox; - } - - /** - * Returns up to {@link #CHANNEL_DATA_CHUNK_SIZE} bytes of stage output data. - *
- * See {@link org.apache.druid.msq.exec.WorkerClient#fetchChannelData} for the client-side code that calls this API. - */ - @GET - @Path("/channels/{queryId}/{stageNumber}/{partitionNumber}") - @Produces(MediaType.APPLICATION_OCTET_STREAM) - public Response httpGetChannelData( - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @PathParam("partitionNumber") final int partitionNumber, - @QueryParam("offset") final long offset, - @Context final HttpServletRequest req + public WorkerChatHandler( + final Worker worker, + final AuthorizerMapper authorizerMapper, + final String dataSource ) { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - try { - final InputStream inputStream = worker.readChannel(queryId, stageNumber, partitionNumber, offset); - if (inputStream == null) { - return Response.status(Response.Status.NOT_FOUND).build(); - } - - final Response.ResponseBuilder responseBuilder = Response.ok(); - - final byte[] readBuf = new byte[8192]; - final MutableLong bytesReadTotal = new MutableLong(0L); - final int firstRead = inputStream.read(readBuf); - - if (firstRead == -1) { - // Empty read means we're at the end of the channel. Set the last fetch header so the client knows this. - inputStream.close(); - return responseBuilder - .header( - FrameFileHttpResponseHandler.HEADER_LAST_FETCH_NAME, - FrameFileHttpResponseHandler.HEADER_LAST_FETCH_VALUE - ) - .entity(ByteArrays.EMPTY_ARRAY) - .build(); - } - - return Response.ok((StreamingOutput) output -> { - try { - int bytesReadThisCall = firstRead; - do { - final int bytesToWrite = - (int) Math.min(CHANNEL_DATA_CHUNK_SIZE - bytesReadTotal.longValue(), bytesReadThisCall); - output.write(readBuf, 0, bytesToWrite); - bytesReadTotal.add(bytesReadThisCall); - } while (bytesReadTotal.longValue() < CHANNEL_DATA_CHUNK_SIZE - && (bytesReadThisCall = inputStream.read(readBuf)) != -1); - } - catch (Throwable e) { - // Suppress the exception to ensure nothing gets written over the wire once we've sent a 200. The client - // will resume from where it left off. - log.noStackTrace().warn( - e, - "Error writing channel for query [%s] stage [%s] partition [%s] offset [%,d] to [%s]", - queryId, - stageNumber, - partitionNumber, - offset, - req.getRemoteAddr() - ); - } - finally { - CloseableUtils.closeAll(inputStream, output); - } - }).build(); - } - catch (IOException e) { - return Response.status(Response.Status.INTERNAL_SERVER_ERROR).build(); - } - } - - /** - * See {@link org.apache.druid.msq.exec.WorkerClient#postWorkOrder} for the client-side code that calls this API. - */ - @POST - @Consumes(MediaType.APPLICATION_JSON) - @Produces(MediaType.APPLICATION_JSON) - @Path("/workOrder") - public Response httpPostWorkOrder(final WorkOrder workOrder, @Context final HttpServletRequest req) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - worker.postWorkOrder(workOrder); - return Response.status(Response.Status.ACCEPTED).build(); - } - - /** - * See {@link org.apache.druid.msq.exec.WorkerClient#postResultPartitionBoundaries} for the client-side code that calls this API. 
- */ - @POST - @Consumes(MediaType.APPLICATION_JSON) - @Produces(MediaType.APPLICATION_JSON) - @Path("/resultPartitionBoundaries/{queryId}/{stageNumber}") - public Response httpPostResultPartitionBoundaries( - final ClusterByPartitions stagePartitionBoundaries, - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - if (worker.postResultPartitionBoundaries(stagePartitionBoundaries, queryId, stageNumber)) { - return Response.status(Response.Status.ACCEPTED).build(); - } else { - return Response.status(Response.Status.BAD_REQUEST).build(); - } - } - - @POST - @Path("/keyStatistics/{queryId}/{stageNumber}") - @Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_OCTET_STREAM}) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpFetchKeyStatistics( - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @QueryParam("sketchEncoding") @Nullable final SketchEncoding sketchEncoding, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper()); - ClusterByStatisticsSnapshot clusterByStatisticsSnapshot; - StageId stageId = new StageId(queryId, stageNumber); - try { - clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId); - if (SketchEncoding.OCTET_STREAM.equals(sketchEncoding)) { - return Response.status(Response.Status.ACCEPTED) - .type(MediaType.APPLICATION_OCTET_STREAM) - .entity((StreamingOutput) output -> ClusterByStatisticsSnapshotSerde.serialize(output, clusterByStatisticsSnapshot)) - .build(); - } else { - return Response.status(Response.Status.ACCEPTED) - .type(MediaType.APPLICATION_JSON) - .entity(clusterByStatisticsSnapshot) - .build(); - } - } - catch (Exception e) { - String errorMessage = StringUtils.format( - "Invalid request for key statistics for query[%s] and stage[%d]", - queryId, - stageNumber - ); - log.error(e, errorMessage); - return Response.status(Response.Status.BAD_REQUEST) - .entity(ImmutableMap.of("error", errorMessage)) - .build(); - } - } - - @POST - @Path("/keyStatisticsForTimeChunk/{queryId}/{stageNumber}/{timeChunk}") - @Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_OCTET_STREAM}) - @Consumes(MediaType.APPLICATION_JSON) - public Response httpFetchKeyStatisticsWithSnapshot( - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @PathParam("timeChunk") final long timeChunk, - @QueryParam("sketchEncoding") @Nullable final SketchEncoding sketchEncoding, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper()); - ClusterByStatisticsSnapshot snapshotForTimeChunk; - StageId stageId = new StageId(queryId, stageNumber); - try { - snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk); - if (SketchEncoding.OCTET_STREAM.equals(sketchEncoding)) { - return Response.status(Response.Status.ACCEPTED) - .type(MediaType.APPLICATION_OCTET_STREAM) - .entity((StreamingOutput) output -> ClusterByStatisticsSnapshotSerde.serialize(output, snapshotForTimeChunk)) - .build(); - } else { - return Response.status(Response.Status.ACCEPTED) - .type(MediaType.APPLICATION_JSON) - .entity(snapshotForTimeChunk) - .build(); - } - } - catch (Exception e) { - 
String errorMessage = StringUtils.format( - "Invalid request for key statistics for query[%s], stage[%d] and timeChunk[%d]", - queryId, - stageNumber, - timeChunk - ); - log.error(e, errorMessage); - return Response.status(Response.Status.BAD_REQUEST) - .entity(ImmutableMap.of("error", errorMessage)) - .build(); - } - } - - /** - * See {@link org.apache.druid.msq.exec.WorkerClient#postCleanupStage} for the client-side code that calls this API. - */ - @POST - @Path("/cleanupStage/{queryId}/{stageNumber}") - public Response httpPostCleanupStage( - @PathParam("queryId") final String queryId, - @PathParam("stageNumber") final int stageNumber, - @Context final HttpServletRequest req - ) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - worker.postCleanupStage(new StageId(queryId, stageNumber)); - return Response.status(Response.Status.ACCEPTED).build(); - } - - /** - * See {@link org.apache.druid.msq.exec.WorkerClient#postFinish} for the client-side code that calls this API. - */ - @POST - @Path("/finish") - public Response httpPostFinish(@Context final HttpServletRequest req) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - worker.postFinish(); - return Response.status(Response.Status.ACCEPTED).build(); - } - - - /** - * See {@link org.apache.druid.msq.exec.WorkerClient#getCounters} for the client-side code that calls this API. - */ - @GET - @Produces(MediaType.APPLICATION_JSON) - @Path("/counters") - public Response httpGetCounters(@Context final HttpServletRequest req) - { - ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper()); - return Response.status(Response.Status.OK).entity(worker.getCounters()).build(); - } - - /** - * Determines the encoding of key collectors returned by {@link #httpFetchKeyStatistics} and - * {@link #httpFetchKeyStatisticsWithSnapshot}. - */ - public enum SketchEncoding - { - /** - * The key collector is encoded as a byte stream with {@link ClusterByStatisticsSnapshotSerde}. 
- */ - OCTET_STREAM, - /** - * The key collector is encoded as json - */ - JSON + super(worker, new IndexerResourcePermissionMapper(dataSource), authorizerMapper); } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java index 5c80f065eef3b..6f4b36da1eec6 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/NotEnoughMemoryFault.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.indexing.error; import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeName; @@ -35,6 +36,7 @@ public class NotEnoughMemoryFault extends BaseMSQFault private final long usableMemory; private final int serverWorkers; private final int serverThreads; + private final int maxConcurrentStages; @JsonCreator public NotEnoughMemoryFault( @@ -42,19 +44,23 @@ public NotEnoughMemoryFault( @JsonProperty("serverMemory") final long serverMemory, @JsonProperty("usableMemory") final long usableMemory, @JsonProperty("serverWorkers") final int serverWorkers, - @JsonProperty("serverThreads") final int serverThreads + @JsonProperty("serverThreads") final int serverThreads, + @JsonProperty("maxConcurrentStages") final int maxConcurrentStages ) { super( CODE, "Not enough memory. Required at least %,d bytes. (total = %,d bytes; usable = %,d bytes; " - + "worker capacity = %,d; processing threads = %,d). Increase JVM memory with the -Xmx option" - + (serverWorkers > 1 ? " or reduce worker capacity on this server" : ""), + + "worker capacity = %,d; processing threads = %,d; concurrent stages = %,d). " + + "Increase JVM memory with the -Xmx option" + + (serverWorkers > 1 ? ", or reduce worker capacity on this server" : "") + + (maxConcurrentStages > 1 ? 
", or reduce maxConcurrentStages for this query" : ""), suggestedServerMemory, serverMemory, usableMemory, serverWorkers, - serverThreads + serverThreads, + maxConcurrentStages ); this.suggestedServerMemory = suggestedServerMemory; @@ -62,6 +68,7 @@ public NotEnoughMemoryFault( this.usableMemory = usableMemory; this.serverWorkers = serverWorkers; this.serverThreads = serverThreads; + this.maxConcurrentStages = maxConcurrentStages; } @JsonProperty @@ -94,6 +101,13 @@ public int getServerThreads() return serverThreads; } + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_DEFAULT) + public int getMaxConcurrentStages() + { + return maxConcurrentStages; + } + @Override public boolean equals(Object o) { @@ -107,12 +121,12 @@ public boolean equals(Object o) return false; } NotEnoughMemoryFault that = (NotEnoughMemoryFault) o; - return - suggestedServerMemory == that.suggestedServerMemory - && serverMemory == that.serverMemory - && usableMemory == that.usableMemory - && serverWorkers == that.serverWorkers - && serverThreads == that.serverThreads; + return suggestedServerMemory == that.suggestedServerMemory + && serverMemory == that.serverMemory + && usableMemory == that.usableMemory + && serverWorkers == that.serverWorkers + && serverThreads == that.serverThreads + && maxConcurrentStages == that.maxConcurrentStages; } @Override @@ -124,7 +138,8 @@ public int hashCode() serverMemory, usableMemory, serverWorkers, - serverThreads + serverThreads, + maxConcurrentStages ); } @@ -137,6 +152,7 @@ public String toString() " bytes, usableMemory=" + usableMemory + " bytes, serverWorkers=" + serverWorkers + ", serverThreads=" + serverThreads + + ", maxConcurrentStages=" + maxConcurrentStages + '}'; } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSlices.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSlices.java index 028f1b5bd48a4..de01235447aa5 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSlices.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/InputSlices.java @@ -41,6 +41,22 @@ private InputSlices() // No instantiation. } + /** + * Returns all {@link StageInputSlice} from the provided list of input slices. Ignores other types of input slices. + */ + public static List allStageSlices(final List slices) + { + final List retVal = new ArrayList<>(); + + for (final InputSlice slice : slices) { + if (slice instanceof StageInputSlice) { + retVal.add((StageInputSlice) slice); + } + } + + return retVal; + } + /** * Combines all {@link StageInputSlice#getPartitions()} from the input slices that are {@link StageInputSlice}. * Ignores other types of input slices. 
@@ -49,10 +65,8 @@ public static ReadablePartitions allReadablePartitions(final List sl { final List partitionsList = new ArrayList<>(); - for (final InputSlice slice : slices) { - if (slice instanceof StageInputSlice) { - partitionsList.add(((StageInputSlice) slice).getPartitions()); - } + for (final StageInputSlice slice : allStageSlices(slices)) { + partitionsList.add(slice.getPartitions()); } return ReadablePartitions.combine(partitionsList); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java index 03aa7cd0fe4f2..4b68a3bf1b01b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/external/ExternalInputSliceReader.java @@ -31,7 +31,7 @@ import org.apache.druid.data.input.impl.InlineInputSource; import org.apache.druid.data.input.impl.TimestampSpec; import org.apache.druid.java.util.common.DateTimes; -import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.FileUtils; import org.apache.druid.msq.counters.ChannelCounters; import org.apache.druid.msq.counters.CounterNames; import org.apache.druid.msq.counters.CounterTracker; @@ -53,6 +53,7 @@ import org.apache.druid.timeline.SegmentId; import java.io.File; +import java.io.IOException; import java.util.Iterator; import java.util.List; import java.util.function.Consumer; @@ -94,7 +95,7 @@ public ReadableInputs attach( externalInputSlice.getInputSources(), externalInputSlice.getInputFormat(), externalInputSlice.getSignature(), - temporaryDirectory, + new File(temporaryDirectory, String.valueOf(inputNumber)), counters.channel(CounterNames.inputChannel(inputNumber)).setTotalFiles(slice.fileCount()), counters.warnings(), warningPublisher @@ -128,9 +129,13 @@ private static Iterator inputSourceSegmentIterator( ColumnsFilter.all() ); - if (!temporaryDirectory.exists() && !temporaryDirectory.mkdir()) { - throw new ISE("Cannot create temporary directory at [%s]", temporaryDirectory); + try { + FileUtils.mkdirp(temporaryDirectory); } + catch (IOException e) { + throw new RuntimeException(e); + } + return Iterators.transform( inputSources.iterator(), inputSource -> { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java index 7db2fa1a9dd9a..da962a9d39314 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/FrameContext.java @@ -20,8 +20,11 @@ package org.apache.druid.msq.kernel; import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.frame.processor.Bouncer; import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.exec.WorkerMemoryParameters; +import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.querykit.DataSegmentProvider; import org.apache.druid.query.groupby.GroupingEngine; import org.apache.druid.segment.IndexIO; @@ -30,12 +33,16 @@ import org.apache.druid.segment.incremental.RowIngestionMeters; import org.apache.druid.segment.loading.DataSegmentPusher; 
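// Hedged usage sketch, not part of this patch (the helper below is hypothetical): the new
// InputSlices.allStageSlices makes it easy to work with only the stage-to-stage inputs of a
// slice list, ignoring external, table, and other slice types.
import org.apache.druid.msq.input.InputSlice;
import org.apache.druid.msq.input.InputSlices;
import org.apache.druid.msq.input.stage.StageInputSlice;

import java.util.List;
import java.util.Set;
import java.util.TreeSet;

public class StageSliceSketch
{
  /** Returns the upstream stage numbers referenced by any StageInputSlice in "slices". */
  public static Set<Integer> upstreamStageNumbers(final List<InputSlice> slices)
  {
    final Set<Integer> stageNumbers = new TreeSet<>();
    for (final StageInputSlice stageSlice : InputSlices.allStageSlices(slices)) {
      stageNumbers.add(stageSlice.getStageNumber());
    }
    return stageNumbers;
  }
}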
+import java.io.Closeable; import java.io.File; /** - * Provides services and objects for the functioning of the frame processors + * Provides services and objects for the functioning of the frame processors. Scoped to a specific stage of a + * specific query, i.e., one {@link WorkOrder}. + * + * Generated by {@link org.apache.druid.msq.exec.WorkerContext#frameContext(QueryDefinition, int, OutputChannelMode)}. */ -public interface FrameContext +public interface FrameContext extends Closeable { SegmentWrangler segmentWrangler(); @@ -59,5 +66,14 @@ public interface FrameContext IndexMergerV9 indexMerger(); + Bouncer processorBouncer(); + WorkerMemoryParameters memoryParameters(); + + WorkerStorageParameters storageParameters(); + + default File tempDir(String name) + { + return new File(tempDir(), name); + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java index 201a1783c05f9..0c85787021033 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/WorkOrder.java @@ -109,7 +109,7 @@ ExtraInfoHolder getExtraInfoHolder() /** * Worker IDs for this query, if known in advance (at the time the work order is created). May be null, in which - * case workers use {@link ControllerClient#getTaskList()} to find worker IDs. + * case workers use {@link ControllerClient#getWorkerIds()} to find worker IDs. */ @Nullable @JsonProperty("workers") diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java index 632b8a8106ddb..b838092ca7140 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java @@ -42,6 +42,8 @@ * This separation of decision-making from the "real world" allows the decision-making to live in one, * easy-to-follow place. * + * Not thread-safe. 
+ * * @see org.apache.druid.msq.kernel.controller.ControllerQueryKernel state machine on the controller side */ public class WorkerStageKernel @@ -51,9 +53,10 @@ public class WorkerStageKernel private WorkerStagePhase phase = WorkerStagePhase.NEW; - // We read this variable in the main thread and the netty threads @Nullable - private volatile ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot; + private ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot; + + private boolean doneReadingInput; @Nullable private ClusterByPartitions resultPartitionBoundaries; @@ -107,25 +110,25 @@ public void startReading() public void startPreshuffleWaitingForResultPartitionBoundaries() { - assertPreshuffleStatisticsNeeded(); + assertPreshuffleStatisticsNeeded(true); transitionTo(WorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES); } public void startPreshuffleWritingOutput() { - assertPreshuffleStatisticsNeeded(); transitionTo(WorkerStagePhase.PRESHUFFLE_WRITING_OUTPUT); } - public void setResultKeyStatisticsSnapshot(final ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot) + public void setResultKeyStatisticsSnapshot(@Nullable final ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot) { - assertPreshuffleStatisticsNeeded(); + assertPreshuffleStatisticsNeeded(resultKeyStatisticsSnapshot != null); this.resultKeyStatisticsSnapshot = resultKeyStatisticsSnapshot; + this.doneReadingInput = true; } public void setResultPartitionBoundaries(final ClusterByPartitions resultPartitionBoundaries) { - assertPreshuffleStatisticsNeeded(); + assertPreshuffleStatisticsNeeded(true); this.resultPartitionBoundaries = resultPartitionBoundaries; } @@ -134,6 +137,11 @@ public boolean hasResultKeyStatisticsSnapshot() return resultKeyStatisticsSnapshot != null; } + public boolean isDoneReadingInput() + { + return doneReadingInput; + } + public boolean hasResultPartitionBoundaries() { return resultPartitionBoundaries != null; @@ -152,10 +160,10 @@ public ClusterByPartitions getResultPartitionBoundaries() @Nullable public Object getResultObject() { - if (phase == WorkerStagePhase.RESULTS_READY || phase == WorkerStagePhase.FINISHED) { + if (phase == WorkerStagePhase.RESULTS_COMPLETE) { return resultObject; } else { - throw new ISE("Results are not ready yet"); + throw new ISE("Results are not ready in phase[%s]", phase); } } @@ -174,7 +182,7 @@ public void setResultsComplete(Object resultObject) throw new NullPointerException("resultObject must not be null"); } - transitionTo(WorkerStagePhase.RESULTS_READY); + transitionTo(WorkerStagePhase.RESULTS_COMPLETE); this.resultObject = resultObject; } @@ -196,16 +204,18 @@ public void fail(Throwable t) } } - public boolean addPostedResultsComplete(Pair stageIdAndWorkerNumber) + public boolean addPostedResultsComplete(StageId stageId, int workerNumber) { - return postedResultsComplete.add(stageIdAndWorkerNumber); + return postedResultsComplete.add(Pair.of(stageId, workerNumber)); } - private void assertPreshuffleStatisticsNeeded() + private void assertPreshuffleStatisticsNeeded(final boolean delivered) { - if (!workOrder.getStageDefinition().mustGatherResultKeyStatistics()) { + if (delivered != workOrder.getStageDefinition().mustGatherResultKeyStatistics()) { throw new ISE( - "Result partitioning is not necessary for stage [%s]", + "Result key statistics %s, but %s, for stage[%s]", + delivered ? "delivered" : "not delivered", + workOrder.getStageDefinition().mustGatherResultKeyStatistics() ? 
"expected" : "not expected", workOrder.getStageDefinition().getId() ); } @@ -222,7 +232,12 @@ private void transitionTo(final WorkerStagePhase newPhase) ); phase = newPhase; } else { - throw new IAE("Cannot transition from [%s] to [%s]", phase, newPhase); + throw new IAE( + "Cannot transition stage[%s] from[%s] to[%s]", + workOrder.getStageDefinition().getId(), + phase, + newPhase + ); } } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStagePhase.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStagePhase.java index f54aa52349ea8..4e59e7d17a89c 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStagePhase.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStagePhase.java @@ -54,11 +54,12 @@ public boolean canTransitionFrom(final WorkerStagePhase priorPhase) @Override public boolean canTransitionFrom(final WorkerStagePhase priorPhase) { - return priorPhase == PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES; + return priorPhase == PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES /* if globally sorting */ + || priorPhase == READING_INPUT /* if locally sorting */; } }, - RESULTS_READY { + RESULTS_COMPLETE { @Override public boolean canTransitionFrom(final WorkerStagePhase priorPhase) { @@ -70,7 +71,9 @@ public boolean canTransitionFrom(final WorkerStagePhase priorPhase) @Override public boolean canTransitionFrom(final WorkerStagePhase priorPhase) { - return priorPhase == RESULTS_READY; + // Stages can transition to FINISHED even if they haven't generated all output yet. For example, this is + // possible if the downstream stage is applying a limit. + return priorPhase.compareTo(FINISHED) < 0; } }, @@ -84,4 +87,24 @@ public boolean canTransitionFrom(final WorkerStagePhase priorPhase) }; public abstract boolean canTransitionFrom(WorkerStagePhase priorPhase); + + /** + * Whether this phase indicates that the stage is no longer running. + */ + public boolean isTerminal() + { + return this == FINISHED || this == FAILED; + } + + /** + * Whether this phase indicates a stage is running and consuming its full complement of resources. + * + * There are still some resources that can be consumed by stages that are not running. For example, in the + * {@link #FINISHED} state, stages can still have data on disk that has not been cleaned-up yet, some pointers + * to that data that still reside in memory, and some counters in memory available for collection by the controller. 
+ */ + public boolean isRunning() + { + return this != NEW && this != RESULTS_COMPLETE && this != FINISHED && this != FAILED; + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java index 8d0fba72a2169..fd1a0323d0fb0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java @@ -97,7 +97,7 @@ public ListenableFuture fetchClusterByStatisticsSna "/keyStatistics/%s/%d?sketchEncoding=%s", StringUtils.urlEncode(stageId.getQueryId()), stageId.getStageNumber(), - WorkerChatHandler.SketchEncoding.OCTET_STREAM + WorkerResource.SketchEncoding.OCTET_STREAM ); return getClient(workerId).asyncRequest( @@ -118,7 +118,7 @@ public ListenableFuture fetchClusterByStatisticsSna StringUtils.urlEncode(stageId.getQueryId()), stageId.getStageNumber(), timeChunk, - WorkerChatHandler.SketchEncoding.OCTET_STREAM + WorkerResource.SketchEncoding.OCTET_STREAM ); return getClient(workerId).asyncRequest( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java index d3e9eefa86d2b..cc570ec992ad4 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ControllerResource.java @@ -82,6 +82,27 @@ public Response httpPostPartialKeyStatistics( return Response.status(Response.Status.ACCEPTED).build(); } + /** + * Used by subtasks to inform the controller that they are done reading their input, in cases where they would + * not be calling {@link #httpPostPartialKeyStatistics(Object, String, int, int, HttpServletRequest)}. + * + * See {@link ControllerClient#postDoneReadingInput(StageId, int)} for the client-side code that calls this API. + */ + @POST + @Path("/doneReadingInput/{queryId}/{stageNumber}/{workerNumber}") + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + public Response httpPostDoneReadingInput( + @PathParam("stageNumber") final int stageNumber, + @PathParam("workerNumber") final int workerNumber, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + controller.doneReadingInput(stageNumber, workerNumber); + return Response.status(Response.Status.ACCEPTED).build(); + } + /** * Used by subtasks to post system errors. Note that the errors are organized by taskId, not by query/stage/worker, * because system errors are associated with a task rather than a specific query/stage/worker execution context. @@ -166,7 +187,7 @@ public Response httpPostResultsComplete( } /** - * See {@link ControllerClient#getTaskList()} for the client-side code that calls this API. + * See {@link ControllerClient#getWorkerIds} for the client-side code that calls this API. 
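// Hedged illustration, not part of this patch (the map and helpers are hypothetical), of the
// new WorkerStagePhase predicates above: isTerminal() marks stages that will never run again,
// while isRunning() marks stages still consuming their full complement of worker resources.
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.worker.WorkerStagePhase;

import java.util.Map;

public class StagePhaseSketch
{
  /** Number of tracked stages currently consuming full resources. */
  public static long countRunningStages(final Map<StageId, WorkerStagePhase> phases)
  {
    return phases.values().stream().filter(WorkerStagePhase::isRunning).count();
  }

  /** True once every tracked stage has either finished or failed. */
  public static boolean allStagesDone(final Map<StageId, WorkerStagePhase> phases)
  {
    return phases.values().stream().allMatch(WorkerStagePhase::isTerminal);
  }
}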
*/ @GET @Path("/taskList") @@ -174,7 +195,7 @@ public Response httpPostResultsComplete( public Response httpGetTaskList(@Context final HttpServletRequest req) { MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); - return Response.ok(new MSQTaskList(controller.getTaskIds())).build(); + return Response.ok(new MSQTaskList(controller.getWorkerIds())).build(); } /** diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java index 30a8179fe0f00..8820b4ead5a0c 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/MSQResourceUtils.java @@ -47,4 +47,20 @@ public static void authorizeAdminRequest( throw new ForbiddenException(access.toString()); } } + + public static void authorizeQueryRequest( + final ResourcePermissionMapper permissionMapper, + final AuthorizerMapper authorizerMapper, + final HttpServletRequest request, + final String queryId + ) + { + final List resourceActions = permissionMapper.getQueryPermissions(queryId); + + Access access = AuthorizationUtils.authorizeAllResourceActions(request, resourceActions, authorizerMapper); + + if (!access.isAllowed()) { + throw new ForbiddenException(access.toString()); + } + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java index 8c79f4fa0e054..0a7fb874f6d1c 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/ResourcePermissionMapper.java @@ -23,11 +23,9 @@ import java.util.List; -/** - * Provides HTTP resources such as {@link ControllerResource} with information about which permissions are needed - * for requests. - */ public interface ResourcePermissionMapper { List getAdminPermissions(); + + List getQueryPermissions(String queryId); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java new file mode 100644 index 0000000000000..a0bfecff5427d --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java @@ -0,0 +1,391 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.rpc; + +import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.frame.file.FrameFileHttpResponseHandler; +import org.apache.druid.frame.key.ClusterByPartitions; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.exec.Worker; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; +import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde; +import org.apache.druid.server.security.AuthorizerMapper; +import org.apache.druid.utils.CloseableUtils; + +import javax.annotation.Nullable; +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.ws.rs.Consumes; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import javax.ws.rs.core.StreamingOutput; +import java.io.InputStream; +import java.io.OutputStream; + +public class WorkerResource +{ + private static final Logger log = new Logger(WorkerResource.class); + + /** + * Callers must be able to store an entire chunk in memory. It can't be too large. + */ + private static final long CHANNEL_DATA_CHUNK_SIZE = 1_000_000; + private static final long GET_CHANNEL_DATA_TIMEOUT = 30_000L; + + protected final Worker worker; + protected final ResourcePermissionMapper permissionMapper; + protected final AuthorizerMapper authorizerMapper; + + public WorkerResource( + final Worker worker, + final ResourcePermissionMapper permissionMapper, + final AuthorizerMapper authorizerMapper + ) + { + this.worker = worker; + this.permissionMapper = permissionMapper; + this.authorizerMapper = authorizerMapper; + } + + /** + * Returns up to {@link #CHANNEL_DATA_CHUNK_SIZE} bytes of stage output data. + *

+ * See {@link org.apache.druid.msq.exec.WorkerClient#fetchChannelData} for the client-side code that calls this API. + */ + @GET + @Path("/channels/{queryId}/{stageNumber}/{partitionNumber}") + @Produces(MediaType.APPLICATION_OCTET_STREAM) + public Response httpGetChannelData( + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @PathParam("partitionNumber") final int partitionNumber, + @QueryParam("offset") final long offset, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); + + final ListenableFuture dataFuture = + worker.readStageOutput(new StageId(queryId, stageNumber), partitionNumber, offset); + + final AsyncContext asyncContext = req.startAsync(); + asyncContext.setTimeout(GET_CHANNEL_DATA_TIMEOUT); + asyncContext.addListener( + new AsyncListener() + { + @Override + public void onComplete(AsyncEvent event) + { + } + + @Override + public void onTimeout(AsyncEvent event) + { + HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); + response.setStatus(HttpServletResponse.SC_OK); + event.getAsyncContext().complete(); + } + + @Override + public void onError(AsyncEvent event) + { + } + + @Override + public void onStartAsync(AsyncEvent event) + { + } + } + ); + + // Save these items, since "req" becomes inaccessible in future exception handlers. + final String remoteAddr = req.getRemoteAddr(); + final String requestURI = req.getRequestURI(); + + Futures.addCallback( + dataFuture, + new FutureCallback() + { + @Override + public void onSuccess(final InputStream inputStream) + { + HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); + + try (final OutputStream outputStream = response.getOutputStream()) { + if (inputStream == null) { + response.setStatus(HttpServletResponse.SC_NOT_FOUND); + } else { + response.setStatus(HttpServletResponse.SC_OK); + response.setContentType(MediaType.APPLICATION_OCTET_STREAM); + + final byte[] readBuf = new byte[8192]; + final int firstRead = inputStream.read(readBuf); + + if (firstRead == -1) { + // Empty read means we're at the end of the channel. + // Set the last fetch header so the client knows this. 
+ response.setHeader( + FrameFileHttpResponseHandler.HEADER_LAST_FETCH_NAME, + FrameFileHttpResponseHandler.HEADER_LAST_FETCH_VALUE + ); + } else { + long bytesReadTotal = 0; + int bytesReadThisCall = firstRead; + do { + final int bytesToWrite = + (int) Math.min(CHANNEL_DATA_CHUNK_SIZE - bytesReadTotal, bytesReadThisCall); + outputStream.write(readBuf, 0, bytesToWrite); + bytesReadTotal += bytesReadThisCall; + } while (bytesReadTotal < CHANNEL_DATA_CHUNK_SIZE + && (bytesReadThisCall = inputStream.read(readBuf)) != -1); + } + } + } + catch (Exception e) { + log.noStackTrace().warn(e, "Could not respond to request from[%s] to[%s]", remoteAddr, requestURI); + } + finally { + CloseableUtils.closeAndSuppressExceptions(inputStream, e -> log.warn("Failed to close output channel")); + asyncContext.complete(); + } + } + + @Override + public void onFailure(Throwable e) + { + if (!dataFuture.isCancelled()) { + try { + HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + asyncContext.complete(); + } + catch (Exception e2) { + e.addSuppressed(e2); + } + + log.noStackTrace().warn(e, "Request failed from[%s] to[%s]", remoteAddr, requestURI); + } + } + }, + Execs.directExecutor() + ); + + return null; + } + + /** + * See {@link org.apache.druid.msq.exec.WorkerClient#postWorkOrder} for the client-side code that calls this API. + */ + @POST + @Consumes({MediaType.APPLICATION_JSON, SmileMediaTypes.APPLICATION_JACKSON_SMILE}) + @Path("/workOrder") + public Response httpPostWorkOrder(final WorkOrder workOrder, @Context final HttpServletRequest req) + { + final String queryId = workOrder.getQueryDefinition().getQueryId(); + MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); + worker.postWorkOrder(workOrder); + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * See {@link org.apache.druid.msq.exec.WorkerClient#postResultPartitionBoundaries} for the client-side code that calls this API. 
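// Hedged client-side sketch, not part of this patch and not how Druid's own WorkerClient is
// implemented: it only illustrates the paging contract of the async channel-data endpoint
// above. The caller re-requests with a growing "offset" until an empty body arrives together
// with the last-fetch header; an empty body without that header (for example after the
// 30-second async timeout) simply means "retry". Base URL, query ID, and stage/partition
// numbers are hypothetical, and error handling is omitted.
import org.apache.druid.frame.file.FrameFileHttpResponseHandler;

import java.io.ByteArrayOutputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class ChannelDataPagingSketch
{
  public static byte[] fetchPartitionZero(final String workerBaseUrl, final String queryId) throws Exception
  {
    final HttpClient client = HttpClient.newHttpClient();
    final ByteArrayOutputStream out = new ByteArrayOutputStream();
    long offset = 0;

    while (true) {
      final HttpRequest request = HttpRequest.newBuilder(
          URI.create(workerBaseUrl + "/channels/" + queryId + "/0/0?offset=" + offset)
      ).GET().build();

      final HttpResponse<byte[]> response = client.send(request, HttpResponse.BodyHandlers.ofByteArray());
      out.write(response.body());
      offset += response.body().length;

      final boolean lastFetch = FrameFileHttpResponseHandler.HEADER_LAST_FETCH_VALUE.equals(
          response.headers()
                  .firstValue(FrameFileHttpResponseHandler.HEADER_LAST_FETCH_NAME)
                  .orElse(null)
      );

      if (lastFetch && response.body().length == 0) {
        return out.toByteArray();
      }
    }
  }
}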
+ */ + @POST + @Consumes({MediaType.APPLICATION_JSON, SmileMediaTypes.APPLICATION_JACKSON_SMILE}) + @Path("/resultPartitionBoundaries/{queryId}/{stageNumber}") + public Response httpPostResultPartitionBoundaries( + final ClusterByPartitions stagePartitionBoundaries, + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); + if (worker.postResultPartitionBoundaries(new StageId(queryId, stageNumber), stagePartitionBoundaries)) { + return Response.status(Response.Status.ACCEPTED).build(); + } else { + return Response.status(Response.Status.BAD_REQUEST).build(); + } + } + + @POST + @Path("/keyStatistics/{queryId}/{stageNumber}") + @Consumes({MediaType.APPLICATION_JSON, SmileMediaTypes.APPLICATION_JACKSON_SMILE}) + @Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_OCTET_STREAM}) + public Response httpFetchKeyStatistics( + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @QueryParam("sketchEncoding") @Nullable final SketchEncoding sketchEncoding, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); + ClusterByStatisticsSnapshot clusterByStatisticsSnapshot; + StageId stageId = new StageId(queryId, stageNumber); + try { + clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId); + if (SketchEncoding.OCTET_STREAM.equals(sketchEncoding)) { + return Response.status(Response.Status.ACCEPTED) + .type(MediaType.APPLICATION_OCTET_STREAM) + .entity( + (StreamingOutput) output -> + ClusterByStatisticsSnapshotSerde.serialize(output, clusterByStatisticsSnapshot) + ) + .build(); + } else { + return Response.status(Response.Status.ACCEPTED) + .type(MediaType.APPLICATION_JSON) + .entity(clusterByStatisticsSnapshot) + .build(); + } + } + catch (Exception e) { + String errorMessage = StringUtils.format( + "Invalid request for key statistics for query[%s] and stage[%d]", + queryId, + stageNumber + ); + log.error(e, errorMessage); + return Response.status(Response.Status.BAD_REQUEST) + .entity(ImmutableMap.of("error", errorMessage)) + .build(); + } + } + + @POST + @Path("/keyStatisticsForTimeChunk/{queryId}/{stageNumber}/{timeChunk}") + @Consumes({MediaType.APPLICATION_JSON, SmileMediaTypes.APPLICATION_JACKSON_SMILE}) + @Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_OCTET_STREAM}) + public Response httpFetchKeyStatisticsWithSnapshot( + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @PathParam("timeChunk") final long timeChunk, + @QueryParam("sketchEncoding") @Nullable final SketchEncoding sketchEncoding, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); + ClusterByStatisticsSnapshot snapshotForTimeChunk; + StageId stageId = new StageId(queryId, stageNumber); + try { + snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk); + if (SketchEncoding.OCTET_STREAM.equals(sketchEncoding)) { + return Response.status(Response.Status.ACCEPTED) + .type(MediaType.APPLICATION_OCTET_STREAM) + .entity( + (StreamingOutput) output -> + ClusterByStatisticsSnapshotSerde.serialize(output, snapshotForTimeChunk) + ) + .build(); + } else { + return Response.status(Response.Status.ACCEPTED) + .type(MediaType.APPLICATION_JSON) 
+ .entity(snapshotForTimeChunk) + .build(); + } + } + catch (Exception e) { + String errorMessage = StringUtils.format( + "Invalid request for key statistics for query[%s], stage[%d] and timeChunk[%d]", + queryId, + stageNumber, + timeChunk + ); + log.error(e, errorMessage); + return Response.status(Response.Status.BAD_REQUEST) + .entity(ImmutableMap.of("error", errorMessage)) + .build(); + } + } + + /** + * See {@link org.apache.druid.msq.exec.WorkerClient#postCleanupStage} for the client-side code that calls this API. + */ + @POST + @Path("/cleanupStage/{queryId}/{stageNumber}") + public Response httpPostCleanupStage( + @PathParam("queryId") final String queryId, + @PathParam("stageNumber") final int stageNumber, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeQueryRequest(permissionMapper, authorizerMapper, req, queryId); + worker.postCleanupStage(new StageId(queryId, stageNumber)); + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * See {@link org.apache.druid.msq.exec.WorkerClient#postFinish} for the client-side code that calls this API. + */ + @POST + @Path("/finish") + public Response httpPostFinish(@Context final HttpServletRequest req) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + worker.postFinish(); + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * See {@link org.apache.druid.msq.exec.WorkerClient#getCounters} for the client-side code that calls this API. + */ + @GET + @Produces({MediaType.APPLICATION_JSON + "; qs=0.9", SmileMediaTypes.APPLICATION_JACKSON_SMILE + "; qs=0.1"}) + @Path("/counters") + public Response httpGetCounters(@Context final HttpServletRequest req) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + return Response.status(Response.Status.OK).entity(worker.getCounters()).build(); + } + + /** + * Determines the encoding of key collectors returned by {@link #httpFetchKeyStatistics} and + * {@link #httpFetchKeyStatisticsWithSnapshot}. + */ + public enum SketchEncoding + { + /** + * The key collector is encoded as a byte stream with {@link ClusterByStatisticsSnapshotSerde}. + */ + OCTET_STREAM, + /** + * The key collector is encoded as json + */ + JSON + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/MetaInputChannelFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/MetaInputChannelFactory.java new file mode 100644 index 0000000000000..37595050c8198 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/MetaInputChannelFactory.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.shuffle.input; + +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.indexing.InputChannelFactory; +import org.apache.druid.msq.input.stage.StageInputSlice; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.kernel.WorkOrder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +/** + * Meta-factory that wraps {@link #inputChannelFactoryProvider}, and can create various other kinds of factories. + */ +public class MetaInputChannelFactory implements InputChannelFactory +{ + private final Int2ObjectMap stageOutputModeMap; + private final Function inputChannelFactoryProvider; + private final Map inputChannelFactoryMap = new HashMap<>(); + + public MetaInputChannelFactory( + final Int2ObjectMap stageOutputModeMap, + final Function inputChannelFactoryProvider + ) + { + this.stageOutputModeMap = stageOutputModeMap; + this.inputChannelFactoryProvider = inputChannelFactoryProvider; + } + + /** + * Create a meta-factory. + * + * @param slices stage slices from {@link WorkOrder#getInputs()} + * @param defaultOutputChannelMode mode to use when {@link StageInputSlice#getOutputChannelMode()} is null; i.e., + * when running with an older controller + * @param inputChannelFactoryProvider provider of {@link InputChannelFactory} for various {@link OutputChannelMode} + */ + public static MetaInputChannelFactory create( + final List slices, + final OutputChannelMode defaultOutputChannelMode, + final Function inputChannelFactoryProvider + ) + { + final Int2ObjectMap stageOutputModeMap = new Int2ObjectOpenHashMap<>(); + + for (final StageInputSlice slice : slices) { + final OutputChannelMode newMode; + + if (slice.getOutputChannelMode() != null) { + newMode = slice.getOutputChannelMode(); + } else { + newMode = defaultOutputChannelMode; + } + + final OutputChannelMode prevMode = stageOutputModeMap.putIfAbsent( + slice.getStageNumber(), + newMode + ); + + if (prevMode != null && prevMode != newMode) { + throw new ISE( + "Inconsistent output modes for stage[%s], got[%s] and[%s]", + slice.getStageNumber(), + prevMode, + newMode + ); + } + } + + return new MetaInputChannelFactory(stageOutputModeMap, inputChannelFactoryProvider); + } + + @Override + public ReadableFrameChannel openChannel( + final StageId stageId, + final int workerNumber, + final int partitionNumber + ) throws IOException + { + final OutputChannelMode outputChannelMode = stageOutputModeMap.get(stageId.getStageNumber()); + + if (outputChannelMode == null) { + throw new ISE("No output mode for stageNumber[%s]", stageId.getStageNumber()); + } + + return inputChannelFactoryMap.computeIfAbsent(outputChannelMode, inputChannelFactoryProvider) + .openChannel(stageId, workerNumber, partitionNumber); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/WorkerOrLocalInputChannelFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/WorkerOrLocalInputChannelFactory.java new file mode 100644 index 0000000000000..08c7176b7c2b5 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/input/WorkerOrLocalInputChannelFactory.java @@ -0,0 +1,70 
@@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.input; + +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.msq.indexing.InputChannelFactory; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.shuffle.output.StageOutputHolder; + +import java.io.IOException; +import java.util.List; +import java.util.function.Supplier; + +/** + * An {@link InputChannelFactory} that loads data locally when possible, and otherwise connects directly to other + * workers. Used when durable shuffle storage is off. + */ +public class WorkerOrLocalInputChannelFactory implements InputChannelFactory +{ + private final String myId; + private final Supplier> workerIdsSupplier; + private final InputChannelFactory workerInputChannelFactory; + private final StageOutputHolderProvider stageOutputHolderProvider; + + public WorkerOrLocalInputChannelFactory( + final String myId, + final Supplier> workerIdsSupplier, + final InputChannelFactory workerInputChannelFactory, + final StageOutputHolderProvider stageOutputHolderProvider + ) + { + this.myId = myId; + this.workerIdsSupplier = workerIdsSupplier; + this.workerInputChannelFactory = workerInputChannelFactory; + this.stageOutputHolderProvider = stageOutputHolderProvider; + } + + @Override + public ReadableFrameChannel openChannel(StageId stageId, int workerNumber, int partitionNumber) throws IOException + { + final String taskId = workerIdsSupplier.get().get(workerNumber); + if (taskId.equals(myId)) { + return stageOutputHolderProvider.getHolder(stageId, partitionNumber).readLocally(); + } else { + return workerInputChannelFactory.openChannel(stageId, workerNumber, partitionNumber); + } + } + + public interface StageOutputHolderProvider + { + StageOutputHolder getHolder(StageId stageId, int partitionNumber); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java new file mode 100644 index 0000000000000..f623e58f65b31 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStream.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
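// Hedged wiring sketch, not part of this patch (makeInputChannelFactory and
// makeBaseFactoryFor are hypothetical): MetaInputChannelFactory.create, added above, maps
// each OutputChannelMode seen among the work order's stage inputs to one lazily created
// underlying factory. A real worker might return, for example, a
// WorkerOrLocalInputChannelFactory for non-durable modes.
import org.apache.druid.msq.exec.OutputChannelMode;
import org.apache.druid.msq.indexing.InputChannelFactory;
import org.apache.druid.msq.input.InputSlices;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.shuffle.input.MetaInputChannelFactory;

public class InputChannelFactorySketch
{
  public static InputChannelFactory makeInputChannelFactory(
      final WorkOrder workOrder,
      final OutputChannelMode defaultMode // used for slices written by an older controller
  )
  {
    return MetaInputChannelFactory.create(
        InputSlices.allStageSlices(workOrder.getInputs()), // only stage slices carry a mode
        defaultMode,
        InputChannelFactorySketch::makeBaseFactoryFor
    );
  }

  private static InputChannelFactory makeBaseFactoryFor(final OutputChannelMode mode)
  {
    // Stand-in for real per-mode construction (local files, durable storage, memory, ...).
    throw new UnsupportedOperationException("per-mode factory construction not shown");
  }
}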
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import org.apache.druid.error.DruidException; + +import java.io.InputStream; +import java.util.List; + +/** + * Input stream based on a list of byte arrays. + */ +public class ByteChunksInputStream extends InputStream +{ + private final List chunks; + private int chunkNum; + private int positionWithinChunk; + + /** + * Create a new stream wrapping a list of chunks. + * + * @param chunks byte arrays + * @param positionWithinFirstChunk starting position within the first byte array + */ + public ByteChunksInputStream(final List chunks, final int positionWithinFirstChunk) + { + this.chunks = chunks; + this.positionWithinChunk = positionWithinFirstChunk; + this.chunkNum = -1; + advanceChunk(); + } + + @Override + public int read() + { + if (chunkNum >= chunks.size()) { + return -1; + } else { + final byte[] currentChunk = chunks.get(chunkNum); + final byte b = currentChunk[positionWithinChunk++]; + + if (positionWithinChunk == currentChunk.length) { + chunkNum++; + positionWithinChunk = 0; + } + + return b & 0xFF; + } + } + + @Override + public int read(byte[] b) + { + return read(b, 0, b.length); + } + + @Override + public int read(byte[] b, int off, int len) + { + if (len == 0) { + return 0; + } else if (chunkNum >= chunks.size()) { + return -1; + } else { + int r = 0; + + while (r < len && chunkNum < chunks.size()) { + final byte[] currentChunk = chunks.get(chunkNum); + int toReadFromCurrentChunk = Math.min(len - r, currentChunk.length - positionWithinChunk); + System.arraycopy(currentChunk, positionWithinChunk, b, off + r, toReadFromCurrentChunk); + r += toReadFromCurrentChunk; + positionWithinChunk += toReadFromCurrentChunk; + if (positionWithinChunk == currentChunk.length) { + chunkNum++; + positionWithinChunk = 0; + } + } + + return r; + } + } + + @Override + public void close() + { + chunkNum = chunks.size(); + positionWithinChunk = 0; + } + + private void advanceChunk() + { + chunkNum++; + + // Verify nonempty + if (chunkNum < chunks.size() && chunks.get(chunkNum).length == 0) { + throw DruidException.defensive("Empty chunk not allowed"); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java new file mode 100644 index 0000000000000..ec95ca7af6a78 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReader.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
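// Quick hedged illustration, not part of this patch, of ByteChunksInputStream: the stream
// concatenates the given chunks, starting at the supplied position within the first chunk.
import org.apache.druid.msq.shuffle.output.ByteChunksInputStream;

import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;

public class ByteChunksInputStreamSketch
{
  public static void main(String[] args) throws Exception
  {
    final List<byte[]> chunks = Arrays.asList(
        "hello ".getBytes(StandardCharsets.UTF_8),
        "world".getBytes(StandardCharsets.UTF_8)
    );

    // Start two bytes into the first chunk, so the stream yields "llo world".
    try (InputStream in = new ByteChunksInputStream(chunks, 2)) {
      System.out.println(new String(in.readAllBytes(), StandardCharsets.UTF_8));
    }
  }
}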
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import it.unimi.dsi.fastutil.bytes.ByteArrays; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.error.DruidException; +import org.apache.druid.frame.channel.ByteTracker; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.file.FrameFileWriter; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.OutputChannelMode; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelUtils; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.ArrayDeque; +import java.util.Deque; + +/** + * Reader for the case where stage output is a generic {@link ReadableFrameChannel}. + * + * Because this reader returns an underlying channel directly, it must only be used when it is certain that + * only a single consumer exists, i.e., when using output mode {@link OutputChannelMode#MEMORY}. See + * {@link ControllerQueryKernelUtils#canUseMemoryOutput} for the code that ensures that there is only a single + * consumer in the in-memory case. + */ +public class ChannelStageOutputReader implements StageOutputReader +{ + enum State + { + INIT, + LOCAL, + REMOTE, + CLOSED + } + + private final ReadableFrameChannel channel; + private final FrameFileWriter writer; + + /** + * Pair of chunk size + chunk InputStream. + */ + private final Deque chunks = new ArrayDeque<>(); + + /** + * State of this reader. + */ + @GuardedBy("this") + private State state = State.INIT; + + /** + * Position of {@link #positionWithinFirstChunk} in the first chunk of {@link #chunks}, within the overall stream. + */ + @GuardedBy("this") + private long cursor; + + /** + * Offset of the first chunk in {@link #chunks} which corresponds to {@link #cursor}. + */ + @GuardedBy("this") + private int positionWithinFirstChunk; + + /** + * Whether {@link FrameFileWriter#close()} is called on {@link #writer}. + */ + @GuardedBy("this") + private boolean didCloseWriter; + + public ChannelStageOutputReader(final ReadableFrameChannel channel) + { + this.channel = channel; + this.writer = FrameFileWriter.open(new ChunkAcceptor(), null, ByteTracker.unboundedTracker()); + } + + /** + * Returns an input stream starting at the provided offset. + * + * The returned {@link InputStream} is non-blocking, and is slightly buffered (up to one frame). It does not + * necessarily contain the complete remaining dataset; this means that multiple calls to this method are necessary + * to fetch the complete dataset. + * + * The provided offset must be greater than, or equal to, the offset provided to the prior call. + * + * This class supports either remote or local reads, but not both. 
Calling both this method and {@link #readLocally()} + * on the same instance of this class is an error. + * + * @param offset offset into the stage output stream + */ + @Override + public synchronized ListenableFuture readRemotelyFrom(final long offset) + { + if (state == State.INIT) { + state = State.REMOTE; + } else if (state == State.LOCAL) { + throw new ISE("Cannot read both remotely and locally"); + } else if (state == State.CLOSED) { + throw new ISE("Closed"); + } + + if (offset < cursor) { + return Futures.immediateFailedFuture( + new ISE("Offset[%,d] no longer available, current cursor is[%,d]", offset, cursor)); + } + + while (chunks.isEmpty() || offset > cursor) { + // Fetch additional chunks if needed. + if (chunks.isEmpty()) { + if (didCloseWriter) { + if (offset == cursor) { + return Futures.immediateFuture(new ByteArrayInputStream(ByteArrays.EMPTY_ARRAY)); + } else { + throw DruidException.defensive( + "Channel finished but cursor[%,d] does not match requested offset[%,d]", + cursor, + offset + ); + } + } else if (channel.isFinished()) { + try { + writer.close(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + + didCloseWriter = true; + continue; + } else if (channel.canRead()) { + try { + writer.writeFrame(channel.read(), FrameFileWriter.NO_PARTITION); + } + catch (Exception e) { + try { + writer.abort(); + } + catch (IOException e2) { + e.addSuppressed(e2); + } + + throw new RuntimeException(e); + } + } else { + return FutureUtils.transformAsync(channel.readabilityFuture(), ignored -> readRemotelyFrom(offset)); + } + } + + // Advance cursor to the provided offset, or the end of the current chunk, whichever is earlier. + final byte[] chunk = chunks.peek(); + final long amountToAdvance = Math.min(offset - cursor, chunk.length - positionWithinFirstChunk); + cursor += amountToAdvance; + positionWithinFirstChunk += Ints.checkedCast(amountToAdvance); + + // Remove first chunk if it is no longer needed. (i.e., if the cursor is at the end of it.) + if (positionWithinFirstChunk == chunk.length) { + chunks.poll(); + positionWithinFirstChunk = 0; + } + } + + if (chunks.isEmpty() || offset != cursor) { + throw DruidException.defensive( + "Expected cursor[%,d] to be caught up to offset[%,d] by this point, and to have nonzero chunks", + cursor, + offset + ); + } + + return Futures.immediateFuture(new ByteChunksInputStream(ImmutableList.copyOf(chunks), positionWithinFirstChunk)); + } + + /** + * Returns the {@link ReadableFrameChannel} that backs this reader. + * + * Callers are responsible for closing the returned channel. Once this method is called, the caller becomes the + * owner of the channel, and this class's {@link #close()} method will no longer close the channel. + * + * Only a single reader is supported. Once this method is called, it cannot be called again. + * + * This class supports either remote or local reads, but not both. Calling both this method and + * {@link #readRemotelyFrom(long)} on the same instance of this class is an error. 
+ */ + @Override + public synchronized ReadableFrameChannel readLocally() + { + if (state == State.INIT) { + state = State.LOCAL; + return channel; + } else if (state == State.REMOTE) { + throw new ISE("Cannot read both remotely and locally"); + } else if (state == State.LOCAL) { + throw new ISE("Cannot read channel multiple times"); + } else { + assert state == State.CLOSED; + throw new ISE("Closed"); + } + } + + /** + * Closes the {@link ReadableFrameChannel} backing this reader, unless {@link #readLocally()} has been called. + * In that case, the caller of {@link #readLocally()} is responsible for closing the channel. + */ + @Override + public synchronized void close() + { + // Call channel.close() unless readLocally() has been called. In that case, we expect the caller to close it. + if (state != State.LOCAL) { + state = State.CLOSED; + channel.close(); + } + } + + /** + * Input stream that can have bytes appended to it, and that can have bytes acknowledged. + */ + private class ChunkAcceptor implements WritableByteChannel + { + private boolean open = true; + + @Override + public int write(final ByteBuffer src) throws IOException + { + if (!open) { + throw new IOException("Closed"); + } + + final int len = src.remaining(); + if (len > 0) { + final byte[] bytes = new byte[len]; + src.get(bytes); + chunks.add(bytes); + } + + return len; + } + + @Override + public boolean isOpen() + { + return open; + } + + @Override + public void close() + { + open = false; + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FileStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FileStageOutputReader.java new file mode 100644 index 0000000000000..29fb7b17ee78e --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FileStageOutputReader.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import it.unimi.dsi.fastutil.bytes.ByteArrays; +import org.apache.druid.frame.channel.ReadableFileFrameChannel; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.file.FrameFile; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.RandomAccessFile; +import java.nio.channels.Channels; + +/** + * Reader for the case where stage output is stored in a {@link FrameFile} on disk. 
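[Editor's illustrative sketch, not part of the patch: the consumption pattern readRemotelyFrom is designed for, as exercised by the tests later in this patch: repeatedly request from the current offset, copy whatever the returned stream contains, and stop once a request yields zero bytes. The helper method and its assumption that the channel has already finished are hypothetical.]

import com.google.common.io.ByteStreams;
import org.apache.druid.common.guava.FutureUtils;
import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.msq.shuffle.output.ChannelStageOutputReader;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

public class ChannelStageOutputReaderExample
{
  /**
   * Copies the complete stage output, in frame-file format, to "out". Assumes the channel has
   * already finished (all frames written and the writable side closed), so each returned future
   * resolves without waiting on upstream processing.
   */
  static long copyAllRemotely(final ReadableFrameChannel channel, final OutputStream out) throws IOException
  {
    try (final ChannelStageOutputReader reader = new ChannelStageOutputReader(channel)) {
      long offset = 0;
      while (true) {
        // Offsets must be non-decreasing; each returned stream is buffered up to about one frame.
        try (final InputStream in = FutureUtils.getUnchecked(reader.readRemotelyFrom(offset), true)) {
          final long copied = ByteStreams.copy(in, out);
          if (copied == 0) {
            // The writer has been closed and the cursor has caught up: end of stream.
            return offset;
          }
          offset += copied;
        }
      }
    }
  }
}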
+ */ +public class FileStageOutputReader implements StageOutputReader +{ + private final FrameFile frameFile; + + public FileStageOutputReader(FrameFile frameFile) + { + this.frameFile = frameFile; + } + + /** + * Returns an input stream starting at the provided offset. The file is opened and seeked in-line with this method + * call, so the returned future is always immediately resolved. Callers are responsible for closing the returned + * input stream. + * + * This class supports remote and local reads from the same {@link FrameFile}, which, for example, is useful when + * broadcasting the output of a stage. + * + * @param offset offset into the stage output file + */ + @Override + public ListenableFuture readRemotelyFrom(long offset) + { + try { + final RandomAccessFile randomAccessFile = new RandomAccessFile(frameFile.file(), "r"); + + if (offset >= randomAccessFile.length()) { + randomAccessFile.close(); + return Futures.immediateFuture(new ByteArrayInputStream(ByteArrays.EMPTY_ARRAY)); + } else { + randomAccessFile.seek(offset); + return Futures.immediateFuture(Channels.newInputStream(randomAccessFile.getChannel())); + } + } + catch (Exception e) { + return Futures.immediateFailedFuture(e); + } + } + + /** + * Returns a channel pointing to a fresh {@link FrameFile#newReference()} of the underlying frame file. Callers are + * responsible for closing the returned channel. + * + * This class supports remote and local reads from the same {@link FrameFile}, which, for example, is useful when + * broadcasting the output of a stage. + */ + @Override + public ReadableFrameChannel readLocally() + { + return new ReadableFileFrameChannel(frameFile.newReference()); + } + + /** + * Closes the initial reference to the underlying {@link FrameFile}. Does not close additional references created by + * calls to {@link #readLocally()}; those references are closed when the channel(s) returned by {@link #readLocally()} + * are closed. + */ + @Override + public void close() throws IOException + { + frameFile.close(); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java new file mode 100644 index 0000000000000..8dcb8786713ba --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/FutureReadableFrameChannel.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
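[Editor's illustrative sketch, not part of the patch: how the file-backed reader differs from the channel-backed one. It permits repeated local reads, each over its own FrameFile reference, and remote reads at arbitrary offsets whose futures are already resolved. FrameFile.open is used the same way as in the tests later in this patch; the method and its file argument are hypothetical.]

import com.google.common.io.ByteStreams;
import org.apache.druid.common.guava.FutureUtils;
import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.file.FrameFile;
import org.apache.druid.msq.shuffle.output.FileStageOutputReader;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;

public class FileStageOutputReaderExample
{
  static void readSameOutputTwice(final File stageOutputFile) throws IOException
  {
    final FrameFile frameFile = FrameFile.open(stageOutputFile, null);

    try (final FileStageOutputReader reader = new FileStageOutputReader(frameFile)) {
      // Unlike ChannelStageOutputReader, multiple local reads are allowed: each channel owns a
      // fresh FrameFile reference and is closed independently of the reader.
      final ReadableFrameChannel first = reader.readLocally();
      final ReadableFrameChannel second = reader.readLocally();
      first.close();
      second.close();

      // Remote reads may be mixed in, at any offset; the future is resolved immediately.
      try (final InputStream in = FutureUtils.getUnchecked(reader.readRemotelyFrom(0), true)) {
        ByteStreams.exhaust(in);
      }
    }
  }
}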
+ */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.java.util.common.logger.Logger; + +import java.util.NoSuchElementException; + +/** + * Channel that wraps a {@link ListenableFuture} of a {@link ReadableFrameChannel}, but acts like a regular (non-future) + * {@link ReadableFrameChannel}. + */ +public class FutureReadableFrameChannel implements ReadableFrameChannel +{ + private static final Logger log = new Logger(FutureReadableFrameChannel.class); + + private final ListenableFuture channelFuture; + private ReadableFrameChannel channel; + + public FutureReadableFrameChannel(final ListenableFuture channelFuture) + { + this.channelFuture = channelFuture; + } + + @Override + public boolean isFinished() + { + if (populateChannel()) { + return channel.isFinished(); + } else { + return false; + } + } + + @Override + public boolean canRead() + { + if (populateChannel()) { + return channel.canRead(); + } else { + return false; + } + } + + @Override + public Frame read() + { + if (populateChannel()) { + return channel.read(); + } else { + throw new NoSuchElementException(); + } + } + + @Override + public ListenableFuture readabilityFuture() + { + if (populateChannel()) { + return channel.readabilityFuture(); + } else { + return FutureUtils.transformAsync(channelFuture, ignored -> readabilityFuture()); + } + } + + @Override + public void close() + { + if (populateChannel()) { + channel.close(); + } else { + channelFuture.cancel(true); + + // In case of a race where channelFuture resolved between populateChannel() and here, the cancel call above would + // have no effect. Guard against this case by checking if the channelFuture has resolved, and if so, close the + // channel here. + try { + final ReadableFrameChannel theChannel = FutureUtils.getUncheckedImmediately(channelFuture); + + try { + theChannel.close(); + } + catch (Throwable t) { + log.noStackTrace().warn(t, "Failed to close channel"); + } + } + catch (Throwable ignored) { + // Suppress. + } + } + } + + private boolean populateChannel() + { + if (channel != null) { + return true; + } else if (channelFuture.isDone()) { + channel = FutureUtils.getUncheckedImmediately(channelFuture); + return true; + } else { + return false; + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java new file mode 100644 index 0000000000000..86530dad1d01c --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/NilStageOutputReader.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
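[Editor's illustrative sketch, not part of the patch: the wrapper's behavior before and after its future resolves. SettableFuture stands in for whatever eventually supplies the channel, for example StageOutputHolder#setChannel; the example class is hypothetical.]

import com.google.common.util.concurrent.SettableFuture;
import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.channel.ReadableNilFrameChannel;
import org.apache.druid.msq.shuffle.output.FutureReadableFrameChannel;

public class FutureReadableFrameChannelExample
{
  public static void main(String[] args)
  {
    final SettableFuture<ReadableFrameChannel> future = SettableFuture.create();
    final ReadableFrameChannel channel = new FutureReadableFrameChannel(future);

    // Before the future resolves, the channel simply reports that nothing is available yet.
    System.out.println(channel.isFinished()); // false
    System.out.println(channel.canRead());    // false

    // Once the future resolves, calls delegate to the real channel.
    future.set(ReadableNilFrameChannel.INSTANCE);
    System.out.println(channel.isFinished()); // true

    channel.close();
  }
}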
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.frame.channel.ByteTracker; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.ReadableNilFrameChannel; +import org.apache.druid.frame.file.FrameFileWriter; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.channels.Channels; + +/** + * Reader for the case where stage output is known to be empty. + */ +public class NilStageOutputReader implements StageOutputReader +{ + public static final NilStageOutputReader INSTANCE = new NilStageOutputReader(); + + private static final byte[] EMPTY_FRAME_FILE; + + static { + try { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + FrameFileWriter.open(Channels.newChannel(baos), null, ByteTracker.unboundedTracker()).close(); + EMPTY_FRAME_FILE = baos.toByteArray(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public ListenableFuture readRemotelyFrom(final long offset) + { + final ByteArrayInputStream in = new ByteArrayInputStream(EMPTY_FRAME_FILE); + + //noinspection ResultOfMethodCallIgnored: OK to ignore since "skip" always works for ByteArrayInputStream. + in.skip(offset); + + return Futures.immediateFuture(in); + } + + @Override + public ReadableFrameChannel readLocally() + { + return ReadableNilFrameChannel.INSTANCE; + } + + @Override + public void close() + { + // Nothing to do. + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputHolder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputHolder.java new file mode 100644 index 0000000000000..c19519dfb7bbc --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputHolder.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
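[Editor's illustrative sketch, not part of the patch: the nil reader serves a valid but frame-less frame file for remote reads, and the shared nil channel for local reads. The example class is hypothetical.]

import com.google.common.io.ByteStreams;
import org.apache.druid.common.guava.FutureUtils;
import org.apache.druid.frame.channel.ReadableNilFrameChannel;
import org.apache.druid.msq.shuffle.output.NilStageOutputReader;

import java.io.IOException;
import java.io.InputStream;

public class NilStageOutputReaderExample
{
  public static void main(String[] args) throws IOException
  {
    final NilStageOutputReader reader = NilStageOutputReader.INSTANCE;

    try (final InputStream in = FutureUtils.getUnchecked(reader.readRemotelyFrom(0), true)) {
      // A few bytes of frame-file header and footer, but no frames.
      System.out.println(ByteStreams.exhaust(in));
    }

    System.out.println(reader.readLocally() == ReadableNilFrameChannel.INSTANCE); // true
  }
}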
+ */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.frame.channel.ReadableFileFrameChannel; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.ReadableNilFrameChannel; +import org.apache.druid.frame.file.FrameFile; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.Worker; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.rpc.WorkerResource; +import org.apache.druid.utils.CloseableUtils; + +import javax.servlet.http.HttpServletRequest; +import java.io.Closeable; +import java.io.InputStream; + +/** + * Container for a {@link StageOutputReader}, which is used to read the output of a stage. + */ +public class StageOutputHolder implements Closeable +{ + private final SettableFuture channelFuture; + private final ListenableFuture readerFuture; + + public StageOutputHolder() + { + this.channelFuture = SettableFuture.create(); + this.readerFuture = FutureUtils.transform(channelFuture, StageOutputHolder::createReader); + } + + /** + * Method for remote reads. + * + * Provides the implementation for {@link Worker#readStageOutput(StageId, int, long)}, which is in turn used by + * {@link WorkerResource#httpGetChannelData(String, int, int, long, HttpServletRequest)}. + * + * @see StageOutputReader#readRemotelyFrom(long) for details on behavior + */ + public ListenableFuture readRemotelyFrom(final long offset) + { + return FutureUtils.transformAsync(readerFuture, reader -> reader.readRemotelyFrom(offset)); + } + + /** + * Method for local reads. + * + * Used instead of {@link #readRemotelyFrom(long)} when a worker is reading a channel from itself, to avoid needless + * HTTP calls to itself. + * + * @see StageOutputReader#readLocally() for details on behavior + */ + public ReadableFrameChannel readLocally() + { + return new FutureReadableFrameChannel(FutureUtils.transform(readerFuture, StageOutputReader::readLocally)); + } + + /** + * Sets the channel that backs {@link #readLocally()} and {@link #readRemotelyFrom(long)}. + */ + public void setChannel(final ReadableFrameChannel channel) + { + if (!channelFuture.set(channel)) { + if (FutureUtils.getUncheckedImmediately(channelFuture) == null) { + throw new ISE("Closed"); + } else { + throw new ISE("Channel already set"); + } + } + } + + @Override + public void close() + { + channelFuture.set(null); + + final StageOutputReader reader; + + try { + reader = FutureUtils.getUnchecked(readerFuture, true); + } + catch (Throwable e) { + // Error creating the reader, nothing to close. Suppress. + return; + } + + if (reader != null) { + CloseableUtils.closeAndWrapExceptions(reader); + } + } + + private static StageOutputReader createReader(final ReadableFrameChannel channel) + { + if (channel == null) { + // Happens if close() was called before the channel resolved. + throw new ISE("Closed"); + } + + if (channel instanceof ReadableNilFrameChannel) { + return NilStageOutputReader.INSTANCE; + } + + if (channel instanceof ReadableFileFrameChannel) { + // Optimized implementation when reading an entire file. + final ReadableFileFrameChannel fileChannel = (ReadableFileFrameChannel) channel; + + if (fileChannel.isEntireFile()) { + final FrameFile frameFile = fileChannel.newFrameFileReference(); + + // Close original channel, so we don't leak a frame file reference. 
+ channel.close(); + + return new FileStageOutputReader(frameFile); + } + } + + // Generic implementation for any other type of channel. + return new ChannelStageOutputReader(channel); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java new file mode 100644 index 0000000000000..36b993611ca4a --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/shuffle/output/StageOutputReader.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.file.FrameFile; +import org.apache.druid.msq.exec.Worker; +import org.apache.druid.msq.kernel.StageId; + +import java.io.Closeable; +import java.io.InputStream; + +/** + * Interface for reading output channels for a particular stage. Each instance of this interface represents a + * stream from a single {@link org.apache.druid.msq.kernel.StagePartition} in {@link FrameFile} format. + * + * @see FileStageOutputReader implementation backed by {@link FrameFile} + * @see ChannelStageOutputReader implementation backed by {@link ReadableFrameChannel} + * @see NilStageOutputReader implementation for an empty channel + */ +public interface StageOutputReader extends Closeable +{ + /** + * Method for remote reads. + * + * This method ultimately backs {@link Worker#readStageOutput(StageId, int, long)}. Refer to that method's + * documentation for details about behavior of the returned future. + * + * Callers are responsible for closing the returned {@link InputStream}. This input stream may encapsulate + * resources that are not closed by this class's {@link #close()} method. + * + * It is implementation-dependent whether calls to this method must have monotonically increasing offsets. + * In particular, {@link ChannelStageOutputReader} requires monotonically increasing offsets, but + * {@link FileStageOutputReader} and {@link NilStageOutputReader} do not. + * + * @param offset offset into the stage output file + * + * @see StageOutputHolder#readRemotelyFrom(long) which uses this method + * @see Worker#readStageOutput(StageId, int, long) for documentation on behavior of the returned future + */ + ListenableFuture readRemotelyFrom(long offset); + + /** + * Method for local reads. + * + * Depending on implementation, this method may or may not be able to be called multiple times, and may or may not + * be able to be mixed with {@link #readRemotelyFrom(long)}. 
Refer to the specific implementation for more details. + * + * Callers are responsible for closing the returned channel. The returned channel may encapsulate resources that + * are not closed by this class's {@link #close()} method. + * + * It is implementation-dependent whether this method can be called multiple times. In particular, + * {@link ChannelStageOutputReader#readLocally()} can only be called one time, but the implementations in + * {@link FileStageOutputReader} and {@link NilStageOutputReader} can be called multiple times. + * + * @see StageOutputHolder#readLocally() which uses this method + */ + ReadableFrameChannel readLocally(); +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQDrillWindowQueryTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQDrillWindowQueryTest.java index d3a67fdd659cb..1b2eebe7742e6 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQDrillWindowQueryTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQDrillWindowQueryTest.java @@ -28,6 +28,7 @@ import org.apache.druid.msq.sql.MSQTaskSqlEngine; import org.apache.druid.msq.test.CalciteMSQTestsHelper; import org.apache.druid.msq.test.ExtractResultsFactory; +import org.apache.druid.msq.test.MSQTestBase; import org.apache.druid.msq.test.MSQTestOverlordServiceClient; import org.apache.druid.msq.test.MSQTestTaskActionClient; import org.apache.druid.msq.test.VerifyMSQSupportedNativeQueriesPredicate; @@ -63,15 +64,7 @@ public SqlEngine createEngine( Injector injector ) { - final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 0, - 0 - ); + final WorkerMemoryParameters workerMemoryParameters = MSQTestBase.makeTestWorkerMemoryParameters(); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerImplTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerImplTest.java deleted file mode 100644 index 171f476ebfe0b..0000000000000 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerImplTest.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
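[Editor's illustrative sketch, not part of the patch: the typical StageOutputHolder lifecycle described earlier in this patch. The consumer side asks for a channel up front; the producer side supplies the real channel later via setChannel(), at which point the holder picks an appropriate StageOutputReader. The example class and the use of a BlockingQueueFrameChannel as the backing channel are hypothetical.]

import org.apache.druid.frame.channel.BlockingQueueFrameChannel;
import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.msq.shuffle.output.StageOutputHolder;

public class StageOutputHolderExample
{
  public static void main(String[] args)
  {
    final StageOutputHolder holder = new StageOutputHolder();

    // Consumer side: safe to call before the producer has set a channel; returns a
    // FutureReadableFrameChannel that starts delegating once setChannel() is called.
    final ReadableFrameChannel forConsumer = holder.readLocally();
    System.out.println(forConsumer.canRead()); // false: nothing set yet

    // Producer side: hand over the real channel when the stage output becomes available.
    final BlockingQueueFrameChannel queue = new BlockingQueueFrameChannel(1);
    holder.setChannel(queue.readable());

    // From here on, forConsumer delegates to queue.readable(); the consumer owns and closes it.
    forConsumer.close();
    holder.close();
  }
}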
- */ - -package org.apache.druid.msq.exec; - -import org.apache.druid.java.util.common.ISE; -import org.apache.druid.msq.indexing.MSQWorkerTask; -import org.apache.druid.msq.kernel.StageId; -import org.junit.Assert; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; - -import java.util.HashMap; - - -@RunWith(MockitoJUnitRunner.class) -public class WorkerImplTest -{ - @Mock - WorkerContext workerContext; - - @Test - public void testFetchStatsThrows() - { - WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>(), 0), workerContext, WorkerStorageParameters.createInstanceForTests(Long.MAX_VALUE)); - Assert.assertThrows(ISE.class, () -> worker.fetchStatisticsSnapshot(new StageId("xx", 1))); - } - - @Test - public void testFetchStatsWithTimeChunkThrows() - { - WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>(), 0), workerContext, WorkerStorageParameters.createInstanceForTests(Long.MAX_VALUE)); - Assert.assertThrows(ISE.class, () -> worker.fetchStatisticsSnapshotForTimeChunk(new StageId("xx", 1), 1L)); - } - -} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java index 29614fc073471..1ead2a181fd9c 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java @@ -32,34 +32,54 @@ public class WorkerMemoryParametersTest @Test public void test_oneWorkerInJvm_alone() { - Assert.assertEquals(params(335_500_000, 1, 41, 75_000_000), create(1_000_000_000, 1, 1, 1, 0, 0)); - Assert.assertEquals(params(223_000_000, 2, 13, 75_000_000), create(1_000_000_000, 1, 2, 1, 0, 0)); - Assert.assertEquals(params(133_000_000, 4, 3, 75_000_000), create(1_000_000_000, 1, 4, 1, 0, 0)); - Assert.assertEquals(params(73_000_000, 3, 2, 75_000_000), create(1_000_000_000, 1, 8, 1, 0, 0)); - Assert.assertEquals(params(49_923_076, 2, 2, 75_000_000), create(1_000_000_000, 1, 12, 1, 0, 0)); + Assert.assertEquals(params(335_500_000, 1, 41, 75_000_000), create(1_000_000_000, 1, 1, 1, 1, 0, 0)); + Assert.assertEquals(params(223_000_000, 2, 13, 75_000_000), create(1_000_000_000, 1, 2, 1, 1, 0, 0)); + Assert.assertEquals(params(133_000_000, 4, 3, 75_000_000), create(1_000_000_000, 1, 4, 1, 1, 0, 0)); + Assert.assertEquals(params(73_000_000, 3, 2, 75_000_000), create(1_000_000_000, 1, 8, 1, 1, 0, 0)); + Assert.assertEquals(params(49_923_076, 2, 2, 75_000_000), create(1_000_000_000, 1, 12, 1, 1, 0, 0)); final MSQException e = Assert.assertThrows( MSQException.class, - () -> create(1_000_000_000, 1, 32, 1, 0, 0) + () -> create(1_000_000_000, 1, 32, 1, 1, 0, 0) ); - Assert.assertEquals(new NotEnoughMemoryFault(1_588_044_000, 1_000_000_000, 750_000_000, 1, 32), e.getFault()); + Assert.assertEquals(new NotEnoughMemoryFault(1_588_044_000, 1_000_000_000, 750_000_000, 1, 32, 1), e.getFault()); - final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 1, 0, 0)) + final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 1, 1, 0, 0)) .getFault(); - Assert.assertEquals(new NotEnoughMemoryFault(2024045333, 1_000_000_000, 750_000_000, 2, 32), fault); + Assert.assertEquals(new 
NotEnoughMemoryFault(2024045333, 1_000_000_000, 750_000_000, 2, 32, 1), fault); + } + + @Test + public void test_oneWorkerInJvm_alone_twoConcurrentStages() + { + Assert.assertEquals(params(166_750_000, 1, 20, 37_500_000), create(1_000_000_000, 1, 1, 2, 1, 0, 0)); + Assert.assertEquals(params(110_500_000, 2, 6, 37_500_000), create(1_000_000_000, 1, 2, 2, 1, 0, 0)); + Assert.assertEquals(params(65_500_000, 2, 3, 37_500_000), create(1_000_000_000, 1, 4, 2, 1, 0, 0)); + Assert.assertEquals(params(35_500_000, 1, 3, 37_500_000), create(1_000_000_000, 1, 8, 2, 1, 0, 0)); + + final MSQException e = Assert.assertThrows( + MSQException.class, + () -> create(1_000_000_000, 1, 12, 2, 1, 0, 0) + ); + + Assert.assertEquals(new NotEnoughMemoryFault(1_736_034_666, 1_000_000_000, 750_000_000, 1, 12, 2), e.getFault()); + + final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 2, 1, 0, 0)) + .getFault(); + Assert.assertEquals(new NotEnoughMemoryFault(4_048_090_666L, 1_000_000_000, 750_000_000, 2, 32, 2), fault); } @Test public void test_oneWorkerInJvm_twoHundredWorkersInCluster() { - Assert.assertEquals(params(474_000_000, 1, 83, 150_000_000), create(2_000_000_000, 1, 1, 200, 0, 0)); - Assert.assertEquals(params(249_000_000, 2, 27, 150_000_000), create(2_000_000_000, 1, 2, 200, 0, 0)); + Assert.assertEquals(params(474_000_000, 1, 83, 150_000_000), create(2_000_000_000, 1, 1, 1, 200, 0, 0)); + Assert.assertEquals(params(249_000_000, 2, 27, 150_000_000), create(2_000_000_000, 1, 2, 1, 200, 0, 0)); final MSQException e = Assert.assertThrows( MSQException.class, - () -> create(1_000_000_000, 1, 4, 200, 0, 0) + () -> create(1_000_000_000, 1, 4, 1, 200, 0, 0) ); Assert.assertEquals(new TooManyWorkersFault(200, 109), e.getFault()); @@ -68,76 +88,102 @@ public void test_oneWorkerInJvm_twoHundredWorkersInCluster() @Test public void test_fourWorkersInJvm_twoHundredWorkersInCluster() { - Assert.assertEquals(params(1_014_000_000, 1, 150, 168_750_000), create(9_000_000_000L, 4, 1, 200, 0, 0)); - Assert.assertEquals(params(811_500_000, 2, 62, 168_750_000), create(9_000_000_000L, 4, 2, 200, 0, 0)); - Assert.assertEquals(params(558_375_000, 4, 22, 168_750_000), create(9_000_000_000L, 4, 4, 200, 0, 0)); - Assert.assertEquals(params(305_250_000, 4, 14, 168_750_000), create(9_000_000_000L, 4, 8, 200, 0, 0)); - Assert.assertEquals(params(102_750_000, 4, 8, 168_750_000), create(9_000_000_000L, 4, 16, 200, 0, 0)); + Assert.assertEquals(params(1_014_000_000, 1, 150, 168_750_000), create(9_000_000_000L, 4, 1, 1, 200, 0, 0)); + Assert.assertEquals(params(811_500_000, 2, 62, 168_750_000), create(9_000_000_000L, 4, 2, 1, 200, 0, 0)); + Assert.assertEquals(params(558_375_000, 4, 22, 168_750_000), create(9_000_000_000L, 4, 4, 1, 200, 0, 0)); + Assert.assertEquals(params(305_250_000, 4, 14, 168_750_000), create(9_000_000_000L, 4, 8, 1, 200, 0, 0)); + Assert.assertEquals(params(102_750_000, 4, 8, 168_750_000), create(9_000_000_000L, 4, 16, 1, 200, 0, 0)); final MSQException e = Assert.assertThrows( MSQException.class, - () -> create(8_000_000_000L, 4, 32, 200, 0, 0) + () -> create(8_000_000_000L, 4, 32, 1, 200, 0, 0) ); Assert.assertEquals(new TooManyWorkersFault(200, 124), e.getFault()); // Make sure 124 actually works, and 125 doesn't. (Verify the error message above.) 
- Assert.assertEquals(params(25_000_000, 4, 3, 150_000_000), create(8_000_000_000L, 4, 32, 124, 0, 0)); + Assert.assertEquals(params(25_000_000, 4, 3, 150_000_000), create(8_000_000_000L, 4, 32, 1, 124, 0, 0)); final MSQException e2 = Assert.assertThrows( MSQException.class, - () -> create(8_000_000_000L, 4, 32, 125, 0, 0) + () -> create(8_000_000_000L, 4, 32, 1, 125, 0, 0) ); Assert.assertEquals(new TooManyWorkersFault(125, 124), e2.getFault()); } + @Test + public void test_fourWorkersInJvm_twoHundredWorkersInCluster_twoConcurrentStages() + { + Assert.assertEquals(params(406_500_000, 1, 74, 84_375_000), create(9_000_000_000L, 4, 1, 2, 200, 0, 0)); + Assert.assertEquals(params(305_250_000, 2, 30, 84_375_000), create(9_000_000_000L, 4, 2, 2, 200, 0, 0)); + Assert.assertEquals(params(178_687_500, 4, 10, 84_375_000), create(9_000_000_000L, 4, 4, 2, 200, 0, 0)); + Assert.assertEquals(params(52_125_000, 4, 6, 84_375_000), create(9_000_000_000L, 4, 8, 2, 200, 0, 0)); + + final MSQException e = Assert.assertThrows( + MSQException.class, + () -> create(8_000_000_000L, 4, 16, 2, 200, 0, 0) + ); + + Assert.assertEquals(new TooManyWorkersFault(200, 109), e.getFault()); + + // Make sure 109 actually works, and 110 doesn't. (Verify the error message above.) + Assert.assertEquals(params(25_000_000, 4, 3, 75_000_000), create(8_000_000_000L, 4, 16, 2, 109, 0, 0)); + + final MSQException e2 = Assert.assertThrows( + MSQException.class, + () -> create(8_000_000_000L, 4, 16, 2, 110, 0, 0) + ); + + Assert.assertEquals(new TooManyWorkersFault(110, 109), e2.getFault()); + } + @Test public void test_oneWorkerInJvm_smallWorkerCapacity() { // Supersorter max channels per processer are one less than they are usually to account for extra frames that are required while creating composing output channels - Assert.assertEquals(params(41_200_000, 1, 3, 9_600_000), create(128_000_000, 1, 1, 1, 0, 0)); - Assert.assertEquals(params(26_800_000, 1, 1, 9_600_000), create(128_000_000, 1, 2, 1, 0, 0)); + Assert.assertEquals(params(41_200_000, 1, 3, 9_600_000), create(128_000_000, 1, 1, 1, 1, 0, 0)); + Assert.assertEquals(params(26_800_000, 1, 1, 9_600_000), create(128_000_000, 1, 2, 1, 1, 0, 0)); final MSQException e = Assert.assertThrows( MSQException.class, - () -> create(1_000_000_000, 1, 32, 1, 0, 0) + () -> create(1_000_000_000, 1, 32, 1, 1, 0, 0) ); - Assert.assertEquals(new NotEnoughMemoryFault(1_588_044_000, 1_000_000_000, 750_000_000, 1, 32), e.getFault()); + Assert.assertEquals(new NotEnoughMemoryFault(1_588_044_000, 1_000_000_000, 750_000_000, 1, 32, 1), e.getFault()); final MSQException e2 = Assert.assertThrows( MSQException.class, - () -> create(128_000_000, 1, 4, 1, 0, 0) + () -> create(128_000_000, 1, 4, 1, 1, 0, 0) ); - Assert.assertEquals(new NotEnoughMemoryFault(580_006_666, 12_8000_000, 96_000_000, 1, 4), e2.getFault()); + Assert.assertEquals(new NotEnoughMemoryFault(580_006_666, 12_8000_000, 96_000_000, 1, 4, 1), e2.getFault()); - final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 1, 0, 0)) + final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 1, 1, 0, 0)) .getFault(); - Assert.assertEquals(new NotEnoughMemoryFault(2024045333, 1_000_000_000, 750_000_000, 2, 32), fault); + Assert.assertEquals(new NotEnoughMemoryFault(2024045333, 1_000_000_000, 750_000_000, 2, 32, 1), fault); } @Test public void test_fourWorkersInJvm_twoHundredWorkersInCluster_hashPartitions() { - Assert.assertEquals(params(814_000_000, 1, 150, 
168_750_000), create(9_000_000_000L, 4, 1, 200, 200, 0)); - Assert.assertEquals(params(611_500_000, 2, 62, 168_750_000), create(9_000_000_000L, 4, 2, 200, 200, 0)); - Assert.assertEquals(params(358_375_000, 4, 22, 168_750_000), create(9_000_000_000L, 4, 4, 200, 200, 0)); - Assert.assertEquals(params(105_250_000, 4, 14, 168_750_000), create(9_000_000_000L, 4, 8, 200, 200, 0)); + Assert.assertEquals(params(814_000_000, 1, 150, 168_750_000), create(9_000_000_000L, 4, 1, 1, 200, 200, 0)); + Assert.assertEquals(params(611_500_000, 2, 62, 168_750_000), create(9_000_000_000L, 4, 2, 1, 200, 200, 0)); + Assert.assertEquals(params(358_375_000, 4, 22, 168_750_000), create(9_000_000_000L, 4, 4, 1, 200, 200, 0)); + Assert.assertEquals(params(105_250_000, 4, 14, 168_750_000), create(9_000_000_000L, 4, 8, 1, 200, 200, 0)); final MSQException e = Assert.assertThrows( MSQException.class, - () -> create(9_000_000_000L, 4, 16, 200, 200, 0) + () -> create(9_000_000_000L, 4, 16, 1, 200, 200, 0) ); Assert.assertEquals(new TooManyWorkersFault(200, 138), e.getFault()); // Make sure 138 actually works, and 139 doesn't. (Verify the error message above.) - Assert.assertEquals(params(26_750_000, 4, 8, 168_750_000), create(9_000_000_000L, 4, 16, 138, 138, 0)); + Assert.assertEquals(params(26_750_000, 4, 8, 168_750_000), create(9_000_000_000L, 4, 16, 1, 138, 138, 0)); final MSQException e2 = Assert.assertThrows( MSQException.class, - () -> create(9_000_000_000L, 4, 16, 139, 139, 0) + () -> create(9_000_000_000L, 4, 16, 1, 139, 139, 0) ); Assert.assertEquals(new TooManyWorkersFault(139, 138), e2.getFault()); @@ -148,10 +194,10 @@ public void test_oneWorkerInJvm_oneByteUsableMemory() { final MSQException e = Assert.assertThrows( MSQException.class, - () -> WorkerMemoryParameters.createInstance(1, 1, 1, 32, 1, 1) + () -> WorkerMemoryParameters.createInstance(1, 1, 1, 1, 32, 1, 1) ); - Assert.assertEquals(new NotEnoughMemoryFault(554669334, 1, 1, 1, 1), e.getFault()); + Assert.assertEquals(new NotEnoughMemoryFault(554669334, 1, 1, 1, 1, 1), e.getFault()); } @Test @@ -179,6 +225,7 @@ private static WorkerMemoryParameters create( final long maxMemoryInJvm, final int numWorkersInJvm, final int numProcessingThreadsInJvm, + final int maxConcurrentStages, final int numInputWorkers, final int numHashOutputPartitions, final int totalLookUpFootprint @@ -188,6 +235,7 @@ private static WorkerMemoryParameters create( maxMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm, + maxConcurrentStages, numInputWorkers, numHashOutputPartitions, totalLookUpFootprint diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerWorkerContextTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerWorkerContextTest.java index 583c21d3407c7..dfb88d17b216a 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerWorkerContextTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/IndexerWorkerContextTest.java @@ -19,6 +19,7 @@ package org.apache.druid.msq.indexing; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.Futures; import com.google.inject.Injector; import org.apache.druid.indexing.common.SegmentCacheManagerFactory; @@ -30,6 +31,7 @@ import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import org.mockito.quality.Strictness; import java.util.Collections; @@ -44,12 +46,19 @@ public void setup() 
Mockito.when(injectorMock.getInstance(SegmentCacheManagerFactory.class)) .thenReturn(Mockito.mock(SegmentCacheManagerFactory.class)); + final MSQWorkerTask task = + Mockito.mock(MSQWorkerTask.class, Mockito.withSettings().strictness(Strictness.STRICT_STUBS)); + Mockito.when(task.getContext()).thenReturn(ImmutableMap.of()); + indexerWorkerContext = new IndexerWorkerContext( + task, Mockito.mock(TaskToolbox.class), injectorMock, null, null, null, + null, + null, null ); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java index 5d86abd129ce9..ccf91acb6667b 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java @@ -19,12 +19,8 @@ package org.apache.druid.msq.indexing; -import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.util.concurrent.ListenableFuture; import org.apache.druid.frame.key.ClusterByPartitions; -import org.apache.druid.indexer.TaskStatus; -import org.apache.druid.indexing.common.TaskToolbox; -import org.apache.druid.indexing.common.task.NoopTestTaskReportFileWriter; -import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.java.util.common.ISE; import org.apache.druid.msq.counters.CounterSnapshotsTree; import org.apache.druid.msq.exec.Worker; @@ -32,12 +28,9 @@ import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; -import org.apache.druid.segment.IndexIO; -import org.apache.druid.segment.IndexMergerV9; -import org.apache.druid.segment.column.ColumnConfig; -import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory; import org.apache.druid.server.security.AuthConfig; import org.apache.druid.server.security.AuthenticationResult; +import org.apache.druid.server.security.AuthorizerMapper; import org.apache.druid.sql.calcite.util.CalciteTests; import org.junit.After; import org.junit.Assert; @@ -51,15 +44,16 @@ import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.Response; import java.io.InputStream; -import java.util.HashMap; public class WorkerChatHandlerTest { private static final StageId TEST_STAGE = new StageId("123", 0); + private static final String DATASOURCE = "foo"; + @Mock private HttpServletRequest req; - private TaskToolbox toolbox; + private AuthorizerMapper authorizerMapper; private AutoCloseable mocks; private final TestWorker worker = new TestWorker(); @@ -67,29 +61,16 @@ public class WorkerChatHandlerTest @Before public void setUp() { - ObjectMapper mapper = new DefaultObjectMapper(); - IndexIO indexIO = new IndexIO(mapper, ColumnConfig.DEFAULT); - IndexMergerV9 indexMerger = new IndexMergerV9( - mapper, - indexIO, - OffHeapMemorySegmentWriteOutMediumFactory.instance() - ); - + authorizerMapper = CalciteTests.TEST_AUTHORIZER_MAPPER; mocks = MockitoAnnotations.openMocks(this); Mockito.when(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) .thenReturn(new AuthenticationResult("druid", "druid", null, null)); - TaskToolbox.Builder builder = new TaskToolbox.Builder(); - toolbox = builder.authorizerMapper(CalciteTests.TEST_AUTHORIZER_MAPPER) - .indexIO(indexIO) - .indexMergerV9(indexMerger) - .taskReportFileWriter(new 
NoopTestTaskReportFileWriter()) - .build(); } @Test public void testFetchSnapshot() { - WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker); + WorkerChatHandler chatHandler = new WorkerChatHandler(worker, authorizerMapper, DATASOURCE); Assert.assertEquals( ClusterByStatisticsSnapshot.empty(), chatHandler.httpFetchKeyStatistics(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), null, req) @@ -100,7 +81,7 @@ public void testFetchSnapshot() @Test public void testFetchSnapshot404() { - WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker); + WorkerChatHandler chatHandler = new WorkerChatHandler(worker, authorizerMapper, DATASOURCE); Assert.assertEquals( Response.Status.BAD_REQUEST.getStatusCode(), chatHandler.httpFetchKeyStatistics("123", 2, null, req) @@ -111,7 +92,7 @@ public void testFetchSnapshot404() @Test public void testFetchSnapshotWithTimeChunk() { - WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker); + WorkerChatHandler chatHandler = new WorkerChatHandler(worker, authorizerMapper, DATASOURCE); Assert.assertEquals( ClusterByStatisticsSnapshot.empty(), chatHandler.httpFetchKeyStatisticsWithSnapshot(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), 1, null, req) @@ -122,7 +103,7 @@ public void testFetchSnapshotWithTimeChunk() @Test public void testFetchSnapshotWithTimeChunk404() { - WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker); + WorkerChatHandler chatHandler = new WorkerChatHandler(worker, authorizerMapper, DATASOURCE); Assert.assertEquals( Response.Status.BAD_REQUEST.getStatusCode(), chatHandler.httpFetchKeyStatisticsWithSnapshot("123", 2, 1, null, req) @@ -133,7 +114,6 @@ public void testFetchSnapshotWithTimeChunk404() private static class TestWorker implements Worker { - @Override public String id() { @@ -141,25 +121,25 @@ public String id() } @Override - public MSQWorkerTask task() + public void run() { - return new MSQWorkerTask("controller", "ds", 1, new HashMap<>(), 0); + } @Override - public TaskStatus run() + public void stop() { - return null; + } @Override - public void stopGracefully() + public void controllerFailed() { } @Override - public void controllerFailed() + public void awaitStop() { } @@ -192,9 +172,8 @@ public ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId s @Override public boolean postResultPartitionBoundaries( - ClusterByPartitions stagePartitionBoundaries, - String queryId, - int stageNumber + StageId stageId, + ClusterByPartitions stagePartitionBoundaries ) { return false; @@ -202,7 +181,7 @@ public boolean postResultPartitionBoundaries( @Nullable @Override - public InputStream readChannel(String queryId, int stageNumber, int partitionNumber, long offset) + public ListenableFuture readStageOutput(StageId stageId, int partitionNumber, long offset) { return null; } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/error/MSQFaultSerdeTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/error/MSQFaultSerdeTest.java index c33faa40c14e0..cffc0f78a497d 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/error/MSQFaultSerdeTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/error/MSQFaultSerdeTest.java @@ -74,7 +74,7 @@ public void testFaultSerde() throws IOException )); assertFaultSerde(new InvalidNullByteFault("the source", 1, "the column", "the value", 2)); assertFaultSerde(new 
InvalidFieldFault("the source", "the column", 1, "the error", "the log msg")); - assertFaultSerde(new NotEnoughMemoryFault(1000, 1000, 900, 1, 2)); + assertFaultSerde(new NotEnoughMemoryFault(1000, 1000, 900, 1, 2, 2)); assertFaultSerde(QueryNotSupportedFault.INSTANCE); assertFaultSerde(new QueryRuntimeFault("new error", "base error")); assertFaultSerde(new QueryRuntimeFault("new error", null)); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStreamTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStreamTest.java new file mode 100644 index 0000000000000..bc349d56c8fdb --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ByteChunksInputStreamTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.collect.ImmutableList; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.List; + +public class ByteChunksInputStreamTest +{ + private final List chunks = ImmutableList.of( + new byte[]{-128, -127, -1, 0, 1, 126, 127}, + new byte[]{0}, + new byte[]{3, 4, 5} + ); + + @Test + public void test_read_fromStart() throws IOException + { + try (final InputStream in = new ByteChunksInputStream(chunks, 0)) { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + int c; + while ((c = in.read()) != -1) { + MatcherAssert.assertThat("InputStream#read contract", c, Matchers.greaterThanOrEqualTo(0)); + baos.write(c); + } + + Assert.assertArrayEquals(chunksSubset(0), baos.toByteArray()); + } + } + + @Test + public void test_read_fromSecondByte() throws IOException + { + try (final InputStream in = new ByteChunksInputStream(chunks, 1)) { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + int c; + while ((c = in.read()) != -1) { + MatcherAssert.assertThat("InputStream#read contract", c, Matchers.greaterThanOrEqualTo(0)); + baos.write(c); + } + + Assert.assertArrayEquals(chunksSubset(1), baos.toByteArray()); + } + } + + @Test + public void test_read_array1_fromStart() throws IOException + { + try (final InputStream in = new ByteChunksInputStream(chunks, 0)) { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final byte[] buf = new byte[2]; + + int r; + while ((r = in.read(buf, 1, 1)) != -1) { + Assert.assertEquals("InputStream#read bytes read", 1, r); + baos.write(buf, 1, 1); + } + + Assert.assertArrayEquals(chunksSubset(0), baos.toByteArray()); + } + } + + @Test 
+ public void test_read_array1_fromSecondByte() throws IOException + { + try (final InputStream in = new ByteChunksInputStream(chunks, 1)) { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final byte[] buf = new byte[2]; + + int r; + while ((r = in.read(buf, 1, 1)) != -1) { + Assert.assertEquals("InputStream#read bytes read", 1, r); + baos.write(buf, 1, 1); + } + + Assert.assertArrayEquals(chunksSubset(1), baos.toByteArray()); + } + } + + @Test + public void test_read_array3_fromStart() throws IOException + { + try (final InputStream in = new ByteChunksInputStream(chunks, 0)) { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final byte[] buf = new byte[5]; + + int r; + while ((r = in.read(buf, 2, 3)) != -1) { + baos.write(buf, 2, r); + } + + Assert.assertArrayEquals(chunksSubset(0), baos.toByteArray()); + } + } + + @Test + public void test_read_array3_fromSecondByte() throws IOException + { + try (final InputStream in = new ByteChunksInputStream(chunks, 1)) { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final byte[] buf = new byte[6]; + + int r; + while ((r = in.read(buf, 2, 3)) != -1) { + baos.write(buf, 2, r); + } + + Assert.assertArrayEquals(chunksSubset(1), baos.toByteArray()); + } + } + + private byte[] chunksSubset(final int positionInFirstChunk) + { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + for (int chunk = 0, p = positionInFirstChunk; chunk < chunks.size(); chunk++, p = 0) { + baos.write(chunks.get(chunk), p, chunks.get(chunk).length - p); + } + + return baos.toByteArray(); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReaderTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReaderTest.java new file mode 100644 index 0000000000000..927372a3a6ae9 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/shuffle/output/ChannelStageOutputReaderTest.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.shuffle.output; + +import com.google.common.io.ByteStreams; +import com.google.common.math.IntMath; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.FrameType; +import org.apache.druid.frame.channel.BlockingQueueFrameChannel; +import org.apache.druid.frame.channel.ReadableFileFrameChannel; +import org.apache.druid.frame.file.FrameFile; +import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.frame.testutil.FrameSequenceBuilder; +import org.apache.druid.frame.testutil.FrameTestUtil; +import org.apache.druid.segment.TestIndex; +import org.apache.druid.segment.incremental.IncrementalIndex; +import org.apache.druid.segment.incremental.IncrementalIndexStorageAdapter; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.internal.matchers.ThrowableCauseMatcher; +import org.junit.internal.matchers.ThrowableMessageMatcher; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.math.RoundingMode; +import java.util.List; + +public class ChannelStageOutputReaderTest extends InitializedNullHandlingTest +{ + private static final int MAX_FRAMES = 10; + private static final int EXPECTED_NUM_ROWS = 1209; + + private final BlockingQueueFrameChannel channel = new BlockingQueueFrameChannel(MAX_FRAMES); + private final ChannelStageOutputReader reader = new ChannelStageOutputReader(channel.readable()); + + @Rule + public TemporaryFolder temporaryFolder = new TemporaryFolder(); + + private FrameReader frameReader; + private List frameList; + + @Before + public void setUp() + { + final IncrementalIndex index = TestIndex.getIncrementalTestIndex(); + final IncrementalIndexStorageAdapter adapter = new IncrementalIndexStorageAdapter(index); + frameReader = FrameReader.create(adapter.getRowSignature()); + frameList = FrameSequenceBuilder.fromAdapter(adapter) + .frameType(FrameType.ROW_BASED) + .maxRowsPerFrame(IntMath.divide(index.size(), MAX_FRAMES, RoundingMode.CEILING)) + .frames() + .toList(); + } + + @After + public void tearDown() + { + reader.close(); + } + + @Test + public void test_readLocally() throws IOException + { + writeAllFramesToChannel(); + + Assert.assertSame(channel.readable(), reader.readLocally()); + reader.close(); // Won't close the channel, because it's already been returned by readLocally + + final int numRows = FrameTestUtil.readRowsFromFrameChannel(channel.readable(), frameReader).toList().size(); + Assert.assertEquals(EXPECTED_NUM_ROWS, numRows); + } + + @Test + public void test_readLocally_closePriorToRead() throws IOException + { + writeAllFramesToChannel(); + + reader.close(); + + // Can't read the channel after closing the reader + Assert.assertThrows( + IllegalStateException.class, + reader::readLocally + ); + } + + @Test + public void test_readLocally_thenReadRemotely() throws IOException + { + writeAllFramesToChannel(); + + Assert.assertSame(channel.readable(), reader.readLocally()); + + // Can't read remotely after reading locally + Assert.assertThrows( + IllegalStateException.class, + () -> reader.readRemotelyFrom(0) + ); + + // Can still read locally after this error + final int numRows = 
FrameTestUtil.readRowsFromFrameChannel(channel.readable(), frameReader).toList().size(); + Assert.assertEquals(EXPECTED_NUM_ROWS, numRows); + } + + @Test + public void test_readRemotely_strideBasedOnReturnedChunk() throws IOException + { + // Test that reads entire chunks from readRemotelyFrom. This is a typical usage pattern. + + writeAllFramesToChannel(); + + final File tmpFile = temporaryFolder.newFile(); + + try (final FileOutputStream tmpOut = new FileOutputStream(tmpFile)) { + int numReads = 0; + long offset = 0; + + while (true) { + try (final InputStream in = FutureUtils.getUnchecked(reader.readRemotelyFrom(offset), true)) { + numReads++; + final long bytesWritten = ByteStreams.copy(in, tmpOut); + offset += bytesWritten; + + if (bytesWritten == 0) { + break; + } + } + } + + MatcherAssert.assertThat(numReads, Matchers.greaterThan(1)); + } + + final FrameFile frameFile = FrameFile.open(tmpFile, null); + final int numRows = + FrameTestUtil.readRowsFromFrameChannel(new ReadableFileFrameChannel(frameFile), frameReader).toList().size(); + + Assert.assertEquals(EXPECTED_NUM_ROWS, numRows); + } + + @Test + public void test_readRemotely_strideOneByte() throws IOException + { + // Test that reads one byte at a time from readRemotelyFrom. This helps ensure that there are no edge cases + // in the chunk-reading logic. + + writeAllFramesToChannel(); + + final File tmpFile = temporaryFolder.newFile(); + + try (final FileOutputStream tmpOut = new FileOutputStream(tmpFile)) { + int numReads = 0; + long offset = 0; + + while (true) { + try (final InputStream in = FutureUtils.getUnchecked(reader.readRemotelyFrom(offset), true)) { + numReads++; + final int nextByte = in.read(); + + if (nextByte < 0) { + break; + } + + tmpOut.write(nextByte); + offset++; + } + } + + Assert.assertEquals(numReads, offset + 1); + } + + final FrameFile frameFile = FrameFile.open(tmpFile, null); + final int numRows = + FrameTestUtil.readRowsFromFrameChannel(new ReadableFileFrameChannel(frameFile), frameReader).toList().size(); + + Assert.assertEquals(EXPECTED_NUM_ROWS, numRows); + } + + @Test + public void test_readRemotely_thenLocally() throws IOException + { + writeAllFramesToChannel(); + + // Read remotely + FutureUtils.getUnchecked(reader.readRemotelyFrom(0), true); + + // Then read locally + Assert.assertThrows( + IllegalStateException.class, + reader::readLocally + ); + } + + @Test + public void test_readRemotely_cannotReverse() throws IOException + { + writeAllFramesToChannel(); + + // Read remotely from offset = 1. + final InputStream in = FutureUtils.getUnchecked(reader.readRemotelyFrom(1), true); + final int offset = ByteStreams.toByteArray(in).length; + MatcherAssert.assertThat(offset, Matchers.greaterThan(0)); + + // Then read again from offset = 0; should get an error. 
+ final RuntimeException e = Assert.assertThrows( + RuntimeException.class, + () -> FutureUtils.getUnchecked(reader.readRemotelyFrom(0), true) + ); + + MatcherAssert.assertThat( + e, + ThrowableCauseMatcher.hasCause( + Matchers.allOf( + CoreMatchers.instanceOf(IllegalStateException.class), + ThrowableMessageMatcher.hasMessage(CoreMatchers.startsWith("Offset[0] no longer available")) + ) + ) + ); + } + + private void writeAllFramesToChannel() throws IOException + { + for (Frame frame : frameList) { + channel.writable().write(frame); + } + channel.writable().close(); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteArraysQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteArraysQueryMSQTest.java index b60c6c71d2e27..124b4fce25880 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteArraysQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteArraysQueryMSQTest.java @@ -64,15 +64,7 @@ public SqlEngine createEngine( Injector injector ) { - final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 0, - 0 - ); + final WorkerMemoryParameters workerMemoryParameters = MSQTestBase.makeTestWorkerMemoryParameters(); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteNestedDataQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteNestedDataQueryMSQTest.java index 2d8067e900e98..5d4c0994ea066 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteNestedDataQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteNestedDataQueryMSQTest.java @@ -67,15 +67,7 @@ public SqlEngine createEngine( Injector injector ) { - final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 0, - 0 - ); + final WorkerMemoryParameters workerMemoryParameters = MSQTestBase.makeTestWorkerMemoryParameters(); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java index 317fe30a646da..6bbf9c6da5e40 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectJoinQueryMSQTest.java @@ -136,15 +136,7 @@ public SqlEngine createEngine( Injector injector ) { - final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 0, - 0 - ); + final WorkerMemoryParameters workerMemoryParameters = MSQTestBase.makeTestWorkerMemoryParameters(); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryMSQTest.java index 3008f9d43b47c..2de9229b4adc7 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteSelectQueryMSQTest.java @@ -73,15 +73,7 @@ public SqlEngine createEngine( Injector injector ) { - final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 0, - 0 - ); + final WorkerMemoryParameters workerMemoryParameters = MSQTestBase.makeTestWorkerMemoryParameters(); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java index b5d8368b068f0..e4b678402a8b8 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/CalciteUnionQueryMSQTest.java @@ -79,15 +79,7 @@ public SqlEngine createEngine( Injector injector ) { - final WorkerMemoryParameters workerMemoryParameters = - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 0, - 0 - ); + final WorkerMemoryParameters workerMemoryParameters = MSQTestBase.makeTestWorkerMemoryParameters(); final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient( queryJsonMapper, injector, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 5f0bd545b7c69..2136d96d6d11c 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -333,16 +333,7 @@ public class MSQTestBase extends BaseCalciteQueryTest private SegmentCacheManager segmentCacheManager; private TestGroupByBuffers groupByBuffers; - protected final WorkerMemoryParameters workerMemoryParameters = Mockito.spy( - WorkerMemoryParameters.createInstance( - WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, - 2, - 10, - 2, - 1, - 0 - ) - ); + protected final WorkerMemoryParameters workerMemoryParameters = Mockito.spy(makeTestWorkerMemoryParameters()); protected static class MSQBaseComponentSupplier extends StandardComponentSupplier { @@ -753,6 +744,19 @@ public static ObjectMapper setupObjectMapper(Injector injector) return mapper; } + public static WorkerMemoryParameters makeTestWorkerMemoryParameters() + { + return WorkerMemoryParameters.createInstance( + WorkerMemoryParameters.PROCESSING_MINIMUM_BYTES * 50, + 2, + 10, + 1, + 2, + 1, + 0 + ); + } + private String runMultiStageQuery(String query, Map context) { final DirectStatement stmt = sqlStatementFactory.directStatement( diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java index 96e26cba77e1f..4c7ca61be0232 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerClient.java @@ -87,9 +87,9 @@ public void postWorkerWarning(List MSQErrorReports) } @Override - public List getTaskList() + public List getWorkerIds() { - return controller.getTaskIds(); + return controller.getWorkerIds(); } @Override diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java index 20d31fbd4cfef..e651043020329 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java @@ -156,32 +156,33 @@ public ListenableFuture runTask(String taskId, Object taskObject) Worker worker = new WorkerImpl( task, new MSQTestWorkerContext( + task.getId(), inMemoryWorkers, controller, mapper, injector, - workerMemoryParameters - ), - workerStorageParameters + workerMemoryParameters, + workerStorageParameters + ) ); inMemoryWorkers.put(task.getId(), worker); statusMap.put(task.getId(), TaskStatus.running(task.getId())); - ListenableFuture future = executor.submit(() -> { + ListenableFuture future = executor.submit(() -> { try { - return worker.run(); + worker.run(); } catch (Exception e) { throw new RuntimeException(e); } }); - Futures.addCallback(future, new FutureCallback() + Futures.addCallback(future, new FutureCallback() { @Override - public void onSuccess(@Nullable TaskStatus result) + public void onSuccess(@Nullable Object result) { - statusMap.put(task.getId(), result); + statusMap.put(task.getId(), TaskStatus.success(task.getId())); } @Override @@ -261,7 +262,7 @@ public ListenableFuture cancelTask(String workerId) { final Worker worker = inMemoryWorkers.remove(workerId); if (worker != null) { - worker.stopGracefully(); + worker.stop(); } return Futures.immediateFuture(null); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java index 72cb246a43e16..65145b5f5c01a 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java @@ -80,11 +80,7 @@ public ListenableFuture postResultPartitionBoundaries( ) { try { - inMemoryWorkers.get(workerTaskId).postResultPartitionBoundaries( - partitionBoundaries, - stageId.getQueryId(), - stageId.getStageNumber() - ); + inMemoryWorkers.get(workerTaskId).postResultPartitionBoundaries(stageId, partitionBoundaries); return Futures.immediateFuture(null); } catch (Exception e) { @@ -122,8 +118,7 @@ public ListenableFuture fetchChannelData( ) { try (InputStream inputStream = - inMemoryWorkers.get(workerTaskId) - .readChannel(stageId.getQueryId(), stageId.getStageNumber(), partitionNumber, offset)) { + inMemoryWorkers.get(workerTaskId).readStageOutput(stageId, partitionNumber, offset).get()) { byte[] buffer = new byte[8 * 1024]; boolean didRead = false; int bytesRead; @@ -138,12 
+133,11 @@ public ListenableFuture fetchChannelData( catch (Exception e) { throw new ISE(e, "Error reading frame file channel"); } - } @Override public void close() { - inMemoryWorkers.forEach((k, v) -> v.stopGracefully()); + inMemoryWorkers.forEach((k, v) -> v.stop()); } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java index 14f6f73b24ab0..082429a9d7b13 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerContext.java @@ -22,59 +22,69 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.inject.Injector; import org.apache.druid.frame.processor.Bouncer; -import org.apache.druid.indexer.report.TaskReportFileWriter; -import org.apache.druid.indexing.common.TaskToolbox; -import org.apache.druid.indexing.common.task.NoopTestTaskReportFileWriter; import org.apache.druid.java.util.common.FileUtils; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.msq.exec.Controller; import org.apache.druid.msq.exec.ControllerClient; import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; +import org.apache.druid.msq.exec.OutputChannelMode; import org.apache.druid.msq.exec.Worker; import org.apache.druid.msq.exec.WorkerClient; import org.apache.druid.msq.exec.WorkerContext; import org.apache.druid.msq.exec.WorkerMemoryParameters; -import org.apache.druid.msq.indexing.IndexerFrameContext; -import org.apache.druid.msq.indexing.IndexerWorkerContext; +import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.kernel.FrameContext; import org.apache.druid.msq.kernel.QueryDefinition; import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.query.groupby.GroupingEngine; import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexMergerV9; +import org.apache.druid.segment.SegmentWrangler; import org.apache.druid.segment.column.ColumnConfig; import org.apache.druid.segment.incremental.NoopRowIngestionMeters; +import org.apache.druid.segment.incremental.RowIngestionMeters; import org.apache.druid.segment.loading.DataSegmentPusher; -import org.apache.druid.segment.realtime.NoopChatHandlerProvider; import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory; import org.apache.druid.server.DruidNode; -import org.apache.druid.server.coordination.DataSegmentAnnouncer; -import org.apache.druid.server.security.AuthTestUtils; import java.io.File; import java.util.Map; public class MSQTestWorkerContext implements WorkerContext { + private final String workerId; private final Controller controller; private final ObjectMapper mapper; private final Injector injector; private final Map inMemoryWorkers; private final File file = FileUtils.createTempDir(); + private final Bouncer bouncer = new Bouncer(1); private final WorkerMemoryParameters workerMemoryParameters; + private final WorkerStorageParameters workerStorageParameters; public MSQTestWorkerContext( + String workerId, Map inMemoryWorkers, Controller controller, ObjectMapper mapper, Injector injector, - WorkerMemoryParameters workerMemoryParameters + WorkerMemoryParameters workerMemoryParameters, + WorkerStorageParameters workerStorageParameters ) { + this.workerId = workerId; this.inMemoryWorkers = inMemoryWorkers; 
this.controller = controller; this.mapper = mapper; this.injector = injector; this.workerMemoryParameters = workerMemoryParameters; + this.workerStorageParameters = workerStorageParameters; + } + + @Override + public String queryId() + { + return controller.queryId(); } @Override @@ -96,7 +106,13 @@ public void registerWorker(Worker worker, Closer closer) } @Override - public ControllerClient makeControllerClient(String controllerId) + public String workerId() + { + return workerId; + } + + @Override + public ControllerClient makeControllerClient() { return new MSQTestControllerClient(controller); } @@ -114,42 +130,9 @@ public File tempDir() } @Override - public FrameContext frameContext(QueryDefinition queryDef, int stageNumber) + public FrameContext frameContext(QueryDefinition queryDef, int stageNumber, OutputChannelMode outputChannelMode) { - IndexIO indexIO = new IndexIO(mapper, ColumnConfig.DEFAULT); - IndexMergerV9 indexMerger = new IndexMergerV9( - mapper, - indexIO, - OffHeapMemorySegmentWriteOutMediumFactory.instance(), - true - ); - final TaskReportFileWriter reportFileWriter = new NoopTestTaskReportFileWriter(); - - return new IndexerFrameContext( - new IndexerWorkerContext( - new TaskToolbox.Builder() - .segmentPusher(injector.getInstance(DataSegmentPusher.class)) - .segmentAnnouncer(injector.getInstance(DataSegmentAnnouncer.class)) - .jsonMapper(mapper) - .taskWorkDir(tempDir()) - .indexIO(indexIO) - .indexMergerV9(indexMerger) - .taskReportFileWriter(reportFileWriter) - .authorizerMapper(AuthTestUtils.TEST_AUTHORIZER_MAPPER) - .chatHandlerProvider(new NoopChatHandlerProvider()) - .rowIngestionMetersFactory(NoopRowIngestionMeters::new) - .build(), - injector, - indexIO, - null, - null, - null - ), - indexIO, - injector.getInstance(DataSegmentProvider.class), - injector.getInstance(DataServerQueryHandlerFactory.class), - workerMemoryParameters - ); + return new FrameContextImpl(new File(tempDir(), queryDef.getStageDefinition(stageNumber).getId().toString())); } @Override @@ -165,9 +148,9 @@ public DruidNode selfNode() } @Override - public Bouncer processorBouncer() + public int maxConcurrentStages() { - return injector.getInstance(Bouncer.class); + return 1; } @Override @@ -175,4 +158,109 @@ public DataServerQueryHandlerFactory dataServerQueryHandlerFactory() { return injector.getInstance(DataServerQueryHandlerFactory.class); } + + class FrameContextImpl implements FrameContext + { + private final File tempDir; + + public FrameContextImpl(File tempDir) + { + this.tempDir = tempDir; + } + + @Override + public SegmentWrangler segmentWrangler() + { + return injector.getInstance(SegmentWrangler.class); + } + + @Override + public GroupingEngine groupingEngine() + { + return injector.getInstance(GroupingEngine.class); + } + + @Override + public RowIngestionMeters rowIngestionMeters() + { + return new NoopRowIngestionMeters(); + } + + @Override + public DataSegmentProvider dataSegmentProvider() + { + return injector.getInstance(DataSegmentProvider.class); + } + + @Override + public DataServerQueryHandlerFactory dataServerQueryHandlerFactory() + { + return injector.getInstance(DataServerQueryHandlerFactory.class); + } + + @Override + public File tempDir() + { + return new File(tempDir, "tmp"); + } + + @Override + public ObjectMapper jsonMapper() + { + return mapper; + } + + @Override + public IndexIO indexIO() + { + return new IndexIO(mapper, ColumnConfig.DEFAULT); + } + + @Override + public File persistDir() + { + return new File(tempDir, "persist"); + } + + @Override + public 
DataSegmentPusher segmentPusher() + { + return injector.getInstance(DataSegmentPusher.class); + } + + @Override + public IndexMergerV9 indexMerger() + { + return new IndexMergerV9( + mapper, + indexIO(), + OffHeapMemorySegmentWriteOutMediumFactory.instance(), + true + ); + } + + @Override + public Bouncer processorBouncer() + { + return bouncer; + } + + @Override + public WorkerMemoryParameters memoryParameters() + { + return workerMemoryParameters; + } + + @Override + public WorkerStorageParameters storageParameters() + { + return workerStorageParameters; + } + + @Override + public void close() + { + + } + } } diff --git a/processing/src/main/java/org/apache/druid/frame/channel/ReadableFileFrameChannel.java b/processing/src/main/java/org/apache/druid/frame/channel/ReadableFileFrameChannel.java index 963a001ad6db3..7da6550ccca72 100644 --- a/processing/src/main/java/org/apache/druid/frame/channel/ReadableFileFrameChannel.java +++ b/processing/src/main/java/org/apache/druid/frame/channel/ReadableFileFrameChannel.java @@ -104,6 +104,14 @@ public void close() } } + /** + * Returns whether this channel represents the entire underlying {@link FrameFile}. + */ + public boolean isEntireFile() + { + return currentFrame == 0 && endFrame == frameFile.numFrames(); + } + /** * Returns a new reference to the {@link FrameFile} that this channel is reading from. Callers should close this * reference when done reading. diff --git a/processing/src/test/java/org/apache/druid/frame/processor/ReadableFileFrameChannelTest.java b/processing/src/test/java/org/apache/druid/frame/processor/ReadableFileFrameChannelTest.java new file mode 100644 index 0000000000000..9025d1820864a --- /dev/null +++ b/processing/src/test/java/org/apache/druid/frame/processor/ReadableFileFrameChannelTest.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.druid.frame.processor;
+
+import org.apache.druid.frame.FrameType;
+import org.apache.druid.frame.channel.ReadableFileFrameChannel;
+import org.apache.druid.frame.file.FrameFile;
+import org.apache.druid.frame.read.FrameReader;
+import org.apache.druid.frame.testutil.FrameSequenceBuilder;
+import org.apache.druid.frame.testutil.FrameTestUtil;
+import org.apache.druid.java.util.common.guava.Sequences;
+import org.apache.druid.segment.QueryableIndexStorageAdapter;
+import org.apache.druid.segment.StorageAdapter;
+import org.apache.druid.segment.TestIndex;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+
+public class ReadableFileFrameChannelTest extends InitializedNullHandlingTest
+{
+  private static final int ROWS_PER_FRAME = 20;
+
+  private List<List<Object>> allRows;
+  private FrameReader frameReader;
+  private FrameFile frameFile;
+
+  @Rule
+  public TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+  @Before
+  public void setUp() throws IOException
+  {
+    final StorageAdapter adapter = new QueryableIndexStorageAdapter(TestIndex.getNoRollupMMappedTestIndex());
+    final File file = FrameTestUtil.writeFrameFile(
+        FrameSequenceBuilder.fromAdapter(adapter)
+                            .frameType(FrameType.ROW_BASED)
+                            .maxRowsPerFrame(ROWS_PER_FRAME)
+                            .frames(),
+        temporaryFolder.newFile()
+    );
+    allRows = FrameTestUtil.readRowsFromAdapter(adapter, adapter.getRowSignature(), false).toList();
+    frameReader = FrameReader.create(adapter.getRowSignature());
+    frameFile = FrameFile.open(file, null, FrameFile.Flag.DELETE_ON_CLOSE);
+  }
+
+  @After
+  public void tearDown() throws Exception
+  {
+    frameFile.close();
+  }
+
+  @Test
+  public void test_fullFile()
+  {
+    final ReadableFileFrameChannel channel = new ReadableFileFrameChannel(frameFile);
+    Assert.assertTrue(channel.isEntireFile());
+
+    FrameTestUtil.assertRowsEqual(
+        Sequences.simple(allRows),
+        FrameTestUtil.readRowsFromFrameChannel(channel, frameReader)
+    );
+
+    Assert.assertFalse(channel.isEntireFile());
+  }
+
+  @Test
+  public void test_partialFile()
+  {
+    final ReadableFileFrameChannel channel = new ReadableFileFrameChannel(frameFile, 1, 2);
+    Assert.assertFalse(channel.isEntireFile());
+
+    FrameTestUtil.assertRowsEqual(
+        Sequences.simple(allRows).skip(ROWS_PER_FRAME).limit(ROWS_PER_FRAME),
+        FrameTestUtil.readRowsFromFrameChannel(channel, frameReader)
+    );
+
+    Assert.assertFalse(channel.isEntireFile());
+  }
+}
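
For illustration, the chunk-at-a-time consumption pattern that test_readRemotely_strideBasedOnReturnedChunk exercises can be sketched as a standalone helper. The class and method names below are hypothetical and not part of this patch; the readRemotelyFrom call (returning a ListenableFuture of InputStream) and the FutureUtils/ByteStreams usage mirror the test code above.

// Hypothetical helper (not part of this patch) mirroring the chunked read loop in
// ChannelStageOutputReaderTest#test_readRemotely_strideBasedOnReturnedChunk.
package org.apache.druid.msq.shuffle.output;

import com.google.common.io.ByteStreams;
import org.apache.druid.common.guava.FutureUtils;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

public final class StageOutputCopyExample
{
  private StageOutputCopyExample()
  {
  }

  /**
   * Drains a {@link ChannelStageOutputReader} into "out" by repeatedly requesting a chunk at the current
   * offset and advancing by however many bytes the chunk actually contained. An empty chunk signals
   * end-of-stream. Returns the total number of bytes copied.
   */
  public static long copyStageOutput(final ChannelStageOutputReader reader, final OutputStream out)
      throws IOException
  {
    long offset = 0;
    while (true) {
      try (final InputStream in = FutureUtils.getUnchecked(reader.readRemotelyFrom(offset), true)) {
        final long bytesRead = ByteStreams.copy(in, out);
        if (bytesRead == 0) {
          return offset;
        }
        offset += bytesRead;
      }
    }
  }
}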