Merge branch 'infer-ux-fixes' of github.com:neuralmagic/deepsparse into infer-ux-fixes
horheynm committed Oct 2, 2023
2 parents 2a4b972 + 8243740 commit ad9e96d
Showing 10 changed files with 113 additions and 180 deletions.
11 changes: 11 additions & 0 deletions CITATION.cff
@@ -0,0 +1,11 @@
cff-version: 1.2.0
title: DeepSparse
message: "Please use this information to cite DeepSparse in research or other publications."
authors:
- affiliation: Neural Magic
given-names: Neural Magic
date-released: 2021-2-4
url: "neuralmagic.com"
repository-code: "https://github.com/neuralmagic/deepsparse"
keywords:
- sparsity-aware inference
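Because CFF is machine-readable YAML, the new file can be consumed by tooling as well as by GitHub's "Cite this repository" button. A quick sketch of reading it back, assuming PyYAML is available (not a dependency shown in this diff):

```python
# Sketch: load the new citation metadata (assumes PyYAML is installed).
import yaml

with open("CITATION.cff") as f:
    citation = yaml.safe_load(f)

print(citation["title"])            # DeepSparse
print(citation["repository-code"])  # https://github.com/neuralmagic/deepsparse
```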
132 changes: 31 additions & 101 deletions src/deepsparse/transformers/helpers.py
@@ -19,7 +19,6 @@

import os
import re
import warnings
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import List, Optional, Tuple, Union
@@ -29,14 +28,13 @@
from onnx import ModelProto

from deepsparse.log import get_main_logger
from deepsparse.utils.onnx import truncate_onnx_model
from deepsparse.utils.onnx import _MODEL_DIR_ONNX_NAME, truncate_onnx_model
from sparsezoo import Model
from sparsezoo.utils import save_onnx


__all__ = [
"get_hugging_face_configs",
"get_onnx_path",
"get_deployment_path",
"overwrite_transformer_onnx_model_inputs",
"fix_numpy_types",
"get_transformer_layer_init_names",
@@ -45,16 +43,22 @@

_LOGGER = get_main_logger()

_MODEL_DIR_ONNX_NAME = "model.onnx"
_MODEL_DIR_CONFIG_NAME = "config.json"
_MODEL_DIR_TOKENIZER_NAME = "tokenizer.json"
_MODEL_DIR_TOKENIZER_CONFIG_NAME = "tokenizer_config.json"
_OPT_TOKENIZER_FILES = ["special_tokens_map.json", "vocab.json", "merges.txt"]


def get_onnx_path(model_path: str) -> str:
def get_deployment_path(model_path: str) -> Tuple[str, str]:
"""
Returns the path to the deployment directory
for the given model path and the path to the mandatory
ONNX model that should reside in the deployment directory.
The deployment directory contains all the necessary files
for running the transformers model in the deepsparse pipeline
:param model_path: path to model directory, sparsezoo stub, or ONNX file
:return: path to the deployment directory and path to the ONNX file inside
the deployment directory
"""
if os.path.isfile(model_path):
return model_path
# return the parent directory of the ONNX file
return os.path.dirname(model_path), model_path

if os.path.isdir(model_path):
model_files = os.listdir(model_path)
@@ -65,103 +69,29 @@ def get_onnx_path(model_path: str) -> str:
f"{model_path}. Be sure that an export of the model is written to "
f"{os.path.join(model_path, _MODEL_DIR_ONNX_NAME)}"
)
onnx_path = os.path.join(model_path, _MODEL_DIR_ONNX_NAME)
return model_path, os.path.join(model_path, _MODEL_DIR_ONNX_NAME)

elif model_path.startswith("zoo:"):
zoo_model = Model(model_path)
onnx_path = zoo_model.onnx_model.path
deployment_path = zoo_model.deployment.path
return deployment_path, os.path.join(deployment_path, _MODEL_DIR_ONNX_NAME)
elif model_path.startswith("hf:"):
from huggingface_hub import snapshot_download

deployment_path = snapshot_download(repo_id=model_path.replace("hf:", "", 1))
onnx_path = os.path.join(deployment_path, _MODEL_DIR_ONNX_NAME)
if not os.path.isfile(onnx_path):
raise ValueError(
f"{_MODEL_DIR_ONNX_NAME} not found in transformers model directory "
f"{deployment_path}. Be sure that an export of the model is written to "
f"{onnx_path}"
)
return deployment_path, onnx_path
else:
raise ValueError(
f"model_path {model_path} is not a valid file, directory, or zoo stub"
)

return onnx_path


def get_hugging_face_configs(model_path: str) -> Tuple[str, str]:
"""
:param model_path: path to model directory, transformers sparsezoo stub,
or directory containing `config.json`, and `tokenizer.json` files.
If the json files are not found, an exception will be raised.
:return: tuple of ONNX file path, parent directory of config file
if it exists, and parent directory of tokenizer config file if it
exists. (Parent directories returned instead of absolute path
for compatibility with transformers .from_pretrained() method)
"""
config_path = None
tokenizer_path = None

if os.path.isdir(model_path):
model_files = os.listdir(model_path)
# attempt to read config and tokenizer from sparsezoo-like framework directory
framework_dir = None
if "framework" in model_files:
framework_dir = os.path.join(model_path, "framework")
if "pytorch" in model_files:
framework_dir = os.path.join(model_path, "pytorch")
if framework_dir and os.path.isdir(framework_dir):
framework_files = os.listdir(framework_dir)
if _MODEL_DIR_CONFIG_NAME in framework_files:
config_path = framework_dir
if (
_MODEL_DIR_TOKENIZER_NAME
or _MODEL_DIR_TOKENIZER_CONFIG_NAME in framework_files
):
tokenizer_path = framework_dir

# prefer config and tokenizer files in same directory as model.onnx
if _MODEL_DIR_CONFIG_NAME in model_files:
config_path = model_path
if (
_MODEL_DIR_TOKENIZER_NAME in model_files
or _MODEL_DIR_TOKENIZER_CONFIG_NAME in model_files
):
tokenizer_path = model_path

elif model_path.startswith("zoo:"):
zoo_model = Model(model_path)
config_path = _get_file_parent(
zoo_model.deployment.default.get_file(_MODEL_DIR_CONFIG_NAME).path
)
tokenizer_file = zoo_model.deployment.default.get_file(
_MODEL_DIR_TOKENIZER_NAME
)

tokenizer_config_file = zoo_model.deployment.default.get_file(
_MODEL_DIR_TOKENIZER_CONFIG_NAME
)

if tokenizer_config_file is not None:
tokenizer_config_path = _get_file_parent(
tokenizer_config_file.path
) # trigger download of tokenizer_config

if tokenizer_file is not None:
tokenizer_path = _get_file_parent(tokenizer_file.path)
else:
# if tokenizer_file is not present, we assume it's the OPT model
# this means that we use tokenizer_config_path instead of tokenizer_path
# and need to download the additional tokenizer files
tokenizer_path = tokenizer_config_path
for file in _OPT_TOKENIZER_FILES:
zoo_model.deployment.default.get_file(file).path

else:
raise ValueError(
f"model_path {model_path} is not a valid directory or zoo stub"
)

if config_path is None or tokenizer_path is None:
warnings.warn(
f"Unable to find model or tokenizer configs for model_path {model_path}. "
f"model_path must be a directory containing config.json, and/or "
f"tokenizer.json files. Found config and tokenizer paths: {config_path}, "
f"{tokenizer_path}. If not given, set the `tokenizer` and `config` args "
"for the Pipeline."
)

return config_path, tokenizer_path


def overwrite_transformer_onnx_model_inputs(
path: str,
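The net effect of this refactor is that every input form — a local ONNX file, a model directory, a `zoo:` stub, or an `hf:` repo id — now resolves to the same `(deployment_dir, onnx_path)` pair. A minimal sketch of the two local-path branches, assuming only the `_MODEL_DIR_ONNX_NAME` constant shown above (the `zoo:`/`hf:` branches are omitted here):

```python
# Sketch of the local-path branches of get_deployment_path as shown in the diff.
import os
from typing import Tuple

_MODEL_DIR_ONNX_NAME = "model.onnx"

def get_deployment_path(model_path: str) -> Tuple[str, str]:
    if os.path.isfile(model_path):
        # a bare ONNX file: its parent directory serves as the deployment dir
        return os.path.dirname(model_path), model_path
    if os.path.isdir(model_path):
        onnx_path = os.path.join(model_path, _MODEL_DIR_ONNX_NAME)
        if not os.path.isfile(onnx_path):
            raise ValueError(
                f"{_MODEL_DIR_ONNX_NAME} not found in model directory {model_path}"
            )
        return model_path, onnx_path
    raise ValueError(f"{model_path} is not a valid file or directory")
```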
8 changes: 3 additions & 5 deletions src/deepsparse/transformers/pipelines/chat.py
@@ -179,12 +179,10 @@ def process_engine_outputs(
"""

engine_outputs, session_ids = list(*engine_outputs)

kwargs["session_ids"] = session_ids
# process the engine outputs within the context of TextGenerationPipeline
text_generation_output = super().process_engine_outputs(
engine_outputs, **kwargs
)
# create the ChatOutput from the data provided
return ChatOutput(**text_generation_output.dict(), session_ids=session_ids)
return super().process_engine_outputs(engine_outputs, **kwargs)

def engine_forward(
self, engine_inputs: List[numpy.ndarray], context: Dict
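This pairs with the `self.output_schema(**outputs)` change in text_generation.py below: rather than rebuilding the parent's output and wrapping it in a `ChatOutput`, the chat pipeline now threads `session_ids` through `kwargs` and lets the parent construct whatever schema the subclass declares. A self-contained sketch of the pattern — `TextPipeline`, `ChatPipeline`, and the schemas here are illustrative stand-ins, not the real pipeline API:

```python
# Illustrative sketch: the parent builds whichever schema the subclass declares.
from typing import Optional
from pydantic import BaseModel

class TextOutput(BaseModel):
    generations: list

class ChatOutput(TextOutput):
    session_ids: Optional[str] = None

class TextPipeline:
    output_schema = TextOutput

    def process_engine_outputs(self, engine_outputs, **kwargs):
        # only include session_ids when a subclass actually supplied one
        extra = {"session_ids": sid} if (sid := kwargs.get("session_ids")) else {}
        return self.output_schema(generations=engine_outputs, **extra)

class ChatPipeline(TextPipeline):
    output_schema = ChatOutput

    def process_engine_outputs(self, engine_outputs, **kwargs):
        engine_outputs, session_ids = engine_outputs
        kwargs["session_ids"] = session_ids
        return super().process_engine_outputs(engine_outputs, **kwargs)

print(ChatPipeline().process_engine_outputs((["hi"], "abc123")))
# -> generations=['hi'] session_ids='abc123'
```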
69 changes: 21 additions & 48 deletions src/deepsparse/transformers/pipelines/pipeline.py
@@ -16,9 +16,7 @@
Base Pipeline class for transformers inference pipeline
"""

import json
import logging
import os
import warnings
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Union
@@ -29,8 +27,7 @@

from deepsparse import Bucketable, Pipeline
from deepsparse.transformers.helpers import (
get_hugging_face_configs,
get_onnx_path,
get_deployment_path,
overwrite_transformer_onnx_model_inputs,
)

@@ -121,54 +118,30 @@ def sequence_length(self) -> Union[int, List[int]]:

def setup_onnx_file_path(self) -> str:
"""
Parses ONNX model from the `model_path` provided.
For tokenizers and model configs, supports paths, dictionaries,
or transformers.PretrainedConfig/transformes.PreTrainedTokenizerBase types. Also
supports the default None, in which case the config and tokenizer are read from
the provided `model_path`.
Supports sparsezoo stubs
Parses ONNX model from the `model_path` provided. It additionally
creates config and tokenizer objects from the `deployment path`,
derived from the `model_path` provided.
:return: file path to the processed ONNX file for the engine to compile
"""
onnx_path = get_onnx_path(self.model_path)

if not self.config or not self.tokenizer:
config_found, tokenizer_found = get_hugging_face_configs(self.model_path)
if config_found:
self.config = config_found
if tokenizer_found:
self.tokenizer = tokenizer_found

if isinstance(self.config, dict):
local_config_path = os.path.join(self.model_path, "config.json")
with open(local_config_path, "w") as f:
json.dump(self.config, f)
self.config = local_config_path

if isinstance(self.config, (str, Path)):
if str(self.config).endswith(".json"):
self.config_path = self.config
else:
self.config_path = os.path.join(self.config, "config.json")

hf_logger = logging.getLogger("transformers")
hf_logger_level = hf_logger.level
hf_logger.setLevel(logging.ERROR)
self.config = transformers.PretrainedConfig.from_pretrained(
self.config,
finetuning_task=self.task if hasattr(self, "task") else None,
)
hf_logger.setLevel(hf_logger_level)

if isinstance(self.tokenizer, (str, Path)):
self.tokenizer_config_path = os.path.join(self.tokenizer, "tokenizer.json")
deployment_path, onnx_path = get_deployment_path(self.model_path)

# temporarily set transformers logger to ERROR to avoid
# printing misleading warnings
hf_logger = logging.getLogger("transformers")
hf_logger_level = hf_logger.level
hf_logger.setLevel(logging.ERROR)
self.config = transformers.PretrainedConfig.from_pretrained(
deployment_path,
finetuning_task=self.task if hasattr(self, "task") else None,
)
hf_logger.setLevel(hf_logger_level)

self.tokenizer = AutoTokenizer.from_pretrained(
self.tokenizer,
trust_remote_code=self._trust_remote_code,
model_max_length=self.sequence_length,
)
self.tokenizer = AutoTokenizer.from_pretrained(
deployment_path,
trust_remote_code=self._trust_remote_code,
model_max_length=self.sequence_length,
)

if not self._delay_overwriting_inputs:
# overwrite onnx graph to given required input shape
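Since config and tokenizer are now read from a single deployment directory, the loading logic collapses to two `from_pretrained` calls with the transformers logger temporarily silenced. A hedged sketch of that sequence — the `try`/`finally` is a defensive addition, not in the diff, and it assumes `transformers` is installed and the directory holds `config.json` plus tokenizer files:

```python
# Sketch: load config and tokenizer from one deployment directory.
import logging
import transformers
from transformers import AutoTokenizer

def load_hf_artifacts(deployment_path: str, sequence_length: int = 128):
    # temporarily silence transformers to avoid misleading load-time warnings
    hf_logger = logging.getLogger("transformers")
    previous_level = hf_logger.level
    hf_logger.setLevel(logging.ERROR)
    try:
        config = transformers.PretrainedConfig.from_pretrained(deployment_path)
    finally:
        # always restore the caller's logging configuration
        hf_logger.setLevel(previous_level)
    tokenizer = AutoTokenizer.from_pretrained(
        deployment_path, model_max_length=sequence_length
    )
    return config, tokenizer
```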
16 changes: 13 additions & 3 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -543,7 +543,9 @@ def _create_generated_text_output(
finished=False,
)

def _stream_engine_outputs(self, engine_outputs, prompts, generation_config):
def _stream_engine_outputs(
self, engine_outputs, prompts, generation_config, **kwargs
):
for output in engine_outputs:
generated_tokens, generated_logits, finished_reason = output
logits = generated_logits if generation_config.output_scores else None
@@ -552,10 +554,18 @@ def _stream_engine_outputs(self, engine_outputs, prompts, generation_config):
finished_reason[0],
logits,
)
yield TextGenerationOutput(
# Add session_id to schema if it exists
# more relevant for `ChatPipeline`
schema_kwargs = (
{"session_ids": session_ids}
if (session_ids := kwargs.get("session_ids"))
else {}
)
yield self.output_schema(
created=datetime.datetime.now(),
prompts=prompts,
generations=[generation],
**schema_kwargs,
)

def process_engine_outputs(
@@ -634,7 +644,7 @@ def process_engine_outputs(
)
outputs.update(debug_params)

return TextGenerationOutput(**outputs)
return self.output_schema(**outputs)

def engine_forward(
self, engine_inputs: List[numpy.ndarray], context: Dict
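The walrus expression keeps the optional `session_ids` lookup and the dict construction in one step, so plain text generation pays nothing for the chat-only field. A runnable sketch of the same streaming pattern, with `StreamedOutput` and `stream_outputs` as illustrative stand-ins for the pipeline's output schema and method:

```python
# Sketch: conditionally thread an optional schema kwarg while streaming outputs.
import datetime
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class StreamedOutput:  # illustrative stand-in for the pipeline's output schema
    created: datetime.datetime
    generations: List[str]
    session_ids: Optional[str] = None

def stream_outputs(engine_outputs, output_schema=StreamedOutput, **kwargs):
    # include session_ids only if the caller (e.g. a chat pipeline) passed one
    schema_kwargs = (
        {"session_ids": session_ids}
        if (session_ids := kwargs.get("session_ids"))
        else {}
    )
    for generation in engine_outputs:
        yield output_schema(
            created=datetime.datetime.now(),
            generations=[generation],
            **schema_kwargs,
        )

for out in stream_outputs(["hello", "world"], session_ids="abc"):
    print(out.generations, out.session_ids)
```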
7 changes: 5 additions & 2 deletions src/deepsparse/utils/onnx.py
@@ -55,10 +55,12 @@
"has_model_kv_cache",
"CACHE_INPUT_PREFIX",
"CACHE_OUTPUT_PREFIX",
"_MODEL_DIR_ONNX_NAME",
]

_LOGGER = logging.getLogger(__name__)

_MODEL_DIR_ONNX_NAME = "model.onnx"
CACHE_INPUT_PREFIX = "past_key_values"
CACHE_OUTPUT_PREFIX = "present"

@@ -126,7 +128,8 @@ def model_to_path(model: Union[str, Model, File]) -> str:

if Model is not object and isinstance(model, Model):
# default to the main onnx file for the model
model = model.onnx_model.path
model = model.deployment.get_file(_MODEL_DIR_ONNX_NAME).path

elif File is not object and isinstance(model, File):
# get the downloaded_path -- will auto download if not on local system
model = model.path
@@ -139,7 +142,7 @@ def model_to_path(model: Union[str, Model, File]) -> str:

model_path = Path(model)
if model_path.is_dir():
return str(model_path / "model.onnx")
return str(model_path / _MODEL_DIR_ONNX_NAME)

return model

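With the constant hoisted into `onnx.py`, directory inputs resolve consistently to `<dir>/model.onnx` everywhere. A tiny sketch of the final fallthrough — `resolve_onnx_path` is an illustrative name, not the library function:

```python
# Sketch of the directory-vs-file resolution at the end of model_to_path.
from pathlib import Path

_MODEL_DIR_ONNX_NAME = "model.onnx"

def resolve_onnx_path(model: str) -> str:
    # directories are assumed to contain a model.onnx at their root
    model_path = Path(model)
    if model_path.is_dir():
        return str(model_path / _MODEL_DIR_ONNX_NAME)
    return model

print(resolve_onnx_path("deployment"))          # deployment/model.onnx (if dir exists)
print(resolve_onnx_path("weights/model.onnx"))  # unchanged passthrough
```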