Skip to content

Commit

Permalink
[Text Gen UX] top level constructor aliases + code gen subclass (#1274)
Browse files Browse the repository at this point in the history
  • Loading branch information
bfineran authored Sep 25, 2023
1 parent d13cc2d commit 20a6157
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 3 deletions.
52 changes: 52 additions & 0 deletions src/deepsparse/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
"Bucketable",
"BucketingPipeline",
"create_engine",
"TextGeneration",
"CodeGeneration",
"Chat",
]

DEEPSPARSE_ENGINE = "deepsparse"
Expand Down Expand Up @@ -774,6 +777,55 @@ def _initialize_executor_and_workers(
return executor, num_async_workers


def text_generation_pipeline(
*args, model: Optional[str] = None, **kwargs
) -> "Pipeline":
"""
:return: text generation pipeline with the given args and
kwargs passed to Pipeline.create
"""
kwargs = _parse_model_arg(model, **kwargs)
return Pipeline.create("text_generation", *args, **kwargs)


def code_generation_pipeline(
*args, model: Optional[str] = None, **kwargs
) -> "Pipeline":
"""
:return: text generation pipeline with the given args and
kwargs passed to Pipeline.create
"""
kwargs = _parse_model_arg(model, **kwargs)
return Pipeline.create("code_generation", *args, **kwargs)


def chat_pipeline(*args, model: Optional[str] = None, **kwargs) -> "Pipeline":
"""
:return: text generation pipeline with the given args and
kwargs passed to Pipeline.create
"""
kwargs = _parse_model_arg(model, **kwargs)
return Pipeline.create("chat", *args, **kwargs)


def _parse_model_arg(model: Optional[str], **kwargs) -> dict:
if model is not None:
model_path = kwargs.get("model_path")
if model_path is not None:
raise ValueError(
f"Only one of model and model_path may be supplied, found {model} "
f"and {model_path} respectively"
)
kwargs["model_path"] = model
return kwargs


# aliases for top level import
TextGeneration = text_generation_pipeline
CodeGeneration = code_generation_pipeline
Chat = chat_pipeline


def question_answering_pipeline(*args, **kwargs) -> "Pipeline":
"""
transformers question_answering pipeline
Expand Down
23 changes: 21 additions & 2 deletions src/deepsparse/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,16 @@ class SupportedTasks:
chatbot=AliasedTask("chatbot", []), chat=AliasedTask("chat", [])
)
text_generation = namedtuple(
"text_generation", ["text_generation", "opt", "codegen", "bloom"]
"text_generation", ["text_generation", "opt", "bloom"]
)(
text_generation=AliasedTask("text_generation", []),
codegen=AliasedTask("codegen", []),
opt=AliasedTask("opt", []),
bloom=AliasedTask("bloom", []),
)
code_generation = namedtuple("code_generation", ["code_generation", "codegen"])(
code_generation=AliasedTask("code_generation", []),
codegen=AliasedTask("codegen", []),
)

image_classification = namedtuple("image_classification", ["image_classification"])(
image_classification=AliasedTask(
Expand Down Expand Up @@ -153,6 +156,7 @@ class SupportedTasks:
open_pif_paf,
text_generation,
chat,
code_generation,
]

@classmethod
Expand All @@ -174,6 +178,9 @@ def check_register_task(
elif cls.is_chat(task):
import deepsparse.transformers.pipelines.chat # noqa: F401

elif cls.is_code_generation(task):
import deepsparse.transformers.pipelines.code_generation # noqa: F401

elif cls.is_nlp(task):
# trigger transformers pipelines to register with Pipeline.register
import deepsparse.transformers.pipelines # noqa: F401
Expand Down Expand Up @@ -237,6 +244,18 @@ def is_text_generation(cls, task: str) -> bool:
for text_generation_task in cls.text_generation
)

@classmethod
def is_code_generation(cls, task: str) -> bool:
"""
:param task: the name of the task to check whether it is a text generation task
such as codegen
:return: True if it is a text generation task, False otherwise
"""
return any(
code_generation_task.matches(task)
for code_generation_task in cls.code_generation
)

@classmethod
def is_nlp(cls, task: str) -> bool:
"""
Expand Down
1 change: 1 addition & 0 deletions src/deepsparse/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
from .zero_shot_text_classification import *
from .embedding_extraction import *
from .chat import *
from .code_generation import *
33 changes: 33 additions & 0 deletions src/deepsparse/transformers/pipelines/code_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from deepsparse import Pipeline
from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline


__all__ = ["CodeGenerationPipeline"]


@Pipeline.register(
task="code_generation",
task_aliases=["codegen"],
)
class CodeGenerationPipeline(TextGenerationPipeline):
"""
Subclass of text generation pipeline to support any defaults or
overrides needed for code generation
"""

pass
2 changes: 1 addition & 1 deletion src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class Config:

@Pipeline.register(
task="text_generation",
task_aliases=["codegen", "opt", "bloom"],
task_aliases=["opt", "bloom"],
)
class TextGenerationPipeline(TransformersPipeline):
"""
Expand Down

0 comments on commit 20a6157

Please sign in to comment.