diff --git a/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java b/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java
index 878e1917a2e4a..3196f093bbc8f 100644
--- a/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java
+++ b/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java
@@ -18,7 +18,11 @@ package org.apache.arrow.compression;
 
 import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
 import java.nio.channels.Channels;
+import java.nio.channels.FileChannel;
+import java.nio.channels.SeekableByteChannel;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -62,7 +66,7 @@ public void init() {
     allocator = new RootAllocator(Long.MAX_VALUE);
     dictionaryVector1 =
         (VarCharVector)
-            FieldType.nullable(new ArrowType.Utf8()).createNewSingleVector("D1", allocator, null);
+            FieldType.nullable(new ArrowType.Utf8()).createNewSingleVector("f1", allocator, null);
     setVector(dictionaryVector1,
         "foo".getBytes(StandardCharsets.UTF_8),
@@ -78,9 +82,7 @@ public void terminate() throws Exception {
     dictionaryVector1.close();
     allocator.close();
   }
-
-
-
+
   @Test
   public void testArrowFileZstdRoundTrip() throws Exception {
     // Prepare sample data
@@ -117,7 +119,6 @@ public void testArrowFileZstdRoundTrip() throws Exception {
         new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
             allocator, NoCompressionCodec.Factory.INSTANCE)) {
       Assert.assertEquals(1, reader.getRecordBlocks().size());
-
       Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> reader.loadNextBatch());
       String expectedMessage = "Please add arrow-compression module to use CommonsCompressionFactory for ZSTD";
       Assert.assertEquals(expectedMessage, exception.getMessage());
@@ -170,9 +171,8 @@ public void testArrowFileZstdRoundTripWithDictionary() throws Exception {
     try (ArrowFileReader reader =
         new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
             allocator, NoCompressionCodec.Factory.INSTANCE)) {
-      Assert.assertEquals(1, reader.getRecordBlocks().size());
-      Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> reader.loadNextBatch());
       String expectedMessage = "Please add arrow-compression module to use CommonsCompressionFactory for ZSTD";
+      Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> reader.loadNextBatch());
       Assert.assertEquals(expectedMessage, exception.getMessage());
     }
   }
@@ -196,34 +196,38 @@ public void testArrowStreamZstdRoundTrip() throws Exception {
     fields.add(encodedVector1.getField());
     VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(fields), allocator);
-    final int rowCount = 10;
+    final int rowCount = 3;
     GenerateSampleData.generateTestData(root.getVector(0), rowCount);
     root.setRowCount(rowCount);
     // Write an in-memory compressed arrow file
     ByteArrayOutputStream out = new ByteArrayOutputStream();
+    File tempFile = File.createTempFile("dictionary_compression", ".arrow");
+    FileOutputStream fileOut = new FileOutputStream(tempFile);
     try (final ArrowStreamWriter writer =
-        new ArrowStreamWriter(root, provider, Channels.newChannel(out), IpcOption.DEFAULT,
-            CommonsCompressionFactory.INSTANCE, CompressionUtil.CodecType.ZSTD, Optional.of(7))) {
+        new ArrowStreamWriter(root, provider, Channels.newChannel(fileOut), IpcOption.DEFAULT,
+            CommonsCompressionFactory.INSTANCE, CompressionUtil.CodecType.ZSTD,
+            Optional.of(7))) {
       writer.start();
       writer.writeBatch();
       writer.end();
     }
-    // Read the in-memory compressed arrow file with CommonsCompressionFactory provided
-    try (ArrowStreamReader reader =
-        new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
-            allocator, CommonsCompressionFactory.INSTANCE)) {
+    // Read the on-disk compressed arrow file with CommonsCompressionFactory provided
+    try (SeekableByteChannel channel = FileChannel.open(tempFile.toPath());
+        ArrowStreamReader reader =
+            new ArrowStreamReader(channel, allocator, CommonsCompressionFactory.INSTANCE)) {
+      org.apache.arrow.vector.types.pojo.Schema schema = reader.getVectorSchemaRoot().getSchema();
       Assert.assertTrue(reader.loadNextBatch());
      Assert.assertTrue(root.equals(reader.getVectorSchemaRoot()));
       Assert.assertFalse(reader.loadNextBatch());
-    }
-    // Read the in-memory compressed arrow file without CompressionFactory provided
-    try (ArrowStreamReader reader =
-        new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
-            allocator, NoCompressionCodec.Factory.INSTANCE)) {
+    }
+    // Read the on-disk compressed arrow file without CompressionFactory provided
+    try (SeekableByteChannel channel = FileChannel.open(tempFile.toPath());
+        ArrowStreamReader reader =
+            new ArrowStreamReader(channel, allocator, NoCompressionCodec.Factory.INSTANCE)) {
       Exception exception = Assert.assertThrows(IllegalArgumentException.class, () -> reader.loadNextBatch());
       String expectedMessage = "Please add arrow-compression module to use CommonsCompressionFactory for ZSTD";
       Assert.assertEquals(expectedMessage, exception.getMessage());
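For readers skimming the test changes: the seven-argument `ArrowStreamWriter` constructor exercised above is the compression-aware entry point this patch adds. A minimal round-trip sketch using the same APIs, kept in memory rather than the temp file the test now uses (the vector name, values, and class name are illustrative, not from the patch):

```java
import java.io.ByteArrayOutputStream;
import java.nio.channels.Channels;
import java.util.Optional;

import org.apache.arrow.compression.CommonsCompressionFactory;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;

public class ZstdStreamRoundTrip {
  public static void main(String[] args) throws Exception {
    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
         IntVector vector = new IntVector("ints", allocator)) {
      vector.allocateNew(3);
      for (int i = 0; i < 3; i++) {
        vector.set(i, i);
      }
      vector.setValueCount(3);
      VectorSchemaRoot root = VectorSchemaRoot.of(vector);

      // Write a ZSTD-compressed stream at level 7; the provider is null
      // because nothing here is dictionary-encoded.
      ByteArrayOutputStream out = new ByteArrayOutputStream();
      try (ArrowStreamWriter writer = new ArrowStreamWriter(root, null,
          Channels.newChannel(out), IpcOption.DEFAULT,
          CommonsCompressionFactory.INSTANCE, CompressionUtil.CodecType.ZSTD,
          Optional.of(7))) {
        writer.start();
        writer.writeBatch();
        writer.end();
      }

      // Read it back; the factory decompresses the batch buffers.
      try (ArrowStreamReader reader = new ArrowStreamReader(
          new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
          allocator, CommonsCompressionFactory.INSTANCE)) {
        while (reader.loadNextBatch()) {
          System.out.println(reader.getVectorSchemaRoot().contentToTSVString());
        }
      }
    }
  }
}
```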
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java
index 04c57d7e82fef..01f4e925c69b3 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java
@@ -251,7 +251,7 @@ private void load(ArrowDictionaryBatch dictionaryBatch, FieldVector vector) {
     VectorSchemaRoot root = new VectorSchemaRoot(
         Collections.singletonList(vector.getField()),
         Collections.singletonList(vector), 0);
-    VectorLoader loader = new VectorLoader(root);
+    VectorLoader loader = new VectorLoader(root, this.compressionFactory);
     try {
       loader.load(dictionaryBatch.getDictionary());
     } finally {
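This single-line `ArrowReader` change is the read-side half of the fix: dictionary batches can carry compressed buffers just like record batches, so the loader needs the reader's codec factory to inflate them. A hedged illustration of the two-argument `VectorLoader` in isolation (`loadDictionary` is a made-up helper, not the reader's internal code):

```java
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;

// Made-up helper mirroring the pattern in ArrowReader.load(): passing the
// codec factory lets VectorLoader decompress compressed dictionary buffers;
// with NoCompressionCodec.Factory.INSTANCE, a ZSTD-compressed batch fails
// with "Please add arrow-compression module ...".
static void loadDictionary(VectorSchemaRoot dictRoot, ArrowDictionaryBatch batch,
    CompressionCodec.Factory factory) {
  VectorLoader loader = new VectorLoader(dictRoot, factory);
  loader.load(batch.getDictionary());
}
```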
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
index a33c55de53f23..9953428031372 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
@@ -66,6 +66,12 @@ public abstract class ArrowWriter implements AutoCloseable {
 
   protected IpcOption option;
 
+  private CompressionCodec.Factory compressionFactory;
+
+  private CompressionUtil.CodecType codecType;
+
+  private Optional<Integer> compressionLevel;
+
   protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
     this(root, provider, out, IpcOption.DEFAULT);
   }
@@ -99,6 +105,10 @@ protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, Writab
     this.option = option;
     this.dictionaryProvider = provider;
 
+    this.compressionFactory = compressionFactory;
+    this.codecType = codecType;
+    this.compressionLevel = compressionLevel;
+
     List<Field> fields = new ArrayList<>(root.getSchema().getFields().size());
 
     MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), option.metadataVersion);
@@ -133,7 +143,11 @@ protected void writeDictionaryBatch(Dictionary dictionary) throws IOException {
         Collections.singletonList(vector.getField()),
         Collections.singletonList(vector), count);
 
-    VectorUnloader unloader = new VectorUnloader(dictRoot);
+    VectorUnloader unloader = new VectorUnloader(dictRoot, /*includeNullCount*/ true,
+        this.compressionLevel.isPresent() ?
+            this.compressionFactory.createCodec(this.codecType, this.compressionLevel.get()) :
+            this.compressionFactory.createCodec(this.codecType),
+        /*alignBuffers*/ true);
     ArrowRecordBatch batch = unloader.getRecordBatch();
     ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, false);
     try {
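On the write side, the inlined ternary in `writeDictionaryBatch` picks between the level-aware and default `createCodec` overloads before handing the codec to `VectorUnloader`. One way to read that logic, as a hypothetical helper (`buildCodec` is not part of this patch):

```java
import java.util.Optional;

import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.CompressionUtil;

// Hypothetical refactor of the inline ternary: use the level-aware overload
// only when a compression level was configured, else the codec's default.
static CompressionCodec buildCodec(CompressionCodec.Factory factory,
    CompressionUtil.CodecType codecType, Optional<Integer> level) {
  return level.isPresent()
      ? factory.createCodec(codecType, level.get())
      : factory.createCodec(codecType);
}
```

With that codec in place, dictionary data flows through the same `VectorUnloader` compression path as regular record batches, which is what lets the stream round-trip test above succeed with `CommonsCompressionFactory` and fail loudly without it.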