Skip to content

Commit

Permalink
Use image from S3 bucket
Browse files Browse the repository at this point in the history
Change-Id: Iec3943455cb58dc0d8644841446cb3d91242bb20
  • Loading branch information
frankfliu committed Jul 5, 2022
1 parent df9291f commit c2a27a3
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ public class SemanticSegmentationTranslator extends BaseImageTranslator<Image> {
private final int shortEdge;
private final int maxEdge;

private int rescaledWidth;
private int rescaledHeight;

private static final int CHANNEL = 3;
private static final int CLASSNUM = 21;
private static final int BIKE = 2;
Expand Down Expand Up @@ -79,43 +76,44 @@ public Image processOutput(TranslatorContext ctx, NDList list) {
float[] scores = list.get(1).toFloatArray();

// get dimensions of image
final int width = (int) ctx.getAttachment("originalWidth");
final int height = (int) ctx.getAttachment("originalHeight");
int width = (int) ctx.getAttachment("originalWidth");
int height = (int) ctx.getAttachment("originalHeight");

// build image array
NDManager manager = NDManager.newBaseManager();
ByteBuffer bb = manager.allocateDirect(CHANNEL * height * width);
NDArray intRet = manager.create(bb, new Shape(CHANNEL, height, width), DataType.UINT8);

// change color of pixels in image array where objects have been detected
for (int j = 0; j < height; j++) {
for (int k = 0; k < width; k++) {
int maxi = 0;
double maxnum = -Double.MAX_VALUE;
for (int i = 0; i < CLASSNUM; i++) {

// get score for each i at the k,j pixel of the image
float score = scores[i * (width * height) + j * width + k];
if (score > maxnum) {
maxnum = score;
maxi = i;
try (NDManager manager = NDManager.newBaseManager()) {
ByteBuffer bb = manager.allocateDirect(CHANNEL * height * width);
NDArray intRet = manager.create(bb, new Shape(CHANNEL, height, width), DataType.UINT8);

// change color of pixels in image array where objects have been detected
for (int j = 0; j < height; j++) {
for (int k = 0; k < width; k++) {
int maxi = 0;
double maxnum = -Double.MAX_VALUE;
for (int i = 0; i < CLASSNUM; i++) {

// get score for each i at the k,j pixel of the image
float score = scores[i * (width * height) + j * width + k];
if (score > maxnum) {
maxnum = score;
maxi = i;
}
}
}

// color pixel if object was found, otherwise leave as is (black)
if (maxi == PERSON || maxi == BIKE) {
NDIndex index = new NDIndex(0, j, k);
intRet.set(index, 0xFF00FF);
} else if (maxi == CAT || maxi == SHEEP || maxi == 13) {
NDIndex index = new NDIndex(1, j, k);
intRet.set(index, 0xFF00FF);
} else if (maxi == CAR || maxi == DOG) {
NDIndex index = new NDIndex(2, j, k);
intRet.set(index, 0xFF00FF);
// color pixel if object was found, otherwise leave as is (black)
if (maxi == PERSON || maxi == BIKE) {
NDIndex index = new NDIndex(0, j, k);
intRet.set(index, 0xFF00FF);
} else if (maxi == CAT || maxi == SHEEP || maxi == 13) {
NDIndex index = new NDIndex(1, j, k);
intRet.set(index, 0xFF00FF);
} else if (maxi == CAR || maxi == DOG) {
NDIndex index = new NDIndex(2, j, k);
intRet.set(index, 0xFF00FF);
}
}
}
return BufferedImageFactory.getInstance().fromNDArray(intRet);
}
return BufferedImageFactory.getInstance().fromNDArray(intRet);
}

/**
Expand Down Expand Up @@ -156,8 +154,8 @@ public NDArray transform(NDArray array) {
if (Math.round(scale * max) > maxEdge) {
scale = maxEdge / (float) max;
}
rescaledHeight = Math.round(height * scale);
rescaledWidth = Math.round(width * scale);
int rescaledHeight = Math.round(height * scale);
int rescaledWidth = Math.round(width * scale);

return NDImageUtils.resize(array, rescaledWidth, rescaledHeight);
}
Expand Down
Binary file removed examples/docs/img/semantic_instance_bikes.png
Binary file not shown.
Binary file not shown.
4 changes: 2 additions & 2 deletions examples/docs/semantic_segmentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ This should produce the following output

With the previous command, an output image with bounding box around all objects will be saved at: build/output/semantic_instances.png:

![detected-instances](img/semantic_instance_bikes.png)
![detected-instances](https://resources.djl.ai/images/semantic_segmentation/semantic_instance_bikes.png)

## Run another semantic segmentation example

Expand Down Expand Up @@ -69,4 +69,4 @@ This should produce the following output

With the previous command, an output image with bounding box around all objects will be saved at: build/output/semantic_instances.png:

![detected-instances](img/semantic_instance_dog_bike_car.png)
![detected-instances](https://resources.djl.ai/images/semantic_segmentation/semantic_instance_dog_bike_car.png)
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,25 @@
*/
package ai.djl.examples.inference;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.SemanticSegmentationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
* An example of inference using a semantic segmentation model.
Expand All @@ -44,7 +41,8 @@
*/
public final class SemanticSegmentation {

private static final Logger logger = LoggerFactory.getLogger(SemanticSegmentationTranslator.class);
private static final Logger logger =
LoggerFactory.getLogger(SemanticSegmentationTranslator.class);

private SemanticSegmentation() {}

Expand All @@ -56,25 +54,19 @@ public static void predict() throws IOException, ModelException, TranslateExcept
Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
Image img = ImageFactory.getInstance().fromFile(imageFile);

URL url = new URL("https://djl-misc.s3.amazonaws.com/tmp/semantic_segmentation/ai/djl/pytorch/deeplab/deeplabv3_scripted.pt");
try (InputStream in = url.openStream()) {
Files.copy(in, Paths.get("src/test/resources/deeplabv3_scripted.pt"));
}

// get image dimensions
final int height = img.getHeight();
final int width = img.getWidth();

// rgb coloring means and standard deviations for improved detection performance
final float[] MEAN = {0.485f, 0.456f, 0.406f};
final float[] STD = {0.229f, 0.224f, 0.225f};
String url =
"https://mlrepo.djl.ai/model/cv/semantic_segmentation/ai/djl/pytorch/deeplabv3/0.0.1/deeplabv3.zip";
Map<String, String> arguments = new ConcurrentHashMap<>();
arguments.put("toTensor", "true");
arguments.put("normalize", "true");
SemanticSegmentationTranslator translator =
SemanticSegmentationTranslator.builder(arguments).build();

Criteria<Image, Image> criteria =
Criteria.builder()
.optApplication(Application.CV.SEMANTIC_SEGMENTATION)
.setTypes(Image.class, Image.class)
.optModelPath(Paths.get("src/test/resources/deeplabv3_scripted.pt"))
.optTranslator(SemanticSegmentationTranslator.builder().addTransform(new Resize(width, height)).addTransform(new ToTensor()).addTransform(new Normalize(MEAN, STD)).optSynsetUrl("https://mlrepo.djl.ai/model/cv/instance_segmentation/ai/djl/mxnet/mask_rcnn/classes.txt").build())
.optModelUrls(url)
.optTranslator(translator)
.optEngine("PyTorch")
.optProgress(new ProgressBar())
.build();
Expand Down

0 comments on commit c2a27a3

Please sign in to comment.