Skip to content

Commit

Permalink
update task registry + generalize matching (#201)
Browse files Browse the repository at this point in the history
  • Loading branch information
bfineran committed May 4, 2023
1 parent afc5ece commit 918dc92
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 11 deletions.
41 changes: 34 additions & 7 deletions src/sparsify/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,40 +52,67 @@
TASK_REGISTRY: Dict[str, TaskName] = {
"image_classification": TaskName(
name="image_classification",
aliases=["ic", "classification"],
aliases=["ic", "classification", "cv_classification"],
domain="cv",
sub_domain="classification",
),
"object_detection": TaskName(
name="object_detection",
aliases=["od", "detection"],
aliases=["od", "detection", "cv_detection", "yolo"],
domain="cv",
sub_domain="detection",
),
"segmentation": TaskName(
name="segmentation", domain="cv", sub_domain="segmentation"
name="segmentation",
aliases=["image_segmentation", "cv_segmentation"],
domain="cv",
sub_domain="segmentation",
),
"document_classification": TaskName(
name="document_classification",
aliases=["nlp_document_classification"],
domain="nlp",
sub_domain="document_classification",
),
"information_retrieval": TaskName(
name="information_retrieval",
aliases=["ir, nlp_information_retrieval"],
domain="nlp",
sub_domain="information_retrieval",
),
"masked_language_modeling": TaskName(
name="masked_language_modeling",
aliases=["mlm, masked_language_modeling"],
domain="nlp",
sub_domain="masked_language_modeling",
),
"multilabel_text_classification": TaskName(
name="multilabel_text_classification",
aliases=["nlp_multilabel_text_classification"],
domain="nlp",
sub_domain="multilabel_text_classification",
),
"question_answering": TaskName(
name="question_answering",
aliases=["qa"],
aliases=["qa, nlp_question_answering"],
domain="nlp",
sub_domain="question_answering",
),
"text_classification": TaskName(
name="text_classification",
aliases=["glue"],
aliases=["glue", "nlp_text_classification"],
domain="nlp",
sub_domain="text_classification",
),
"sentiment_analysis": TaskName(
name="sentiment_analysis",
aliases=["sentiment"],
aliases=["sentiment", "nlp_sentiment_analysis"],
domain="nlp",
sub_domain="sentiment_analysis",
),
"token_classification": TaskName(
name="token_classification",
aliases=["ner", "named_entity_recognition"],
aliases=["ner", "named_entity_recognition", "nlp_token_classification"],
domain="nlp",
sub_domain="token_classification",
),
Expand Down
21 changes: 17 additions & 4 deletions src/sparsify/utils/task_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union
from typing import List, Optional, Union


class TaskName:
Expand All @@ -34,8 +34,13 @@ class TaskName:
__slots__ = ("name", "aliases", "domain", "sub_domain")

def __init__(
self, name: str, domain: str, sub_domain: str, aliases: List[str] = []
self,
name: str,
domain: str,
sub_domain: str,
aliases: Optional[List[str]] = None,
):
aliases = aliases or []
for field in [name, domain, sub_domain]:
if not isinstance(field, str):
raise ValueError(f"'{field}' must be a string")
Expand Down Expand Up @@ -70,9 +75,9 @@ def pretty_print(self):

def __eq__(self, other: Union[str, "TaskName"]):
    """
    Compare this task against another TaskName or a plain string.

    :param other: a TaskName, whose ``name`` is compared with this task's
        ``name``, or a str, which is compared with each of this task's
        aliases. Both comparisons go through ``_task_name_eq``, so case and
        the separators ' ', '-' and '_' are ignored.
    :return: True on a match; False when nothing matches or when ``other``
        is neither a TaskName nor a str
    """
    if isinstance(other, TaskName):
        return _task_name_eq(self.name, other.name)
    if isinstance(other, str):
        # any() stops at the first alias that matches
        return any(_task_name_eq(other, alias) for alias in self.aliases)
    return False

Expand All @@ -88,3 +93,11 @@ def _get_supported_aliases(self, task: str):
if "_" in task:
aliases.append(task.replace("_", "-"))
return aliases


def _task_name_eq(str_1: str, str_2: str) -> bool:
# ignore case, ignore separators (-, _, ' ')
def _strip_separators(string):
return string.lower().replace(" ", "").replace("-", "").replace("_", "")

return _strip_separators(str_1) == _strip_separators(str_2)

0 comments on commit 918dc92

Please sign in to comment.