Skip to content

Commit

Permalink
Refactor multipart download to a more async model
Browse files Browse the repository at this point in the history
The previous approach of kicking off the stream requests for all parts
of a file did not work well for very large files. For example, a 20GiB
file uploaded in 16MiB parts will consist of 1200+ parts. When we
attempted to initiate streaming for all parts concurrently, some parts
would hit a client timeout after 2 minutes without being able to get a
connection due to the other parts not having been completed in that time
frame. This refactoring adds yet another layer of indirection in order
to allow the code that is actually writing the destination file to
control the rate at which streams are started. This should allow for
downloading files consisting of arbitrarily many parts at any connection
speed.

This commit also wires in the download rate limiter so that the
`indices.recovery.max_bytes_per_sec` is properly honored.

Signed-off-by: Andrew Ross <andrross@amazon.com>
  • Loading branch information
andrross committed Oct 4, 2023
1 parent 8f0635c commit a9b8b63
Show file tree
Hide file tree
Showing 32 changed files with 359 additions and 334 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -241,17 +241,20 @@ public void readBlobAsync(String blobName, ActionListener<ReadContext> listener)
return;
}

final List<CompletableFuture<InputStreamContainer>> blobPartInputStreamFutures = new ArrayList<>();
final List<ReadContext.StreamPartCreator> blobPartInputStreamFutures = new ArrayList<>();
final long blobSize = blobMetadata.objectSize();
final Integer numberOfParts = blobMetadata.objectParts() == null ? null : blobMetadata.objectParts().totalPartsCount();
final String blobChecksum = blobMetadata.checksum().checksumCRC32();

if (numberOfParts == null) {
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null));
blobPartInputStreamFutures.add(() -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, null));
} else {
// S3 multipart files use 1 to n indexing
for (int partNumber = 1; partNumber <= numberOfParts; partNumber++) {
blobPartInputStreamFutures.add(getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, partNumber));
final int innerPartNumber = partNumber;
blobPartInputStreamFutures.add(
() -> getBlobPartInputStreamContainer(s3AsyncClient, bucketName, blobKey, innerPartNumber)
);
}
}
listener.onResponse(new ReadContext(blobSize, blobPartInputStreamFutures, blobChecksum));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ public void testReadBlobAsyncMultiPart() throws Exception {
assertEquals(objectSize, readContext.getBlobSize());

for (int partNumber = 1; partNumber < objectPartCount; partNumber++) {
InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber).get();
InputStreamContainer inputStreamContainer = readContext.getPartStreams().get(partNumber).get().join();
final int offset = partNumber * partSize;
assertEquals(partSize, inputStreamContainer.getContentLength());
assertEquals(offset, inputStreamContainer.getOffset());
Expand Down Expand Up @@ -1024,7 +1024,7 @@ public void testReadBlobAsyncSinglePart() throws Exception {
assertEquals(checksum, readContext.getBlobChecksum());
assertEquals(objectSize, readContext.getBlobSize());

InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get().get();
InputStreamContainer inputStreamContainer = readContext.getPartStreams().stream().findFirst().get().get().join();
assertEquals(objectSize, inputStreamContainer.getContentLength());
assertEquals(0, inputStreamContainer.getOffset());
assertEquals(objectSize, inputStreamContainer.getInputStream().readAllBytes().length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,8 @@ public static final IndexShard newIndexShard(
null,
null,
() -> IndexSettings.DEFAULT_REMOTE_TRANSLOG_BUFFER_INTERVAL,
nodeId
nodeId,
null
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ public void readBlobAsync(String blobName, ActionListener<ReadContext> listener)
long contentLength = listBlobs().get(blobName).length();
long partSize = contentLength / 10;
int numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1);
List<CompletableFuture<InputStreamContainer>> blobPartStreams = new ArrayList<>();
List<ReadContext.StreamPartCreator> blobPartStreams = new ArrayList<>();
for (int partNumber = 0; partNumber < numberOfParts; partNumber++) {
long offset = partNumber * partSize;
InputStreamContainer blobPartStream = new InputStreamContainer(readBlob(blobName, offset, partSize), partSize, offset);
blobPartStreams.add(CompletableFuture.completedFuture(blobPartStream));
blobPartStreams.add(() -> CompletableFuture.completedFuture(blobPartStream));
}
ReadContext blobReadContext = new ReadContext(contentLength, blobPartStreams, null);
listener.onResponse(blobReadContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@

import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener;
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.core.action.ActionListener;

import java.io.IOException;
import java.nio.file.Path;

/**
* An extension of {@link BlobContainer} that adds {@link AsyncMultiStreamBlobContainer#asyncBlobUpload} to allow
Expand Down Expand Up @@ -44,18 +42,6 @@ public interface AsyncMultiStreamBlobContainer extends BlobContainer {
@ExperimentalApi
void readBlobAsync(String blobName, ActionListener<ReadContext> listener);

/**
* Asynchronously downloads the blob to the specified location using an executor from the thread pool.
* @param blobName The name of the blob for which needs to be downloaded.
* @param fileLocation The path on local disk where the blob needs to be downloaded.
* @param completionListener Listener which will be notified when the download is complete.
*/
@ExperimentalApi
default void asyncBlobDownload(String blobName, Path fileLocation, ActionListener<String> completionListener) {
ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, completionListener);
readBlobAsync(blobName, readContextListener);
}

/*
* Wether underlying blobContainer can verify integrity of data after transfer. If true and if expected
* checksum is provided in WriteContext, then the checksum of transferred data is compared with expected checksum
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -145,9 +144,9 @@ public long getBlobSize() {
}

@Override
public List<CompletableFuture<InputStreamContainer>> getPartStreams() {
public List<StreamPartCreator> getPartStreams() {
return super.getPartStreams().stream()
.map(cf -> cf.thenApply(this::decryptInputStreamContainer))
.map(supplier -> (StreamPartCreator) () -> supplier.get().thenApply(this::decryptInputStreamContainer))
.collect(Collectors.toUnmodifiableList());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;

/**
* ReadContext is used to encapsulate all data needed by <code>BlobContainer#readBlobAsync</code>
*/
@ExperimentalApi
public class ReadContext {
private final long blobSize;
private final List<CompletableFuture<InputStreamContainer>> asyncPartStreams;
private final List<StreamPartCreator> asyncPartStreams;
private final String blobChecksum;

public ReadContext(long blobSize, List<CompletableFuture<InputStreamContainer>> asyncPartStreams, String blobChecksum) {
public ReadContext(long blobSize, List<StreamPartCreator> asyncPartStreams, String blobChecksum) {
this.blobSize = blobSize;
this.asyncPartStreams = asyncPartStreams;
this.blobChecksum = blobChecksum;
Expand All @@ -47,7 +48,23 @@ public long getBlobSize() {
return blobSize;
}

public List<CompletableFuture<InputStreamContainer>> getPartStreams() {
public List<StreamPartCreator> getPartStreams() {
return asyncPartStreams;
}

/**
* Functional interface defining an instance that can create an async action
* to create a part of an object represented as an InputStreamContainer.
*/
@FunctionalInterface
public interface StreamPartCreator extends Supplier<CompletableFuture<InputStreamContainer>> {
/**
* Kicks off a async process to start streaming.
*
* @return When the returned future is completed, streaming has
* just begun. Clients must fully consume the resulting stream.
*/
@Override
CompletableFuture<InputStreamContainer> get();
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,89 +8,37 @@

package org.opensearch.common.blobstore.stream.read.listener;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.annotation.InternalApi;
import org.opensearch.common.io.Channels;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;

import java.io.IOException;
import java.io.InputStream;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.UnaryOperator;

/**
* FilePartWriter transfers the provided stream into the specified file path using a {@link FileChannel}
* instance. It performs offset based writes to the file and notifies the {@link FileCompletionListener} on completion.
* instance.
*/
@InternalApi
class FilePartWriter implements BiConsumer<InputStreamContainer, Throwable> {

private final int partNumber;
private final Path fileLocation;
private final AtomicBoolean anyPartStreamFailed;
private final ActionListener<Integer> fileCompletionListener;
private static final Logger logger = LogManager.getLogger(FilePartWriter.class);

class FilePartWriter {

Check warning on line 27 in server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FilePartWriter.java#L27

Added line #L27 was not covered by tests
// 8 MB buffer for transfer
private static final int BUFFER_SIZE = 8 * 1024 * 2024;

public FilePartWriter(
int partNumber,
Path fileLocation,
AtomicBoolean anyPartStreamFailed,
ActionListener<Integer> fileCompletionListener
) {
this.partNumber = partNumber;
this.fileLocation = fileLocation;
this.anyPartStreamFailed = anyPartStreamFailed;
this.fileCompletionListener = fileCompletionListener;
}

@Override
public void accept(InputStreamContainer blobPartStreamContainer, Throwable throwable) {
if (throwable != null) {
if (throwable instanceof Exception) {
processFailure((Exception) throwable);
} else {
processFailure(new Exception(throwable));
}
return;
}
// Ensures no writes to the file if any stream fails.
if (anyPartStreamFailed.get() == false) {
try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) {
try (InputStream inputStream = blobPartStreamContainer.getInputStream()) {
long streamOffset = blobPartStreamContainer.getOffset();
final byte[] buffer = new byte[BUFFER_SIZE];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset);
streamOffset += bytesRead;
}
public static void write(Path fileLocation, InputStreamContainer stream, UnaryOperator<InputStream> rateLimiter) throws IOException {
try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) {
try (InputStream inputStream = rateLimiter.apply(stream.getInputStream())) {
long streamOffset = stream.getOffset();
final byte[] buffer = new byte[BUFFER_SIZE];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
Channels.writeToChannel(buffer, 0, bytesRead, outputFileChannel, streamOffset);
streamOffset += bytesRead;
}
} catch (IOException e) {
processFailure(e);
return;
}
fileCompletionListener.onResponse(partNumber);
}
}

void processFailure(Exception e) {
try {
Files.deleteIfExists(fileLocation);
} catch (IOException ex) {
// Die silently
logger.info("Failed to delete file {} on stream failure: {}", fileLocation, ex);
}
if (anyPartStreamFailed.getAndSet(true) == false) {
fileCompletionListener.onFailure(e);
}
}
}
Loading

0 comments on commit a9b8b63

Please sign in to comment.