[CLIP]: Sparseml onnx export pathways for open_clip models #1626

Merged 11 commits on Jun 23, 2023
2 changes: 1 addition & 1 deletion README.md
@@ -107,7 +107,7 @@ SparseML enables you to create a sparse model trained on your dataset in two ways
This repository is tested on Python 3.8-3.10, and Linux/Debian systems.

It is recommended to install in a [virtual environment](https://docs.python.org/3/library/venv.html) to keep your system in order.
Currently supported ML Frameworks are the following: `torch>=1.1.0,<1.14`, `tensorflow>=1.8.0,<2.0.0`, `tensorflow.keras >= 2.2.0`.
Currently supported ML Frameworks are the following: `torch>=1.1.0,<=2.0`, `tensorflow>=1.8.0,<2.0.0`, `tensorflow.keras >= 2.2.0`.

Install with pip using:

26 changes: 26 additions & 0 deletions integrations/clip/README.md
@@ -0,0 +1,26 @@
<!--
Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# CLIP Export Examples

The examples in `clip_onnx_export.py` provide the steps needed to export a CLIP model using SparseML's ONNX export functionality. The models and pretrained weights are pulled from [OpenCLIP](https://github.com/mlfoundations/open_clip/tree/main), and the provided command-line tools allow exporting a given model's text and visual branches. See the OpenCLIP repository for a full list of available models. For the CoCa models available in OpenCLIP, an additional text-decoder is also exported. A usage sketch is shown below.
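
As a usage sketch based on the script's default arguments (ViT-B-32 with the laion2b_s34b_b79k pretraining), the two branches can be exported with:

```bash
python clip_onnx_export.py --model ViT-B-32 --pretrained laion2b_s34b_b79k --export-path clip_onnx
```

This writes `clip_visual.onnx` and `clip_text.onnx` to the `clip_onnx` directory; for CoCa models such as `coca_ViT-B-32`, a `clip_text_decoder.onnx` is produced as well.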

## Installation

The examples provided require torch nightly and `open_clip_torch==2.20.0` to be installed. To work within the `sparseml` environment, be sure to set the environment variable `MAX_TORCH` to the installed nightly version.

Example: `MAX_TORCH="2.1.0.dev20230613+cpu"`
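
A minimal setup sketch, assuming a local `sparseml` checkout and a CPU-only nightly build (the nightly index URL and version string below are examples, not part of this PR, and will differ for your install):

```bash
# Install sparseml with the new clip extra (pulls in open_clip_torch==2.20.0)
pip install -e ".[clip]"

# Install a torch nightly build; the index URL and resolved version are examples
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu

# Let sparseml's torch version check accept the installed nightly
export MAX_TORCH="2.1.0.dev20230613+cpu"
```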
61 changes: 61 additions & 0 deletions integrations/clip/clip_models.py
@@ -0,0 +1,61 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn


class VisualModel(nn.Module):
def __init__(self, visual_model: torch.nn.Module, output_tokens: bool):

super().__init__()

self.visual_model = visual_model
self.visual_model.output_tokens = output_tokens

def forward(self, x):
return self.visual_model(x)


class TextModel(nn.Module):
def __init__(
self,
token_embedding: torch.nn.Embedding,
positional_embedding: torch.nn.parameter.Parameter,
transformer: torch.nn.Module,
ln_final: torch.nn.LayerNorm,
text_projection: torch.nn.parameter.Parameter,
attn_mask: torch.Tensor,
):

super().__init__()

self.token_embedding = token_embedding
self.positional_embedding = positional_embedding
self.transformer = transformer
self.ln_final = ln_final
self.text_projection = text_projection
self.attn_mask = attn_mask
self.cast_dtype = self.transformer.get_cast_dtype()

def forward(self, input_ids):
x = self.token_embedding(input_ids).to(self.cast_dtype)
x = x + self.positional_embedding.to(self.cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=self.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token = highest in each sequence)
x = x[torch.arange(x.shape[0]), input_ids.argmax(dim=-1)] @ self.text_projection
return x
229 changes: 229 additions & 0 deletions integrations/clip/clip_onnx_export.py
@@ -0,0 +1,229 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Examples of how to use sparseml's onnx export functionality to export CLIP visual
and text models using the OpenCLIP API.

Note: This requires torch nightly and openclip to be installed:
https://github.com/mlfoundations/open_clip

"""
import argparse
from collections import OrderedDict
from pathlib import Path
from typing import Any, Union

import torch

import open_clip
from clip_models import TextModel, VisualModel
from sparseml.pytorch.utils import export_onnx


def _export_onnx(
module: torch.nn.Module,
sample_batch: Any,
file_path: Union[Path, str],
opset: int = 14,
**export_kwargs,
):
export_onnx(
module=module,
sample_batch=sample_batch,
opset=opset,
file_path=file_path,
**export_kwargs,
)


def _export_visual(
model: torch.nn.Module,
device: str,
export_path: Union[str, Path],
is_coca: bool,
**export_kwargs,
):
module_name = "clip_visual.onnx"
visual_model = VisualModel(
visual_model=model.visual,
output_tokens=is_coca,
)

image_shape = visual_model.visual_model.image_size[0]
sample_input = torch.randn(1, 3, image_shape, image_shape, requires_grad=True)

visual_model = visual_model.to(device)
visual_model.eval()

_export_onnx(
module=visual_model,
sample_batch=sample_input,
file_path=export_path / module_name,
**export_kwargs,
)


def _export_text(
model: torch.nn.Module,
device: str,
export_path: Union[str, Path],
tokenizer,
is_coca: bool,
**export_kwargs,
):
module_name = "clip_text.onnx"
# If the model is a CLIP CoCa model, store the text model as is. For non-CoCa
# models, OpenCLIP does not provide access to the text model, only the transformer;
# therefore, in that case, create a new TextModel object to wrap the transformer
# and all relevant properties needed for the forward pass.
if is_coca:
text_model = model.text
else:
text_model = TextModel(
token_embedding=model.token_embedding,
positional_embedding=model.positional_embedding,
transformer=model.transformer,
ln_final=model.ln_final,
text_projection=model.text_projection,
attn_mask=model.attn_mask,
)

text_model = text_model.to(device)
text_model.eval()
sample_batch = tokenizer(["a dog"])

if is_coca:
sample_batch = sample_batch[:, :-1]

_export_onnx(
module=text_model,
sample_batch=sample_batch,
file_path=export_path / module_name,
**export_kwargs,
)


def _export_text_decoder(
model: torch.nn.Module, device: str, export_path: Union[str, Path], **export_kwargs
):

module_name = "clip_text_decoder.onnx"
decoder = model.text_decoder.to(device)
decoder.eval()

sample_batch = OrderedDict()
sample_batch["image_embs"] = torch.randn(1, 255, model.text.output_dim)
sample_batch["text_embs"] = torch.randn(
1, model.text.context_length, model.text.output_dim
)

_export_onnx(
module=decoder,
sample_batch=sample_batch,
file_path=export_path / module_name,
**export_kwargs,
)


def main():
"""
Given a model name and pretrained weights (see OpenCLIP for available options),
the text and visual branches for CLIP are exported to onnx using sparseml's
export functionality. Command-line arguments allow a specific model and
pretraining to be exported; by default, the visual and text branches of the
ViT-B-32 model are exported and saved to a directory called `clip_onnx`. A custom
path can be provided using the `export-path` argument, and custom names for the
input and output nodes of the graph can be assigned using the `input_name` and
`output_name` arguments.

Specifically for CoCa models, an additional text-decoder is also exported and saved
in the same folder. Currently, only coca_ViT-B-32 and coca_ViT-L-14 are supported.

Example:
python clip_onnx_export.py --model convnext_base_w_320 \
--pretrained laion_aesthetic_s13b_b82k --export-path convnext_onnx

======== Diagnostic Run torch.onnx.export version 2.1.0.dev20230613+cpu ========
verbose: False, log level: 40
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

======== Diagnostic Run torch.onnx.export version 2.1.0.dev20230613+cpu ========
verbose: False, log level: 40
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

"""
parser = argparse.ArgumentParser(
description="Fetch CLIP models and export to onnx using sparseml"
)
parser.add_argument(
"--model",
type=str,
default="ViT-B-32",
help="Name of CLIP model. See OpenClip docs for a list of available models",
)
parser.add_argument(
"--pretrained",
type=str,
default="laion2b_s34b_b79k",
help="Name of the pretraining to use. See OpenClip docs for a list of options.",
)
parser.add_argument(
"--export-path",
type=str,
default="clip_onnx",
help="Path of the directory to which the onnx outputs will be saved.",
)
parser.add_argument(
"--input_name",
type=str,
default="inputs",
help="names to assign to the input nodes",
)
parser.add_argument(
"--output_name",
type=str,
default="outputs",
help="names to assign to the output nodes",
)

args = parser.parse_args()

device = "cpu"
clip_onnx_path = Path(args.export_path)

input_names = [args.input_name]
output_names = [args.output_name]
export_kwargs = {
"input_names": input_names,
"output_names": output_names,
"do_constant_folding": True,
}

model, _, _ = open_clip.create_model_and_transforms(
model_name=args.model, pretrained=args.pretrained
)

tokenizer = open_clip.get_tokenizer(args.model)
is_coca = "coca" in args.model

_export_visual(model, device, clip_onnx_path, is_coca, **export_kwargs)
_export_text(model, device, clip_onnx_path, tokenizer, is_coca, **export_kwargs)

if is_coca:
_export_text_decoder(model, device, clip_onnx_path, **export_kwargs)


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions setup.py
@@ -62,6 +62,7 @@
_deepsparse_ent_deps = [f"deepsparse-ent~={version_nm_deps}"]

_onnxruntime_deps = ["onnxruntime>=1.0.0"]
_clip_deps = ["open_clip_torch==2.20.0"]
supported_torch_version = "torch>=1.7.0,<=2.0"
_pytorch_deps = [
supported_torch_version,
@@ -144,6 +145,7 @@ def _setup_extras() -> Dict:

def _setup_extras() -> Dict:
return {
"clip": _clip_deps,
"dev": _dev_deps,
"deepsparse": _deepsparse_deps,
"deepsparse-ent": _deepsparse_ent_deps,
4 changes: 2 additions & 2 deletions src/sparseml/pytorch/base.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
import os
from typing import Optional

from sparseml.base import check_version
@@ -49,7 +49,7 @@


_TORCH_MIN_VERSION = "1.0.0"
_TORCH_MAX_VERSION = "2.0.100"
_TORCH_MAX_VERSION = os.environ.get("MAX_TORCH", "2.0.100")


def check_torch_install(