Adding semantic segmentation example #1764

Merged
merged 9 commits into from
Jul 6, 2022
@@ -0,0 +1,195 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator;

import ai.djl.modality.cv.BufferedImageFactory;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslatorContext;

import java.nio.ByteBuffer;
import java.util.Map;

/**
* A {@link BaseImageTranslator} that post-processes the {@link NDArray} into {@link DetectedObjects}
* with boundaries at the detailed pixel level.
*/
public class SemanticSegmentationTranslator extends BaseImageTranslator<Image> {
private final int shortEdge;
private final int maxEdge;

private static final int CHANNEL = 3;
private static final int CLASSNUM = 21;
private static final int BIKE = 2;
private static final int CAR = 7;
private static final int DOG = 8;
private static final int CAT = 12;
private static final int PERSON = 15;

// sheep is also identified with id 13; this is taken into account when coloring pixels
private static final int SHEEP = 17; // 13

/**
* Creates the Semantic Segmentation translator from the given builder.
*
* @param builder the builder for the translator
*/
public SemanticSegmentationTranslator(Builder builder) {
super(builder);
this.shortEdge = builder.shortEdge;
this.maxEdge = builder.maxEdge;

pipeline.insert(0, null, new ResizeShort());
}

/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Image image) {
ctx.setAttachment("originalHeight", image.getHeight());
ctx.setAttachment("originalWidth", image.getWidth());
return super.processInput(ctx, image);
}

/** {@inheritDoc} */
@Override
public Image processOutput(TranslatorContext ctx, NDList list) {
// scores contains the probabilities of each pixel being a certain object
float[] scores = list.get(1).toFloatArray();

// get dimensions of image
int width = (int) ctx.getAttachment("originalWidth");
int height = (int) ctx.getAttachment("originalHeight");

// build image array
try (NDManager manager = NDManager.newBaseManager()) {
ByteBuffer bb = manager.allocateDirect(CHANNEL * height * width);
NDArray intRet = manager.create(bb, new Shape(CHANNEL, height, width), DataType.UINT8);

// change color of pixels in image array where objects have been detected
for (int j = 0; j < height; j++) {
for (int k = 0; k < width; k++) {
int maxi = 0;
double maxnum = -Double.MAX_VALUE;
for (int i = 0; i < CLASSNUM; i++) {

// get score for each i at the k,j pixel of the image
float score = scores[i * (width * height) + j * width + k];
if (score > maxnum) {
maxnum = score;
maxi = i;
}
}

// color pixel if object was found, otherwise leave as is (black)
if (maxi == PERSON || maxi == BIKE) {
NDIndex index = new NDIndex(0, j, k);
intRet.set(index, 0xFF00FF);
} else if (maxi == CAT || maxi == SHEEP || maxi == 13) {
NDIndex index = new NDIndex(1, j, k);
intRet.set(index, 0xFF00FF);
} else if (maxi == CAR || maxi == DOG) {
NDIndex index = new NDIndex(2, j, k);
intRet.set(index, 0xFF00FF);
}
}
}
return BufferedImageFactory.getInstance().fromNDArray(intRet);
}
}

/**
* Creates a builder to build a {@code SemanticSegmentationTranslator}.
*
* @return a new builder
*/
public static Builder builder() {
return new Builder();
}

/**
* Creates a builder to build a {@code SemanticSegmentationTranslator} with specified arguments.
*
* @param arguments arguments to specify builder options
* @return a new builder
*/
public static Builder builder(Map<String, ?> arguments) {
Builder builder = new Builder();

builder.configPreProcess(arguments);
builder.configPostProcess(arguments);

return builder;
}

/** Resizes the image based on the shorter edge or maximum edge length. */
private class ResizeShort implements Transform {
Review comment (Contributor), suggested change:
- private class ResizeShort implements Transform {
+ private static final class ResizeShort implements Transform {

Reply (Contributor Author): The ResizeShort class from the InstanceSegmentationTranslator is not static or final, and by making it static, the shortEdge, maxEdge, rescaledHeight, and rescaledWidth variables must be initialized at the beginning of the file. If I make it static, what values should they be assigned?
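
One possible answer to the question above, sketched in plain Java rather than taken from this PR: a static nested class has no implicit reference to the enclosing instance, so it cannot read the outer `shortEdge` and `maxEdge` fields directly, but it does not need static fields either. The values can be passed through its constructor, and `rescaledHeight` and `rescaledWidth` can stay local variables. The names `ResizeShortSketch`, `newResize`, and `rescaledSize` are hypothetical, and the sketch works on plain width/height integers instead of an `NDArray`.

```java
/** Hypothetical stand-in for the translator; this is not DJL code. */
final class ResizeShortSketch {
    private final int shortEdge;
    private final int maxEdge;

    ResizeShortSketch(int shortEdge, int maxEdge) {
        this.shortEdge = shortEdge;
        this.maxEdge = maxEdge;
    }

    /** Hands the outer values to the nested class explicitly. */
    ResizeShort newResize() {
        return new ResizeShort(shortEdge, maxEdge);
    }

    /** Static nested class: it keeps its own copies of the two limits. */
    static final class ResizeShort {
        private final int shortEdge;
        private final int maxEdge;

        ResizeShort(int shortEdge, int maxEdge) {
            this.shortEdge = shortEdge;
            this.maxEdge = maxEdge;
        }

        /** Returns {rescaledWidth, rescaledHeight} using the same rule as the diff. */
        int[] rescaledSize(int width, int height) {
            int min = Math.min(width, height);
            int max = Math.max(width, height);
            // Scale so the shorter edge becomes shortEdge ...
            float scale = shortEdge / (float) min;
            // ... unless that would push the longer edge past maxEdge.
            if (Math.round(scale * max) > maxEdge) {
                scale = maxEdge / (float) max;
            }
            return new int[] {Math.round(width * scale), Math.round(height * scale)};
        }
    }

    public static void main(String[] args) {
        ResizeShort resize = new ResizeShortSketch(600, 1000).newResize();
        int[] capped = resize.rescaledSize(1200, 400); // longer edge hits the maxEdge cap
        int[] plain = resize.rescaledSize(300, 200); // shorter edge is scaled to shortEdge
        System.out.println(capped[0] + "x" + capped[1]); // prints 1000x333
        System.out.println(plain[0] + "x" + plain[1]); // prints 900x600
    }
}
```

Under this approach the constructor of the translator would create the transform with the builder's values, so no default values need to be chosen for static fields.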

/** {@inheritDoc} */
@Override
public NDArray transform(NDArray array) {
Shape shape = array.getShape();
int width = (int) shape.get(1);
int height = (int) shape.get(0);
int min = Math.min(width, height);
int max = Math.max(width, height);
float scale = shortEdge / (float) min;
if (Math.round(scale * max) > maxEdge) {
scale = maxEdge / (float) max;
}
int rescaledHeight = Math.round(height * scale);
int rescaledWidth = Math.round(width * scale);

return NDImageUtils.resize(array, rescaledWidth, rescaledHeight);
}
}

/** The builder for Semantic Segmentation translator. */
public static class Builder extends ClassificationBuilder<Builder> {
int shortEdge = 600;
int maxEdge = 1000;

Builder() {}

/** {@inheritDoc} */
@Override
protected Builder self() {
return this;
}

/** {@inheritDoc} */
@Override
protected void configPostProcess(Map<String, ?> arguments) {
super.configPostProcess(arguments);
shortEdge = ArgumentsUtil.intValue(arguments, "shortEdge", 600);
maxEdge = ArgumentsUtil.intValue(arguments, "maxEdge", 1000);
}

/**
* Builds the translator.
*
* @return the new translator
*/
public SemanticSegmentationTranslator build() {
validate();
return new SemanticSegmentationTranslator(this);
}
}
}
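
The per-pixel class selection in `processOutput` above can be illustrated without DJL. The following is a minimal sketch, assuming the scores are flattened in class-major (class, height, width) order, which is the layout the index expression `i * (width * height) + j * width + k` in the diff implies. The class name `ArgmaxSketch` and the method `argmaxPerPixel` are hypothetical and exist only for this illustration.

```java
import java.util.Arrays;

/** Hypothetical helper; not part of DJL. */
final class ArgmaxSketch {

    /**
     * Returns, for every pixel, the index of the class with the highest score.
     * Scores are expected in flattened (class, height, width) order.
     */
    static int[] argmaxPerPixel(float[] scores, int classes, int height, int width) {
        int[] best = new int[height * width];
        for (int j = 0; j < height; j++) {
            for (int k = 0; k < width; k++) {
                int maxClass = 0;
                float maxScore = Float.NEGATIVE_INFINITY;
                for (int i = 0; i < classes; i++) {
                    // score of class i at row j, column k
                    float score = scores[i * (width * height) + j * width + k];
                    if (score > maxScore) {
                        maxScore = score;
                        maxClass = i;
                    }
                }
                best[j * width + k] = maxClass;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Two classes on a 1 x 2 image:
        // class 0 scores are {0.9, 0.1}, class 1 scores are {0.2, 0.8}.
        float[] scores = {0.9f, 0.1f, 0.2f, 0.8f};
        int[] best = argmaxPerPixel(scores, 2, 1, 2);
        System.out.println(Arrays.toString(best)); // prints [0, 1]
    }
}
```

The translator then maps each winning class id to a color channel; pixels whose winning class is not one of the handled ids stay black.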
@@ -0,0 +1,52 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.cv.translator;

import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.translator.wrapper.FileTranslator;
import ai.djl.modality.cv.translator.wrapper.InputStreamTranslator;
import ai.djl.modality.cv.translator.wrapper.UrlTranslator;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;

import java.io.InputStream;
import java.net.URL;
import java.nio.file.Path;
import java.util.Map;

/** A {@link TranslatorFactory} that creates a {@link SemanticSegmentationTranslator} instance. */
public class SemanticSegmentationTranslatorFactory extends ObjectDetectionTranslatorFactory {

/** {@inheritDoc} */
@Override
public Translator<?, ?> newInstance(
Class<?> input, Class<?> output, Model model, Map<String, ?> arguments) {
if (input == Image.class && output == Image.class) {
return SemanticSegmentationTranslator.builder(arguments).build();
} else if (input == Path.class && output == Image.class) {
return new FileTranslator<>(SemanticSegmentationTranslator.builder(arguments).build());
} else if (input == URL.class && output == Image.class) {
return new UrlTranslator<>(SemanticSegmentationTranslator.builder(arguments).build());
} else if (input == InputStream.class && output == Image.class) {
return new InputStreamTranslator<>(
SemanticSegmentationTranslator.builder(arguments).build());
} else if (input == Input.class && output == Output.class) {
return new ImageServingTranslator(
SemanticSegmentationTranslator.builder(arguments).build());
}
throw new IllegalArgumentException("Unsupported input/output types.");
}
}
2 changes: 1 addition & 1 deletion examples/docs/instance_segmentation.md
@@ -38,6 +38,6 @@ This should produce the following output
```


With the previous command, an output image with bounding box around all objects will be saved at: build/output/instances.jpg:
With the previous command, an output image with bounding box around all objects will be saved at: build/output/instances.png:

![detected-instances](img/detected_instances.png)
72 changes: 72 additions & 0 deletions examples/docs/semantic_segmentation.md
@@ -0,0 +1,72 @@
# Semantic segmentation example

Semantic segmentation is the task of assigning an object class to every pixel of an image. The result is shown by coloring each pixel according to the class detected at that location.

In this example, you learn how to implement inference code with Deep Java Library (DJL) to segment classes at pixel level in an image.

The following is the semantic segmentation example source code:

[SemanticSegmentation.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/SemanticSegmentation.java).


## Setup guide

Follow [setup](../../docs/development/setup.md) to configure your development environment.

## Run semantic segmentation example

### Input image file
You can find the image used in this example in the project's test resource folder: `src/test/resources/segmentation.jpg`

![segmentation](../src/test/resources/segmentation.jpg)

### Build the project and run

```
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.SemanticSegmentation
```

This should produce the following output:

```text
[INFO ] - Segmentation result image has been saved in: build/output/semantic_instances.png
```


With the previous command, an output image in which the pixels of detected objects are colored by class will be saved at: `build/output/semantic_instances.png`:

![detected-instances](https://resources.djl.ai/images/semantic_segmentation/semantic_instance_bikes.png)

## Run another semantic segmentation example

### Input image file
You can find the image used in this example in the project's test resource folder: `src/test/resources/dog_bike_car.jpg`

![segmentation](../src/test/resources/dog_bike_car.jpg)

### Edit project to run inference with new image

In the `SemanticSegmentation.java` file, find the `predict()` method. Change the `imageFile` path to look like this:

```
Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");
```

### Build the project and run

```
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.SemanticSegmentation
```

This should produce the following output:

```text
[INFO ] - Segmentation result image has been saved in: build/output/semantic_instances.png
```


With the previous command, an output image in which the pixels of detected objects are colored by class will be saved at: `build/output/semantic_instances.png`:

![detected-instances](https://resources.djl.ai/images/semantic_segmentation/semantic_instance_dog_bike_car.png)