diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java index 2914fd0c440fa..c77f2384ace0d 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java +++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java @@ -10,7 +10,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.lucene.util.IOUtils; import org.opensearch.action.support.GroupedActionListener; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.UUIDs; import org.opensearch.common.annotation.InternalApi; import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.core.action.ActionListener; @@ -20,6 +23,8 @@ import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.util.Collection; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; @@ -33,9 +38,11 @@ @InternalApi public class ReadContextListener implements ActionListener { private static final Logger logger = LogManager.getLogger(ReadContextListener.class); - + private static final String DOWNLOAD_PREFIX = "download."; private final String blobName; private final Path fileLocation; + private final String tmpFileName; + private final Path tmpFileLocation; private final ActionListener completionListener; private final ThreadPool threadPool; private final UnaryOperator rateLimiter; @@ -55,6 +62,8 @@ public ReadContextListener( this.threadPool = threadPool; this.rateLimiter = rateLimiter; this.maxConcurrentStreams = maxConcurrentStreams; + this.tmpFileName = DOWNLOAD_PREFIX + UUIDs.randomBase64UUID() + "." + blobName; + this.tmpFileLocation = fileLocation.getParent().resolve(tmpFileName); } @Override @@ -62,15 +71,12 @@ public void onResponse(ReadContext readContext) { logger.debug("Received {} parts for blob {}", readContext.getNumberOfParts(), blobName); final int numParts = readContext.getNumberOfParts(); final AtomicBoolean anyPartStreamFailed = new AtomicBoolean(false); - final GroupedActionListener groupedListener = new GroupedActionListener<>( - ActionListener.wrap(r -> completionListener.onResponse(blobName), completionListener::onFailure), - numParts - ); + final GroupedActionListener groupedListener = new GroupedActionListener<>(getFileCompletionListener(), numParts); final Queue queue = new ConcurrentLinkedQueue<>(readContext.getPartStreams()); final StreamPartProcessor processor = new StreamPartProcessor( queue, anyPartStreamFailed, - fileLocation, + tmpFileLocation, groupedListener, threadPool.executor(ThreadPool.Names.REMOTE_RECOVERY), rateLimiter @@ -80,6 +86,37 @@ public void onResponse(ReadContext readContext) { } } + @SuppressForbidden(reason = "need to fsync once all parts received") + private ActionListener> getFileCompletionListener() { + return ActionListener.wrap(response -> { + logger.trace("renaming temp file [{}] to [{}]", tmpFileLocation, fileLocation); + try { + IOUtils.fsync(tmpFileLocation, false); + Files.move(tmpFileLocation, fileLocation, StandardCopyOption.ATOMIC_MOVE); + // sync parent dir metadata + IOUtils.fsync(fileLocation.getParent(), true); + completionListener.onResponse(blobName); + } catch (IOException e) { + logger.error("Unable to rename temp file + " + tmpFileLocation, e); + completionListener.onFailure(e); + } + }, e -> { + try { + Files.deleteIfExists(tmpFileLocation); + } catch (IOException ex) { + logger.warn("Unable to clean temp file {}", tmpFileLocation); + } + completionListener.onFailure(e); + }); + } + + /* + * For Tests + */ + Path getTmpFileLocation() { + return tmpFileLocation; + } + @Override public void onFailure(Exception e) { completionListener.onFailure(e); diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java index 7e4c96cbadcda..0163c2275e7f4 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java @@ -130,6 +130,7 @@ public int available() { countDownLatch.await(); assertFalse(Files.exists(fileLocation)); + assertFalse(Files.exists(readContextListener.getTmpFileLocation())); } public void testReadContextListenerException() { @@ -149,6 +150,68 @@ public void testReadContextListenerException() { assertEquals(exception, listener.getException()); } + public void testWriteToTempFile() throws Exception { + final String fileName = UUID.randomUUID().toString(); + Path fileLocation = path.resolve(fileName); + List blobPartStreams = initializeBlobPartStreams(); + CountDownLatch countDownLatch = new CountDownLatch(1); + ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + completionListener, + threadPool, + UnaryOperator.identity(), + MAX_CONCURRENT_STREAMS + ); + ByteArrayInputStream assertingStream = new ByteArrayInputStream(randomByteArrayOfLength(PART_SIZE)) { + @Override + public int read(byte[] b) throws IOException { + assertTrue("parts written to temp file location", Files.exists(readContextListener.getTmpFileLocation())); + return super.read(b); + } + }; + blobPartStreams.add( + NUMBER_OF_PARTS, + () -> CompletableFuture.supplyAsync( + () -> new InputStreamContainer(assertingStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS), + threadPool.generic() + ) + ); + ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS + 1, blobPartStreams, null); + readContextListener.onResponse(readContext); + + countDownLatch.await(); + assertTrue(Files.exists(fileLocation)); + assertFalse(Files.exists(readContextListener.getTmpFileLocation())); + } + + public void testWriteToTempFile_alreadyExists_replacesFile() throws Exception { + final String fileName = UUID.randomUUID().toString(); + Path fileLocation = path.resolve(fileName); + // create an empty file at location. + Files.createFile(fileLocation); + assertEquals(0, Files.readAllBytes(fileLocation).length); + List blobPartStreams = initializeBlobPartStreams(); + CountDownLatch countDownLatch = new CountDownLatch(1); + ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); + ReadContextListener readContextListener = new ReadContextListener( + TEST_SEGMENT_FILE, + fileLocation, + completionListener, + threadPool, + UnaryOperator.identity(), + MAX_CONCURRENT_STREAMS + ); + ReadContext readContext = new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS, blobPartStreams, null); + readContextListener.onResponse(readContext); + + countDownLatch.await(); + assertTrue(Files.exists(fileLocation)); + assertEquals(50, Files.readAllBytes(fileLocation).length); + assertFalse(Files.exists(readContextListener.getTmpFileLocation())); + } + private List initializeBlobPartStreams() { List blobPartStreams = new ArrayList<>(); for (int partNumber = 0; partNumber < NUMBER_OF_PARTS; partNumber++) {