[CLIP]: Sparseml onnx export pathways for open_clip models (#1626)
* Add example script on how to export open_clip models using sparseml onnx export; add additional models to encapsulate operations needed during inference time
* update setup.py to add clip dependency, add copyright
* format fixes
* add docstring and example cli command
* Add decoder export; update docstring
* format fix
* remove tokenizer and transformations not used in the forward pass
* update readme to bump up pytorch version; add torch nightly caveat for clip
* Update README, use env variable for max torch, add typing
* throw an error if opset < 14 is used
Showing 6 changed files with 326 additions and 3 deletions.
@@ -0,0 +1,26 @@
<!--
Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# CLIP Export Examples

The examples in `clip_onnx_export.py` provide the steps needed to export a CLIP model using SparseML's ONNX export functionality. The models and pretrained weights are pulled from [OpenCLIP](https://github.com/mlfoundations/open_clip/tree/main), and the provided command-line tools allow exporting a given model's text and visual branches. See the OpenCLIP repository for a full list of available models. For the CoCa models available in OpenCLIP, an additional text decoder is also exported.
## Installation

The examples provided require torch nightly and `open_clip_torch==2.20.0` to be installed. To work within the `sparseml` environment, be sure to set the environment variable `MAX_TORCH` to your installed nightly version when installing torch nightly.

Example: `MAX_TORCH="2.1.0.dev20230613+cpu"`
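
For reference, the commands below sketch one possible install sequence; the CPU-only nightly build, the index URL, and installing `sparseml` from a local source checkout are assumptions, so adjust them for your platform and setup.

```bash
# Sketch only -- versions, index URL, and source install are assumptions.
pip install open_clip_torch==2.20.0
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu

# Point sparseml's setup at the nightly build you just installed, then install
# sparseml (here assumed to be from a local source checkout).
export MAX_TORCH="2.1.0.dev20230613+cpu"
pip install -e .
```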
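## Usage

The command below is a sketch of a typical invocation using the script's defaults (the ViT-B-32 model with `laion2b_s34b_b79k` weights, exported to a `clip_onnx/` directory); run `python clip_onnx_export.py --help` for the full list of options.

```bash
python clip_onnx_export.py --model ViT-B-32 --pretrained laion2b_s34b_b79k --export-path clip_onnx
```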
clip_models.py
@@ -0,0 +1,61 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


class VisualModel(nn.Module):
    """Thin wrapper around an OpenCLIP visual tower so it can be exported as-is."""

    def __init__(self, visual_model: torch.nn.Module, output_tokens: bool):
        super().__init__()

        self.visual_model = visual_model
        self.visual_model.output_tokens = output_tokens

    def forward(self, x):
        return self.visual_model(x)


class TextModel(nn.Module):
    """Wraps the text components of a (non-CoCa) OpenCLIP model in a single module
    that encapsulates the operations needed during inference."""

    def __init__(
        self,
        token_embedding: torch.nn.Embedding,
        positional_embedding: torch.nn.parameter.Parameter,
        transformer: torch.nn.Module,
        ln_final: torch.nn.LayerNorm,
        text_projection: torch.nn.parameter.Parameter,
        attn_mask: torch.Tensor,
    ):
        super().__init__()

        self.token_embedding = token_embedding
        self.positional_embedding = positional_embedding
        self.transformer = transformer
        self.ln_final = ln_final
        self.text_projection = text_projection
        self.attn_mask = attn_mask
        self.cast_dtype = self.transformer.get_cast_dtype()

    def forward(self, input_ids):
        x = self.token_embedding(input_ids).to(self.cast_dtype)
        x = x + self.positional_embedding.to(self.cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token = highest in each sequence)
        x = x[torch.arange(x.shape[0]), input_ids.argmax(dim=-1)] @ self.text_projection
        return x
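To make the wrappers above concrete, the following is a minimal sketch (not part of the commit) of how `clip_onnx_export.py` constructs them and what a dummy forward pass returns; it assumes `open_clip_torch==2.20.0` and the default ViT-B-32 / `laion2b_s34b_b79k` weights used by the export script.

```python
# Hypothetical sanity check -- mirrors how clip_onnx_export.py builds the wrappers.
import torch

import open_clip
from clip_models import TextModel, VisualModel

model, _, _ = open_clip.create_model_and_transforms(
    model_name="ViT-B-32", pretrained="laion2b_s34b_b79k"
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")

visual = VisualModel(visual_model=model.visual, output_tokens=False).eval()
text = TextModel(
    token_embedding=model.token_embedding,
    positional_embedding=model.positional_embedding,
    transformer=model.transformer,
    ln_final=model.ln_final,
    text_projection=model.text_projection,
    attn_mask=model.attn_mask,
).eval()

image_size = visual.visual_model.image_size[0]
with torch.no_grad():
    image_features = visual(torch.randn(1, 3, image_size, image_size))
    text_features = text(tokenizer(["a dog"]))

# For ViT-B-32, both branches should produce [1, 512] feature tensors.
print(image_features.shape, text_features.shape)
```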
clip_onnx_export.py
@@ -0,0 +1,234 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Examples on how to use SparseML's ONNX export functionality to export CLIP visual
and text models using the OpenCLIP API.

Note: This requires torch nightly and open_clip to be installed:
https://github.com/mlfoundations/open_clip
"""
import argparse
from collections import OrderedDict
from pathlib import Path
from typing import Any, Union

import torch

import open_clip
from clip_models import TextModel, VisualModel
from sparseml.pytorch.utils import export_onnx


def _export_onnx(
    module: torch.nn.Module,
    sample_batch: Any,
    file_path: Union[Path, str],
    opset: int = 14,
    **export_kwargs,
):
    # _export_onnx uses opset=14 by default, as required by CLIP; export will fail
    # for opset < 14 because certain operators are not supported.
    if opset < 14:
        raise ValueError("CLIP onnx export requires a minimum opset of 14")

    export_onnx(
        module=module,
        sample_batch=sample_batch,
        opset=opset,
        file_path=file_path,
        **export_kwargs,
    )

def _export_visual(
    model: torch.nn.Module,
    device: str,
    export_path: Union[str, Path],
    is_coca: bool,
    **export_kwargs,
):
    module_name = "clip_visual.onnx"
    visual_model = VisualModel(
        visual_model=model.visual,
        output_tokens=is_coca,
    )

    image_shape = visual_model.visual_model.image_size[0]
    sample_input = torch.randn(1, 3, image_shape, image_shape, requires_grad=True)

    visual_model = visual_model.to(device)
    visual_model.eval()

    _export_onnx(
        module=visual_model,
        sample_batch=sample_input,
        file_path=export_path / module_name,
        **export_kwargs,
    )


def _export_text(
    model: torch.nn.Module,
    device: str,
    export_path: Union[str, Path],
    tokenizer,
    is_coca: bool,
    **export_kwargs,
):
    module_name = "clip_text.onnx"
    # If the model is a CLIP CoCa model, store the text model as is. For non-CoCa
    # models, OpenCLIP does not provide access to the text model, only the
    # transformer; therefore, in that case, create a new TextModel object to wrap
    # the transformer and all relevant properties needed for the forward pass.
    if is_coca:
        text_model = model.text
    else:
        text_model = TextModel(
            token_embedding=model.token_embedding,
            positional_embedding=model.positional_embedding,
            transformer=model.transformer,
            ln_final=model.ln_final,
            text_projection=model.text_projection,
            attn_mask=model.attn_mask,
        )

    text_model = text_model.to(device)
    text_model.eval()
    sample_batch = tokenizer(["a dog"])

    if is_coca:
        sample_batch = sample_batch[:, :-1]

    _export_onnx(
        module=text_model,
        sample_batch=sample_batch,
        file_path=export_path / module_name,
        **export_kwargs,
    )


def _export_text_decoder(
    model: torch.nn.Module, device: str, export_path: Union[str, Path], **export_kwargs
):
    module_name = "clip_text_decoder.onnx"
    decoder = model.text_decoder.to(device)
    decoder.eval()

    sample_batch = OrderedDict()
    sample_batch["image_embs"] = torch.randn(1, 255, model.text.output_dim)
    sample_batch["text_embs"] = torch.randn(
        1, model.text.context_length, model.text.output_dim
    )

    _export_onnx(
        module=decoder,
        sample_batch=sample_batch,
        file_path=export_path / module_name,
        **export_kwargs,
    )

def main():
    """
    Given a model name and pretrained weights (see OpenCLIP for available options),
    the text and visual branches of CLIP are exported to ONNX using SparseML's
    export functionality. Command-line arguments are provided to export a specific
    model/pretraining; by default, the visual and text branches of the ViT-B-32
    model are exported and saved to a directory called `clip_onnx`. A custom path
    can also be provided using the `--export-path` argument. Custom names for the
    input and output nodes of the graph can also be assigned, using the
    `--input_name` and `--output_name` arguments.

    Specifically for CoCa models, an additional text decoder is also exported and
    saved in the same folder. Currently, only coca_ViT-B-32 and coca_ViT-L-14 are
    supported.

    Example:

    python clip_onnx_export.py --model convnext_base_w_320 \
        --pretrained laion_aesthetic_s13b_b82k --export-path convnext_onnx

    Sample output:

    ======== Diagnostic Run torch.onnx.export version 2.1.0.dev20230613+cpu ========
    verbose: False, log level: 40
    ======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

    ======== Diagnostic Run torch.onnx.export version 2.1.0.dev20230613+cpu ========
    verbose: False, log level: 40
    ======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================
    """
    parser = argparse.ArgumentParser(
        description="Fetch CLIP models and export to onnx using sparseml"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="ViT-B-32",
        help="Name of CLIP model. See OpenCLIP docs for a list of available models",
    )
    parser.add_argument(
        "--pretrained",
        type=str,
        default="laion2b_s34b_b79k",
        help="Name of the pretraining to use. See OpenCLIP docs for a list of options.",
    )
    parser.add_argument(
        "--export-path",
        type=str,
        default="clip_onnx",
        help="Path of the directory to which the onnx outputs will be saved.",
    )
    parser.add_argument(
        "--input_name",
        type=str,
        default="inputs",
        help="name to assign to the input nodes",
    )
    parser.add_argument(
        "--output_name",
        type=str,
        default="outputs",
        help="name to assign to the output nodes",
    )

    args = parser.parse_args()

    device = "cpu"
    clip_onnx_path = Path(args.export_path)

    input_names = [args.input_name]
    output_names = [args.output_name]
    export_kwargs = {
        "input_names": input_names,
        "output_names": output_names,
        "do_constant_folding": True,
    }

    model, _, _ = open_clip.create_model_and_transforms(
        model_name=args.model, pretrained=args.pretrained
    )

    tokenizer = open_clip.get_tokenizer(args.model)
    is_coca = "coca" in args.model

    _export_visual(model, device, clip_onnx_path, is_coca, **export_kwargs)
    _export_text(model, device, clip_onnx_path, tokenizer, is_coca, **export_kwargs)

    if is_coca:
        _export_text_decoder(model, device, clip_onnx_path, **export_kwargs)


if __name__ == "__main__":
    main()
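As a quick way to confirm the exported graphs load and run, the snippet below is a hypothetical post-export check (not part of the commit); it assumes the default arguments (ViT-B-32, export path `clip_onnx`, input node name `inputs`) and that `onnxruntime` and `numpy` are installed.

```python
# Hypothetical check of the exported files; paths and node names assume the defaults.
import numpy as np
import onnxruntime as ort

import open_clip

visual_session = ort.InferenceSession("clip_onnx/clip_visual.onnx")
text_session = ort.InferenceSession("clip_onnx/clip_text.onnx")

tokenizer = open_clip.get_tokenizer("ViT-B-32")
image = np.random.rand(1, 3, 224, 224).astype(np.float32)  # ViT-B-32 image size
tokens = tokenizer(["a dog"]).numpy()

(image_features,) = visual_session.run(None, {"inputs": image})
(text_features,) = text_session.run(None, {"inputs": tokens})

# Expect [1, 512] feature arrays from both branches for ViT-B-32.
print(image_features.shape, text_features.shape)
```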