Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] adds new auto task type that attempts to automatically determine NLP task type from model config #475

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions bin/eland_import_hub_model
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,11 @@ def get_arg_parser():
)
parser.add_argument(
"--task-type",
required=True,
required=False,
choices=SUPPORTED_TASK_TYPES,
help="The task type for the model usage.",
help="The task type for the model usage. Will attempt to auto-detect task type for the model if not provided. "
"Default: auto",
default="auto"
)
parser.add_argument(
"--quantize",
Expand Down Expand Up @@ -165,7 +167,11 @@ if __name__ == "__main__":

try:
from eland.ml.pytorch import PyTorchModel
from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES, TransformerModel
from eland.ml.pytorch.transformers import (
SUPPORTED_TASK_TYPES,
TaskTypeError,
TransformerModel,
)
except ModuleNotFoundError as e:
logger.error(textwrap.dedent(f"""\
\033[31mFailed to run because module '{e.name}' is not available.\033[0m
Expand All @@ -187,8 +193,12 @@ if __name__ == "__main__":
with tempfile.TemporaryDirectory() as tmp_dir:
logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'")

tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
model_path, config, vocab_path = tm.save(tmp_dir)
try:
tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize)
model_path, config, vocab_path = tm.save(tmp_dir)
except TaskTypeError as err:
logger.error(f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}")
exit(1)

ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id())
model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200
Expand Down
2 changes: 2 additions & 0 deletions eland/ml/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
NlpTrainedModelConfig,
)
from eland.ml.pytorch.traceable_model import TraceableModel # noqa: F401
from eland.ml.pytorch.transformers import task_type_from_model_config

__all__ = [
"PyTorchModel",
Expand All @@ -31,4 +32,5 @@
"NlpBertTokenizationConfig",
"NlpRobertaTokenizationConfig",
"NlpMPNetTokenizationConfig",
"task_type_from_model_config",
]
55 changes: 54 additions & 1 deletion eland/ml/pytorch/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import json
import os.path
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch # type: ignore
import transformers # type: ignore
Expand All @@ -33,6 +33,7 @@
AutoConfig,
AutoModel,
AutoModelForQuestionAnswering,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
Expand Down Expand Up @@ -64,6 +65,15 @@
"zero_shot_classification",
"question_answering",
}
ARCHITECTURE_TO_TASK_TYPE = {
"MaskedLM": ["fill_mask", "text_embedding"],
"TokenClassification": ["ner"],
"SequenceClassification": ["text_classification", "zero_shot_classification"],
"QuestionAnswering": ["question_answering"],
"DPRQuestionEncoder": ["text_embedding"],
"DPRContextEncoder": ["text_embedding"],
}
ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"}
TASK_TYPE_TO_INFERENCE_CONFIG = {
"fill_mask": FillMaskInferenceOptions,
"ner": NerInferenceOptions,
Expand Down Expand Up @@ -97,6 +107,37 @@
]


class TaskTypeError(Exception):
pass


def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]:
if model_config.architectures is None:
if model_config.name_or_path.startswith("sentence-transformers/"):
return "text_embedding"
return None
potential_task_types: Set[str] = set()
for architecture in model_config.architectures:
for (substr, task_type) in ARCHITECTURE_TO_TASK_TYPE.items():
if substr in architecture:
for t in task_type:
potential_task_types.add(t)
if len(potential_task_types) == 0:
return None
if len(potential_task_types) > 1:
if "zero_shot_classification" in potential_task_types:
if model_config.label2id:
labels = set([x.lower() for x in model_config.label2id.keys()])
if len(labels.difference(ZERO_SHOT_LABELS)) == 0:
return "zero_shot_classification"
return "text_classification"
if "text_embedding" in potential_task_types:
if model_config.name_or_path.startswith("sentence-transformers/"):
return "text_embedding"
return "fill_mask"
return potential_task_types.pop()


class _QuestionAnsweringWrapperModule(nn.Module): # type: ignore
"""
A wrapper around a question answering model.
Expand Down Expand Up @@ -581,6 +622,18 @@ def _create_config(self) -> NlpTrainedModelConfig:
)

def _create_traceable_model(self) -> TraceableModel:
if self._task_type == "auto":
model = transformers.AutoModel.from_pretrained(
self._model_id, torchscript=True
)
maybe_task_type = task_type_from_model_config(model.config)
if maybe_task_type is None:
raise TaskTypeError(
f"Unable to automatically determine task type for model {self._model_id}, please supply task type: {SUPPORTED_TASK_TYPES_NAMES}"
)
else:
self._task_type = maybe_task_type

if self._task_type == "fill_mask":
model = transformers.AutoModelForMaskedLM.from_pretrained(
self._model_id, torchscript=True
Expand Down
56 changes: 56 additions & 0 deletions tests/ml/pytorch/test_transformer_pytorch_model_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@
try:
import torch # noqa: F401
from torch import Tensor, nn # noqa: F401
from transformers import PretrainedConfig # noqa: F401

from eland.ml.pytorch import ( # noqa: F401
NlpBertTokenizationConfig,
NlpTrainedModelConfig,
PyTorchModel,
TraceableModel,
task_type_from_model_config,
)
from eland.ml.pytorch.nlp_ml_model import (
NerInferenceOptions,
Expand Down Expand Up @@ -222,6 +224,41 @@ def forward(
),
]

AUTO_TASK_RESULTS = [
("any_bert", "BERTMaskedLM", None, "fill_mask"),
("any_roberta", "RoBERTaMaskedLM", None, "fill_mask"),
("sentence-transformers/any_bert", "BERTMaskedLM", None, "text_embedding"),
("sentence-transformers/any_roberta", "RoBERTaMaskedLM", None, "text_embedding"),
("sentence-transformers/mpnet", "MPNetMaskedLM", None, "text_embedding"),
("anynermodel", "BERTForTokenClassification", None, "ner"),
("anynermodel", "MPNetForTokenClassification", None, "ner"),
("anynermodel", "RoBERTaForTokenClassification", None, "ner"),
("anynermodel", "BERTForQuestionAnswering", None, "question_answering"),
("anynermodel", "MPNetForQuestionAnswering", None, "question_answering"),
("anynermodel", "RoBERTaForQuestionAnswering", None, "question_answering"),
("aqaModel", "DPRQuestionEncoder", None, "text_embedding"),
("aqaModel", "DPRContextEncoder", None, "text_embedding"),
(
"any_bert",
"BERTForSequenceClassification",
["foo", "bar", "baz"],
"text_classification",
),
(
"any_bert",
"BERTForSequenceClassification",
["contradiction", "neutral", "entailment"],
"zero_shot_classification",
),
(
"any_bert",
"BERTForSequenceClassification",
["CONTRADICTION", "NEUTRAL", "ENTAILMENT"],
"zero_shot_classification",
),
("any_bert", "SomeUnknownType", None, None),
]


@pytest.fixture(scope="function", autouse=True)
def setup_and_tear_down():
Expand Down Expand Up @@ -274,3 +311,22 @@ def test_model_upload(self, model_id, model, config, input, prediction):
result = ptm.infer(docs=[{"text_field": input}])
assert result.get("predicted_value") is not None
assert result["predicted_value"] == prediction

@pytest.mark.parametrize(
"model_id,architecture,labels,expected_task", AUTO_TASK_RESULTS
)
def test_auto_task_type(self, model_id, architecture, labels, expected_task):
config = (
PretrainedConfig(
name_or_path=model_id,
architectures=[architecture],
label2id=dict(zip(labels, range(len(labels)))),
id2label=dict(zip(range(len(labels)), labels)),
)
if labels
else PretrainedConfig(
name_or_path=model_id,
architectures=[architecture],
)
)
assert task_type_from_model_config(model_config=config) == expected_task