
Commit

clean-up pipelines, update typing and descriptions
dsikka committed Aug 1, 2023
1 parent c1003be commit b0695f6
Showing 6 changed files with 52 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/deepsparse/clip/README.md
@@ -110,7 +110,7 @@ kwargs = {
pipeline = BasePipeline.create(task="clip_caption", **kwargs)

pipeline_input = CLIPCaptionInput(image=CLIPVisualInput(images="thailand.jpg"))
-output = pipeline(pipeline_input)
+output = pipeline(pipeline_input).caption
print(output[0])
```
Running the code above, we get the following caption:
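The hunk above shows only the tail of the README's captioning example. For context, here is a minimal end-to-end sketch of how the updated line is used; the `kwargs` keys mirror the `CLIPCaptionPipeline` constructor, and the model paths are hypothetical placeholders rather than the README's actual values.

```python
from deepsparse import BasePipeline
from deepsparse.clip import CLIPCaptionInput, CLIPVisualInput

# Hypothetical paths: substitute local ONNX exports or SparseZoo stubs.
kwargs = {
    "visual_model_path": "caption_models/clip_visual.onnx",
    "text_model_path": "caption_models/clip_text.onnx",
    "decoder_model_path": "caption_models/clip_text_decoder.onnx",
}
pipeline = BasePipeline.create(task="clip_caption", **kwargs)

pipeline_input = CLIPCaptionInput(image=CLIPVisualInput(images="thailand.jpg"))
output = pipeline(pipeline_input).caption  # the .caption accessor added in this commit
print(output[0])
```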
37 changes: 26 additions & 11 deletions src/deepsparse/clip/captioning_pipeline.py
@@ -53,6 +53,25 @@ class CLIPCaptionOutput(BaseModel):

@BasePipeline.register(task="clip_caption", default_model_path=None)
class CLIPCaptionPipeline(BasePipeline):
"""
Pipelines designed to generate a caption for a given image. The CLIPCaptionPipeline
relies on 3 other pipelines: CLIPVisualPipeline, CLIPTextPipeline, and the
CLIPDecoder Pipeline. The pipeline takes in a single image and then uses the
pipelines along with Beam Search to generate a caption.
:param visual_model_path: either a local path or sparsezoo stub for the CLIP visual
branch onnx model
:param text_model_path: either a local path or sparsezoo stub for the CLIP text
branch onnx model
:param decoder_model_path: either a local path or sparsezoo stub for the CLIP
decoder branch onnx model
:param num_beams: number of beams to use in Beam Search
:param num_beam_groups: number of beam groups to use in Beam Search
:param min_seq_len: the minimum length of the caption sequence
:param max_seq_len: the maxmium length of the caption sequence
"""

def __init__(
self,
visual_model_path: str,
@@ -61,17 +80,13 @@ def __init__(
num_beams: int = 10,
num_beam_groups: int = 5,
min_seq_len: int = 5,
-        seq_len: int = 20,
-        fixed_output_length: bool = False,
+        max_seq_len: int = 20,
**kwargs,
):
self.num_beams = num_beams
self.num_beam_groups = num_beam_groups
-        self.seq_len = seq_len
+        self.max_seq_len = max_seq_len
self.min_seq_len = min_seq_len
-        self.fixed_output_length = fixed_output_length

-        super().__init__(**kwargs)

self.visual = Pipeline.create(
task="clip_visual",
@@ -86,8 +101,9 @@ def __init__(
**{"model_path": decoder_model_path},
)

-    # TODO: have to verify all input types
-    def _encode_and_decode(self, text, image_embs):
+        super().__init__(**kwargs)
+
+    def _encode_and_decode(self, text: torch.Tensor, image_embs: torch.Tensor):
original_size = text.shape[-1]
padded_tokens = F.pad(text, (15 - original_size, 0))
text_embeddings = self.text(
@@ -104,16 +120,15 @@ def _encode_and_decode(self, text, image_embs):
}

# Adapted from open_clip
-    def _generate(self, pipeline_inputs):
-        # Make these input values?
+    def _generate(self, pipeline_inputs: CLIPCaptionInput):
sot_token_id = 49406
eos_token_id = 49407
pad_token_id = 0
batch_size = 1
repetition_penalty = 1.0
device = "cpu"

-        stopping_criteria = [MaxLengthCriteria(max_length=self.seq_len)]
+        stopping_criteria = [MaxLengthCriteria(max_length=self.max_seq_len)]
stopping_criteria = StoppingCriteriaList(stopping_criteria)

logits_processor = LogitsProcessorList(
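The new class docstring and the `seq_len` to `max_seq_len` rename define the knobs the caption pipeline exposes. Below is a minimal sketch of constructing the pipeline with those beam-search parameters spelled out; the model paths are hypothetical and the values simply restate the defaults from the signature above.

```python
from deepsparse import BasePipeline

# Hypothetical paths; any local ONNX export or SparseZoo stub should work here.
pipeline = BasePipeline.create(
    task="clip_caption",
    visual_model_path="clip_onnx/visual.onnx",
    text_model_path="clip_onnx/text.onnx",
    decoder_model_path="clip_onnx/text_decoder.onnx",
    num_beams=10,       # beams explored during beam search
    num_beam_groups=5,  # groups used for diverse beam search
    min_seq_len=5,      # shortest caption, in tokens, the search may emit
    max_seq_len=20,     # renamed from seq_len in this commit; caps caption length
)
```

As the last hunk shows, `max_seq_len` is what feeds the `MaxLengthCriteria` inside the `StoppingCriteriaList` that `_generate` uses to stop beam search.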
10 changes: 7 additions & 3 deletions src/deepsparse/clip/decoder_pipeline.py
@@ -29,8 +29,12 @@ class CLIPDecoderInput(BaseModel):
Input for the CLIP Decoder Branch
"""

-    text_embeddings: Any = Field(description="Text emebddings from the text branch")
-    image_embeddings: Any = Field(description="Image embeddings from the visual branch")
+    text_embeddings: Any = Field(
+        description="np.array of text embeddings from the text branch"
+    )
+    image_embeddings: Any = Field(
+        description="np.array of image embeddings from the visual branch"
+    )


class CLIPDecoderOutput(BaseModel):
@@ -39,7 +43,7 @@ class CLIPDecoderOutput(BaseModel):
"""

logits: List[Any] = Field(
description="Logits produced from the text and image emebeddings."
description="np.array of logits produced from the decoder."
)


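The reworded field descriptions make explicit that the decoder branch consumes numpy arrays produced by the text and visual branches. A tiny illustrative sketch of the input schema follows; the embedding shapes are hypothetical and only meant to show the types involved, assuming `CLIPDecoderInput` is exported from `deepsparse.clip` like the other schemas.

```python
import numpy as np

from deepsparse.clip import CLIPDecoderInput

# Hypothetical shapes standing in for embeddings returned by the text and visual branches.
decoder_input = CLIPDecoderInput(
    text_embeddings=np.random.rand(1, 15, 512).astype(np.float32),
    image_embeddings=np.random.rand(1, 512).astype(np.float32),
)
# Running a clip_decoder pipeline on this input yields a CLIPDecoderOutput whose
# .logits field holds the np.array logits described above.
```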
8 changes: 5 additions & 3 deletions src/deepsparse/clip/text_pipeline.py
@@ -31,7 +31,7 @@ class CLIPTextInput(BaseModel):
"""

text: Union[str, List[str], Any, List[Any]] = Field(
description="Either raw text or text embeddings"
description="Either raw strings or an np.array with tokenized text"
)


@@ -41,8 +41,10 @@ class CLIPTextOutput(BaseModel):
"""

text_embeddings: List[Any] = Field(
description="Text embeddings for the single text or list of embeddings for "
"multiple."
description="np.array of text embeddings. For the caption "
"pipeline, a list of two embeddings is produced. For zero-shot "
"classifcation, one array is produced with the embeddings stacked along "
"batch axis."
)


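The new description separates the two ways `CLIPTextInput` is used: raw strings (e.g. zero-shot class names) versus an np.array of already-tokenized text (the caption pipeline). A short sketch of the raw-string path, with a hypothetical model path:

```python
from deepsparse import Pipeline
from deepsparse.clip import CLIPTextInput

# Hypothetical path to the CLIP text-branch ONNX model.
text_pipeline = Pipeline.create(task="clip_text", model_path="clip_onnx/text.onnx")

output = text_pipeline(CLIPTextInput(text=["a photo of a dog", "a photo of a church"]))
# Per the description above, raw strings come back as a single np.array of
# embeddings stacked along the batch axis.
embeddings = output.text_embeddings[0]
```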
6 changes: 4 additions & 2 deletions src/deepsparse/clip/visual_pipeline.py
@@ -44,8 +44,10 @@ class CLIPVisualOutput(BaseModel):
"""

image_embeddings: List[Any] = Field(
description="Image embeddings for the single image or list of embeddings for "
"multiple images"
description="np.arrays consisting of image embeddings. For the caption "
"pipeline, a list of two image embeddings is produced. For zero-shot "
"classifcation, one array is produced with the embeddings stacked along "
"batch axis."
)


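The visual output description mirrors the text branch: one stacked np.array for zero-shot use, a two-element list for captioning. A correspondingly small sketch, again with a hypothetical model path:

```python
from deepsparse import Pipeline
from deepsparse.clip import CLIPVisualInput

# Hypothetical path to the CLIP visual-branch ONNX model.
visual_pipeline = Pipeline.create(task="clip_visual", model_path="clip_onnx/visual.onnx")

output = visual_pipeline(CLIPVisualInput(images=["thailand.jpg", "basilica.jpg"]))
image_embeddings = output.image_embeddings[0]  # np.array stacked along the batch axis
```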
13 changes: 9 additions & 4 deletions src/deepsparse/clip/zeroshot_pipeline.py
@@ -32,18 +32,23 @@ class CLIPZeroShotInput(BaseModel):
"""

image: CLIPVisualInput = Field(
description="Path to image to run zero-shot prediction on."
description="Image(s) to run zero-shot prediction. See CLIPVisualPipeline "
"for details."
)
text: CLIPTextInput = Field(
description="Text/classes to run zero-shot prediction "
"see CLIPTextPipeline for details."
)
text: CLIPTextInput = Field(description="List of text to process")


class CLIPZeroShotOutput(BaseModel):
"""
Output for the CLIP Zero Shot Model
"""

-    # TODO: Maybe change this to a dictionary where keys are text inputs
-    text_scores: List[Any] = Field(description="Probability of each text class")
+    text_scores: List[Any] = Field(
+        description="np.array consisting of probabilities for each class provided."
+    )


@BasePipeline.register(task="clip_zeroshot", default_model_path=None)
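The reworked zero-shot input and output descriptions are easiest to read against a usage sketch: images go in through `CLIPVisualInput`, candidate classes through `CLIPTextInput`, and `text_scores` comes back as per-class probabilities. The model paths and file names below are hypothetical, and the constructor arguments assume the zero-shot pipeline takes visual and text branch paths the way the caption pipeline does.

```python
from deepsparse import BasePipeline
from deepsparse.clip import CLIPTextInput, CLIPVisualInput, CLIPZeroShotInput

# Hypothetical model paths for the two CLIP branches used by zero-shot classification.
pipeline = BasePipeline.create(
    task="clip_zeroshot",
    visual_model_path="clip_onnx/visual.onnx",
    text_model_path="clip_onnx/text.onnx",
)

pipeline_input = CLIPZeroShotInput(
    image=CLIPVisualInput(images=["thailand.jpg"]),
    text=CLIPTextInput(text=["an elephant", "a dog", "a church"]),
)
scores = pipeline(pipeline_input).text_scores  # np.array of per-class probabilities
print(scores)
```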
