Skip to content

Commit

Permalink
Refactor multipart download to a more async model (opensearch-project…
Browse files Browse the repository at this point in the history
…#10349)

* Refactor read context streams to async streams

Signed-off-by: Kunal Kotwani <kkotwani@amazon.com>

* Refactor multipart download to a more async model

The previous approach of kicking off the stream requests for all parts
of a file did not work well for very large files. For example, a 20GiB
file uploaded in 16MiB parts will consist of 1200+ parts. When we
attempted to initiate streaming for all parts concurrently, some parts
would hit a client timeout after 2 minutes without being able to get a
connection due to the other parts not having been completed in that time
frame. This refactoring adds yet another layer of indirection in order
to allow the code that is actually writing the destination file to
control the rate at which streams are started. This should allow for
downloading files consisting of arbitrarily many parts at any connection
speed.

This commit also wires in the download rate limiter so that the
`indices.recovery.max_bytes_per_sec` is properly honored.

Signed-off-by: Andrew Ross <andrross@amazon.com>

---------

Signed-off-by: Kunal Kotwani <kkotwani@amazon.com>
Signed-off-by: Andrew Ross <andrross@amazon.com>
Co-authored-by: Kunal Kotwani <kkotwani@amazon.com>
Signed-off-by: Shivansh Arora <hishiv@amazon.com>
  • Loading branch information
2 people authored and shiv0408 committed Apr 25, 2024
1 parent ea92189 commit 1304987
Show file tree
Hide file tree
Showing 33 changed files with 400 additions and 418 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -241,37 +241,23 @@ public void readBlobAsync(String blobName, ActionListener<ReadContext> listener)
return;
}

final List<CompletableFuture<InputStreamContainer>> blobPartInputStreamFutures = new ArrayList<>();
final List<ReadContext.StreamPartCreator> blobPartInputStreamFutures = new ArrayList<>();
final long blobSize = blobMetadata.objectSize();
final Integer numberOfParts = blobMetadata.objectParts() == null ? null : blobMetadata.objectParts().totalPartsCount();
final String blobChecksum = blobMetadata.checksum().checksumCRC32();

if (numberOfParts == null) {
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null));
blobPartInputStreamFutures.add(() -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null));
} else {
// S3 multipart files use 1 to n indexing
for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) {
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, partNumber));
final int innerPartNumber = partNumber;
blobPartInputStreamFutures.add(
() -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, innerPartNumber)
);
}
}

CompletableFuture.allOf(blobPartInputStreamFutures.toArray(CompletableFuture[]::new))
.whenComplete((unused, partThrowable) -> {
if (partThrowable == null) {
listener.onResponse(
new ReadContext(
blobSize,
blobPartInputStreamFutures.stream().map(CompletableFuture::join).collect(Collectors.toList()),
blobChecksum
)
);
} else {
Exception ex = partThrowable.getCause() instanceof Exception
? (Exception) partThrowable.getCause()
: new Exception(partThrowable.getCause());
listener.onFailure(ex);
}
});
listener.onResponse(new ReadContext(blobSize, blobPartInputStreamFutures, blobChecksum));
});
} catch (Exception ex) {
listener.onFailure(SdkException.create("Error occurred while fetching blob parts from the repository", ex));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ public void testReadBlobAsyncMultiPart() throws Exception {
assertEquals(objectSize, readContext.getBlobSize());

for (int partNumber = 1; partNumber < objectPartCount; partNumber++) {
InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber);
InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber).get().join();
final int offset = partNumber * partSize;
assertEquals(partSize, inputStreamContainer.getContentLength());
assertEquals(offset, inputStreamContainer.getOffset());
Expand Down Expand Up @@ -1024,7 +1024,7 @@ public void testReadBlobAsyncSinglePart() throws Exception {
assertEquals(checksum, readContext.getBlobChecksum());
assertEquals(objectSize, readContext.getBlobSize());

InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get();
InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get().get().join();
assertEquals(objectSize, inputStreamContainer.getContentLength());
assertEquals(0, inputStreamContainer.getOffset());
assertEquals(objectSize, inputStreamContainer.getInputStream().readAllBytes().length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,8 @@ public static final IndexShard newIndexShard(
null,
null,
() -> IndexSettings.DEFAULT_REMOTE_TRANSLOG_BUFFER_INTERVAL,
nodeId
nodeId,
null
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
Expand Down Expand Up @@ -124,11 +125,11 @@ public void readBlobAsync(String blobName, ActionListener<ReadContext> listener)
long contentLength = listBlobs().get(blobName).length();
long partSize = contentLength / 10;
int numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1);
List<InputStreamContainer> blobPartStreams = new ArrayList<>();
List<ReadContext.StreamPartCreator> blobPartStreams = new ArrayList<>();
for (int partNumber = 0; partNumber < numberOfParts; partNumber++) {
long offset = partNumber * partSize;
InputStreamContainer blobPartStream = new InputStreamContainer(readBlob(blobName, offset, partSize), partSize, offset);
blobPartStreams.add(blobPartStream);
blobPartStreams.add(() -> CompletableFuture.completedFuture(blobPartStream));
}
ReadContext blobReadContext = new ReadContext(contentLength, blobPartStreams, null);
listener.onResponse(blobReadContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,10 @@

import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener;
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.nio.file.Path;

/**
* An extension of {@link BlobContainer} that adds {@link AsyncMultiStreamBlobContainer#asyncBlobUpload} to allow
Expand Down Expand Up @@ -45,19 +42,6 @@ public interface AsyncMultiStreamBlobContainer extends BlobContainer {
@ExperimentalApi
void readBlobAsync(String blobName, ActionListener<ReadContext> listener);

/**
* Asynchronously downloads the blob to the specified location using an executor from the thread pool.
* @param blobName The name of the blob for which needs to be downloaded.
* @param fileLocation The path on local disk where the blob needs to be downloaded.
* @param threadPool The threadpool instance which will provide the executor for performing a multipart download.
* @param completionListener Listener which will be notified when the download is complete.
*/
@ExperimentalApi
default void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, threadPool, completionListener);
readBlobAsync(blobName, readContextListener);
}

/*
* Wether underlying blobContainer can verify integrity of data after transfer. If true and if expected
* checksum is provided in WriteContext, then the checksum of transferred data is compared with expected checksum
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,10 @@ public long getBlobSize() {
}

@Override
public List<InputStreamContainer> getPartStreams() {
return super.getPartStreams().stream().map(this::decryptInputStreamContainer).collect(Collectors.toList());
public List<StreamPartCreator> getPartStreams() {
return super.getPartStreams().stream()
.map(supplier -> (StreamPartCreator) () -> supplier.get().thenApply(this::decryptInputStreamContainer))
.collect(Collectors.toUnmodifiableList());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,27 @@
import org.opensearch.common.io.InputStreamContainer;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;

/**
* ReadContext is used to encapsulate all data needed by <code>BlobContainer#readBlobAsync</code>
*/
@ExperimentalApi
public class ReadContext {
private final long blobSize;
private final List<InputStreamContainer> partStreams;
private final List<StreamPartCreator> asyncPartStreams;
private final String blobChecksum;

public ReadContext(long blobSize, List<InputStreamContainer> partStreams, String blobChecksum) {
public ReadContext(long blobSize, List<StreamPartCreator> asyncPartStreams, String blobChecksum) {
this.blobSize = blobSize;
this.partStreams = partStreams;
this.asyncPartStreams = asyncPartStreams;
this.blobChecksum = blobChecksum;
}

public ReadContext(ReadContext readContext) {
this.blobSize = readContext.blobSize;
this.partStreams = readContext.partStreams;
this.asyncPartStreams = readContext.asyncPartStreams;
this.blobChecksum = readContext.blobChecksum;
}

Expand All @@ -39,14 +41,30 @@ public String getBlobChecksum() {
}

public int getNumberOfParts() {
return partStreams.size();
return asyncPartStreams.size();
}

public long getBlobSize() {
return blobSize;
}

public List<InputStreamContainer> getPartStreams() {
return partStreams;
public List<StreamPartCreator> getPartStreams() {
return asyncPartStreams;
}

/**
* Functional interface defining an instance that can create an async action
* to create a part of an object represented as an InputStreamContainer.
*/
@FunctionalInterface
public interface StreamPartCreator extends Supplier<CompletableFuture<InputStreamContainer>> {
/**
* Kicks off a async process to start streaming.
*
* @return When the returned future is completed, streaming has
* just begun. Clients must fully consume the resulting stream.
*/
@Override
CompletableFuture<InputStreamContainer> get();
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,83 +8,37 @@

package org.opensearch.common.blobstore.stream.read.listener;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.annotation.InternalApi;
import org.opensearch.common.io.Channels;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;

import java.io.IOException;
import java.io.InputStream;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.UnaryOperator;

/**
* FilePartWriter transfers the provided stream into the specified file path using a {@link FileChannel}
* instance. It performs offset based writes to the file and notifies the {@link FileCompletionListener} on completion.
* instance.
*/
@InternalApi
class FilePartWriter implements Runnable {

private final int partNumber;
private final InputStreamContainer blobPartStreamContainer;
private final Path fileLocation;
private final AtomicBoolean anyPartStreamFailed;
private final ActionListener<Integer> fileCompletionListener;
private static final Logger logger = LogManager.getLogger(FilePartWriter.class);

class FilePartWriter {
// 8 MB buffer for transfer
private static final int BUFFER_SIZE = 8 * 1024 * 2024;

public FilePartWriter(
int partNumber,
InputStreamContainer blobPartStreamContainer,
Path fileLocation,
AtomicBoolean anyPartStreamFailed,
ActionListener<Integer> fileCompletionListener
) {
this.partNumber = partNumber;
this.blobPartStreamContainer = blobPartStreamContainer;
this.fileLocation = fileLocation;
this.anyPartStreamFailed = anyPartStreamFailed;
this.fileCompletionListener = fileCompletionListener;
}

@Override
public void run() {
// Ensures no writes to the file if any stream fails.
if (anyPartStreamFailed.get() == false) {
try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) {
try (InputStream inputStream = blobPartStreamContainer.getInputStream()) {
long streamOffset = blobPartStreamContainer.getOffset();
final byte[] buffer = new byte[BUFFER_SIZE];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset);
streamOffset += bytesRead;
}
public static void write(Path fileLocation, InputStreamContainer stream, UnaryOperator<InputStream> rateLimiter) throws IOException {
try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) {
try (InputStream inputStream = rateLimiter.apply(stream.getInputStream())) {
long streamOffset = stream.getOffset();
final byte[] buffer = new byte[BUFFER_SIZE];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset);
streamOffset += bytesRead;
}
} catch (IOException e) {
processFailure(e);
return;
}
fileCompletionListener.onResponse(partNumber);
}
}

void processFailure(Exception e) {
try {
Files.deleteIfExists(fileLocation);
} catch (IOException ex) {
// Die silently
logger.info("Failed to delete file {} on stream failure: {}", fileLocation, ex);
}
if (anyPartStreamFailed.getAndSet(true) == false) {
fileCompletionListener.onFailure(e);
}
}
}
Loading

0 comments on commit 1304987

Please sign in to comment.