
[examples] Re-organize CV examples #3135

Merged
merged 1 commit into from
Apr 29, 2024
2 changes: 1 addition & 1 deletion docs/development/development_guideline.md
@@ -134,7 +134,7 @@ Before you run any examples in IntelliJ, configure your Application template as
3. Change the "Working directory:" value to: "$MODULE_WORKING_DIR$".
4. Select "OK" to save the template.

Navigate to the 'examples' module. Open the class that you want to execute (e.g. ai.djl.examples.inference.ObjectDetection).
Navigate to the 'examples' module. Open the class that you want to execute (e.g. ai.djl.examples.inference.cv.ObjectDetection).
Select the triangle at the class declaration line. A popup menu appears with 3 items:

- Run 'ObjectDetection.main()'
2 changes: 1 addition & 1 deletion examples/build.gradle
@@ -22,7 +22,7 @@ dependencies {
}

application {
mainClass = System.getProperty("main", "ai.djl.examples.inference.ObjectDetection")
mainClass = System.getProperty("main", "ai.djl.examples.inference.cv.ObjectDetection")
}

run {
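The `mainClass` assignment in the Gradle snippet above relies on `System.getProperty(key, default)`: the `-Dmain=...` value is used when supplied, and the new `ai.djl.examples.inference.cv.ObjectDetection` default is used otherwise. A minimal standalone sketch of that lookup (the class name here is hypothetical, chosen only for illustration):

```java
// Sketch of how the build script selects the example to run:
// System.getProperty returns the -Dmain value if one was passed,
// otherwise the supplied default class name.
public class MainClassDemo {

    public static String mainClass() {
        // Same call the build.gradle application block makes.
        return System.getProperty(
                "main", "ai.djl.examples.inference.cv.ObjectDetection");
    }

    public static void main(String[] args) {
        System.out.println(mainClass());
    }
}
```

Running with `java -Dmain=ai.djl.examples.inference.cv.PoseEstimation MainClassDemo` would print the overridden class name instead of the default.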
4 changes: 2 additions & 2 deletions examples/docs/action_recognition.md
@@ -4,7 +4,7 @@ Action recognition is a computer vision technique to infer human actions (presen

In this example, you learn how to implement inference code with a [ModelZoo model](../../docs/model-zoo.md) to detect human actions in an image.

The source code can be found at [ActionRecognition.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ActionRecognition.java).
The source code can be found at [ActionRecognition.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/cv/ActionRecognition.java).

## Setup Guide

@@ -22,7 +22,7 @@ Use the following command to run the project:

```sh
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.ActionRecognition
./gradlew run -Dmain=ai.djl.examples.inference.cv.ActionRecognition
```

Your output should look like the following:
4 changes: 2 additions & 2 deletions examples/docs/biggan.md
@@ -5,7 +5,7 @@ They consist of 2 neural networks that act as adversaries, the Generator and the

In this example, you will learn how to use a [BigGAN](https://deepmind.com/research/open-source/biggan) generator to create images, using the generator directly from the [ModelZoo](../../docs/model-zoo.md).

The source code for this example can be found at [BigGAN.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/BigGAN.java).
The source code for this example can be found at [BigGAN.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/cv/BigGAN.java).

## Setup guide

@@ -30,7 +30,7 @@ Use the following commands to run the project:

```sh
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.BigGAN
./gradlew run -Dmain=ai.djl.examples.inference.cv.BigGAN
```

### Output
4 changes: 2 additions & 2 deletions examples/docs/image_classification.md
@@ -4,7 +4,7 @@ Image classification refers to the task of extracting information classes from a

In this example, you learn how to implement inference code with Deep Java Library (DJL) to recognize handwritten digits from an image.

The image classification example code can be found at [ImageClassification.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ImageClassification.java).
The image classification example code can be found at [ImageClassification.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/cv/ImageClassification.java).

You can also use the [Jupyter notebook tutorial](http://docs.djl.ai/docs/demos/jupyter/tutorial/03_image_classification_with_your_model.html).
The Jupyter notebook explains the key concepts in detail.
@@ -32,7 +32,7 @@ Run the project by using the following command:

```sh
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.ImageClassification
./gradlew run -Dmain=ai.djl.examples.inference.cv.ImageClassification
```

Your output should look like the following:
4 changes: 2 additions & 2 deletions examples/docs/instance_segmentation.md
@@ -6,7 +6,7 @@ In this example, you learn how to implement inference code with Deep Java Librar

The following is the instance segmentation example source code:

[InstanceSegmentation.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/InstanceSegmentation.java).
[InstanceSegmentation.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/cv/InstanceSegmentation.java).


## Setup guide
@@ -24,7 +24,7 @@ You can find the image used in this example in project test resource folder: `sr

```sh
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.InstanceSegmentation
./gradlew run -Dmain=ai.djl.examples.inference.cv.InstanceSegmentation
```

This should produce the following output
6 changes: 3 additions & 3 deletions examples/docs/mask_detection.md
@@ -5,15 +5,15 @@ YOLOv5 is a powerful model for object detection tasks. With the transfer learnin
In this example, we apply it to the [Face Mask Detection dataset](https://www.kaggle.com/datasets/andrewmvd/face-mask-detection?select=images). We first train the YOLOv5s model in Python with the help of [ATLearn](https://github.com/awslabs/atlearn/blob/main/examples/docs/face_mask_detection.md), a Python transfer learning toolkit.
The model is then saved as an ONNX model and imported into DJL for inference on the mask-wearing detection task.

The source code can be found at [MaskDetectionOnnx.java](../src/main/java/ai/djl/examples/inference/MaskDetection.java)
The source code can be found at [MaskDetectionOnnx.java](../src/main/java/ai/djl/examples/inference/cv/MaskDetection.java)

## The training part in ATLearn

We initially attempted to import a pretrained YOLOv5 into DJL and fine-tune it with the [Face Mask Detection dataset](https://www.kaggle.com/datasets/andrewmvd/face-mask-detection?select=images), similar to [Train ResNet for Fruit Freshness Classification](./train_transfer_fresh_fruit.md). However, YOLOv5 cannot be converted to a PyTorch traced model due to its data-dependent execution flow (see this [discussion](https://discuss.pytorch.org/t/yolov5-convert-to-torchscript/150180)), which rules out retraining a YOLOv5 model in DJL. The training part is therefore entirely in Python.

The retraining of YOLOv5 can be found in an example in ATLearn: `examples/docs/face_mask_detection.md`. In this example, the YOLOv5 layers near the input are frozen while those near the output are fine-tuned with the customized data. This follows the transfer learning idea.

In this example, the trained model is first exported to an ONNX file, e.g. `mask.onnx`, and then imported in [MaskDetectionOnnx.java](../src/main/java/ai/djl/examples/inference/MaskDetection.java).
In this example, the trained model is first exported to an ONNX file, e.g. `mask.onnx`, and then imported in [MaskDetectionOnnx.java](../src/main/java/ai/djl/examples/inference/cv/MaskDetection.java).
This ONNX model file can also be used for inference in Python, which serves as a benchmark. See the [tutorial doc](https://github.com/awslabs/atlearn/blob/main/examples/docs/face_mask_detection.md).

## Setup guide
@@ -32,7 +32,7 @@ Use the following command to run the project:

```sh
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.MaskDetection
./gradlew run -Dmain=ai.djl.examples.inference.cv.MaskDetection
```

Your output should look like the following:
4 changes: 2 additions & 2 deletions examples/docs/object_detection.md
@@ -5,7 +5,7 @@ for locating instances of objects in images or videos.

In this example, you learn how to implement inference code with a [ModelZoo model](../../docs/model-zoo.md) to detect dogs in an image.

The source code can be found at [ObjectDetection.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ObjectDetection.java).
The source code can be found at [ObjectDetection.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/cv/ObjectDetection.java).

You can also use the [Jupyter notebook tutorial](http://docs.djl.ai/docs/demos/jupyter/object_detection_with_model_zoo.html).
The Jupyter notebook explains the key concepts in detail.
@@ -26,7 +26,7 @@ Use the following command to run the project:

```sh
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.ObjectDetection
./gradlew run -Dmain=ai.djl.examples.inference.cv.ObjectDetection
```

Your output should look like the following:
@@ -7,7 +7,7 @@ In this example we will use pre-trained model from [tensorflow model zoo](https:
The following code has been tested with EfficientDet, SSD MobileNet V2, and Faster RCNN Inception ResNet V2,
but it should work with most TensorFlow object detection models.

The source code can be found at [ObjectDetectionWithTensorflowSavedModel.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ObjectDetectionWithTensorflowSavedModel.java).
The source code can be found at [ObjectDetectionWithTensorflowSavedModel.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/cv/ObjectDetectionWithTensorflowSavedModel.java).

## Setup guide

@@ -45,7 +45,7 @@ Use the following command to run the project:

```sh
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.ObjectDetectionWithTensorflowSavedModel
./gradlew run -Dmain=ai.djl.examples.inference.cv.ObjectDetectionWithTensorflowSavedModel
```

Your output should look like the following:
4 changes: 2 additions & 2 deletions examples/docs/pose_estimation.md
@@ -4,7 +4,7 @@ Pose estimation is a computer vision technique for determining the pose of an ob

In this example, you learn how to implement inference code with a [ModelZoo model](../../docs/model-zoo.md) to detect people and their joints in an image.

The source code can be found at [PoseEstimation.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/PoseEstimation.java).
The source code can be found at [PoseEstimation.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/cv/PoseEstimation.java).

## Setup guide

@@ -22,7 +22,7 @@ Use the following command to run the project:

```sh
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.PoseEstimation
./gradlew run -Dmain=ai.djl.examples.inference.cv.PoseEstimation
```

Your output should look like the following:
6 changes: 3 additions & 3 deletions examples/docs/semantic_segmentation.md
@@ -6,7 +6,7 @@ In this example, you learn how to implement inference code with Deep Java Librar

The following is the semantic segmentation example source code:

[SemanticSegmentation.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/SemanticSegmentation.java).
[SemanticSegmentation.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/cv/SemanticSegmentation.java).


## Setup guide
@@ -24,7 +24,7 @@ You can find the image used in this example in project test resource folder: `sr

```sh
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.SemanticSegmentation
./gradlew run -Dmain=ai.djl.examples.inference.cv.SemanticSegmentation
```

This should produce the following output
@@ -57,7 +57,7 @@ Path imageFile = Paths.get("src/test/resources/dog_bike_car.jpg");

```sh
cd examples
./gradlew run -Dmain=ai.djl.examples.inference.InstanceSegmentation
./gradlew run -Dmain=ai.djl.examples.inference.cv.SemanticSegmentation
```

This should produce the following output
2 changes: 1 addition & 1 deletion examples/docs/super_resolution.md
@@ -3,7 +3,7 @@
[Generative Adversarial Networks](https://en.wikipedia.org/wiki/Generative_adversarial_network) (GANs) are a branch of deep learning used for generative modeling.
They consist of 2 neural networks that act as adversaries, the Generator and the Discriminator. The Generator is assigned to generate fake images that look real, and the Discriminator needs to correctly identify the fake ones.

In this example, you will learn how to use an [ESRGAN](https://deepmind.com/research/open-source/biggan) generator to upscale images 4x with better quality than interpolation algorithms such as bicubic or bilinear, using the [generator](https://tfhub.dev/captain-pool/esrgan-tf2/1) available on TensorFlow Hub. ESRGAN is short for Enhanced Super-Resolution Generative Adversarial Networks.
In this example, you will learn how to use an [ESRGAN](https://esrgan.readthedocs.io/en/latest/) generator to upscale images 4x with better quality than interpolation algorithms such as bicubic or bilinear, using the [generator](https://tfhub.dev/captain-pool/esrgan-tf2/1) available on TensorFlow Hub. ESRGAN is short for Enhanced Super-Resolution Generative Adversarial Networks.

The source code for this example can be found in the [examples/sr](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/sr/) package.

2 changes: 1 addition & 1 deletion examples/pom.xml
@@ -11,7 +11,7 @@
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<djl.version>0.28.0-SNAPSHOT</djl.version>
<exec.mainClass>ai.djl.examples.inference.ObjectDetection</exec.mainClass>
<exec.mainClass>ai.djl.examples.inference.cv.ObjectDetection</exec.mainClass>
</properties>

<repositories>
@@ -10,9 +10,8 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference;
package ai.djl.examples.inference.cv;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
@@ -44,28 +43,27 @@ public final class ActionRecognition {
private ActionRecognition() {}

public static void main(String[] args) throws IOException, ModelException, TranslateException {
Classifications classification = ActionRecognition.predict();
Classifications classification = predict();
logger.info("{}", classification);
}

public static Classifications predict() throws IOException, ModelException, TranslateException {
Path imageFile = Paths.get("src/test/resources/action_discus_throw.png");
Image img = ImageFactory.getInstance().fromFile(imageFile);

// Use DJL MXNet model zoo model
Criteria<Image, Classifications> criteria =
Criteria.builder()
.optApplication(Application.CV.ACTION_RECOGNITION)
.setTypes(Image.class, Classifications.class)
.optFilter("backbone", "inceptionv3")
.optFilter("dataset", "ucf101")
.optModelUrls(
"djl://ai.djl.mxnet/action_recognition/0.0.1/inceptionv3_ucf101")
.optEngine("MXNet")
.optProgress(new ProgressBar())
.build();

try (ZooModel<Image, Classifications> inception = criteria.loadModel()) {
try (Predictor<Image, Classifications> action = inception.newPredictor()) {
return action.predict(img);
}
try (ZooModel<Image, Classifications> inception = criteria.loadModel();
Predictor<Image, Classifications> action = inception.newPredictor()) {
return action.predict(img);
}
}
}
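The refactor above collapses two nested `try` blocks into a single try-with-resources declaring both the model and the predictor. This is behavior-preserving because resources in one `try` header are closed in reverse declaration order, exactly matching the original nesting. A self-contained sketch (class and resource names are hypothetical, for illustration only):

```java
// Demonstrates that a single try-with-resources closes resources in
// reverse declaration order, matching two nested try blocks.
public class CloseOrderDemo {

    static final StringBuilder log = new StringBuilder();

    // AutoCloseable whose close() records the resource name.
    static AutoCloseable resource(String name) {
        return () -> log.append(name);
    }

    public static String run() throws Exception {
        log.setLength(0);
        try (AutoCloseable model = resource("model");
                AutoCloseable predictor = resource("predictor")) {
            // use predictor here...
        }
        // predictor (declared last) closes first, then model.
        return log.toString();
    }

    public static void main(String[] args) throws Exception {
        System.out.println(run()); // prints "predictormodel"
    }
}
```

So in the refactored examples the `Predictor` is still released before its `ZooModel`, just as with the nested blocks.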
@@ -10,7 +10,7 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference;
package ai.djl.examples.inference.cv;

import ai.djl.Application;
import ai.djl.ModelException;
@@ -10,7 +10,7 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference;
package ai.djl.examples.inference.cv;

import ai.djl.Model;
import ai.djl.ModelException;
@@ -48,7 +48,7 @@ public final class ImageClassification {
private ImageClassification() {}

public static void main(String[] args) throws IOException, ModelException, TranslateException {
Classifications classifications = ImageClassification.predict();
Classifications classifications = predict();
logger.info("{}", classifications);
}

@@ -10,9 +10,8 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference;
package ai.djl.examples.inference.cv;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
@@ -45,7 +44,7 @@ public final class InstanceSegmentation {
private InstanceSegmentation() {}

public static void main(String[] args) throws IOException, ModelException, TranslateException {
DetectedObjects detection = InstanceSegmentation.predict();
DetectedObjects detection = predict();
logger.info("{}", detection);
}

@@ -55,20 +54,18 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran

Criteria<Image, DetectedObjects> criteria =
Criteria.builder()
.optApplication(Application.CV.INSTANCE_SEGMENTATION)
.setTypes(Image.class, DetectedObjects.class)
.optFilter("backbone", "resnet18")
.optFilter("flavor", "v1b")
.optFilter("dataset", "coco")
.optModelUrls(
"djl://ai.djl.mxnet/mask_rcnn/0.0.1/mask_rcnn_resnet18_v1b_coco")
.optEngine("MXNet")
.optProgress(new ProgressBar())
.build();

try (ZooModel<Image, DetectedObjects> model = criteria.loadModel()) {
try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
DetectedObjects detection = predictor.predict(img);
saveBoundingBoxImage(img, detection);
return detection;
}
try (ZooModel<Image, DetectedObjects> model = criteria.loadModel();
Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
DetectedObjects detection = predictor.predict(img);
saveBoundingBoxImage(img, detection);
return detection;
}
}

@@ -10,7 +10,7 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference;
package ai.djl.examples.inference.cv;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
@@ -10,7 +10,7 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference;
package ai.djl.examples.inference.cv;

import ai.djl.Application;
import ai.djl.ModelException;
@@ -46,7 +46,7 @@ public final class ObjectDetection {
private ObjectDetection() {}

public static void main(String[] args) throws IOException, ModelException, TranslateException {
DetectedObjects detection = ObjectDetection.predict();
DetectedObjects detection = predict();
logger.info("{}", detection);
}

@@ -70,12 +70,11 @@ public static DetectedObjects predict() throws IOException, ModelException, Tran
.optProgress(new ProgressBar())
.build();

try (ZooModel<Image, DetectedObjects> model = criteria.loadModel()) {
try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
DetectedObjects detection = predictor.predict(img);
saveBoundingBoxImage(img, detection);
return detection;
}
try (ZooModel<Image, DetectedObjects> model = criteria.loadModel();
Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
DetectedObjects detection = predictor.predict(img);
saveBoundingBoxImage(img, detection);
return detection;
}
}
