Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Refactor PublisherBytesSupplier.java #2831

Merged
merged 1 commit into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@

import ai.djl.ndarray.BytesSupplier;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

/**
Expand All @@ -29,16 +26,14 @@
*/
public class PublisherBytesSupplier implements BytesSupplier {

private final List<byte[]> allData;
private final AtomicBoolean completed;
private Consumer<byte[]> subscriber;
private final AtomicInteger dataPushed;
private CountDownLatch latch;
private CompletableFuture<Void> future;

/** Constructs a {@link PublisherBytesSupplier}. */
public PublisherBytesSupplier() {
allData = new ArrayList<>();
completed = new AtomicBoolean();
dataPushed = new AtomicInteger();
latch = new CountDownLatch(1);
future = new CompletableFuture<>();
}

/**
Expand All @@ -48,83 +43,46 @@ public PublisherBytesSupplier() {
* @param lastChunk true if this is the last chunk
*/
public void appendContent(byte[] data, boolean lastChunk) {
synchronized (allData) {
allData.add(data);
if (subscriber == null) {
try {
if (!latch.await(2, TimeUnit.MINUTES)) {
throw new IllegalStateException("Wait for subscriber timeout.");
}
if (subscriber == null) {
// workaround Spotbugs
throw new IllegalStateException("subscriber is not set.");
}
} catch (InterruptedException e) {
throw new IllegalStateException("Append content interrupted.", e);
}
}
subscriber.accept(data);
if (lastChunk) {
completed.set(true);
subscriber.accept(null);
future.complete(null);
}
pushData();
}

/**
* Adds the subscriber to the {@link BytesSupplier} to get notified about additional data.
*
* @param subscriber a consumer function that will receive bytes when new daata is added and
* null when completed
* @return a {@code CompletableFuture} object
*/
public void subscribe(Consumer<byte[]> subscriber) {
public CompletableFuture<Void> subscribe(Consumer<byte[]> subscriber) {
if (this.subscriber != null) {
throw new IllegalStateException(
"The PublisherBytesSupplier only allows a single Subscriber");
}
this.subscriber = subscriber;
pushData();
}

private void pushData() {
if (subscriber == null) {
return;
}

int dataAvailable;
synchronized (allData) {
dataAvailable = allData.size();
}

int sent = dataPushed.getAndSet(dataAvailable);
if (sent < dataAvailable) {
synchronized (this) {
for (; sent < dataAvailable; sent++) {
subscriber.accept(allData.get(sent));
}
if (completed.get()) {
subscriber.accept(null);
}
}
}
}

/** Waits until completed before passing thread (BLOCKS THREAD!). */
@SuppressWarnings("PMD.EmptyControlStatement")
public void waitToRead() {
// Block until complete!!!
while (!completed.get()) {
// Do nothing
}
}

/** {@inheritDoc} */
@Override
public byte[] getAsBytes() {
if (!completed.get()) {
throw new IllegalStateException(
"PublisherByteSupplier must be completely filled before reading.");
}

try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
for (byte[] data : allData) {
bos.write(data);
}
return bos.toByteArray();
} catch (IOException e) {
throw new AssertionError("Failed to read BytesSupplier", e);
}
latch.countDown();
return future;
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
return ByteBuffer.wrap(getAsBytes());
throw new UnsupportedOperationException("Not supported.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,38 @@
import org.testng.Assert;
import org.testng.annotations.Test;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;

public class PublisherBytesSupplierTest {

@Test
public void test() {
public void test() throws ExecutionException, InterruptedException {
AtomicInteger contentCount = new AtomicInteger();
PublisherBytesSupplier supplier = new PublisherBytesSupplier();

// Add to supplier without subscriber
supplier.appendContent(new byte[] {1}, false);
Assert.assertEquals(contentCount.get(), 0);
new Thread(
() -> {
// Add to supplier without subscriber
supplier.appendContent(new byte[] {1}, false);
// Add to supplier with subscriber
supplier.appendContent(new byte[] {1}, true);
})
.start();

// Subscribing with data should trigger subscriptions
supplier.subscribe(
d -> {
if (d == null) {
// Do nothing on completion
return;
}
contentCount.getAndIncrement();
});
Assert.assertEquals(contentCount.get(), 1);
CompletableFuture<Void> future =
supplier.subscribe(
d -> {
if (d == null) {
// Do nothing on completion
return;
}
contentCount.getAndIncrement();
});

// Add to supplier with subscriber
supplier.appendContent(new byte[] {1}, true);
future.get();
Assert.assertEquals(contentCount.get(), 2);
}
}
Loading