update README instructions; fix linalg import
dsikka committed Jul 31, 2023
1 parent c637dbd commit 402bfc6
Showing 3 changed files with 60 additions and 9 deletions.
6 changes: 5 additions & 1 deletion setup.py
@@ -162,7 +162,11 @@ def _parse_requirements_file(file_path):
"haystack_reqs.txt",
)
_haystack_integration_deps = _parse_requirements_file(_haystack_requirements_file_path)
_clip_deps = ["open_clip_torch==2.20.0", "scipy==1.10.1"]
_clip_deps = [
    "open_clip_torch==2.20.0",
    "scipy==1.10.1",
    f"{'nm-transformers' if is_release else 'nm-transformers-nightly'}",
]


def _check_supported_system():
57 changes: 52 additions & 5 deletions src/deepsparse/clip/README.md
@@ -4,6 +4,7 @@ DeepSparse allows inference on [CLIP](https://github.com/mlfoundations/open_clip

The CLIP integration currently supports the following task:
- **Zero-shot Image Classification** - Classifying images given possible classes
- **Caption Generation** - Generating a caption given an image

## Getting Started

@@ -13,24 +14,38 @@ Before you start your adventure with the DeepSparse Engine, make sure that your
```pip install deepsparse[clip]```

### Model Format
By default, to deploy CLIP models using the DeepSparse Engine, it is required to supply the model in the ONNX format. This grants the engine the flexibility to serve any model in a framework-agnostic environment. To see examples of pulling CLIP models and exporting them to ONNX, please see the [sparseml documentation](https://github.com/neuralmagic/sparseml/tree/main/integrations/clip). For the Zero-shot image classification workflow, two ONNX models are required, a visual model for CLIP's visual branch, and a text model for CLIP's text branch. Both of these models should be produced through the sparseml integration linked above.
By default, to deploy CLIP models using the DeepSparse Engine, it is required to supply the model in the ONNX format. This grants the engine the flexibility to serve any model in a framework-agnostic environment. To see examples of pulling CLIP models and exporting them to ONNX, please see the [sparseml documentation](https://github.com/neuralmagic/sparseml/tree/main/integrations/clip).

For the Zero-shot image classification workflow, two ONNX models are required: a visual model for CLIP's visual branch and a text model for CLIP's text branch. Both of these models can be produced through the sparseml integration linked above. For caption generation, specific models called CoCa models are required; instructions on how to export CoCa models are also provided in the sparseml documentation above. The CoCa export pathway generates one additional decoder model, along with the text and visual models.
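
For reference, the deployment examples later in this README expect the exported ONNX files in layouts along these lines. The directory names are simply the ones used in the snippets below; your own export location may differ:

```
zeroshot_research/
├── text/model.onnx
└── visual/model.onnx

caption_models/
├── clip_visual.onnx
├── clip_text.onnx
└── clip_text_decoder.onnx
```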

### Deployment examples:
The following example uses pipelines to run the CLIP models for inference. As input, the pipeline ingests a list of images and a list of possible classes. A class is returned for each of the provided images.
The following example uses pipelines to run the CLIP models for inference. For Zero-shot prediction, the pipeline ingests a list of images and a list of possible classes. A class is returned for each of the provided images. For caption generation, only an image file is required.

If you don't have images ready, pull down the sample images using the following commands:

```bash
wget -O basilica.jpg https://github.com/raw/neuralmagic/deepsparse/main/src/deepsparse/yolo/sample_images/basilica.jpg
```

```bash
wget -O buddy.jpeg https://github.com/raw/neuralmagic/deepsparse/main/tests/deepsparse/pipelines/sample_images/buddy.jpeg
```

This will pull down two images, one with a happy dog and one with St. Peter's Basilica.
```bash
wget -O thailand.jpg https://github.com/raw/neuralmagic/deepsparse/main/src/deepsparse/yolact/sample_images/thailand.jpg
```

<p float="left">
<img src="https://github.com/raw/neuralmagic/deepsparse/main/src/deepsparse/yolo/sample_images/basilica.jpg" width="300" />
<img src="https://github.com/raw/neuralmagic/deepsparse/main/tests/deepsparse/pipelines/sample_images/buddy.jpeg" width="300" />
<img src="https://github.com/raw/neuralmagic/deepsparse/main/src/deepsparse/yolact/sample_images/thailand.jpg" width="300" />
</p>

This will pull down three images: a happy dog, St. Peter's Basilica, and two elephants.

#### Zero-shot Prediction

Let's run an example to classify the images. We'll provide the images in a list with their file names as well as a list of possible classes. We'll also provide paths to the exported ONNX models.
Let's run an example to classify the images. We'll provide the images in a list with their file names as well as a list of possible classes. We'll also provide paths to the exported ONNX models under the `zeroshot_research` root folder.

```python
import numpy as np
@@ -43,7 +58,7 @@ from deepsparse.clip import (
)

possible_classes = ["ice cream", "an elephant", "a dog", "a building", "a church"]
images = ["basilica.jpg", "buddy.jpeg"]
images = ["basilica.jpg", "buddy.jpeg", "thailand.jpg"]

model_path_text = "zeroshot_research/text/model.onnx"
model_path_visual = "zeroshot_research/visual/model.onnx"
@@ -72,4 +87,36 @@ DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230727 C
Image basilica.jpg is a picture of a church
Image buddy.jpeg is a picture of a dog
Image thailand.jpg is a picture of an elephant
```
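
The zero-shot snippet above is only partially shown in the diff, so a minimal end-to-end sketch of the flow is included here for reference. The task name `"clip_zeroshot"`, the `CLIPZeroShotInput` wrapper, and the `text_scores` output field are assumptions based on the surrounding CLIP integration and are not visible in the diff above.

```python
import numpy as np

from deepsparse import BasePipeline
from deepsparse.clip import CLIPTextInput, CLIPVisualInput, CLIPZeroShotInput  # CLIPZeroShotInput is assumed

possible_classes = ["ice cream", "an elephant", "a dog", "a building", "a church"]
images = ["basilica.jpg", "buddy.jpeg", "thailand.jpg"]

# Paths match the exported models referenced earlier in this README
model_path_text = "zeroshot_research/text/model.onnx"
model_path_visual = "zeroshot_research/visual/model.onnx"

# Task name is an assumption, not confirmed by the visible diff
pipeline = BasePipeline.create(
    task="clip_zeroshot",
    visual_model_path=model_path_visual,
    text_model_path=model_path_text,
)

pipeline_input = CLIPZeroShotInput(
    image=CLIPVisualInput(images=images),
    text=CLIPTextInput(text=possible_classes),
)

# text_scores is assumed to hold one row of class scores per input image
scores = pipeline(pipeline_input).text_scores
for image, row in zip(images, scores):
    print(f"Image {image} is a picture of {possible_classes[np.argmax(row)]}")
```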

#### Caption Generation
Let's try a caption generation example. We'll leverage the `thailand.jpg` file that was pulled down earlier. We'll also provide the 3 exported CoCa ONNX models under the `caption_models` folder.

```python
from deepsparse import BasePipeline
from deepsparse.clip import CLIPCaptionInput, CLIPVisualInput

root = "caption_models"
model_path_visual = f"{root}/clip_visual.onnx"
model_path_text = f"{root}/clip_text.onnx"
model_path_decoder = f"{root}/clip_text_decoder.onnx"

kwargs = {
    "visual_model_path": model_path_visual,
    "text_model_path": model_path_text,
    "decoder_model_path": model_path_decoder,
}
pipeline = BasePipeline.create(task="clip_caption", **kwargs)

pipeline_input = CLIPCaptionInput(image=CLIPVisualInput(images="thailand.jpg"))
output = pipeline(pipeline_input)
print(output[0])
```
Running the code above, we get the following caption:

```
DeepSparse, Copyright 2021-present / Neuralmagic, Inc. version: 1.6.0.20230727 COMMUNITY | (3cb4a3e5) (optimized) (system=avx2, binary=avx2)
an adult elephant and a baby elephant .
```
6 changes: 3 additions & 3 deletions src/deepsparse/clip/zeroshot_pipeline.py
@@ -15,7 +15,7 @@
from typing import Any, List, Type

import numpy as np
from numpy import linalg as la
from numpy import linalg as lingalg
from pydantic import BaseModel, Field

from deepsparse.clip import CLIPTextInput, CLIPVisualInput
@@ -82,8 +82,8 @@ def __call__(self, *args, **kwargs):
visual_output = self.visual(pipeline_inputs.image).image_embeddings[0]
text_output = self.text(pipeline_inputs.text).text_embeddings[0]

visual_output /= la.norm(visual_output, axis=-1, keepdims=True)
text_output /= la.norm(text_output, axis=-1, keepdims=True)
visual_output /= lingalg.norm(visual_output, axis=-1, keepdims=True)
text_output /= lingalg.norm(text_output, axis=-1, keepdims=True)

output_product = 100.0 * visual_output @ text_output.T
text_probs = softmax(output_product, axis=-1)
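The scoring logic in this file L2-normalizes the image and text embeddings, scales their dot product, and applies a softmax over the candidate classes to get per-image probabilities. A standalone numpy sketch of that math, with made-up embedding shapes, is shown below; it is an illustration of the computation, not the pipeline's actual code.

```python
import numpy as np
from scipy.special import softmax

# Made-up shapes: 3 image embeddings and 5 class-text embeddings, 512-dim each
visual_output = np.random.rand(3, 512).astype(np.float32)
text_output = np.random.rand(5, 512).astype(np.float32)

# L2-normalize so the dot product below is a cosine similarity
visual_output /= np.linalg.norm(visual_output, axis=-1, keepdims=True)
text_output /= np.linalg.norm(text_output, axis=-1, keepdims=True)

# Scaled (images x classes) similarity matrix, then per-image class probabilities
output_product = 100.0 * visual_output @ text_output.T
text_probs = softmax(output_product, axis=-1)

print(text_probs.shape)  # (3, 5): each row sums to 1.0
```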
