[CLIP] Update clip dependencies, README.md, export pathways (#1711)
* fix instructions; export pathways

* update variable; quality
dsikka committed Aug 25, 2023
1 parent f4a9a2f commit 9a5db83
Showing 3 changed files with 22 additions and 28 deletions.
16 changes: 13 additions & 3 deletions integrations/clip/README.md
@@ -16,11 +16,21 @@ limitations under the License.

# CLIP Export Examples

The examples in `clip_onnx_export.py` provide the steps needed to export a CLIP model using sparseml's ONNX export functionality. The models and pretrained weights are pulled in from [OpenClip](https://github.com/mlfoundations/open_clip/tree/main), and the provided command line tools allow exporting a given model's Text and Visual branches. See the OpenClip repository for a full list of available models. For the CoCa models available in OpenClip, an additional text-decoder is also exported.
The examples in `clip_onnx_export.py` provide the steps needed to export a CLIP model using sparseml's ONNX export functionality. The models and pretrained weights are pulled in from [OpenClip](https://github.com/mlfoundations/open_clip/tree/main), and the provided command line tools allow exporting a given model's Text and Visual branches. See the OpenClip repository for a full list of available models. For the CoCa/Caption models available in OpenClip, an additional text-decoder is also exported.

## Installation

The examples provided require torch nightly and `open_clip_torch==2.20.0` to be installed. To work within the `sparseml` environment, be sure to set the environment variable `MAX_TORCH` to your installed version when
The examples provided require `open_clip_torch==2.20.0` to be installed along with **torch nightly**. To work within the `sparseml` environment, be sure to set the environment variable `MAX_TORCH` to your installed version when
installing torch nightly.

Example: `MAX_TORCH="2.1.0.dev20230613+cpu"`
Steps:
- Install `sparseml[clip]`. This will ensure `open_clip_torch` is installed.
- Uninstall torch by running:
```
pip uninstall torch
```
- Install torch nightly:
```
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/
```
- Set the environment variable to your installed torch version, for example: `export MAX_TORCH="2.1.0.dev20230613+cpu"`
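
With the steps above completed, the export script can be run from the `integrations/clip` directory. A minimal example invocation, using the `--model` and `--pretrained` flags and the new default values visible in the `clip_onnx_export.py` diff below (the script may accept additional arguments, such as an export directory, that are not shown in this diff):

```
python clip_onnx_export.py --model coca_ViT-B-32 --pretrained mscoco_finetuned_laion2b_s13b_b90k
```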
12 changes: 0 additions & 12 deletions integrations/clip/clip_models.py
@@ -16,18 +16,6 @@
import torch.nn as nn


class VisualModel(nn.Module):
def __init__(self, visual_model: torch.nn.Module, output_tokens: bool):

super().__init__()

self.visual_model = visual_model
self.visual_model.output_tokens = output_tokens

def forward(self, x):
return self.visual_model(x)


class TextModel(nn.Module):
def __init__(
self,
22 changes: 9 additions & 13 deletions integrations/clip/clip_onnx_export.py
@@ -28,7 +28,7 @@
import torch

import open_clip
from clip_models import TextModel, VisualModel
from clip_models import TextModel
from sparseml.pytorch.utils import export_onnx


@@ -61,12 +61,9 @@ def _export_visual(
**export_kwargs,
):
module_name = "clip_visual.onnx"
visual_model = VisualModel(
visual_model=model.visual,
output_tokens=is_coca,
)
visual_model = model.visual

image_shape = visual_model.visual_model.image_size[0]
image_shape = visual_model.image_size[0]
sample_input = torch.randn(1, 3, image_shape, image_shape, requires_grad=True)

visual_model = visual_model.to(device)
@@ -107,10 +104,11 @@ def _export_text(

text_model = text_model.to(device)
text_model.eval()
sample_batch = tokenizer(["a dog"])

if is_coca:
sample_batch = sample_batch[:, :-1]
sample_batch = torch.ones(6, 15, dtype=torch.long)
else:
sample_batch = tokenizer(["a dog"]).to(torch.int32)

_export_onnx(
module=text_model,
@@ -130,9 +128,7 @@ def _export_text_decoder(

sample_batch = OrderedDict()
sample_batch["image_embs"] = torch.randn(1, 255, model.text.output_dim)
sample_batch["text_embs"] = torch.randn(
1, model.text.context_length, model.text.output_dim
)
sample_batch["text_embs"] = torch.randn(1, 15, model.text.output_dim)

_export_onnx(
module=decoder,
@@ -175,13 +171,13 @@ def main():
parser.add_argument(
"--model",
type=str,
default="ViT-B-32",
default="coca_ViT-B-32",
help="Name of CLIP model. See OpenClip docs for a list of available models",
)
parser.add_argument(
"--pretrained",
type=str,
default="laion2b_s34b_b79k",
default="mscoco_finetuned_laion2b_s13b_b90k",
help="Name of the pretraining to use. See OpenClip docs for a list of options.",
)
parser.add_argument(
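Taken together, the `_export_visual` changes drop the `VisualModel` wrapper and export the OpenClip visual branch directly. The sketch below illustrates that pathway in isolation; it is an approximation that uses `torch.onnx.export` as a stand-in for the `sparseml.pytorch.utils.export_onnx` helper the script actually imports, and it reuses the new default `--model`/`--pretrained` values from the diff above.

```
import torch
import open_clip

# Build the model the same way clip_onnx_export.py now does by default
# (downloading the pretrained weights requires network access).
model, _, _ = open_clip.create_model_and_transforms(
    "coca_ViT-B-32", pretrained="mscoco_finetuned_laion2b_s13b_b90k"
)

# The visual branch is used directly; the removed VisualModel wrapper is no longer needed.
visual_model = model.visual.eval()

# Size the sample input from the visual tower, mirroring the new _export_visual logic.
image_shape = visual_model.image_size[0]
sample_input = torch.randn(1, 3, image_shape, image_shape)

# Stand-in export call; the real script routes this through sparseml's export_onnx helper.
torch.onnx.export(visual_model, (sample_input,), "clip_visual.onnx")
```

The text branch and, for CoCa models, the text decoder follow the same pattern, using the sample batches shown in `_export_text` and `_export_text_decoder`.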
