Skip to content

Commit

Permalink
Merge branch 'sparsify.alpha' into remove_optim
Browse files Browse the repository at this point in the history
  • Loading branch information
KSGulin authored May 5, 2023
2 parents ad53356 + 0c5ed2c commit 41dabe6
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 19 deletions.
8 changes: 7 additions & 1 deletion src/sparsify/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from pathlib import Path

import click
Expand Down Expand Up @@ -42,6 +43,10 @@ def one_shot(**kwargs):
# raises exception if sparsifyml not installed
from sparsify.one_shot import one_shot

recipe_args = kwargs.get("recipe_args")
if isinstance(recipe_args, str):
recipe_args = json.loads(recipe_args)

one_shot.one_shot(
task=kwargs["use_case"],
model_file=Path(kwargs["model"]),
Expand All @@ -50,6 +55,7 @@ def one_shot(**kwargs):
deploy_dir=Path(kwargs["working_dir"]),
eval_metric=kwargs["eval_metric"],
recipe_file=Path(kwargs["recipe"]) if kwargs["recipe"] is not None else None,
recipe_args=recipe_args,
)


Expand Down Expand Up @@ -95,7 +101,7 @@ def _parse_run_args_to_auto(sparse_transfer: bool, **kwargs):
task=kwargs["use_case"],
dataset=kwargs["data"],
save_directory=kwargs["working_dir"],
performance=kwargs["optim_level"],
optim_level=kwargs["optim_level"],
base_model=kwargs["model"],
recipe=kwargs["recipe"],
recipe_args=kwargs["recipe_args"] or {},
Expand Down
12 changes: 5 additions & 7 deletions src/sparsify/schemas/auto_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,13 @@ class APIArgs(BaseModel):
description="Absolute path to save directory",
default=DEFAULT_OUTPUT_DIRECTORY,
)
opt_level: Union[str, float] = Field(
title="performance",
optim_level: float = Field(
title="optim_level",
description=(
"Preferred tradeoff between accuracy and performance. Can be a string or a "
"float value in the range [0, 1]. Currently supported strings (and their "
"respective float values are `accuracy` (0), `balanced` (0.5), and "
"`performant` (1.0)"
"Preferred tradeoff between accuracy and performance. "
"Float value in the range [0, 1]. Default 0.5"
),
default="balanced",
default=0.5,
)
base_model: Optional[str] = Field(
title="base_model",
Expand Down
41 changes: 34 additions & 7 deletions src/sparsify/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,40 +52,67 @@
TASK_REGISTRY: Dict[str, TaskName] = {
"image_classification": TaskName(
name="image_classification",
aliases=["ic", "classification"],
aliases=["ic", "classification", "cv_classification"],
domain="cv",
sub_domain="classification",
),
"object_detection": TaskName(
name="object_detection",
aliases=["od", "detection"],
aliases=["od", "detection", "cv_detection", "yolo"],
domain="cv",
sub_domain="detection",
),
"segmentation": TaskName(
name="segmentation", domain="cv", sub_domain="segmentation"
name="segmentation",
aliases=["image_segmentation", "cv_segmentation"],
domain="cv",
sub_domain="segmentation",
),
"document_classification": TaskName(
name="document_classification",
aliases=["nlp_document_classification"],
domain="nlp",
sub_domain="document_classification",
),
"information_retrieval": TaskName(
name="information_retrieval",
aliases=["ir, nlp_information_retrieval"],
domain="nlp",
sub_domain="information_retrieval",
),
"masked_language_modeling": TaskName(
name="masked_language_modeling",
aliases=["mlm, masked_language_modeling"],
domain="nlp",
sub_domain="masked_language_modeling",
),
"multilabel_text_classification": TaskName(
name="multilabel_text_classification",
aliases=["nlp_multilabel_text_classification"],
domain="nlp",
sub_domain="multilabel_text_classification",
),
"question_answering": TaskName(
name="question_answering",
aliases=["qa"],
aliases=["qa, nlp_question_answering"],
domain="nlp",
sub_domain="question_answering",
),
"text_classification": TaskName(
name="text_classification",
aliases=["glue"],
aliases=["glue", "nlp_text_classification"],
domain="nlp",
sub_domain="text_classification",
),
"sentiment_analysis": TaskName(
name="sentiment_analysis",
aliases=["sentiment"],
aliases=["sentiment", "nlp_sentiment_analysis"],
domain="nlp",
sub_domain="sentiment_analysis",
),
"token_classification": TaskName(
name="token_classification",
aliases=["ner", "named_entity_recognition"],
aliases=["ner", "named_entity_recognition", "nlp_token_classification"],
domain="nlp",
sub_domain="token_classification",
),
Expand Down
21 changes: 17 additions & 4 deletions src/sparsify/utils/task_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union
from typing import List, Optional, Union


class TaskName:
Expand All @@ -34,8 +34,13 @@ class TaskName:
__slots__ = ("name", "aliases", "domain", "sub_domain")

def __init__(
self, name: str, domain: str, sub_domain: str, aliases: List[str] = []
self,
name: str,
domain: str,
sub_domain: str,
aliases: Optional[List[str]] = None,
):
aliases = aliases or []
for field in [name, domain, sub_domain]:
if not isinstance(field, str):
raise ValueError(f"'{field}' must be a string")
Expand Down Expand Up @@ -70,9 +75,9 @@ def pretty_print(self):

def __eq__(self, other: Union[str, "TaskName"]):
if isinstance(other, TaskName):
return other.name == self.name
return _task_name_eq(self.name, other.name)
elif isinstance(other, str):
return other.lower() in self.aliases
return any(_task_name_eq(other, alias) for alias in self.aliases)
else:
return False

Expand All @@ -88,3 +93,11 @@ def _get_supported_aliases(self, task: str):
if "_" in task:
aliases.append(task.replace("_", "-"))
return aliases


def _task_name_eq(str_1: str, str_2: str) -> bool:
# ignore case, ignore separators (-, _, ' ')
def _strip_separators(string):
return string.lower().replace(" ", "").replace("-", "").replace("_", "")

return _strip_separators(str_1) == _strip_separators(str_2)

0 comments on commit 41dabe6

Please sign in to comment.