From 88cacfaf435a07b950657ef16105ad7ecfc90708 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Wed, 21 Jun 2023 03:20:11 +0000
Subject: [PATCH] remove tokenizer and transformations not used in the forward
 pass

---
 integrations/clip/clip_models.py      | 5 +----
 integrations/clip/clip_onnx_export.py | 8 ++------
 2 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/integrations/clip/clip_models.py b/integrations/clip/clip_models.py
index ad194040407..21aa536293f 100644
--- a/integrations/clip/clip_models.py
+++ b/integrations/clip/clip_models.py
@@ -17,12 +17,11 @@


 class VisualModel(nn.Module):
-    def __init__(self, visual_model, transformations, output_tokens):
+    def __init__(self, visual_model, output_tokens):

         super().__init__()

         self.visual_model = visual_model
-        self.transformations = transformations
         self.visual_model.output_tokens = output_tokens

     def forward(self, x):
@@ -33,7 +32,6 @@ class TextModel(nn.Module):
     def __init__(
         self,
         token_embedding,
-        tokenizer,
         positional_embedding,
         transformer,
         ln_final,
@@ -44,7 +42,6 @@ def __init__(
         super().__init__()

         self.token_embedding = token_embedding
-        self.tokenizer = tokenizer
         self.positional_embedding = positional_embedding
         self.transformer = transformer
         self.ln_final = ln_final
diff --git a/integrations/clip/clip_onnx_export.py b/integrations/clip/clip_onnx_export.py
index eaaf65f24ea..8162c5e7ce5 100644
--- a/integrations/clip/clip_onnx_export.py
+++ b/integrations/clip/clip_onnx_export.py
@@ -26,7 +26,6 @@
 from typing import Any, Union

 import torch
-from torchvision.transforms.transforms import Compose

 import open_clip
 from clip_models import TextModel, VisualModel
@@ -53,14 +52,12 @@ def _export_visual(
     model: torch.nn.Module,
     device: str,
     export_path: Union[str, Path],
-    transformations: Compose,
     is_coca: bool,
     **export_kwargs,
 ):

     module_name = "clip_visual.onnx"
     visual_model = VisualModel(
         visual_model=model.visual,
-        transformations=transformations,
         output_tokens=is_coca,
     )
@@ -92,7 +89,6 @@ def _export_text(
     else:
         text_model = TextModel(
             token_embedding=model.token_embedding,
-            tokenizer=tokenizer,
             positional_embedding=model.positional_embedding,
             transformer=model.transformer,
             ln_final=model.ln_final,
@@ -211,14 +207,14 @@ def main():
         "do_constant_folding": True,
     }

-    model, _, transform = open_clip.create_model_and_transforms(
+    model, _, _ = open_clip.create_model_and_transforms(
         model_name=args.model, pretrained=args.pretrained
     )
     tokenizer = open_clip.get_tokenizer(args.model)

     is_coca = "coca" in args.model

-    _export_visual(model, device, clip_onnx_path, transform, is_coca, **export_kwargs)
+    _export_visual(model, device, clip_onnx_path, is_coca, **export_kwargs)
     _export_text(model, device, clip_onnx_path, tokenizer, is_coca, **export_kwargs)

     if is_coca:
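
Note: with the tokenizer and image transforms no longer stored on the wrappers,
preprocessing must now happen before the exported graphs are invoked. A minimal
usage sketch of that flow follows, assuming onnxruntime and open_clip's
documented API; the model name, pretrained tag, and the text module file name
(clip_text.onnx) are illustrative assumptions, while clip_visual.onnx matches
the export script above.

import onnxruntime
import open_clip
from PIL import Image

model_name = "ViT-B-32"  # assumption: any non-CoCa open_clip model name
_, _, preprocess = open_clip.create_model_and_transforms(
    model_name, pretrained="laion2b_s34b_b79k"  # illustrative pretrained tag
)
tokenizer = open_clip.get_tokenizer(model_name)

# Image transforms now run here, outside VisualModel.forward.
image = preprocess(Image.open("example.jpg")).unsqueeze(0).numpy()

# Tokenization now runs here, outside TextModel.forward.
tokens = tokenizer(["a photo of a dog"]).numpy()

# Feed the already-preprocessed tensors to the exported graphs.
visual = onnxruntime.InferenceSession("clip_visual.onnx")
text = onnxruntime.InferenceSession("clip_text.onnx")  # assumed file name
image_features = visual.run(None, {visual.get_inputs()[0].name: image})[0]
text_features = text.run(None, {text.get_inputs()[0].name: tokens})[0]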