From 837905a56d97edf1c8830d5f17ba70c44f4d9d6a Mon Sep 17 00:00:00 2001 From: Damian Date: Thu, 14 Sep 2023 15:20:15 +0000 Subject: [PATCH 1/7] initial commit --- tests/test_data/pipeline_bench_config.json | 1 - tests/test_pipeline_benchmark.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/test_data/pipeline_bench_config.json b/tests/test_data/pipeline_bench_config.json index 5886762cea..afd4db352d 100644 --- a/tests/test_data/pipeline_bench_config.json +++ b/tests/test_data/pipeline_bench_config.json @@ -2,7 +2,6 @@ "data_type": "dummy", "gen_sequence_length": 100, "input_image_shape": [500,500,3], - "data_folder": "/home/sadkins/imagenette2-320/", "recursive_search": true, "max_string_length": -1, "pipeline_kwargs": {}, diff --git a/tests/test_pipeline_benchmark.py b/tests/test_pipeline_benchmark.py index 485599d044..782a1f8016 100644 --- a/tests/test_pipeline_benchmark.py +++ b/tests/test_pipeline_benchmark.py @@ -95,7 +95,6 @@ def test_pipeline_benchmark( if res.stdout is not None: print(f"\n==== test_benchmark output ====\n{res.stdout}") assert res.returncode == 0 - assert "error" not in res.stdout.lower() assert "fail" not in res.stdout.lower() assert "total_inference" in res.stdout.lower() From a4f5d7ed99852fc4b567a4fea7ea18cb0ced25fc Mon Sep 17 00:00:00 2001 From: Damian Date: Fri, 15 Sep 2023 14:11:55 +0000 Subject: [PATCH 2/7] initial commit --- src/deepsparse/transformers/helpers.py | 116 +++--------------- .../transformers/pipelines/pipeline.py | 49 ++------ tests/deepsparse/transformers/test_helpers.py | 21 +--- 3 files changed, 32 insertions(+), 154 deletions(-) diff --git a/src/deepsparse/transformers/helpers.py b/src/deepsparse/transformers/helpers.py index 44f0cfc77f..a71c66ec43 100644 --- a/src/deepsparse/transformers/helpers.py +++ b/src/deepsparse/transformers/helpers.py @@ -19,7 +19,6 @@ import os import re -import warnings from pathlib import Path from tempfile import NamedTemporaryFile from typing import List, Optional, Tuple, Union @@ -35,8 +34,7 @@ __all__ = [ - "get_hugging_face_configs", - "get_onnx_path", + "get_deployment_path", "overwrite_transformer_onnx_model_inputs", "fix_numpy_types", "get_transformer_layer_init_names", @@ -46,17 +44,25 @@ _LOGGER = get_main_logger() _MODEL_DIR_ONNX_NAME = "model.onnx" -_MODEL_DIR_CONFIG_NAME = "config.json" -_MODEL_DIR_TOKENIZER_NAME = "tokenizer.json" -_MODEL_DIR_TOKENIZER_CONFIG_NAME = "tokenizer_config.json" -_OPT_TOKENIZER_FILES = ["special_tokens_map.json", "vocab.json", "merges.txt"] -def get_onnx_path(model_path: str) -> str: +def get_deployment_path(model_path: str) -> Tuple[str, str]: + """ + Returns the path to the deployment directory + for the given model path. The deployment directory + contains all the necessary files for running the model + in the deepsparse pipeline + + :param model_path: path to model directory, sparsezoo stub, or ONNX file + :return: path to the deployment directory and path to the ONNX file inside + the deployment directory + """ if os.path.isfile(model_path): - return model_path + # return the parent directory of the ONNX file + return os.path.dirname(model_path), model_path if os.path.isdir(model_path): + # model_files = os.listdir(model_path) if _MODEL_DIR_ONNX_NAME not in model_files: @@ -65,103 +71,17 @@ def get_onnx_path(model_path: str) -> str: f"{model_path}. Be sure that an export of the model is written to " f"{os.path.join(model_path, _MODEL_DIR_ONNX_NAME)}" ) - onnx_path = os.path.join(model_path, _MODEL_DIR_ONNX_NAME) + return model_path, os.path.join(model_path, _MODEL_DIR_ONNX_NAME) elif model_path.startswith("zoo:"): zoo_model = Model(model_path) - onnx_path = zoo_model.onnx_model.path + deployment_path = zoo_model.deployment.path + return deployment_path, os.path.join(deployment_path, _MODEL_DIR_ONNX_NAME) else: raise ValueError( f"model_path {model_path} is not a valid file, directory, or zoo stub" ) - return onnx_path - - -def get_hugging_face_configs(model_path: str) -> Tuple[str, str]: - """ - :param model_path: path to model directory, transformers sparsezoo stub, - or directory containing `config.json`, and `tokenizer.json` files. - If the json files are not found, an exception will be raised. - :return: tuple of ONNX file path, parent directory of config file - if it exists, and parent directory of tokenizer config file if it - exists. (Parent directories returned instead of absolute path - for compatibility with transformers .from_pretrained() method) - """ - config_path = None - tokenizer_path = None - - if os.path.isdir(model_path): - model_files = os.listdir(model_path) - # attempt to read config and tokenizer from sparsezoo-like framework directory - framework_dir = None - if "framework" in model_files: - framework_dir = os.path.join(model_path, "framework") - if "pytorch" in model_files: - framework_dir = os.path.join(model_path, "pytorch") - if framework_dir and os.path.isdir(framework_dir): - framework_files = os.listdir(framework_dir) - if _MODEL_DIR_CONFIG_NAME in framework_files: - config_path = framework_dir - if ( - _MODEL_DIR_TOKENIZER_NAME - or _MODEL_DIR_TOKENIZER_CONFIG_NAME in framework_files - ): - tokenizer_path = framework_dir - - # prefer config and tokenizer files in same directory as model.onnx - if _MODEL_DIR_CONFIG_NAME in model_files: - config_path = model_path - if ( - _MODEL_DIR_TOKENIZER_NAME in model_files - or _MODEL_DIR_TOKENIZER_CONFIG_NAME in model_files - ): - tokenizer_path = model_path - - elif model_path.startswith("zoo:"): - zoo_model = Model(model_path) - config_path = _get_file_parent( - zoo_model.deployment.default.get_file(_MODEL_DIR_CONFIG_NAME).path - ) - tokenizer_file = zoo_model.deployment.default.get_file( - _MODEL_DIR_TOKENIZER_NAME - ) - - tokenizer_config_file = zoo_model.deployment.default.get_file( - _MODEL_DIR_TOKENIZER_CONFIG_NAME - ) - - if tokenizer_config_file is not None: - tokenizer_config_path = _get_file_parent( - tokenizer_config_file.path - ) # trigger download of tokenizer_config - - if tokenizer_file is not None: - tokenizer_path = _get_file_parent(tokenizer_file.path) - else: - # if tokenizer_file is not present, we assume it's the OPT model - # this means that we use tokenizer_config_path instead of tokenizer_path - # and need to download the additional tokenizer files - tokenizer_path = tokenizer_config_path - for file in _OPT_TOKENIZER_FILES: - zoo_model.deployment.default.get_file(file).path - - else: - raise ValueError( - f"model_path {model_path} is not a valid directory or zoo stub" - ) - - if config_path is None or tokenizer_path is None: - warnings.warn( - f"Unable to find model or tokenizer configs for model_path {model_path}. " - f"model_path must be a directory containing config.json, and/or " - f"tokenizer.json files. Found config and tokenizer paths: {config_path}, " - f"{tokenizer_path}. If not given, set the `tokenizer` and `config` args " - "for the Pipeline." - ) - - return config_path, tokenizer_path - def overwrite_transformer_onnx_model_inputs( path: str, diff --git a/src/deepsparse/transformers/pipelines/pipeline.py b/src/deepsparse/transformers/pipelines/pipeline.py index 997c604c4a..3c8f08e9dc 100644 --- a/src/deepsparse/transformers/pipelines/pipeline.py +++ b/src/deepsparse/transformers/pipelines/pipeline.py @@ -16,8 +16,6 @@ Base Pipeline class for transformers inference pipeline """ -import json -import os import warnings from pathlib import Path from typing import Any, Dict, List, Mapping, Optional, Union @@ -28,8 +26,7 @@ from deepsparse import Bucketable, Pipeline from deepsparse.transformers.helpers import ( - get_hugging_face_configs, - get_onnx_path, + get_deployment_path, overwrite_transformer_onnx_model_inputs, ) @@ -130,41 +127,17 @@ def setup_onnx_file_path(self) -> str: :return: file path to the processed ONNX file for the engine to compile """ - onnx_path = get_onnx_path(self.model_path) - - if not self.config or not self.tokenizer: - config_found, tokenizer_found = get_hugging_face_configs(self.model_path) - if config_found: - self.config = config_found - if tokenizer_found: - self.tokenizer = tokenizer_found - - if isinstance(self.config, dict): - local_config_path = os.path.join(self.model_path, "config.json") - with open(local_config_path, "w") as f: - json.dump(self.config, f) - self.config = local_config_path - - if isinstance(self.config, (str, Path)): - if str(self.config).endswith(".json"): - self.config_path = self.config - else: - self.config_path = os.path.join(self.config, "config.json") - - self.config = transformers.PretrainedConfig.from_pretrained( - self.config, - finetuning_task=self.task if hasattr(self, "task") else None, - ) - - if isinstance(self.tokenizer, (str, Path)): - self.tokenizer_config_path = os.path.join(self.tokenizer, "tokenizer.json") - - self.tokenizer = AutoTokenizer.from_pretrained( - self.tokenizer, - trust_remote_code=self._trust_remote_code, - model_max_length=self.sequence_length, - ) + deployment_path, onnx_path = get_deployment_path(self.model_path) + self.config = transformers.PretrainedConfig.from_pretrained( + deployment_path, + finetuning_task=self.task if hasattr(self, "task") else None, + ) + self.tokenizer = AutoTokenizer.from_pretrained( + deployment_path, + trust_remote_code=self._trust_remote_code, + model_max_length=self.sequence_length, + ) if not self._delay_overwriting_inputs: # overwrite onnx graph to given required input shape ( diff --git a/tests/deepsparse/transformers/test_helpers.py b/tests/deepsparse/transformers/test_helpers.py index 30309ff1be..5cd1cf0dfa 100644 --- a/tests/deepsparse/transformers/test_helpers.py +++ b/tests/deepsparse/transformers/test_helpers.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - import onnx import pytest from deepsparse.transformers.helpers import ( - get_hugging_face_configs, - get_onnx_path, + get_deployment_path, get_transformer_layer_init_names, truncate_transformer_onnx_model, ) @@ -35,20 +32,8 @@ ), ], ) -def test_get_onnx_path_and_configs_from_stub(stub): - onnx_path = get_onnx_path(stub) - config_dir, tokenizer_dir = get_hugging_face_configs(stub) - - assert onnx_path.endswith("model.onnx") - assert os.path.exists(onnx_path) - - config_dir_files = os.listdir(config_dir) - assert "config.json" in config_dir_files - - tokenizer_dir_files = os.listdir(tokenizer_dir) - assert "tokenizer.json" in tokenizer_dir_files - # make assert optional if stubs added for models with no known tokenizer_config - assert "tokenizer_config.json" in tokenizer_dir_files +def test_get_deployment_path(stub): + assert get_deployment_path(stub) @pytest.fixture(scope="session") From 310ca6ecf8957fa37062686affd9384a66c73abf Mon Sep 17 00:00:00 2001 From: Damian Date: Mon, 18 Sep 2023 09:53:57 +0000 Subject: [PATCH 3/7] ready for review --- src/deepsparse/transformers/helpers.py | 11 +++++------ src/deepsparse/transformers/pipelines/pipeline.py | 13 ++++++------- src/deepsparse/utils/onnx.py | 4 +++- tests/test_pipeline_benchmark.py | 1 + 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/deepsparse/transformers/helpers.py b/src/deepsparse/transformers/helpers.py index a71c66ec43..f49e8271e6 100644 --- a/src/deepsparse/transformers/helpers.py +++ b/src/deepsparse/transformers/helpers.py @@ -28,7 +28,7 @@ from onnx import ModelProto from deepsparse.log import get_main_logger -from deepsparse.utils.onnx import truncate_onnx_model +from deepsparse.utils.onnx import _MODEL_DIR_ONNX_NAME, truncate_onnx_model from sparsezoo import Model from sparsezoo.utils import save_onnx @@ -43,15 +43,14 @@ _LOGGER = get_main_logger() -_MODEL_DIR_ONNX_NAME = "model.onnx" - def get_deployment_path(model_path: str) -> Tuple[str, str]: """ Returns the path to the deployment directory - for the given model path. The deployment directory - contains all the necessary files for running the model - in the deepsparse pipeline + for the given model path and the path to the mandatory + ONNX model that should reside in the deployment directory. + The deployment directory contains all the necessary files + for running the transformers model in the deepsparse pipeline :param model_path: path to model directory, sparsezoo stub, or ONNX file :return: path to the deployment directory and path to the ONNX file inside diff --git a/src/deepsparse/transformers/pipelines/pipeline.py b/src/deepsparse/transformers/pipelines/pipeline.py index 5e4a1d4f05..065a26ce71 100644 --- a/src/deepsparse/transformers/pipelines/pipeline.py +++ b/src/deepsparse/transformers/pipelines/pipeline.py @@ -118,18 +118,16 @@ def sequence_length(self) -> Union[int, List[int]]: def setup_onnx_file_path(self) -> str: """ - Parses ONNX model from the `model_path` provided. - For tokenizers and model configs, supports paths, dictionaries, - or transformers.PretrainedConfig/transformes.PreTrainedTokenizerBase types. Also - supports the default None, in which case the config and tokenizer are read from - the provided `model_path`. - - Supports sparsezoo stubs + Parses ONNX model from the `model_path` provided. It additionally + creates config and tokenizer objects from the `deployment path`, + derived from the `model_path` provided. :return: file path to the processed ONNX file for the engine to compile """ deployment_path, onnx_path = get_deployment_path(self.model_path) + # temporarily set transformers logger to ERROR to avoid + # printing misleading warnings hf_logger = logging.getLogger("transformers") hf_logger_level = hf_logger.level hf_logger.setLevel(logging.ERROR) @@ -138,6 +136,7 @@ def setup_onnx_file_path(self) -> str: finetuning_task=self.task if hasattr(self, "task") else None, ) hf_logger.setLevel(hf_logger_level) + self.tokenizer = AutoTokenizer.from_pretrained( deployment_path, trust_remote_code=self._trust_remote_code, diff --git a/src/deepsparse/utils/onnx.py b/src/deepsparse/utils/onnx.py index b6d358475e..a291bad58b 100644 --- a/src/deepsparse/utils/onnx.py +++ b/src/deepsparse/utils/onnx.py @@ -55,10 +55,12 @@ "has_model_kv_cache", "CACHE_INPUT_PREFIX", "CACHE_OUTPUT_PREFIX", + "_MODEL_DIR_ONNX_NAME", ] _LOGGER = logging.getLogger(__name__) +_MODEL_DIR_ONNX_NAME = "model.onnx" CACHE_INPUT_PREFIX = "past_key_values" CACHE_OUTPUT_PREFIX = "present" @@ -126,7 +128,7 @@ def model_to_path(model: Union[str, Model, File]) -> str: if Model is not object and isinstance(model, Model): # default to the main onnx file for the model - model = model.onnx_model.path + model = model.deployment.get_file(_MODEL_DIR_ONNX_NAME) elif File is not object and isinstance(model, File): # get the downloaded_path -- will auto download if not on local system model = model.path diff --git a/tests/test_pipeline_benchmark.py b/tests/test_pipeline_benchmark.py index 782a1f8016..485599d044 100644 --- a/tests/test_pipeline_benchmark.py +++ b/tests/test_pipeline_benchmark.py @@ -95,6 +95,7 @@ def test_pipeline_benchmark( if res.stdout is not None: print(f"\n==== test_benchmark output ====\n{res.stdout}") assert res.returncode == 0 + assert "error" not in res.stdout.lower() assert "fail" not in res.stdout.lower() assert "total_inference" in res.stdout.lower() From 5946ef86169c2f8651c6ee3dbf5cf05cc058684c Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Mon, 18 Sep 2023 11:56:55 +0200 Subject: [PATCH 4/7] Apply suggestions from code review --- src/deepsparse/transformers/helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/deepsparse/transformers/helpers.py b/src/deepsparse/transformers/helpers.py index f49e8271e6..2cc75792fe 100644 --- a/src/deepsparse/transformers/helpers.py +++ b/src/deepsparse/transformers/helpers.py @@ -61,7 +61,6 @@ def get_deployment_path(model_path: str) -> Tuple[str, str]: return os.path.dirname(model_path), model_path if os.path.isdir(model_path): - # model_files = os.listdir(model_path) if _MODEL_DIR_ONNX_NAME not in model_files: From 77a1bdefd2fd3eb2cab3c3c3a0ef5dddf865783d Mon Sep 17 00:00:00 2001 From: Damian Date: Mon, 18 Sep 2023 10:23:42 +0000 Subject: [PATCH 5/7] fix (failing tests) --- src/deepsparse/utils/onnx.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/deepsparse/utils/onnx.py b/src/deepsparse/utils/onnx.py index a291bad58b..83c38befc0 100644 --- a/src/deepsparse/utils/onnx.py +++ b/src/deepsparse/utils/onnx.py @@ -128,7 +128,8 @@ def model_to_path(model: Union[str, Model, File]) -> str: if Model is not object and isinstance(model, Model): # default to the main onnx file for the model - model = model.deployment.get_file(_MODEL_DIR_ONNX_NAME) + model = model.deployment.get_file(_MODEL_DIR_ONNX_NAME).path + elif File is not object and isinstance(model, File): # get the downloaded_path -- will auto download if not on local system model = model.path From 32f5ae6bc153b3164722383f58cb4f002e2f8c64 Mon Sep 17 00:00:00 2001 From: Damian Date: Mon, 18 Sep 2023 12:07:18 +0000 Subject: [PATCH 6/7] fix test --- tests/utils/test_engine_mocking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_engine_mocking.py b/tests/utils/test_engine_mocking.py index 8056de0686..cc4a000bf4 100644 --- a/tests/utils/test_engine_mocking.py +++ b/tests/utils/test_engine_mocking.py @@ -39,7 +39,7 @@ def test_mock_engine_calls(engine_mock: MagicMock): os.path.join( os.path.expanduser("~"), ".cache/sparsezoo/neuralmagic/", - "resnet_v1-50-imagenet-pruned85.4block_quantized/model.onnx", + "resnet_v1-50-imagenet-pruned85.4block_quantized/deployment/model.onnx", ), 3, 1, From d22bfce32d9cd8c675dba6e97601d7af986408c1 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 20 Sep 2023 15:47:25 -0400 Subject: [PATCH 7/7] Support loading HF repos with `hf:` stubs (#1260) --- src/deepsparse/transformers/helpers.py | 12 ++++++++++++ src/deepsparse/utils/onnx.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/deepsparse/transformers/helpers.py b/src/deepsparse/transformers/helpers.py index 2cc75792fe..3f24749ccb 100644 --- a/src/deepsparse/transformers/helpers.py +++ b/src/deepsparse/transformers/helpers.py @@ -75,6 +75,18 @@ def get_deployment_path(model_path: str) -> Tuple[str, str]: zoo_model = Model(model_path) deployment_path = zoo_model.deployment.path return deployment_path, os.path.join(deployment_path, _MODEL_DIR_ONNX_NAME) + elif model_path.startswith("hf:"): + from huggingface_hub import snapshot_download + + deployment_path = snapshot_download(repo_id=model_path.replace("hf:", "", 1)) + onnx_path = os.path.join(deployment_path, _MODEL_DIR_ONNX_NAME) + if not os.path.isfile(onnx_path): + raise ValueError( + f"{_MODEL_DIR_ONNX_NAME} not found in transformers model directory " + f"{deployment_path}. Be sure that an export of the model is written to " + f"{onnx_path}" + ) + return deployment_path, onnx_path else: raise ValueError( f"model_path {model_path} is not a valid file, directory, or zoo stub" diff --git a/src/deepsparse/utils/onnx.py b/src/deepsparse/utils/onnx.py index 83c38befc0..4357c50f7a 100644 --- a/src/deepsparse/utils/onnx.py +++ b/src/deepsparse/utils/onnx.py @@ -142,7 +142,7 @@ def model_to_path(model: Union[str, Model, File]) -> str: model_path = Path(model) if model_path.is_dir(): - return str(model_path / "model.onnx") + return str(model_path / _MODEL_DIR_ONNX_NAME) return model