Skip to content

Commit

Permalink
remove tokenizer and transformations not used in the forward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
dsikka committed Jun 21, 2023
1 parent 28bd2e5 commit 88cacfa
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 10 deletions.
5 changes: 1 addition & 4 deletions integrations/clip/clip_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@


class VisualModel(nn.Module):
def __init__(self, visual_model, transformations, output_tokens):
def __init__(self, visual_model, output_tokens):

super().__init__()

self.visual_model = visual_model
self.transformations = transformations
self.visual_model.output_tokens = output_tokens

def forward(self, x):
Expand All @@ -33,7 +32,6 @@ class TextModel(nn.Module):
def __init__(
self,
token_embedding,
tokenizer,
positional_embedding,
transformer,
ln_final,
Expand All @@ -44,7 +42,6 @@ def __init__(
super().__init__()

self.token_embedding = token_embedding
self.tokenizer = tokenizer
self.positional_embedding = positional_embedding
self.transformer = transformer
self.ln_final = ln_final
Expand Down
8 changes: 2 additions & 6 deletions integrations/clip/clip_onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from typing import Any, Union

import torch
from torchvision.transforms.transforms import Compose

import open_clip
from clip_models import TextModel, VisualModel
Expand All @@ -53,14 +52,12 @@ def _export_visual(
model: torch.nn.Module,
device: str,
export_path: Union[str, Path],
transformations: Compose,
is_coca: bool,
**export_kwargs,
):
module_name = "clip_visual.onnx"
visual_model = VisualModel(
visual_model=model.visual,
transformations=transformations,
output_tokens=is_coca,
)

Expand Down Expand Up @@ -92,7 +89,6 @@ def _export_text(
else:
text_model = TextModel(
token_embedding=model.token_embedding,
tokenizer=tokenizer,
positional_embedding=model.positional_embedding,
transformer=model.transformer,
ln_final=model.ln_final,
Expand Down Expand Up @@ -211,14 +207,14 @@ def main():
"do_constant_folding": True,
}

model, _, transform = open_clip.create_model_and_transforms(
model, _, _ = open_clip.create_model_and_transforms(
model_name=args.model, pretrained=args.pretrained
)

tokenizer = open_clip.get_tokenizer(args.model)
is_coca = "coca" in args.model

_export_visual(model, device, clip_onnx_path, transform, is_coca, **export_kwargs)
_export_visual(model, device, clip_onnx_path, is_coca, **export_kwargs)
_export_text(model, device, clip_onnx_path, tokenizer, is_coca, **export_kwargs)

if is_coca:
Expand Down

0 comments on commit 88cacfa

Please sign in to comment.