diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java index c6ae58371e15c..75f5990a1ccf2 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java @@ -254,24 +254,7 @@ public void readBlobAsync(String blobName, ActionListener listener) blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, partNumber)); } } - - CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new)) - .whenComplete((unused, partThrowable) -> { - if (partThrowable == null) { - listener.onResponse( - new ReadContext( - blobSize, - blobPartInputStreamFutures.stream().map(CompletableFuture::join).collect(Collectors.toList()), - blobChecksum - ) - ); - } else { - Exception ex = partThrowable.getCause() instanceof Exception - ? (Exception) partThrowable.getCause() - : new Exception(partThrowable.getCause()); - listener.onFailure(ex); - } - }); + listener.onResponse(new ReadContext(blobSize, blobPartInputStreamFutures, blobChecksum)); }); } catch (Exception ex) { listener.onFailure(SdkException.create("Error occurred while fetching blob parts from the repository", ex)); diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java index 9817d7cd520ef..055a882885065 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java @@ -969,7 +969,7 @@ public void testReadBlobAsyncMultiPart() throws Exception { assertEquals(objectSize, readContext.getBlobSize()); for (int partNumber = 1; partNumber < objectPartCount; partNumber++) { - InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber); + InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber).get(); final int offset = partNumber * partSize; assertEquals(partSize, inputStreamContainer.getContentLength()); assertEquals(offset, inputStreamContainer.getOffset()); @@ -1024,7 +1024,7 @@ public void testReadBlobAsyncSinglePart() throws Exception { assertEquals(checksum, readContext.getBlobChecksum()); assertEquals(objectSize, readContext.getBlobSize()); - InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get(); + InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get().get(); assertEquals(objectSize, inputStreamContainer.getContentLength()); assertEquals(0, inputStreamContainer.getOffset()); assertEquals(objectSize, inputStreamContainer.getInputStream().readAllBytes().length); diff --git a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java index 079753de95680..934656cdf702b 100644 --- a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java +++ b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsAsyncBlobContainer.java @@ -27,6 +27,7 @@ import java.nio.file.StandardOpenOption; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; @@ -124,11 +125,11 @@ public void readBlobAsync(String blobName, ActionListener listener) long contentLength = listBlobs().get(blobName).length(); long partSize = contentLength / 10; int numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1); - List blobPartStreams = new ArrayList<>(); + List> blobPartStreams = new ArrayList<>(); for (int partNumber = 0; partNumber < numberOfParts; partNumber++) { long offset = partNumber * partSize; InputStreamContainer blobPartStream = new InputStreamContainer(readBlob(blobName, offset, partSize), partSize, offset); - blobPartStreams.add(blobPartStream); + blobPartStreams.add(CompletableFuture.completedFuture(blobPartStream)); } ReadContext blobReadContext = new ReadContext(contentLength, blobPartStreams, null); listener.onResponse(blobReadContext); diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java index e73a9f5cd0bc9..220178d587ac4 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java @@ -13,7 +13,6 @@ import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener; import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.threadpool.ThreadPool; import java.io.IOException; import java.nio.file.Path; @@ -49,12 +48,11 @@ public interface AsyncMultiStreamBlobContainer extends BlobContainer { * Asynchronously downloads the blob to the specified location using an executor from the thread pool. * @param blobName The name of the blob for which needs to be downloaded. * @param fileLocation The path on local disk where the blob needs to be downloaded. - * @param threadPool The threadpool instance which will provide the executor for performing a multipart download. * @param completionListener Listener which will be notified when the download is complete. */ @ExperimentalApi - default void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener completionListener) { - ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, threadPool, completionListener); + default void asyncBlobDownload(String blobName, Path fileLocation, ActionListener completionListener) { + ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, completionListener); readBlobAsync(blobName, readContextListener); } diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java index c64dc6b9e3ae4..5637326915746 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.io.InputStream; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; /** @@ -144,8 +145,10 @@ public long getBlobSize() { } @Override - public List getPartStreams() { - return super.getPartStreams().stream().map(this::decryptInputStreamContainer).collect(Collectors.toList()); + public List> getPartStreams() { + return super.getPartStreams().stream() + .map(cf -> cf.thenApply(this::decryptInputStreamContainer)) + .collect(Collectors.toUnmodifiableList()); } /** diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java index 2c305fb03c475..dc3e2e931c7d3 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java @@ -12,6 +12,7 @@ import org.opensearch.common.io.InputStreamContainer; import java.util.List; +import java.util.concurrent.CompletableFuture; /** * ReadContext is used to encapsulate all data needed by BlobContainer#readBlobAsync @@ -19,18 +20,18 @@ @ExperimentalApi public class ReadContext { private final long blobSize; - private final List partStreams; + private final List> asyncPartStreams; private final String blobChecksum; - public ReadContext(long blobSize, List partStreams, String blobChecksum) { + public ReadContext(long blobSize, List> asyncPartStreams, String blobChecksum) { this.blobSize = blobSize; - this.partStreams = partStreams; + this.asyncPartStreams = asyncPartStreams; this.blobChecksum = blobChecksum; } public ReadContext(ReadContext readContext) { this.blobSize = readContext.blobSize; - this.partStreams = readContext.partStreams; + this.asyncPartStreams = readContext.asyncPartStreams; this.blobChecksum = readContext.blobChecksum; } @@ -39,14 +40,14 @@ public String getBlobChecksum() { } public int getNumberOfParts() { - return partStreams.size(); + return asyncPartStreams.size(); } public long getBlobSize() { return blobSize; } - public List getPartStreams() { - return partStreams; + public List> getPartStreams() { + return asyncPartStreams; } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java index 84fd7ed9ffebf..0eae22220ea82 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java @@ -22,16 +22,16 @@ import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; /** * FilePartWriter transfers the provided stream into the specified file path using a {@link FileChannel} * instance. It performs offset based writes to the file and notifies the {@link FileCompletionListener} on completion. */ @InternalApi -class FilePartWriter implements Runnable { +class FilePartWriter implements BiConsumer { private final int partNumber; - private final InputStreamContainer blobPartStreamContainer; private final Path fileLocation; private final AtomicBoolean anyPartStreamFailed; private final ActionListener fileCompletionListener; @@ -42,20 +42,26 @@ class FilePartWriter implements Runnable { public FilePartWriter( int partNumber, - InputStreamContainer blobPartStreamContainer, Path fileLocation, AtomicBoolean anyPartStreamFailed, ActionListener fileCompletionListener ) { this.partNumber = partNumber; - this.blobPartStreamContainer = blobPartStreamContainer; this.fileLocation = fileLocation; this.anyPartStreamFailed = anyPartStreamFailed; this.fileCompletionListener = fileCompletionListener; } @Override - public void run() { + public void accept(InputStreamContainer blobPartStreamContainer, Throwable throwable) { + if (throwable != null) { + if (throwable instanceof Exception) { + processFailure((Exception) throwable); + } else { + processFailure(new Exception(throwable)); + } + return; + } // Ensures no writes to the file if any stream fails. if (anyPartStreamFailed.get() == false) { try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) { diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java index 4338bddb3fbe7..4aa028fd6e7cc 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java @@ -13,29 +13,25 @@ import org.opensearch.common.annotation.InternalApi; import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.core.action.ActionListener; -import org.opensearch.threadpool.ThreadPool; import java.nio.file.Path; import java.util.concurrent.atomic.AtomicBoolean; /** * ReadContextListener orchestrates the async file fetch from the {@link org.opensearch.common.blobstore.BlobContainer} - * using a {@link ReadContext} callback. On response, it spawns off the download using multiple streams which are - * spread across a {@link ThreadPool} executor. + * using a {@link ReadContext} callback. On response, it spawns off the download using multiple streams. */ @InternalApi public class ReadContextListener implements ActionListener { private final String fileName; private final Path fileLocation; - private final ThreadPool threadPool; private final ActionListener completionListener; private static final Logger logger = LogManager.getLogger(ReadContextListener.class); - public ReadContextListener(String fileName, Path fileLocation, ThreadPool threadPool, ActionListener completionListener) { + public ReadContextListener(String fileName, Path fileLocation, ActionListener completionListener) { this.fileName = fileName; this.fileLocation = fileLocation; - this.threadPool = threadPool; this.completionListener = completionListener; } @@ -47,14 +43,9 @@ public void onResponse(ReadContext readContext) { FileCompletionListener fileCompletionListener = new FileCompletionListener(numParts, fileName, completionListener); for (int partNumber = 0; partNumber < numParts; partNumber++) { - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - readContext.getPartStreams().get(partNumber), - fileLocation, - anyPartStreamFailed, - fileCompletionListener - ); - threadPool.executor(ThreadPool.Names.GENERIC).submit(filePartWriter); + readContext.getPartStreams() + .get(partNumber) + .whenComplete(new FilePartWriter(partNumber, fileLocation, anyPartStreamFailed, fileCompletionListener)); } } diff --git a/server/src/main/java/org/opensearch/index/shard/IndexShard.java b/server/src/main/java/org/opensearch/index/shard/IndexShard.java index d476e8b7c9288..b3aca98407d9e 100644 --- a/server/src/main/java/org/opensearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/opensearch/index/shard/IndexShard.java @@ -62,7 +62,6 @@ import org.opensearch.action.admin.indices.flush.FlushRequest; import org.opensearch.action.admin.indices.forcemerge.ForceMergeRequest; import org.opensearch.action.admin.indices.upgrade.post.UpgradeRequest; -import org.opensearch.action.support.GroupedActionListener; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.replication.PendingReplicationActions; import org.opensearch.action.support.replication.ReplicationResponse; @@ -4914,24 +4913,17 @@ private void downloadSegments( RemoteSegmentStoreDirectory targetRemoteDirectory, Set toDownloadSegments, final Runnable onFileSync - ) { - final PlainActionFuture completionListener = PlainActionFuture.newFuture(); - final GroupedActionListener batchDownloadListener = new GroupedActionListener<>( - ActionListener.map(completionListener, v -> null), - toDownloadSegments.size() - ); - - final ActionListener segmentsDownloadListener = ActionListener.map(batchDownloadListener, fileName -> { + ) throws IOException { + final Path indexPath = store.shardPath() == null ? null : store.shardPath().resolveIndex(); + for (String segment : toDownloadSegments) { + final PlainActionFuture segmentListener = PlainActionFuture.newFuture(); + sourceRemoteDirectory.copyTo(segment, storeDirectory, indexPath, segmentListener); + segmentListener.actionGet(); onFileSync.run(); if (targetRemoteDirectory != null) { - targetRemoteDirectory.copyFrom(storeDirectory, fileName, fileName, IOContext.DEFAULT); + targetRemoteDirectory.copyFrom(storeDirectory, segment, segment, IOContext.DEFAULT); } - return null; - }); - - final Path indexPath = store.shardPath() == null ? null : store.shardPath().resolveIndex(); - toDownloadSegments.forEach(file -> { sourceRemoteDirectory.copyTo(file, storeDirectory, indexPath, segmentsDownloadListener); }); - completionListener.actionGet(); + } } private boolean localDirectoryContains(Directory localDirectory, String file, long checksum) { diff --git a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java index b23d2d7d0a3f8..e74252e89894a 100644 --- a/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java +++ b/server/src/main/java/org/opensearch/index/store/RemoteSegmentStoreDirectory.java @@ -468,7 +468,7 @@ public void copyTo(String source, Directory destinationDirectory, Path destinati if (destinationPath != null && remoteDataDirectory.getBlobContainer() instanceof AsyncMultiStreamBlobContainer) { final AsyncMultiStreamBlobContainer blobContainer = (AsyncMultiStreamBlobContainer) remoteDataDirectory.getBlobContainer(); final Path destinationFilePath = destinationPath.resolve(source); - blobContainer.asyncBlobDownload(blobName, destinationFilePath, threadPool, fileCompletionListener); + blobContainer.asyncBlobDownload(blobName, destinationFilePath, fileCompletionListener); } else { // Fallback to older mechanism of downloading the file try { diff --git a/server/src/main/java/org/opensearch/indices/replication/RemoteStoreReplicationSource.java b/server/src/main/java/org/opensearch/indices/replication/RemoteStoreReplicationSource.java index aeb690465905f..e17c5293c38ac 100644 --- a/server/src/main/java/org/opensearch/indices/replication/RemoteStoreReplicationSource.java +++ b/server/src/main/java/org/opensearch/indices/replication/RemoteStoreReplicationSource.java @@ -14,7 +14,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.store.FilterDirectory; import org.apache.lucene.util.Version; -import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.concurrent.GatedCloseable; import org.opensearch.core.action.ActionListener; import org.opensearch.index.shard.IndexShard; @@ -141,14 +141,12 @@ private void downloadSegments( ActionListener completionListener ) { final Path indexPath = shardPath == null ? null : shardPath.resolveIndex(); - final GroupedActionListener batchDownloadListener = new GroupedActionListener<>( - ActionListener.map(completionListener, v -> new GetSegmentFilesResponse(toDownloadSegments)), - toDownloadSegments.size() - ); - ActionListener segmentsDownloadListener = ActionListener.map(batchDownloadListener, result -> null); - toDownloadSegments.forEach( - fileMetadata -> remoteStoreDirectory.copyTo(fileMetadata.name(), storeDirectory, indexPath, segmentsDownloadListener) - ); + for (StoreFileMetadata storeFileMetadata : toDownloadSegments) { + final PlainActionFuture segmentListener = PlainActionFuture.newFuture(); + remoteStoreDirectory.copyTo(storeFileMetadata.name(), storeDirectory, indexPath, segmentListener); + segmentListener.actionGet(); + } + completionListener.onResponse(new GetSegmentFilesResponse(toDownloadSegments)); } @Override diff --git a/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java index 947a4f9b1c9ab..99b21fb26ea2c 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainerTests.java @@ -20,6 +20,7 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.function.UnaryOperator; import org.mockito.Mockito; @@ -51,10 +52,12 @@ public void testReadBlobAsync() throws Exception { // Objects needed for API call final byte[] data = new byte[size]; Randomness.get().nextBytes(data); + final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0); final ListenerTestUtils.CountingCompletionListener completionListener = new ListenerTestUtils.CountingCompletionListener<>(); - final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null); + final CompletableFuture streamContainerFuture = CompletableFuture.completedFuture(inputStreamContainer); + final ReadContext readContext = new ReadContext(size, List.of(streamContainerFuture), null); Mockito.doAnswer(invocation -> { ActionListener readContextActionListener = invocation.getArgument(1); @@ -76,7 +79,7 @@ public void testReadBlobAsync() throws Exception { assertEquals(1, response.getNumberOfParts()); assertEquals(size, response.getBlobSize()); - InputStreamContainer responseContainer = response.getPartStreams().get(0); + InputStreamContainer responseContainer = response.getPartStreams().get(0).get(); assertEquals(0, responseContainer.getOffset()); assertEquals(size, responseContainer.getContentLength()); assertEquals(100, responseContainer.getInputStream().available()); @@ -99,7 +102,8 @@ public void testReadBlobAsyncException() throws Exception { final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0); final ListenerTestUtils.CountingCompletionListener completionListener = new ListenerTestUtils.CountingCompletionListener<>(); - final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null); + final CompletableFuture streamContainerFuture = CompletableFuture.completedFuture(inputStreamContainer); + final ReadContext readContext = new ReadContext(size, List.of(streamContainerFuture), null); Mockito.doAnswer(invocation -> { ActionListener readContextActionListener = invocation.getArgument(1); diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java index 811566eb5767b..a2917feb00b10 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriterTests.java @@ -40,14 +40,8 @@ public void testFilePartWriter() throws Exception { AtomicBoolean anyStreamFailed = new AtomicBoolean(); CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - filePartWriter.run(); + FilePartWriter filePartWriter = new FilePartWriter(partNumber, segmentFilePath, anyStreamFailed, fileCompletionListener); + filePartWriter.accept(inputStreamContainer, null); assertTrue(Files.exists(segmentFilePath)); assertEquals(contentLength, Files.size(segmentFilePath)); @@ -65,14 +59,8 @@ public void testFilePartWriterWithOffset() throws Exception { AtomicBoolean anyStreamFailed = new AtomicBoolean(); CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - filePartWriter.run(); + FilePartWriter filePartWriter = new FilePartWriter(partNumber, segmentFilePath, anyStreamFailed, fileCompletionListener); + filePartWriter.accept(inputStreamContainer, null); assertTrue(Files.exists(segmentFilePath)); assertEquals(contentLength + offset, Files.size(segmentFilePath)); @@ -89,14 +77,8 @@ public void testFilePartWriterLargeInput() throws Exception { AtomicBoolean anyStreamFailed = new AtomicBoolean(); CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - filePartWriter.run(); + FilePartWriter filePartWriter = new FilePartWriter(partNumber, segmentFilePath, anyStreamFailed, fileCompletionListener); + filePartWriter.accept(inputStreamContainer, null); assertTrue(Files.exists(segmentFilePath)); assertEquals(contentLength, Files.size(segmentFilePath)); @@ -107,21 +89,12 @@ public void testFilePartWriterLargeInput() throws Exception { public void testFilePartWriterException() throws Exception { Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); - int contentLength = 100; int partNumber = 1; - InputStream inputStream = new ByteArrayInputStream(randomByteArrayOfLength(contentLength)); - InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, contentLength, 0); AtomicBoolean anyStreamFailed = new AtomicBoolean(); CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); IOException ioException = new IOException(); - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); + FilePartWriter filePartWriter = new FilePartWriter(partNumber, segmentFilePath, anyStreamFailed, fileCompletionListener); assertFalse(anyStreamFailed.get()); filePartWriter.processFailure(ioException); @@ -148,14 +121,8 @@ public void testFilePartWriterStreamFailed() throws Exception { AtomicBoolean anyStreamFailed = new AtomicBoolean(true); CountingCompletionListener fileCompletionListener = new CountingCompletionListener<>(); - FilePartWriter filePartWriter = new FilePartWriter( - partNumber, - inputStreamContainer, - segmentFilePath, - anyStreamFailed, - fileCompletionListener - ); - filePartWriter.run(); + FilePartWriter filePartWriter = new FilePartWriter(partNumber, segmentFilePath, anyStreamFailed, fileCompletionListener); + filePartWriter.accept(inputStreamContainer, null); assertFalse(Files.exists(segmentFilePath)); assertEquals(0, fileCompletionListener.getResponseCount()); diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java index 21b7b47390a9b..b2ae0d20e7486 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java @@ -29,6 +29,7 @@ import java.util.ArrayList; import java.util.List; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import static org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils.CountingCompletionListener; @@ -64,10 +65,10 @@ public void init() throws Exception { public void testReadContextListener() throws InterruptedException, IOException { Path fileLocation = path.resolve(UUID.randomUUID().toString()); - List blobPartStreams = initializeBlobPartStreams(); + List> blobPartStreams = initializeBlobPartStreams(); CountDownLatch countDownLatch = new CountDownLatch(1); ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); - ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener); + ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, completionListener); ReadContext readContext = new ReadContext((long) PART_SIZE * NUMBER_OF_PARTS, blobPartStreams, null); readContextListener.onResponse(readContext); @@ -79,10 +80,10 @@ public void testReadContextListener() throws InterruptedException, IOException { public void testReadContextListenerFailure() throws Exception { Path fileLocation = path.resolve(UUID.randomUUID().toString()); - List blobPartStreams = initializeBlobPartStreams(); + List> blobPartStreams = initializeBlobPartStreams(); CountDownLatch countDownLatch = new CountDownLatch(1); ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); - ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener); + ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, completionListener); InputStream badInputStream = new InputStream() { @Override @@ -101,7 +102,13 @@ public int available() { } }; - blobPartStreams.add(NUMBER_OF_PARTS, new InputStreamContainer(badInputStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS)); + blobPartStreams.add( + NUMBER_OF_PARTS, + CompletableFuture.supplyAsync( + () -> new InputStreamContainer(badInputStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS), + threadPool.generic() + ) + ); ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS, blobPartStreams, null); readContextListener.onResponse(readContext); @@ -112,18 +119,24 @@ public int available() { public void testReadContextListenerException() { Path fileLocation = path.resolve(UUID.randomUUID().toString()); CountingCompletionListener listener = new CountingCompletionListener(); - ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, listener); + ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, listener); IOException exception = new IOException(); readContextListener.onFailure(exception); assertEquals(1, listener.getFailureCount()); assertEquals(exception, listener.getException()); } - private List initializeBlobPartStreams() { - List blobPartStreams = new ArrayList<>(); + private List> initializeBlobPartStreams() { + List> blobPartStreams = new ArrayList<>(); for (int partNumber = 0; partNumber < NUMBER_OF_PARTS; partNumber++) { InputStream testStream = new ByteArrayInputStream(randomByteArrayOfLength(PART_SIZE)); - blobPartStreams.add(new InputStreamContainer(testStream, PART_SIZE, (long) partNumber * PART_SIZE)); + int finalPartNumber = partNumber; + blobPartStreams.add( + CompletableFuture.supplyAsync( + () -> new InputStreamContainer(testStream, PART_SIZE, (long) finalPartNumber * PART_SIZE), + threadPool.generic() + ) + ); } return blobPartStreams; } diff --git a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java index 8d99d98fbaaf4..7393c3c52d2a1 100644 --- a/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java +++ b/server/src/test/java/org/opensearch/index/store/RemoteSegmentStoreDirectoryTests.java @@ -536,10 +536,10 @@ public void testCopyFilesToMultipart() throws Exception { when(remoteDataDirectory.getBlobContainer()).thenReturn(blobContainer); Mockito.doAnswer(invocation -> { - ActionListener completionListener = invocation.getArgument(3); + ActionListener completionListener = invocation.getArgument(2); completionListener.onResponse(invocation.getArgument(0)); return null; - }).when(blobContainer).asyncBlobDownload(any(), any(), any(), any()); + }).when(blobContainer).asyncBlobDownload(any(), any(), any()); CountDownLatch downloadLatch = new CountDownLatch(1); ActionListener completionListener = new ActionListener() { @@ -554,7 +554,7 @@ public void onFailure(Exception e) {} Path path = createTempDir(); remoteSegmentStoreDirectory.copyTo(filename, storeDirectory, path, completionListener); assertTrue(downloadLatch.await(5000, TimeUnit.SECONDS)); - verify(blobContainer, times(1)).asyncBlobDownload(contains(filename), eq(path.resolve(filename)), any(), any()); + verify(blobContainer, times(1)).asyncBlobDownload(contains(filename), eq(path.resolve(filename)), any()); verify(storeDirectory, times(0)).copyFrom(any(), any(), any(), any()); }