diff --git a/api/src/main/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslator.java new file mode 100644 index 00000000000..5baaec52f0c --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslator.java @@ -0,0 +1,195 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.modality.cv.BufferedImageFactory; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.output.DetectedObjects; +import ai.djl.modality.cv.util.NDImageUtils; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.translate.ArgumentsUtil; +import ai.djl.translate.Transform; +import ai.djl.translate.TranslatorContext; + +import java.nio.ByteBuffer; +import java.util.Map; + +/** + * A {@link BaseImageTranslator} that post-process the {@link NDArray} into {@link DetectedObjects} + * with boundaries at the detailed pixel level. + */ +public class SemanticSegmentationTranslator extends BaseImageTranslator { + private final int shortEdge; + private final int maxEdge; + + private static final int CHANNEL = 3; + private static final int CLASSNUM = 21; + private static final int BIKE = 2; + private static final int CAR = 7; + private static final int DOG = 8; + private static final int CAT = 12; + private static final int PERSON = 15; + + // sheep is also identified with id 13 as well, this is taken into account when coloring pixels + private static final int SHEEP = 17; // 13 + + /** + * Creates the Semantic Segmentation translator from the given builder. + * + * @param builder the builder for the translator + */ + public SemanticSegmentationTranslator(Builder builder) { + super(builder); + this.shortEdge = builder.shortEdge; + this.maxEdge = builder.maxEdge; + + pipeline.insert(0, null, new ResizeShort()); + } + + /** {@inheritDoc} */ + @Override + public NDList processInput(TranslatorContext ctx, Image image) { + ctx.setAttachment("originalHeight", image.getHeight()); + ctx.setAttachment("originalWidth", image.getWidth()); + return super.processInput(ctx, image); + } + + /** {@inheritDoc} */ + @Override + public Image processOutput(TranslatorContext ctx, NDList list) { + // scores contains the probabilities of each pixel being a certain object + float[] scores = list.get(1).toFloatArray(); + + // get dimensions of image + int width = (int) ctx.getAttachment("originalWidth"); + int height = (int) ctx.getAttachment("originalHeight"); + + // build image array + try (NDManager manager = NDManager.newBaseManager()) { + ByteBuffer bb = manager.allocateDirect(CHANNEL * height * width); + NDArray intRet = manager.create(bb, new Shape(CHANNEL, height, width), DataType.UINT8); + + // change color of pixels in image array where objects have been detected + for (int j = 0; j < height; j++) { + for (int k = 0; k < width; k++) { + int maxi = 0; + double maxnum = -Double.MAX_VALUE; + for (int i = 0; i < CLASSNUM; i++) { + + // get score for each i at the k,j pixel of the image + float score = scores[i * (width * height) + j * width + k]; + if (score > maxnum) { + maxnum = score; + maxi = i; + } + } + + // color pixel if object was found, otherwise leave as is (black) + if (maxi == PERSON || maxi == BIKE) { + NDIndex index = new NDIndex(0, j, k); + intRet.set(index, 0xFF00FF); + } else if (maxi == CAT || maxi == SHEEP || maxi == 13) { + NDIndex index = new NDIndex(1, j, k); + intRet.set(index, 0xFF00FF); + } else if (maxi == CAR || maxi == DOG) { + NDIndex index = new NDIndex(2, j, k); + intRet.set(index, 0xFF00FF); + } + } + } + return BufferedImageFactory.getInstance().fromNDArray(intRet); + } + } + + /** + * Creates a builder to build a {@code SemanticSegmentationTranslator}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Creates a builder to build a {@code SemanticSegmentationTranslator} with specified arguments. + * + * @param arguments arguments to specify builder options + * @return a new builder + */ + public static Builder builder(Map arguments) { + Builder builder = new Builder(); + + builder.configPreProcess(arguments); + builder.configPostProcess(arguments); + + return builder; + } + + /** Resizes the image based on the shorter edge or maximum edge length. */ + private class ResizeShort implements Transform { + /** {@inheritDoc} */ + @Override + public NDArray transform(NDArray array) { + Shape shape = array.getShape(); + int width = (int) shape.get(1); + int height = (int) shape.get(0); + int min = Math.min(width, height); + int max = Math.max(width, height); + float scale = shortEdge / (float) min; + if (Math.round(scale * max) > maxEdge) { + scale = maxEdge / (float) max; + } + int rescaledHeight = Math.round(height * scale); + int rescaledWidth = Math.round(width * scale); + + return NDImageUtils.resize(array, rescaledWidth, rescaledHeight); + } + } + + /** The builder for Semantic Segmentation translator. */ + public static class Builder extends ClassificationBuilder { + int shortEdge = 600; + int maxEdge = 1000; + + Builder() {} + + /** {@inheritDoc} */ + @Override + protected Builder self() { + return this; + } + + /** {@inheritDoc} */ + @Override + protected void configPostProcess(Map arguments) { + super.configPostProcess(arguments); + shortEdge = ArgumentsUtil.intValue(arguments, "shortEdge", 600); + maxEdge = ArgumentsUtil.intValue(arguments, "maxEdge", 1000); + } + + /** + * Builds the translator. + * + * @return the new translator + */ + public SemanticSegmentationTranslator build() { + validate(); + return new SemanticSegmentationTranslator(this); + } + } +} diff --git a/api/src/main/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslatorFactory.java new file mode 100644 index 00000000000..ddb61b81c23 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/SemanticSegmentationTranslatorFactory.java @@ -0,0 +1,52 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.Input; +import ai.djl.modality.Output; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.translator.wrapper.FileTranslator; +import ai.djl.modality.cv.translator.wrapper.InputStreamTranslator; +import ai.djl.modality.cv.translator.wrapper.UrlTranslator; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorFactory; + +import java.io.InputStream; +import java.net.URL; +import java.nio.file.Path; +import java.util.Map; + +/** A {@link TranslatorFactory} that creates a {@link SemanticSegmentationTranslator} instance. */ +public class SemanticSegmentationTranslatorFactory extends ObjectDetectionTranslatorFactory { + + /** {@inheritDoc} */ + @Override + public Translator newInstance( + Class input, Class output, Model model, Map arguments) { + if (input == Image.class && output == Image.class) { + return SemanticSegmentationTranslator.builder(arguments).build(); + } else if (input == Path.class && output == Image.class) { + return new FileTranslator<>(SemanticSegmentationTranslator.builder(arguments).build()); + } else if (input == URL.class && output == Image.class) { + return new UrlTranslator<>(SemanticSegmentationTranslator.builder(arguments).build()); + } else if (input == InputStream.class && output == Image.class) { + return new InputStreamTranslator<>( + SemanticSegmentationTranslator.builder(arguments).build()); + } else if (input == Input.class && output == Output.class) { + return new ImageServingTranslator( + SemanticSegmentationTranslator.builder(arguments).build()); + } + throw new IllegalArgumentException("Unsupported input/output types."); + } +} diff --git a/examples/docs/instance_segmentation.md b/examples/docs/instance_segmentation.md index 24e534a4818..ce8fde0069c 100644 --- a/examples/docs/instance_segmentation.md +++ b/examples/docs/instance_segmentation.md @@ -38,6 +38,6 @@ This should produce the following output ``` -With the previous command, an output image with bounding box around all objects will be saved at: build/output/instances.jpg: +With the previous command, an output image with bounding box around all objects will be saved at: build/output/instances.png: ![detected-instances](img/detected_instances.png) diff --git a/examples/docs/semantic_segmentation.md b/examples/docs/semantic_segmentation.md new file mode 100644 index 00000000000..c1d242df5be --- /dev/null +++ b/examples/docs/semantic_segmentation.md @@ -0,0 +1,72 @@ +# Semantic segmentation example + +Semantic segmentation refers to the task of detecting objects of various classes at pixel level. It colors the pixels based on the objects detected in that space. + +In this example, you learn how to implement inference code with Deep Java Library (DJL) to segment classes at instance level in an image. + +The following is the semantic segmentation example source code: + +[SemanticSegmentation.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/SemanticSegmentation.java). + + +## Setup guide + +Follow [setup](../../docs/development/setup.md) to configure your development environment. + +## Run semantic segmentation example + +### Input image file +You can find the image used in this example in project test resource folder: `src/test/resources/segmentation.jpg` + +![segmentation](../src/test/resources/segmentation.jpg) + +### Build the project and run + +``` +cd examples +./gradlew run -Dmain=ai.djl.examples.inference.SemanticSegmentation +``` + +This should produce the following output + +```text +[INFO ] - Segmentation result image has been saved in: build/output/semantic_instances.png +``` + + +With the previous command, an output image with bounding box around all objects will be saved at: build/output/semantic_instances.png: + +![detected-instances](https://resources.djl.ai/images/semantic_segmentation/semantic_instance_bikes.png) + +## Run another semantic segmentation example + +### Input image file +You can find the image used in this example in project test resource folder: `src/test/resources/dog_bike_car.jpg` + +![segmentation](../src/test/resources/dog_bike_car.jpg) + +### Edit project to run inference with new image + +In the `SemanticSegmentation.java` file, find the `predict()` method. Change the `imageFile` path to look like this: + +``` +Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg"); +``` + +### Build the project and run + +``` +cd examples +./gradlew run -Dmain=ai.djl.examples.inference.InstanceSegmentation +``` + +This should produce the following output + +```text +[INFO ] - Segmentation result image has been saved in: build/output/semantic_instances.png +``` + + +With the previous command, an output image with bounding box around all objects will be saved at: build/output/semantic_instances.png: + +![detected-instances](https://resources.djl.ai/images/semantic_segmentation/semantic_instance_dog_bike_car.png) diff --git a/examples/src/main/java/ai/djl/examples/inference/SemanticSegmentation.java b/examples/src/main/java/ai/djl/examples/inference/SemanticSegmentation.java new file mode 100644 index 00000000000..0ee791de59e --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/SemanticSegmentation.java @@ -0,0 +1,95 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.examples.inference; + +import ai.djl.ModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.translator.SemanticSegmentationTranslator; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.training.util.ProgressBar; +import ai.djl.translate.TranslateException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * An example of inference using a semantic segmentation model. + * + *

See this doc + * for information about this example. + */ +public final class SemanticSegmentation { + + private static final Logger logger = + LoggerFactory.getLogger(SemanticSegmentationTranslator.class); + + private SemanticSegmentation() {} + + public static void main(String[] args) throws IOException, ModelException, TranslateException { + SemanticSegmentation.predict(); + } + + public static void predict() throws IOException, ModelException, TranslateException { + Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg"); + Image img = ImageFactory.getInstance().fromFile(imageFile); + + int height = img.getHeight(); + int width = img.getWidth(); + + String url = + "https://mlrepo.djl.ai/model/cv/semantic_segmentation/ai/djl/pytorch/deeplabv3/0.0.1/deeplabv3.zip"; + Map arguments = new ConcurrentHashMap<>(); + arguments.put("toTensor", "true"); + arguments.put("normalize", "true"); + arguments.put("resize", "true"); + arguments.put("width", String.valueOf(width)); + arguments.put("height", String.valueOf(height)); + SemanticSegmentationTranslator translator = + SemanticSegmentationTranslator.builder(arguments).build(); + + Criteria criteria = + Criteria.builder() + .setTypes(Image.class, Image.class) + .optModelUrls(url) + .optTranslator(translator) + .optEngine("PyTorch") + .optProgress(new ProgressBar()) + .build(); + try (ZooModel model = criteria.loadModel()) { + try (Predictor predictor = model.newPredictor()) { + Image semanticImage = predictor.predict(img); + saveSemanticImage(semanticImage); + } + } + } + + private static void saveSemanticImage(Image img) throws IOException { + Path outputDir = Paths.get("build/output"); + Files.createDirectories(outputDir); + + Path imagePath = outputDir.resolve("semantic_instances.png"); + img.save(Files.newOutputStream(imagePath), "png"); + logger.info("Segmentation result image has been saved in: {}", imagePath); + } +}