From a68410e73a42975a0da476d34c216b3c43f63041 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 22 Jun 2022 14:01:57 -0400 Subject: [PATCH 1/2] [ML] adds new auto task type that attempts to automatically determine NLP task type from model config --- bin/eland_import_hub_model | 20 +++++-- eland/ml/pytorch/__init__.py | 2 + eland/ml/pytorch/transformers.py | 55 +++++++++++++++++- .../test_transformer_pytorch_model_pytest.py | 56 +++++++++++++++++++ 4 files changed, 127 insertions(+), 6 deletions(-) diff --git a/bin/eland_import_hub_model b/bin/eland_import_hub_model index 332c982d..f65e7327 100755 --- a/bin/eland_import_hub_model +++ b/bin/eland_import_hub_model @@ -84,9 +84,11 @@ def get_arg_parser(): ) parser.add_argument( "--task-type", - required=True, + required=False, choices=SUPPORTED_TASK_TYPES, - help="The task type for the model usage.", + help="The task type for the model usage. Will attempt to auto-detect task type for the model if not provided. " + "Default: auto", + default="auto" ) parser.add_argument( "--quantize", @@ -165,7 +167,11 @@ if __name__ == "__main__": try: from eland.ml.pytorch import PyTorchModel - from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES, TransformerModel + from eland.ml.pytorch.transformers import ( + SUPPORTED_TASK_TYPES, + TaskTypeError, + TransformerModel, + ) except ModuleNotFoundError as e: logger.error(textwrap.dedent(f"""\ \033[31mFailed to run because module '{e.name}' is not available.\033[0m @@ -187,8 +193,12 @@ if __name__ == "__main__": with tempfile.TemporaryDirectory() as tmp_dir: logger.info(f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'") - tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize) - model_path, config, vocab_path = tm.save(tmp_dir) + try: + tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize) + model_path, config, vocab_path = tm.save(tmp_dir) + except TransformerModel as err: + logger.error(f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}") + exit(1) ptm = PyTorchModel(es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id()) model_exists = es.options(ignore_status=404).ml.get_trained_models(model_id=ptm.model_id).meta.status == 200 diff --git a/eland/ml/pytorch/__init__.py b/eland/ml/pytorch/__init__.py index f0fc74d6..5cd49b51 100644 --- a/eland/ml/pytorch/__init__.py +++ b/eland/ml/pytorch/__init__.py @@ -23,6 +23,7 @@ NlpTrainedModelConfig, ) from eland.ml.pytorch.traceable_model import TraceableModel # noqa: F401 +from eland.ml.pytorch.transformers import task_type_from_model_config __all__ = [ "PyTorchModel", @@ -31,4 +32,5 @@ "NlpBertTokenizationConfig", "NlpRobertaTokenizationConfig", "NlpMPNetTokenizationConfig", + "task_type_from_model_config", ] diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index 7520e765..ad900766 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -23,7 +23,7 @@ import json import os.path from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch # type: ignore import transformers # type: ignore @@ -33,6 +33,7 @@ AutoConfig, AutoModel, AutoModelForQuestionAnswering, + PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, @@ -64,6 +65,15 @@ "zero_shot_classification", "question_answering", } +ARCHITECTURE_TO_TASK_TYPE = { + "MaskedLM": ["fill_mask", "text_embedding"], + "TokenClassification": ["ner"], + "SequenceClassification": ["text_classification", "zero_shot_classification"], + "QuestionAnswering": ["question_answering"], + "DPRQuestionEncoder": ["text_embedding"], + "DPRContextEncoder": ["text_embedding"], +} +ZERO_SHOT_LABELS = {"contradiction", "neutral", "entailment"} TASK_TYPE_TO_INFERENCE_CONFIG = { "fill_mask": FillMaskInferenceOptions, "ner": NerInferenceOptions, @@ -97,6 +107,37 @@ ] +class TaskTypeError(Exception): + pass + + +def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]: + if model_config.architectures is None: + if model_config.name_or_path.startswith("sentence-transformers/"): + return "text_embedding" + return None + potential_task_types: Set[str] = set() + for architecture in model_config.architectures: + for (substr, task_type) in ARCHITECTURE_TO_TASK_TYPE.items(): + if substr in architecture: + for t in task_type: + potential_task_types.add(t) + if len(potential_task_types) == 0: + return None + if len(potential_task_types) > 1: + if "zero_shot_classification" in potential_task_types: + if model_config.label2id: + labels = set([x.lower() for x in model_config.label2id.keys()]) + if len(labels.difference(ZERO_SHOT_LABELS)) == 0: + return "zero_shot_classification" + return "text_classification" + if "text_embedding" in potential_task_types: + if model_config.name_or_path.startswith("sentence-transformers/"): + return "text_embedding" + return "fill_mask" + return potential_task_types.pop() + + class _QuestionAnsweringWrapperModule(nn.Module): # type: ignore """ A wrapper around a question answering model. @@ -581,6 +622,18 @@ def _create_config(self) -> NlpTrainedModelConfig: ) def _create_traceable_model(self) -> TraceableModel: + if self._task_type == "auto": + model = transformers.AutoModel.from_pretrained( + self._model_id, torchscript=True + ) + maybe_task_type = task_type_from_model_config(model.config) + if maybe_task_type is None: + raise TaskTypeError( + f"Unable to automatically determine task type for model {self._model_id}, please supply task type: {SUPPORTED_TASK_TYPES_NAMES}" + ) + else: + self._task_type = maybe_task_type + if self._task_type == "fill_mask": model = transformers.AutoModelForMaskedLM.from_pretrained( self._model_id, torchscript=True diff --git a/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py b/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py index a09caa29..ab73f412 100644 --- a/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py +++ b/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py @@ -34,12 +34,14 @@ try: import torch # noqa: F401 from torch import Tensor, nn # noqa: F401 + from transformers import PretrainedConfig # noqa: F401 from eland.ml.pytorch import ( # noqa: F401 NlpBertTokenizationConfig, NlpTrainedModelConfig, PyTorchModel, TraceableModel, + task_type_from_model_config, ) from eland.ml.pytorch.nlp_ml_model import ( NerInferenceOptions, @@ -222,6 +224,41 @@ def forward( ), ] +AUTO_TASK_RESULTS = [ + ("any_bert", "BERTMaskedLM", None, "fill_mask"), + ("any_roberta", "RoBERTaMaskedLM", None, "fill_mask"), + ("sentence-transformers/any_bert", "BERTMaskedLM", None, "text_embedding"), + ("sentence-transformers/any_roberta", "RoBERTaMaskedLM", None, "text_embedding"), + ("sentence-transformers/mpnet", "MPNetMaskedLM", None, "text_embedding"), + ("anynermodel", "BERTForTokenClassification", None, "ner"), + ("anynermodel", "MPNetForTokenClassification", None, "ner"), + ("anynermodel", "RoBERTaForTokenClassification", None, "ner"), + ("anynermodel", "BERTForQuestionAnswering", None, "question_answering"), + ("anynermodel", "MPNetForQuestionAnswering", None, "question_answering"), + ("anynermodel", "RoBERTaForQuestionAnswering", None, "question_answering"), + ("aqaModel", "DPRQuestionEncoder", None, "text_embedding"), + ("aqaModel", "DPRContextEncoder", None, "text_embedding"), + ( + "any_bert", + "BERTForSequenceClassification", + ["foo", "bar", "baz"], + "text_classification", + ), + ( + "any_bert", + "BERTForSequenceClassification", + ["contradiction", "neutral", "entailment"], + "zero_shot_classification", + ), + ( + "any_bert", + "BERTForSequenceClassification", + ["CONTRADICTION", "NEUTRAL", "ENTAILMENT"], + "zero_shot_classification", + ), + ("any_bert", "SomeUnknownType", None, None), +] + @pytest.fixture(scope="function", autouse=True) def setup_and_tear_down(): @@ -274,3 +311,22 @@ def test_model_upload(self, model_id, model, config, input, prediction): result = ptm.infer(docs=[{"text_field": input}]) assert result.get("predicted_value") is not None assert result["predicted_value"] == prediction + + @pytest.mark.parametrize( + "model_id,architecture,labels,expected_task", AUTO_TASK_RESULTS + ) + def test_auto_task_type(self, model_id, architecture, labels, expected_task): + config = ( + PretrainedConfig( + name_or_path=model_id, + architectures=[architecture], + label2id=dict(zip(labels, range(len(labels)))), + id2label=dict(zip(range(len(labels)), labels)), + ) + if labels + else PretrainedConfig( + name_or_path=model_id, + architectures=[architecture], + ) + ) + assert task_type_from_model_config(model_config=config) == expected_task From 2fbd10d9a954c67025d4135cd1ef12be8d53a8a1 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 23 Jun 2022 07:41:31 -0400 Subject: [PATCH 2/2] Update bin/eland_import_hub_model Co-authored-by: David Kyle --- bin/eland_import_hub_model | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/eland_import_hub_model b/bin/eland_import_hub_model index f65e7327..bef0770a 100755 --- a/bin/eland_import_hub_model +++ b/bin/eland_import_hub_model @@ -196,7 +196,7 @@ if __name__ == "__main__": try: tm = TransformerModel(args.hub_model_id, args.task_type, args.quantize) model_path, config, vocab_path = tm.save(tmp_dir) - except TransformerModel as err: + except TaskTypeError as err: logger.error(f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}") exit(1)