Skip to content

Commit

Permalink
update codegen alias to use the new registry and text generation pipe…
Browse files Browse the repository at this point in the history
…line
  • Loading branch information
dsikka committed Dec 8, 2023
1 parent 6c42956 commit 8c71397
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
27 changes: 25 additions & 2 deletions src/deepsparse/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,22 @@ class SupportedTasks:
bloom=AliasedTask("bloom", []),
)

code_generation = namedtuple(
"code_generation", ["code_generation", "code_gen", "codegen"]
)(
code_generation=AliasedTask("code_generation", []),
code_gen=AliasedTask("code_gen", []),
codegen=AliasedTask("codegen", []),
)

image_classification = namedtuple("image_classification", ["image_classification"])(
image_classification=AliasedTask(
"image_classification",
["image_classification"],
),
)

all_task_categories = [text_generation]
all_task_categories = [text_generation, code_generation, image_classification]

@classmethod
def check_register_task(
Expand All @@ -107,6 +115,9 @@ def check_register_task(
if cls.is_text_generation(task):
import deepsparse.transformers.pipelines.text_generation # noqa: F401

elif cls.is_code_generation(task):
import deepsparse.transformers.pipelines.code_generation # noqa: F401

elif cls.is_image_classification(task):
# trigger image classification pipelines to
# register with Pipeline.register
Expand Down Expand Up @@ -142,7 +153,7 @@ def is_image_classification(cls, task: str) -> bool:

@classmethod
def task_names(cls):
task_names = ["custom"]
task_names = []
for task_category in cls.all_task_categories:
for task in task_category:
unique_aliases = (
Expand All @@ -151,6 +162,18 @@ def task_names(cls):
task_names += (task._name, *unique_aliases)
return task_names

@classmethod
def is_code_generation(cls, task: str) -> bool:
"""
:param task: the name of the task to check whether it is a text generation task
such as codegen
:return: True if it is a text generation task, False otherwise
"""
return any(
code_generation_task.matches(task)
for code_generation_task in cls.code_generation
)


def dynamic_import_task(module_or_path: str) -> str:
"""
Expand Down
11 changes: 3 additions & 8 deletions src/deepsparse/transformers/pipelines/code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,14 @@
# limitations under the License.


from deepsparse.legacy import Pipeline
from deepsparse.legacy.transformers.pipelines.text_generation import (
TextGenerationPipeline,
)
from deepsparse.operators import OperatorRegistry
from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline


__all__ = ["CodeGenerationPipeline"]


@Pipeline.register(
task="code_generation",
task_aliases=["codegen"],
)
@OperatorRegistry.register(name=["code_generation", "code_gen", "codegen"])
class CodeGenerationPipeline(TextGenerationPipeline):
"""
Subclass of text generation pipeline to support any defaults or
Expand Down

0 comments on commit 8c71397

Please sign in to comment.