Skip to content

Commit

Permalink
move docstring; add params
Browse files Browse the repository at this point in the history
  • Loading branch information
dsikka committed Jul 31, 2023
1 parent 4674c95 commit 06e958e
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions src/deepsparse/clip/zeroshot_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,21 @@ class CLIPZeroShotOutput(BaseModel):

@BasePipeline.register(task="clip_zeroshot", default_model_path=None)
class CLIPZeroShotPipeline(BasePipeline):
"""
Pipeline designed to run zero-shot classification given a list of images and
possible classes. The CLIPZeroShotPipeline relies on two pipelines, the
CLIPTextPipeline which handles CLIP's text branch adn the CLIPVisualPipeline
which handles CLIP's visual branch. The final score calculations are handled and
returned by the CLIPZeroShotPipeline. See README.md for a detailed example.
:param visual_model_path: either a local path or sparsezoo stub for the CLIP visual
branch onnx model
:param text_model_path: either a local path or sparsezoo stub for the CLIP text
branch onnx model
"""

def __init__(self, visual_model_path: str, text_model_path: str, **kwargs):
"""
Pipeline designed to run zero-shot classification given a list of images and
possible classes. The CLIPZeroShotPipeline relies on two pipelines, the
CLIPTextPipeline which handles CLIP's text branch adn the CLIPVisualPipeline
which handles CLIP's visual branch. The final score calculations are handled and
returned by the CLIPZeroShotPipeline. See README.md for a detailed example.
"""
self.visual = Pipeline.create(
task="clip_visual", **{"model_path": visual_model_path}
)
Expand Down

0 comments on commit 06e958e

Please sign in to comment.