From c6aa08fde919046b4b6a5980cd4d80018849cf95 Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Wed, 12 Jul 2023 17:15:12 +0200 Subject: [PATCH] [Feature Branch] KV Cache Interface (#1083) * initial commit * Update src/deepsparse/license.py * limit to 150mb * ready to review * initial commit * [Codegen][ORT][Static Seq Length] TextGenerationPipeline (#946) * initial commit * coreys simplifications * finishing the second model static * ready, time for beautification * ready for review * moved the code to examples * fix eos logic * add argument num_tokens_to_generate * [CodeGen][Documentation] (#956) * initial commit * coreys simplifications * finishing the second model static * ready, time for beautification * ready for review * moved the code to examples * fix eos logic * add argument num_tokens_to_generate * initial commit * change order * Update examples/codegen/README.md Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com> --------- Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com> * reimplementation for generative pipelines * restore text generation from examples * [CodeGen] ONNX model loading to support >2Gb models / two engines (#991) * refactor sucessfull * Pipeline fully refactored, time to test engine support. Note: Sliding window not yet implemented! * First iteration with Sage * Apply suggestions from code review * ORT agrees with the Engine. But they both give not entirely correct result. Hey, this is good news still * dynamic ORT vs static DS * pipeline handles OPT multitoken pass * fixes to get static pipeline a little further along * adjust shapes and slicing to enable static autoregressive pass - ISSUE: tokens past the base seq len are repeated * migrate from cache_length to positions input * got if working for multitoken + single token scenario * cleanup the pipeline * further cleanup post merge * Pipeline working for single-token inference only * do not load the onnx model with external files twice * pipeline never redundantly saves the external data + more robust tokenizer * Stop saving tmp files, otherwise the engine looks for external files in the wrong place * Left pad support * cleanup * cleanup2 * Add in pipeline timing * add in force tokens logic * remove input validation for text generation pipelines * remove multitoken support for now * remove kv cache engine and other fixes * nest input shape override * comment out input shape override * add non batch override for ORT * clean up generation pipeline * initial commit * Update src/deepsparse/license.py * limit to 150mb * ready to review * fix the erronous Makefile * perhaps fixed GHA * take into consideration that GHA creates four files * initial commit * tested with actual model * remove val_inp argument * Update README.md * Apply suggestions from code review * Update README.md * [BugFix] Update deepsparse dockerfile (#1069) * Remove autoinstall triggering commands * Fix typo * initial implementation * working implementation for pipeline input * [Fix] Fix CLI benchmark errors (#1071) * initial commit * ready for review * Update src/deepsparse/utils/onnx.py * Clean a typo in the pipeline code * initial commit * [KV Cache Interface] DecoderKVCache (#1084) * initial implementation * initial implementation * Revert "initial implementation" This reverts commit 765a5f71fdb4b6beb6cff3b990de73ec32784fd6. 
* Merge DecoderKVCache with KVCacheORT (KVCacheORT will not exist, it is just an abstraction) * rebase * add tests * DecoderKVCache that manipulates cache state and additionally passes info to the engine via KVCache object * improvements after the sync with Mark * remove prefill * fix the computation of total cache capacity * address PR comments * [WiP] [KV Cache Interface] Text Generation & Decoder Engine Implementation (#1089) * initial commit * Update src/deepsparse/license.py * limit to 150mb * ready to review * initial commit * [Codegen][ORT][Static Seq Length] TextGenerationPipeline (#946) * initial commit * coreys simplifications * finishing the second model static * ready, time for beautification * ready for review * moved the code to examples * fix eos logic * add argument num_tokens_to_generate * [CodeGen][Documentation] (#956) * initial commit * coreys simplifications * finishing the second model static * ready, time for beautification * ready for review * moved the code to examples * fix eos logic * add argument num_tokens_to_generate * initial commit * change order * Update examples/codegen/README.md Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com> --------- Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com> * reimplementation for generative pipelines * restore text generation from examples * [CodeGen] ONNX model loading to support >2Gb models / two engines (#991) * refactor sucessfull * Pipeline fully refactored, time to test engine support. Note: Sliding window not yet implemented! * First iteration with Sage * Apply suggestions from code review * ORT agrees with the Engine. But they both give not entirely correct result. Hey, this is good news still * dynamic ORT vs static DS * pipeline handles OPT multitoken pass * fixes to get static pipeline a little further along * adjust shapes and slicing to enable static autoregressive pass - ISSUE: tokens past the base seq len are repeated * migrate from cache_length to positions input * got if working for multitoken + single token scenario * cleanup the pipeline * further cleanup post merge * Pipeline working for single-token inference only * do not load the onnx model with external files twice * pipeline never redundantly saves the external data + more robust tokenizer * Stop saving tmp files, otherwise the engine looks for external files in the wrong place * Left pad support * cleanup * cleanup2 * Add in pipeline timing * add in force tokens logic * remove input validation for text generation pipelines * remove multitoken support for now * remove kv cache engine and other fixes * nest input shape override * comment out input shape override * add non batch override for ORT * clean up generation pipeline * initial commit * Update src/deepsparse/license.py * limit to 150mb * ready to review * fix the erronous Makefile * perhaps fixed GHA * take into consideration that GHA creates four files * initial commit * tested with actual model * remove val_inp argument * Update README.md * Apply suggestions from code review * Update README.md * initial implementation * initial implementation * Revert "initial implementation" This reverts commit 765a5f71fdb4b6beb6cff3b990de73ec32784fd6. * rebase * add tests * strip down complexity out of text generation pipeline * initial implementation * In a good state for the review on 22.06 * remove files to make review easier * Revert "remove files to make review easier" This reverts commit ea82e99c858091b62513a0451889edd6d5e82898. 
* Merge DecoderKVCache with KVCacheORT (KVCacheORT will not exist, it is just an abstraction) * rebase * add tests * Delete decoder_kv_cache.py * Delete test_decoder_kv_cache.py * DecoderKVCache that manipulates cache state and additionally passes info to the engine via KVCache object * fix formatting of the transformers/utils/__init__.py * improvements after the sync with Mark * All changes applied, time for testing * Scaffolding to also run multitoken * add delay_overwriting_inputs * multitoken is working (although in limited capacity) * fix no kv cache inference * Do not create engine if not needed * remove the prefill option * fix docstring * remove prefill * fix the computation of total cache capacity * merge * addressed PR comments * quality --------- Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com> Co-authored-by: Mark Kurtz Co-authored-by: Benjamin * now kv cache decoder holds information about the num of tokens preprocessed. also encountered first bug when running with the engine * cleanup the old files * Update src/deepsparse/transformers/engines/nl_decoder_engine.py * ready for review * ready for testing * managed to get first logits right * Delete example * cleanup before sharing with Ben and Sage * Update src/deepsparse/transformers/engines/nl_decoder_engine.py * assert proper padding on pipeline init * now also supporting kv cache perplexity. time for cleanup * ready for review * correctly print engine info * work with left padding of the tokenizer * quality * fix the multitoken inference * Perplexity Eval for Text Generation Models (#1073) * initial commit * Update src/deepsparse/license.py * limit to 150mb * ready to review * initial commit * [Codegen][ORT][Static Seq Length] TextGenerationPipeline (#946) * initial commit * coreys simplifications * finishing the second model static * ready, time for beautification * ready for review * moved the code to examples * fix eos logic * add argument num_tokens_to_generate * [CodeGen][Documentation] (#956) * initial commit * coreys simplifications * finishing the second model static * ready, time for beautification * ready for review * moved the code to examples * fix eos logic * add argument num_tokens_to_generate * initial commit * change order * Update examples/codegen/README.md Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com> --------- Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com> * reimplementation for generative pipelines * restore text generation from examples * [CodeGen] ONNX model loading to support >2Gb models / two engines (#991) * refactor sucessfull * Pipeline fully refactored, time to test engine support. Note: Sliding window not yet implemented! * First iteration with Sage * Apply suggestions from code review * ORT agrees with the Engine. But they both give not entirely correct result. 
Hey, this is good news still * dynamic ORT vs static DS * pipeline handles OPT multitoken pass * fixes to get static pipeline a little further along * adjust shapes and slicing to enable static autoregressive pass - ISSUE: tokens past the base seq len are repeated * migrate from cache_length to positions input * got if working for multitoken + single token scenario * cleanup the pipeline * further cleanup post merge * Pipeline working for single-token inference only * do not load the onnx model with external files twice * pipeline never redundantly saves the external data + more robust tokenizer * Stop saving tmp files, otherwise the engine looks for external files in the wrong place * Left pad support * cleanup * cleanup2 * Add in pipeline timing * add in force tokens logic * remove input validation for text generation pipelines * remove multitoken support for now * remove kv cache engine and other fixes * nest input shape override * comment out input shape override * add non batch override for ORT * clean up generation pipeline * initial commit * Update src/deepsparse/license.py * limit to 150mb * ready to review * fix the erronous Makefile * perhaps fixed GHA * take into consideration that GHA creates four files * initial commit * tested with actual model * remove val_inp argument * Update README.md * Apply suggestions from code review * Update README.md * [BugFix] Update deepsparse dockerfile (#1069) * Remove autoinstall triggering commands * Fix typo * initial implementation * working implementation for pipeline input * [Fix] Fix CLI benchmark errors (#1071) * initial commit * ready for review * Update src/deepsparse/utils/onnx.py * Clean a typo in the pipeline code * cleanup the old files * Update src/deepsparse/transformers/engines/nl_decoder_engine.py * ready for review * ready for testing * assert proper padding on pipeline init * now also supporting kv cache perplexity. 
time for cleanup * ready for review * correctly print engine info * work with left padding of the tokenizer * quality * fix the multitoken inference --------- Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com> Co-authored-by: Mark Kurtz Co-authored-by: Benjamin Co-authored-by: Rahul Tuli * [Text Generation] Run deepsparse engine without the LIB.kv_cache object (#1108) * Update src/deepsparse/transformers/engines/nl_decoder_engine.py * fixed the logic to assert correct multibatch inference * fix integration tests * initial implementation * fix the integration test * better solution for fixing the issues caused by this PR in GHA * revert changes to yolo pipeline * Update src/deepsparse/transformers/engines/nl_decoder_engine.py Co-authored-by: Rahul Tuli * response to Rahuls comments --------- Co-authored-by: Mark Kurtz Co-authored-by: Benjamin Co-authored-by: Rahul Tuli --- src/deepsparse/engine.py | 48 -- src/deepsparse/pipeline.py | 62 ++- src/deepsparse/tasks.py | 21 + src/deepsparse/transformers/README.md | 48 +- .../transformers/engines/__init__.py | 15 + .../transformers/engines/nl_decoder_engine.py | 340 +++++++++++++ .../transformers/eval_downstream.py | 45 +- src/deepsparse/transformers/helpers.py | 7 +- src/deepsparse/transformers/metrics.py | 139 +++++- .../transformers/pipelines/pipeline.py | 36 +- .../transformers/pipelines/text_generation.py | 468 ++++++++++++++++++ src/deepsparse/transformers/utils/__init__.py | 18 + .../transformers/utils/decoder_kv_cache.py | 198 ++++++++ src/deepsparse/transformers/utils/helpers.py | 66 +++ .../utils/test_decoder_kv_cache.py | 70 +++ 15 files changed, 1488 insertions(+), 93 deletions(-) create mode 100644 src/deepsparse/transformers/engines/__init__.py create mode 100644 src/deepsparse/transformers/engines/nl_decoder_engine.py create mode 100644 src/deepsparse/transformers/pipelines/text_generation.py create mode 100644 src/deepsparse/transformers/utils/__init__.py create mode 100644 src/deepsparse/transformers/utils/decoder_kv_cache.py create mode 100644 src/deepsparse/transformers/utils/helpers.py create mode 100644 tests/deepsparse/transformers/utils/test_decoder_kv_cache.py diff --git a/src/deepsparse/engine.py b/src/deepsparse/engine.py index 8f4bb14c0b..0f5160299d 100644 --- a/src/deepsparse/engine.py +++ b/src/deepsparse/engine.py @@ -28,7 +28,6 @@ from deepsparse.benchmark import BenchmarkResults from deepsparse.utils import ( generate_random_inputs, - get_output_names, join_engine_outputs, model_to_path, override_onnx_input_shapes, @@ -56,7 +55,6 @@ "Scheduler", "Context", "MultiModelEngine", - "KVCacheEngine", "BaseEngine", ] @@ -867,52 +865,6 @@ def __init__( ) -class KVCacheEngine(Engine): - """ - Engine that can do kv caching. 
- """ - - def __init__( - self, - model: Union[str, "Model", "File"], - batch_size: int = 1, - num_cores: int = None, - num_streams: int = None, - scheduler: Scheduler = None, - input_shapes: List[List[int]] = None, - kv_cache_bools: List[bool] = None, - prev_cache_length: int = 0, - ): - BaseEngine.construct( - self, model, batch_size, num_cores, num_streams, scheduler, input_shapes - ) - - if kv_cache_bools is None: - # If no list was provided, then we assume all outputs except for the first are KV caches - # Note: In the future we can look at the names of outputs to be more sure - # - # Create a boolean list of every output of the model - output_names = get_output_names(self._model_path) - kv_cache_bools = [True for i in range(len(output_names))] - # Assume first input is logits and logits ought not to be cached - kv_cache_bools[0] = False - - num_streams = _validate_num_streams(num_streams, self._num_cores) - if self._input_shapes: - raise NotImplementedError("Don't do this yet :)") - else: - self._eng_net = LIB.deepsparse_engine( - self._model_path, - self._batch_size, - self._num_cores, - num_streams, - self._scheduler.value, - None, - kv_cache_bools, - prev_cache_length, - ) - - def compile_model( model: Union[str, "Model", "File"], batch_size: int = 1, diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index f84a190946..4f21ab54aa 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -59,6 +59,7 @@ "yolo_pipeline", "Bucketable", "BucketingPipeline", + "create_engine", ] DEEPSPARSE_ENGINE = "deepsparse" @@ -753,26 +754,10 @@ def log_inference_times(self, timer: StagedTimer): category=MetricCategories.SYSTEM, ) - def _initialize_engine(self) -> Union[Engine, ORTEngine]: - engine_type = self.engine_type.lower() - - if engine_type == DEEPSPARSE_ENGINE: - if self.context is not None and isinstance(self.context, Context): - self._engine_args.pop("num_cores", None) - self._engine_args.pop("scheduler", None) - self._engine_args["context"] = self.context - return MultiModelEngine( - model=self.onnx_file_path, - **self._engine_args, - ) - return Engine(self.onnx_file_path, **self._engine_args) - elif engine_type == ORT_ENGINE: - return ORTEngine(self.onnx_file_path, **self._engine_args) - else: - raise ValueError( - f"Unknown engine_type {self.engine_type}. Supported values include: " - f"{SUPPORTED_PIPELINE_ENGINES}" - ) + def _initialize_engine(self) -> Union[Engine, MultiModelEngine, ORTEngine]: + return create_engine( + self.onnx_file_path, self.engine_type, self._engine_args, self.context + ) def _identifier(self): # get pipeline identifier; used in the context of logging @@ -950,6 +935,43 @@ def route_input_to_bucket( pass +def create_engine( + onnx_file_path: str, + engine_type: str, + engine_args: Dict, + context: Optional[Context] = None, +) -> Union[Engine, MultiModelEngine, ORTEngine]: + """ + Create an inference engine for a given ONNX model + + :param onnx_file_path: path to ONNX model file + :param engine_type: type of engine to create. 
+ :param engine_args: arguments to pass to engine constructor + :param context: context to use for engine + :return: inference engine + """ + engine_type = engine_type.lower() + + if engine_type == DEEPSPARSE_ENGINE: + if context is not None and isinstance(context, Context): + engine_args.pop("num_cores", None) + engine_args.pop("scheduler", None) + engine_args["context"] = context + return MultiModelEngine( + model=onnx_file_path, + **engine_args, + ) + return Engine(onnx_file_path, **engine_args) + + if engine_type == ORT_ENGINE: + return ORTEngine(onnx_file_path, **engine_args) + + raise ValueError( + f"Unknown engine_type {engine_type}. Supported values include: " + f"{SUPPORTED_PIPELINE_ENGINES}" + ) + + def _initialize_executor_and_workers( batch_size: Optional[int], workers_or_executor: Optional[Union[int, ThreadPoolExecutor]], diff --git a/src/deepsparse/tasks.py b/src/deepsparse/tasks.py index aa6c349eb6..c02ed44482 100644 --- a/src/deepsparse/tasks.py +++ b/src/deepsparse/tasks.py @@ -95,6 +95,12 @@ class SupportedTasks: ), ) + text_generation = namedtuple("text_generation", ["opt", "codegen", "bloom"])( + codegen=AliasedTask("codegen", []), + opt=AliasedTask("opt", []), + bloom=AliasedTask("bloom", []), + ) + image_classification = namedtuple("image_classification", ["image_classification"])( image_classification=AliasedTask( "image_classification", @@ -150,6 +156,9 @@ def check_register_task( # custom task, register the CustomPipeline import deepsparse.pipelines.custom_pipeline # noqa: F401 + elif cls.is_text_generation(task): + import deepsparse.transformers.pipelines.text_generation # noqa: F401 + elif cls.is_nlp(task): # trigger transformers pipelines to register with Pipeline.register import deepsparse.transformers.pipelines # noqa: F401 @@ -193,6 +202,18 @@ def check_register_task( f"{list(all_tasks)}" ) + @classmethod + def is_text_generation(cls, task: str) -> bool: + """ + :param task: the name of the task to check whether it is a text generation task + such as codegen + :return: True if it is a text generation task, False otherwise + """ + return any( + text_generation_task.matches(task) + for text_generation_task in cls.text_generation + ) + @classmethod def is_nlp(cls, task: str) -> bool: """ diff --git a/src/deepsparse/transformers/README.md b/src/deepsparse/transformers/README.md index 86a1adbffe..be8df1ebd5 100644 --- a/src/deepsparse/transformers/README.md +++ b/src/deepsparse/transformers/README.md @@ -10,6 +10,7 @@ methods such as [pruning](https://neuralmagic.com/blog/pruning-overview/) and [q These techniques result in significantly more performant and smaller models with limited to no effect on the baseline metrics. This integration currently supports several fundamental NLP tasks: +- **Text Generation** - given the input prompt, generate an output text sequence (e.g. to fill in incomplete text, summarize or paraphrase a text paragraph) - **Question Answering** - posing questions about a document - **Sentiment Analysis** - assigning a sentiment to a piece of text - **Text Classification** - assigning a label or class to a piece of text (e.g duplicate question pairing) @@ -32,9 +33,9 @@ This grants the engine the flexibility to serve any model in a framework-agnosti The DeepSparse pipelines require the following files within a folder on the local server to properly load a Transformers model: - `model.onnx`: The exported Transformers model in the [ONNX format](https://github.com/onnx/onnx). 
-- `tokenizer.json`: The [HuggingFace compatible tokenizer configuration](https://huggingface.co/docs/transformers/fast_tokenizers) used with the model. - `config.json`: The [HuggingFace compatible configuration file](https://huggingface.co/docs/transformers/main_classes/configuration) used with the model. - +- `tokenizer_config.json`: The [HuggingFace compatible tokenizer configuration](https://huggingface.co/docs/transformers/fast_tokenizers) used with the model. +- `tokenizer.json`, `special_tokens_map.json`, `vocab.json`, `merges.txt` (optional): Other files that may be required by the tokenizer. Below we describe two possibilities to obtain the required structure. #### SparseML Export @@ -48,7 +49,7 @@ sparseml.transformers.export_onnx --task question-answering --model_path model_p ``` This creates `model.onnx` file, in the directory of your `model_path`(e.g. `/trained_model/model.onnx`). -The `tokenizer.json` and `config.json` are stored under the `model_path` folder as well, so a DeepSparse pipeline ca be directly instantiated by using that folder after export (e.g. `/trained_model/`). +Any additional required files, such as `tokenizer.json` or `config.json`, are stored under the `model_path` folder as well, so a DeepSparse pipeline can be directly instantiated by using that folder after export (e.g. `/trained_model/`). #### SparseZoo Stub Alternatively, you can skip the process of the ONNX model export by using Neural Magic's [SparseZoo](https://sparsezoo.neuralmagic.com/). The SparseZoo contains pre-sparsified models and SparseZoo stubs enable you to reference any model on the SparseZoo in a convenient and predictable way. @@ -138,6 +139,47 @@ response.text >> '{"score":0.9534820914268494,"start":8,"end":14,"answer":"batman"}' ``` +### Text Generation +The text generation task generates a sequence of tokens given the prompt. Popular text generation LLMs (Large Language Models) are used +for chatbots (instruction-tuned models), code generation, text summarization, and filling in missing text. The following example uses a sparsified text generation +OPT model to complete the prompt. + +[List of available SparseZoo Text Generation Models]( +https://sparsezoo.neuralmagic.com/?useCase=text_generation) + +#### Python Pipeline +```python +from deepsparse import Pipeline + +opt_pipeline = Pipeline.create(task="opt") + +inference = opt_pipeline("Who is the president of the United States?") + +>> 'The president of the United States is the head of the executive branch of government...' +``` + +#### HTTP Server +Spinning up: +```bash +deepsparse.server \ + task text-generation \ + --model_path # TODO: Pending until text generation models get uploaded to SparseZoo +``` + +Making a request: +```python +import requests + +url = "http://localhost:5543/predict" # Server's port defaults to 5543 + +obj = {"sequence": "Who is the president of the United States?"} + +response = requests.post(url, json=obj) +response.text + +>> 'The president of the United States is the head of the executive branch of government...' +``` + ### Sentiment Analysis The sentiment analysis task takes in a sentence and classifies its sentiment.
The following example uses a pruned and quantized text sentiment analysis BERT model trained on the `sst2` dataset downloaded diff --git a/src/deepsparse/transformers/engines/__init__.py b/src/deepsparse/transformers/engines/__init__.py new file mode 100644 index 0000000000..95107e43f8 --- /dev/null +++ b/src/deepsparse/transformers/engines/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# flake8: noqa +from .nl_decoder_engine import * diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py new file mode 100644 index 0000000000..6ca2b81dc7 --- /dev/null +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -0,0 +1,340 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import logging +from typing import Any, Dict, List, Optional, Tuple + +import numpy +import onnx +from transformers import AutoTokenizer + +from deepsparse.engine import Context +from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine +from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache +from deepsparse.transformers.utils.helpers import generate_session_id, softmax +from sparsezoo.utils.onnx import save_onnx + + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["NLDecoderEngine"] + +_CACHE_INPUT_NAME = "past_key_values" + + +class NLDecoderEngine: + """ + The NLDecoderEngine (NaturalLanguageDecoderEngine) handles the + logic around the inference for Natural Language pipeline, + including batching and kv cache logic. 
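For orientation, a minimal construction sketch for this engine (not part of the patch itself); the model directory, tokenizer, and engine arguments below are placeholders, and the keyword arguments mirror the constructor parameters documented next:

```python
from transformers import AutoTokenizer

from deepsparse.transformers.engines import NLDecoderEngine

# hypothetical model directory, for illustration only
tokenizer = AutoTokenizer.from_pretrained("/path/to/model_dir")
engine = NLDecoderEngine(
    onnx_file_path="/path/to/model_dir/model.onnx",
    engine_type="onnxruntime",
    engine_args={"batch_size": 1},
    sequence_length=128,
    input_ids_length=1,
    tokenizer=tokenizer,
)
# calling the engine with [input_ids, attention_mask, positions]
# returns a (generated_token, logits) tuple -- see __call__ below
```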
+ + :param onnx_file_path: The path to the onnx model file + :param engine_type: The type of engine to use for the inference + :param engine_args: The arguments to pass to the engine + :param sequence_length: The maximum sequence length to run the engine for + :param input_ids_length: The maximum input ids length to run the engine for + :param sampling_temperature: The temperature to use for sampling + :param deterministic: Whether to use deterministic sampling + :param tokenizer: The tokenizer to use for engine inputs + :param engine_context: The context to run the engine in + :param use_deepsparse_cache: Whether to use the deepsparse + kv cache in the DecoderKVCache object or not + """ + + def __init__( + self, + onnx_file_path: str, + engine_type: str, + engine_args: Dict[str, Any], + sequence_length: int, + input_ids_length: int, + tokenizer: AutoTokenizer, + sampling_temperature: float = 1.0, + deterministic: bool = True, + engine_context: Optional[Context] = None, + use_deepsparse_cache=False, + ): + + onnx_file_path, output_indices_to_be_cached = self.overwrite_onnx_model_inputs( + onnx_file_path=onnx_file_path, + batch_size=engine_args.get("batch_size", 1), + sequence_length=sequence_length, + input_ids_length=input_ids_length, + ) + kv_cache_enabled = False + if sum(output_indices_to_be_cached): + kv_cache_enabled = True + if use_deepsparse_cache and engine_type == DEEPSPARSE_ENGINE: + # inform the engine that we are using the kv cache + engine_args["cache_output_bools"] = output_indices_to_be_cached + + self.engine = create_engine( + onnx_file_path=onnx_file_path, + engine_type=engine_type, + engine_args=engine_args, + context=engine_context, + ) + self.sequence_length = sequence_length + self.sampling_temperature = sampling_temperature + self.deterministic = deterministic + self.input_ids_length = input_ids_length + self.kv_cache_enabled = kv_cache_enabled + self.kv_cache = ( + DecoderKVCache(use_deepsparse_cache) if kv_cache_enabled else None + ) + self._freeze_first_position = self._should_freeze_first_position(tokenizer) + self._session_id = generate_session_id() + + @property + def session_id(self) -> str: + """ + :return: The session id for the kv_cache if enabled + """ + return self._session_id + + @session_id.setter + def session_id(self, session_id: str): + """ + :param session_id: The session id to set for the kv_cache + """ + self._session_id = session_id + + @property + def onnx_input_names_no_cache(self) -> List[str]: + """ + :return: The input names for the onnx model, excluding + the potential kv cache inputs + """ + return [ + name + for name in self.engine.input_names + if not name.startswith(_CACHE_INPUT_NAME) + ] + + def __call__( + self, + inp: List[numpy.ndarray], + val_inp: bool = True, + ) -> Tuple[numpy.ndarray, numpy.ndarray]: + """ + The main entry point for running the engine. + + :param inp: The input to run the engine with.
We expect a + list of numpy arrays that contain the input ids, + attention mask, and position ids (optionally) + :param val_inp: Whether the input is for validation or not + :return: The generated token and corresponding logits + """ + if self.kv_cache: + # if kv cache is enabled, we need to add the kv cache state + # to the input + inp = self.add_kv_cache_to_input(inp) + + out = self.engine.run(inp, val_inp) + + if self.kv_cache: + logits, *kv_cache_state = out + self.update_kv_cache( + kv_cache_state=kv_cache_state, input_ids_len=self.input_ids_length + ) + else: + logits = out[0] + + token = self.generate_token(logits=logits[:, -1, :]) + + return token, logits + + def __str__(self): + return f"{self.__class__.__name__}: {self.engine}" + + def __repr__(self): + return str(self) + + def transfer_cache_state(self, cache: DecoderKVCache): + """ + Transfers the kv cache state and the number of tokens processed + information from another NLDecoderEngine. Call this method when + you want to transfer the kv cache state from one engine to another. + + :param cache: The `DecoderKVCache` object to transfer to the engine + from + """ + self.kv_cache = copy.deepcopy(cache) + + @staticmethod + def overwrite_onnx_model_inputs( + onnx_file_path: str, + sequence_length: int, + input_ids_length: int, + batch_size: int = 1, + ) -> Tuple[str, List[int]]: + """ + Enforces the appropriate input shapes for the onnx model, as well as + checks whether kv cache is enabled or not. + + :param onnx_file_path: The path to the onnx model file that will be + overwritten with the new input shapes + :param batch_size: The batch size to use for the input + :param sequence_length: The sequence length to use for the input + :param input_ids_length: The length of input_ids + :return: The path to the onnx model file that has been overwritten + with the new input shapes, as well as the indices of the inputs + that should be cached + """ + model = onnx.load(onnx_file_path, load_external_data=False) + initializer_input_names = set(node.name for node in model.graph.initializer) + external_inputs = [ + inp for inp in model.graph.input if inp.name not in initializer_input_names + ] + for external_input in external_inputs: + # overwrite the batch size for all the inputs + external_input.type.tensor_type.shape.dim[0].dim_value = batch_size + + if external_input.name in ["input_ids", "positions"]: + external_input.type.tensor_type.shape.dim[ + 1 + ].dim_value = input_ids_length + elif external_input.name == "attention_mask": + external_input.type.tensor_type.shape.dim[1].dim_value = sequence_length + elif external_input.name.startswith(_CACHE_INPUT_NAME): + external_input.type.tensor_type.shape.dim[2].dim_value = ( + sequence_length - input_ids_length + ) + else: + raise ValueError( + f"Unexpected external input name: {external_input.name}" + ) + + _LOGGER.info( + "Overwriting in-place the input shapes " + f"of the transformer model at {onnx_file_path}" + ) + save_onnx(model, onnx_file_path) + + output_indices_to_be_cached = [ + 1 if inp.name.startswith("present") else 0 for inp in model.graph.output + ] + + return onnx_file_path, output_indices_to_be_cached + + def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray: + """ + Samples a token from the logits using the sampling temperature. 
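As a standalone illustration of the sampling rule described here (argmax when deterministic, otherwise temperature-scaled softmax sampling); the numbers are made up and the helper below is not part of the patch:

```python
import numpy

def sample_token(logits, deterministic=True, temperature=1.0):
    # deterministic decoding: pick the highest-scoring token
    if deterministic:
        return int(numpy.argmax(logits))
    # stochastic decoding: scale logits by the temperature, then sample
    scaled = logits / temperature
    probs = numpy.exp(scaled - scaled.max())
    probs /= probs.sum()
    return int(numpy.random.choice(len(probs), p=probs))

logits = numpy.array([2.0, 1.0, 0.5])
sample_token(logits)                                        # always 0
sample_token(logits, deterministic=False, temperature=5.0)  # close to uniform
```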
+ + :param logits: the logits from the model with shape (vocab_size,) + :return: the sampled token + """ + if self.deterministic: + return numpy.argmax(logits) + + logits /= self.sampling_temperature + + probs = softmax(logits) + + return numpy.random.choice(len(probs), p=probs) + + def reset_kv_cache(self): + """ + Resets the kv cache state. + """ + kv_cache_state = self._initialize_kv_cache_state( + self.sequence_length - self.input_ids_length + ) + self.kv_cache.setup_session( + session_id=self._session_id, + state=kv_cache_state, + num_processed_tokens=0, + freeze_first_position=self._freeze_first_position, + ) + + def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]: + """ + Takes the input and adds the past kv cache state to it. + + :param inp: The input to the model + :return The input with the kv cache state added to it + """ + kv_cache_state = self.kv_cache.cached_inputs + if kv_cache_state is None: + self.reset_kv_cache() + kv_cache_state = self.kv_cache.cached_inputs + + kv_cache_state["input_ids"] = inp[0] + kv_cache_state["attention_mask"] = inp[1] + if len(inp) == 3: + kv_cache_state["positions"] = inp[2] + + new_inp = [kv_cache_state[name] for name in self.engine.input_names] + return new_inp + + def update_kv_cache( + self, + kv_cache_state: List[numpy.ndarray], + input_ids_len: int, + ): + """ + Updates the state of the kv cache + + :param kv_cache_state: The state of the kv cache storage + :param input_ids_len: The length of input_ids + """ + cache_onnx_names = [ + name + for name in self.engine.input_names + if name.startswith(_CACHE_INPUT_NAME) + ] + kv_cache_state = { + name: array for name, array in zip(cache_onnx_names, kv_cache_state) + } + + self.kv_cache.update_session( + state=kv_cache_state, + input_ids_len=input_ids_len, + ) + + def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]: + # initialize empty kv cache of size + # (batch_size, num_attention_heads, length, hidden_dims) + + cache_engine_input_index = next( + i + for i, name in enumerate(self.engine.input_names) + if _CACHE_INPUT_NAME in name + ) + batch_size, num_attention_heads, _, hidden_dims = self.engine.input_shapes[ + cache_engine_input_index + ] + + empty_kv_cache_tensor = numpy.zeros( + (batch_size, num_attention_heads, length, hidden_dims), + dtype=numpy.float32, + ) + + cache_keys = [ + output_name.replace("present", _CACHE_INPUT_NAME) + for output_name in self.engine.output_names + if output_name.startswith("present") + ] + return {key: empty_kv_cache_tensor for key in cache_keys} + + @staticmethod + def _should_freeze_first_position(tokenizer) -> bool: + # use tokenizer to find out whether we should freeze the first position + # (True if tokenizer has a prefix for a BOS token) + if tokenizer is None: + return False + if hasattr(tokenizer, "bos_token"): + return True + return False diff --git a/src/deepsparse/transformers/eval_downstream.py b/src/deepsparse/transformers/eval_downstream.py index 01d0861580..ffe83aa5d0 100644 --- a/src/deepsparse/transformers/eval_downstream.py +++ b/src/deepsparse/transformers/eval_downstream.py @@ -68,14 +68,45 @@ import numpy from tqdm.auto import tqdm -from deepsparse import Pipeline -from deepsparse.transformers.metrics import PrecisionRecallF1 +from deepsparse import DEEPSPARSE_ENGINE, ORT_ENGINE, Pipeline +from deepsparse.transformers.metrics import Perplexity, PrecisionRecallF1 from datasets import load_dataset, load_metric # isort: skip -DEEPSPARSE_ENGINE = "deepsparse" -ORT_ENGINE = "onnxruntime" + 
+def perplexity_eval(args, batch_size=16, dataset_name="openai_humaneval"): + if args.max_samples: + batch_size = min(batch_size, args.max_samples) + + dataset = load_dataset(dataset_name)["test"] + + text_generation = Pipeline.create( + task="text-generation", + model_path=args.model_path, + engine_type=args.engine, + num_cores=args.num_cores, + sequence_length=args.max_sequence_length, + prompt_processing_sequence_length=args.max_sequence_length, + max_generated_tokens=1, + ) + perplexity_metrics = Perplexity(pipeline=text_generation, batch_size=batch_size) + active_engines = [ + engine + for engine in [text_generation.engine, text_generation.multitoken_engine] + if engine + ] + print("Engine info: ") + [print(f"{engine}\n") for engine in active_engines] + predictions = [] + for idx, sample in _enumerate_progress(dataset, args.max_samples): + predictions.append(sample["prompt"] + sample["canonical_solution"]) + if len(predictions) == batch_size: + perplexity_metrics.add_batch(predictions) + predictions = [] + if args.max_samples and idx >= args.max_samples: + break + return perplexity_metrics def qa_eval(args, dataset_name="squad"): @@ -443,12 +474,14 @@ def _split_train_val(train_dataset, val_ratio, seed=42): "imdb": imdb_eval, "conll2003": conll2003_eval, "go_emotions": go_emotions_eval, + "openai_humaneval": perplexity_eval, } def parse_args(): parser = argparse.ArgumentParser( - description="Evaluate a BERT ONNX model on a downstream dataset" + description="Evaluate a Hugging Face Transformers " + "ONNX model on a downstream dataset" ) parser.add_argument( "model_path", @@ -461,9 +494,9 @@ def parse_args(): parser.add_argument( "-d", "--dataset", - type=str, choices=list(SUPPORTED_DATASETS.keys()), required=True, + type=str, ) parser.add_argument( "-v", diff --git a/src/deepsparse/transformers/helpers.py b/src/deepsparse/transformers/helpers.py index 83b519baa5..c951d232c8 100644 --- a/src/deepsparse/transformers/helpers.py +++ b/src/deepsparse/transformers/helpers.py @@ -93,13 +93,16 @@ def get_onnx_path_and_configs( framework_files = os.listdir(framework_dir) if _MODEL_DIR_CONFIG_NAME in framework_files: config_path = framework_dir - if _MODEL_DIR_TOKENIZER_NAME in framework_files: + if ( + _MODEL_DIR_TOKENIZER_NAME + or _MODEL_DIR_TOKENIZER_CONFIG_NAME in framework_files + ): tokenizer_path = framework_dir # prefer config and tokenizer files in same directory as model.onnx if _MODEL_DIR_CONFIG_NAME in model_files: config_path = model_path - if _MODEL_DIR_TOKENIZER_NAME in model_files: + if _MODEL_DIR_TOKENIZER_NAME or _MODEL_DIR_TOKENIZER_CONFIG_NAME in model_files: tokenizer_path = model_path elif model_path.startswith("zoo:"): diff --git a/src/deepsparse/transformers/metrics.py b/src/deepsparse/transformers/metrics.py index 407e9b9d6b..87292ecfd7 100644 --- a/src/deepsparse/transformers/metrics.py +++ b/src/deepsparse/transformers/metrics.py @@ -17,18 +17,155 @@ """ -from typing import Dict, Optional +from typing import Any, Dict, List, Optional import numpy +from tqdm import tqdm +import torch +from deepsparse import Pipeline +from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline +from deepsparse.transformers.utils.helpers import pad_to_fixed_length from sklearn.metrics import precision_recall_fscore_support __all__ = [ "PrecisionRecallF1", + "Perplexity", ] +class Perplexity: + def __init__(self, pipeline: Pipeline, batch_size: int = 16): + """ + Given the pipeline, compute the perplexity of the model + on the given text input. 
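A hedged end-to-end sketch of how this metric is meant to be used; the model path is a placeholder for any local text generation model directory compatible with the pipeline:

```python
from deepsparse import Pipeline
from deepsparse.transformers.metrics import Perplexity

pipeline = Pipeline.create(
    task="text-generation",
    model_path="/path/to/model_dir",  # placeholder path
    max_generated_tokens=1,
)
perplexity = Perplexity(pipeline=pipeline, batch_size=2)
perplexity.add_batch(
    ["def add(a, b):\n    return a + b", "print('hello world')"]
)
print(perplexity.compute()["mean_perplexity"])
```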
+ + Code adapted from: + https://huggingface.co/spaces/evaluate-metric/perplexity/blob/main/perplexity.py # noqa: E501 + + :param pipeline: The pipeline to use for text generation + :param batch_size: The batch size to split the input text into + non-overlapping batches + """ + if not isinstance(pipeline, TextGenerationPipeline): + raise ValueError( + "Perplexity can only be computed for text generation pipelines" + ) + self._pipeline = pipeline + self._batch_size = batch_size + self._sequence_length = pipeline.sequence_length + self._loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + + self.perplexities = [] + + def add_batch(self, predictions: List[str]): + """ + Run the model on the given input sequences and compute the perplexity. + The resulting perplexity is appended to the list of perplexities. + + :param predictions: The predictions to compute perplexity on + """ + # tokenize the input text + encodings = self._pipeline.tokenizer( + predictions, + return_attention_mask=True, + max_length=self._sequence_length, + truncation=True, + padding="max_length", + ) + + encoded_texts = encodings["input_ids"] + attention_masks = encodings["attention_mask"] + + for start_index in tqdm(range(0, len(encoded_texts), self._batch_size)): + end_index = min(start_index + self._batch_size, len(encoded_texts)) + encoded_batch = encoded_texts[start_index:end_index] + attention_mask = attention_masks[start_index:end_index] + + # Computing the ground truth labels + + # `encoded_batch` contains sequences of tokens padded + # with tokens from the left side. We need to remove + # them and zero-pad from the right side up to the length + # of the longest sequence in the batch + encoded_batch = numpy.array(encoded_batch) * numpy.array(attention_mask) + encoded_batch = [ + list(filter(lambda num: num != 0, sequence)) + for sequence in encoded_batch + ] + max_sequence_len = max([len(sequence) for sequence in encoded_batch]) + + encoded_batch = [ + pad_to_fixed_length(numpy.array(sequence), max_sequence_len) + for sequence in encoded_batch + ] + encoded_batch = numpy.stack(encoded_batch) + + # We need to apply the analogous transformation to the attention mask + attention_mask = numpy.array(attention_mask) + attention_mask = [ + list(filter(lambda num: num != 0, mask)) for mask in attention_mask + ] + attention_mask = [ + pad_to_fixed_length(numpy.array(mask), max_sequence_len) + for mask in attention_mask + ] + attention_mask = numpy.stack(attention_mask) + + labels = encoded_batch + + out = self._pipeline( + sequences=predictions, return_logits=True, fixed_sequences_length=True + ) + + logits = out.logits + + if not self._pipeline.has_cache: + # when running inference without cache, we need to apply + # analogous transformations to the logits as we did to the labels + # and attention mask + + # remove "nonsensical" logits for tokens + logits = [ + logit[-attn_mask.sum() :, :] + for (logit, attn_mask) in zip(logits, attention_mask) + ] + # pad logits to max length + logits = [ + pad_to_fixed_length(logit, max_sequence_len) for logit in logits + ] + logits = numpy.stack(logits) + + # shift logits and labels create the input and target for the loss function + shift_logits = logits[:, :-1, :] + shift_labels = labels[:, 1:] + shift_attention_mask_batch = attention_mask[:, 1:] + + # compute perplexity for this batch + perplexity_batch = torch.exp( + ( + self._loss_fct( + torch.tensor(shift_logits.transpose(0, 2, 1)), + torch.tensor(shift_labels), + ) + * torch.tensor(shift_attention_mask_batch) + ).sum(1) + / 
torch.tensor(shift_attention_mask_batch).sum(1) + ) + self.perplexities.extend(perplexity_batch.numpy().tolist()) + + def compute(self) -> Dict[str, Any]: + """ + :return: A dictionary containing the mean perplexity + and the list of perplexities + """ + return { + "mean_perplexity": numpy.mean(self.perplexities), + "perplexities": self.perplexities, + } + + class PrecisionRecallF1: def __init__(self, id_to_label: Optional[Dict[int, str]] = None): self._id_to_label = id_to_label diff --git a/src/deepsparse/transformers/pipelines/pipeline.py b/src/deepsparse/transformers/pipelines/pipeline.py index 38073e260f..843391768c 100644 --- a/src/deepsparse/transformers/pipelines/pipeline.py +++ b/src/deepsparse/transformers/pipelines/pipeline.py @@ -82,7 +82,9 @@ def __init__( self.config_path = None self.tokenizer_config_path = None # path to 'tokenizer.json' self.onnx_input_names = None - + self._delay_overwriting_inputs = ( + kwargs.pop("_delay_overwriting_inputs", None) or False + ) self._temp_model_directory = None super().__init__(**kwargs) @@ -99,6 +101,9 @@ def setup_onnx_file_path(self) -> str: Parses ONNX, tokenizer, and config file paths from the given `model_path`. Supports sparsezoo stubs + :param delay_overwriting_inputs: if True, do not overwrite the ONNX model + inputs to the given sequence length. Default is False + :return: file path to the processed ONNX file for the engine to compile """ onnx_path, config_path, tokenizer_path = get_onnx_path_and_configs( @@ -114,31 +119,36 @@ def setup_onnx_file_path(self) -> str: self.config_path = os.path.join(config_path, "config.json") self.tokenizer_config_path = os.path.join(tokenizer_path, "tokenizer.json") - # overwrite onnx graph to given required input shape - ( - onnx_path, - self.onnx_input_names, - self._temp_model_directory, - ) = overwrite_transformer_onnx_model_inputs( - onnx_path, max_length=self.sequence_length - ) + if not self._delay_overwriting_inputs: + # overwrite onnx graph to given required input shape + ( + onnx_path, + self.onnx_input_names, + self._temp_model_directory, + ) = overwrite_transformer_onnx_model_inputs( + onnx_path, max_length=self.sequence_length + ) return onnx_path def tokens_to_engine_input( - self, tokens: Mapping[Any, numpy.ndarray] + self, + tokens: Mapping[Any, numpy.ndarray], + onnx_input_names: Optional[List[str]] = None, ) -> List[numpy.ndarray]: """ :param tokens: outputs of the pipeline tokenizer :return: list of numpy arrays in expected order for model input """ - if not all(name in tokens for name in self.onnx_input_names): + if onnx_input_names is None: + onnx_input_names = self.onnx_input_names + if not all(name in tokens for name in onnx_input_names): raise ValueError( - f"pipeline expected arrays with names {self.onnx_input_names}, " + f"pipeline expected arrays with names {onnx_input_names}, " f"received inputs: {list(tokens.keys())}" ) - return [tokens[name] for name in self.onnx_input_names] + return [tokens[name] for name in onnx_input_names] @staticmethod def should_bucket(*args, **kwargs) -> bool: diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py new file mode 100644 index 0000000000..f74696d37a --- /dev/null +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -0,0 +1,468 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List, Optional, Tuple, Type, Union + +import numpy +from pydantic import BaseModel, Field + +from deepsparse import Pipeline +from deepsparse.pipeline import DEEPSPARSE_ENGINE +from deepsparse.transformers.engines import NLDecoderEngine +from deepsparse.transformers.pipelines import TransformersPipeline +from deepsparse.transformers.utils.helpers import pad_to_fixed_length + + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["TextGenerationPipeline"] + + +class TextGenerationInput(BaseModel): + sequences: Union[str, List[str]] = Field( + description="The input sequences to generate the text from.", + ) + return_logits: bool = Field( + default=False, + description="A flag that indicates whether to return " + "the logits for the input text sequence and the " + "generated text sequence. ", + ) + session_id: Optional[str] = Field( + default=None, + description="A user may set a string identifier " + "for the kv cache session. If None, " + "and the model is using kv cache, it " + "will be set to a random uuid.", + ) + fixed_sequences_length: bool = Field( + default=False, + description="A flag that indicates whether to modify " + "(pad or truncate) each input text sequence, so that " + "its tokenized length is equal to `sequence_length` " + "of tokens. Useful when a batch of predictions needs " + "to have a consistent length so one " + "can compute metrics in a batched fashion. ", + ) + + +class TextGenerationOutput(BaseModel): + sequences: Union[str, List[str]] = Field( + description="The generated text sequences.", + ) + logits: Optional[numpy.ndarray] = Field( + default=None, + description="The logits for the generated text sequence. " + "The logits have dimensions " + "[batch_size, sequence_length, vocab_size]", + ) + session_id: Optional[str] = Field( + default=None, description="A string identifier for the kv cache session." + ) + + class Config: + arbitrary_types_allowed = True + + +@Pipeline.register( + task="text_generation", + task_aliases=["codegen", "opt", "bloom"], +) +class TextGenerationPipeline(TransformersPipeline): + """ + Pipeline for text generation tasks. + + :param deterministic: if True, the pipeline will get the next token by + applying an argmax function to the logits. If False, the pipeline + will sample the next token from the probability distribution + computed from the logits. + :param sampling_temperature: the temperature to use when sampling + from the probability distribution computed from the logits. + Higher values will result in more random samples. Should + be greater than 0.0. + :param max_generated_tokens: the maximum number of tokens to generate + given the input sequence. If None, the model will generate + tokens until the end of the sequence is reached. + Otherwise, it will generate up to the maximum number of tokens or until + the end of the sequence is reached. + :param prompt_processing_sequence_length: For large prompts, the prompt is + processed in chunks of this length. This is to maximize the inference + speed. By default, this is set to 128.
+ :param force_max_tokens: if True, the pipeline will generate the maximum number + of tokens supplied even if the stop token is reached. + :param use_deepsparse_cache: if True, the pipeline will use the deepsparse kv cache + for caching the model outputs. + :param kwargs: kwargs to pass to the TransformersPipeline + """ + + def __init__( + self, + deterministic: bool = True, + sampling_temperature: float = 1.0, + max_generated_tokens: Optional[int] = 1024, + # TODO: Set this to 64 once we modify the OPT injection logic + prompt_processing_sequence_length: int = 128, + force_max_tokens: bool = False, + use_deepsparse_cache: bool = False, + **kwargs, + ): + if use_deepsparse_cache: + if kwargs["engine_type"] != DEEPSPARSE_ENGINE: + raise ValueError( + "`use_deepsparse_cache` is set to True " + "but the chosen `engine_type` " + f"is {kwargs['engine_type']}. " + f"Make sure to set `engine_type` to {DEEPSPARSE_ENGINE}" + ) + raise NotImplementedError( + "The deepsparse kv cache is not yet " + "supported for text generation pipelines" + ) + + super().__init__( + **kwargs, _delay_engine_initialize=True, _delay_overwriting_inputs=True + ) + + if self.engine_type == DEEPSPARSE_ENGINE: + _LOGGER.warning( + "The support for deepsparse engine is limited " + f"for {self.__class__.__name__}. " + "The multi-token engine will not be " + "used for prompt processing." + ) + + self.deterministic = deterministic + self.sampling_temperature = sampling_temperature + self.max_generated_tokens = max_generated_tokens + self.prompt_processing_sequence_length = prompt_processing_sequence_length + self.force_max_tokens = force_max_tokens + + # override tokenizer to pad to left + self.tokenizer.padding_side = "left" + if not self.tokenizer.pad_token: + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.engine = None + + self.multitoken_engine = NLDecoderEngine( + onnx_file_path=self.onnx_file_path, + engine_type=self.engine_type, + engine_args=self.engine_args, + engine_context=self.context, + sampling_temperature=self.sampling_temperature, + deterministic=self.deterministic, + sequence_length=self.sequence_length, + input_ids_length=prompt_processing_sequence_length, + tokenizer=self.tokenizer, + use_deepsparse_cache=use_deepsparse_cache, + ) + + if self.multitoken_engine.kv_cache_enabled: + # unless kv cache is enabled, we don't + # need to initialize the single token engine + self.engine = NLDecoderEngine( + onnx_file_path=self.onnx_file_path, + engine_type=self.engine_type, + engine_args=self.engine_args, + engine_context=self.context, + sampling_temperature=self.sampling_temperature, + deterministic=self.deterministic, + sequence_length=self.sequence_length, + input_ids_length=1, + tokenizer=self.tokenizer, + use_deepsparse_cache=use_deepsparse_cache, + ) + if ( + not self.multitoken_engine.kv_cache_enabled + and self.max_generated_tokens > 1 + ): + raise ValueError( + "The model used for inference does not support kv cache. It is " + "assumed that it maps from the token sequence to predicted logits." + "Set `max_generated_tokens` to 1 to support that scenario." + ) + + @staticmethod + def route_input_to_bucket( + *args, input_schema: BaseModel, pipelines: List[Pipeline], **kwargs + ) -> Pipeline: + """ + This method is used to route the input to the correct pipeline. 
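For reference, a hedged usage sketch that exercises the request fields defined by `TextGenerationInput` above; the model path is a placeholder and the values chosen here are illustrative only:

```python
from deepsparse import Pipeline

text_pipeline = Pipeline.create(
    task="text_generation",
    model_path="/path/to/model_dir",  # placeholder path
    max_generated_tokens=64,
)
output = text_pipeline(
    sequences=["def fibonacci(n):"],
    return_logits=False,
    fixed_sequences_length=False,
)
print(output.sequences)
```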
+ + :param args: args to pass to the pipeline + :param input_schema: the input schema for the pipeline + :param pipelines: the list of pipelines to route the input to + :param kwargs: kwargs to pass to the pipeline + :return: the pipeline to route the input to + """ + raise ValueError("Bucketing is not supported for generation pipelines") + + @property + def input_schema(self) -> Type[BaseModel]: + """ + Property to return the input schema for the pipeline. + + :return: the input schema for the pipeline + """ + return TextGenerationInput + + @property + def output_schema(self) -> Type[BaseModel]: + """ + Property to return the output schema for the pipeline. + + :return: the output schema for the pipeline + """ + return TextGenerationOutput + + def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]: + """ + Convert the input schema for the pipeline to the inputs for the engine. + + :param inputs: the input schema for the pipeline + :return: the inputs for the engine + """ + + if inputs.fixed_sequences_length: + # to enforce a fixed sequence length, we need to + # truncate the input to the maximum sequence length + # or/and pad it to the maximum sequence length + truncate, padding = True, "max_length" + else: + # otherwise, we do not need to truncate the input + # and we shall can pad it to the longest sequence + # in the batch (so that the engine can process multiple inputs + # at once) + truncate, padding = False, "longest" + + input_tokens = self.tokenizer( + inputs.sequences, + return_tensors="np", + max_length=self.sequence_length, + padding=padding, + truncation=truncate, + ) + + attention_mask = input_tokens["attention_mask"] + + # TODO: Positions input is not required by BLOOM + # let's make it optional in the future + positions = attention_mask.cumsum(1) * attention_mask + positions -= 1 # assert that positions start at 0 + positions_input = dict(positions=positions) + + input_tokens = {**input_tokens, **positions_input} + onnx_input_names = self.multitoken_engine.onnx_input_names_no_cache + engine_input = self.tokens_to_engine_input(input_tokens, onnx_input_names) + + if inputs.session_id is not None: + # if session_id is provided, we need to set it in engines + self.engine.session_id = inputs.session_id + self.multitoken_engine.session_id = inputs.session_id + + postprocessing_kwargs = dict(return_logits=inputs.return_logits) + return engine_input, postprocessing_kwargs + + def process_engine_outputs( + self, engine_outputs: List[numpy.ndarray], **kwargs + ) -> TextGenerationOutput: + """ + Convert the engine outputs to the output schema for the pipeline. + + :param engine_outputs: the outputs from the engine + :return: the output schema for the pipeline + """ + generated_tokens, generated_logits = engine_outputs + sequences = self.tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True + ) + logits = generated_logits if kwargs.get("return_logits") else None + + return TextGenerationOutput(sequences=sequences, logits=logits) + + def engine_forward( + self, engine_inputs: List[numpy.ndarray], **kwargs + ) -> Tuple[numpy.ndarray, numpy.ndarray]: + """ + Run the forward pass on the engine. 
+
+ :param engine_inputs: list of numpy inputs to
+ Pipeline engine forward pass
+ :return: A tuple of numpy arrays that contains the
+ sequence of generated tokens and the sequence
+ of logits for each generated token
+ """
+ if not self.multitoken_engine.kv_cache_enabled:
+ tokens, prompt_logits = self.multitoken_engine(engine_inputs)
+ return numpy.array([tokens]), prompt_logits
+
+ else:
+ # run the prompt through the model
+ tokens, prompt_logits = self.prompt_inference(engine_inputs)
+
+ # create the generated output
+ max_tokens = (
+ self.max_generated_tokens
+ if self.max_generated_tokens and self.max_generated_tokens > 0
+ else 100 * self.sequence_length
+ ) # set safety for absolute max generation
+
+ generated_tokens = [tokens[-1]]
+ generated_logits = prompt_logits
+
+ while len(generated_tokens) < max_tokens:
+ (
+ token,
+ logits,
+ ) = self.autoregressive_inference(tokens)
+ tokens.append(token)
+ generated_tokens.append(token)
+ generated_logits.append(logits)
+
+ if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
+ break
+
+ return numpy.array([generated_tokens]), numpy.concatenate(
+ generated_logits, axis=1
+ )
+
+ def prompt_inference(
+ self, engine_inputs: List[numpy.ndarray]
+ ) -> Tuple[List[int], List[numpy.ndarray]]:
+ """
+ An inference run that processes the prompt through the
+ model to generate a new token and logits
+
+ :param engine_inputs: the prompt (context) represented by a
+ list of numpy inputs to the engine
+ :return: A tuple of:
+ - The list of prompt tokens plus the new, generated token
+ - The logits generated from the prompt (with dimensions
+ ['batch_size', 'num_tokens', 'vocab_size'])
+ """
+ # get the prompt tokens using the attention mask
+ tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist()
+
+ prompt_logits = []
+ new_token = None
+ num_tokens_processed = 0
+
+ # reset the state of the engines' cache
+ # in the future, this will be paired with session ids
+ # so that the reset is skipped when a session id is passed
+ self._reset_engines_cache()
+
+ # TODO: Allow multiple passes through the multitoken
+ # engine once the OPT injection is fixed
+ if (
+ len(tokens) > self.prompt_processing_sequence_length
+ and self.engine_type != DEEPSPARSE_ENGINE
+ ):
+ # trim the input to the prompt size
+ engine_inputs = [
+ input[:, : self.prompt_processing_sequence_length]
+ for input in engine_inputs
+ ]
+ new_token, new_logits = self.multitoken_engine(engine_inputs)
+ num_tokens_processed = self.prompt_processing_sequence_length
+ prompt_logits.append(new_logits)
+
+ if num_tokens_processed:
+ # transfer the cache state from the multi-token engine to the main engine
+ self.engine.transfer_cache_state(cache=self.multitoken_engine.kv_cache)
+
+ # process the remaining prompt tokens autoregressively to populate the kv cache
+ run_tokens = [] if num_tokens_processed == 0 else tokens[:num_tokens_processed]
+
+ for token in tokens[num_tokens_processed:]:
+ run_tokens.append(token)
+ new_token, new_logits = self.autoregressive_inference(
+ run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
+ )
+ prompt_logits.append(new_logits)
+
+ tokens.append(new_token)
+
+ return tokens, prompt_logits
+
+ def autoregressive_inference(
+ self,
+ tokens: List[int],
+ shift_positions_by_one: bool = False,
+ ) -> Tuple[int, numpy.ndarray]:
+ """
+ An inference run that processes the last token to generate
+ a new token and new logits.
+
+ :param tokens: The current context (prompt + generated tokens so far)
+ :param shift_positions_by_one: Whether to shift the positions
+ by one. Used if we are processing the prompt from scratch
+ (i.e. not using the multitoken engine)
+ :return: The new, generated token and the logits for the new token
+ (with dimensions ['batch_size', 'num_tokens', 'vocab_size'])
+ """
+ new_token = tokens[-1]
+ # padding is added to the left, so the attention mask is 1s from the
+ # right up to the number of total tokens (prompt + generated)
+ attention_mask = numpy.zeros((1, self.sequence_length), dtype=numpy.int64)
+ num_tokens_processed = min(len(tokens), self.sequence_length) # cap by seq len
+ attention_mask[:, -num_tokens_processed:] = 1
+ positions = numpy.array([[len(tokens)]], dtype=numpy.int64)
+ if shift_positions_by_one:
+ positions -= 1
+ input_ids = numpy.array([[new_token]])
+ engine_inputs = [input_ids, attention_mask, positions]
+
+ generated_token, generated_logits = self.engine(engine_inputs)
+
+ return generated_token, generated_logits
+
+ @property
+ def has_cache(self) -> bool:
+ """
+ Returns whether the model that is run has a kv cache or not
+
+ :return: True if the model has a kv cache, False otherwise
+ """
+ return self.multitoken_engine.kv_cache_enabled
+
+ @staticmethod
+ def join_engine_outputs(
+ batch_outputs: List[List[numpy.ndarray]],
+ ) -> List[numpy.ndarray]:
+ """
+ Takes a list of outputs (batches) from the engine
+ and joins them into a single output. Pads the logits
+ along the sequence axis to the same length, so that
+ they can be concatenated.
+
+ :param batch_outputs: A list of outputs from the engine
+ :return: A list of joined outputs
+ """
+ tokens, logits = zip(*batch_outputs)
+ tokens = numpy.concatenate(tokens, axis=0)
+ # find the longest sequence in the batch of logits
+ max_len = max(single_logits.shape[1] for single_logits in logits)
+ # pad all logits to the same length
+ logits = [
+ pad_to_fixed_length(array=single_logits, max_len=max_len, axis=1)
+ for single_logits in logits
+ ]
+ logits = numpy.concatenate(logits, axis=0)
+ return [tokens, logits]
+
+ def _reset_engines_cache(self):
+ self.engine.reset_kv_cache()
+ self.multitoken_engine.reset_kv_cache()
diff --git a/src/deepsparse/transformers/utils/__init__.py b/src/deepsparse/transformers/utils/__init__.py
new file mode 100644
index 0000000000..eb7e731ef4
--- /dev/null
+++ b/src/deepsparse/transformers/utils/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# flake8: noqa
+from .decoder_kv_cache import *
+from .helpers import *
diff --git a/src/deepsparse/transformers/utils/decoder_kv_cache.py b/src/deepsparse/transformers/utils/decoder_kv_cache.py
new file mode 100644
index 0000000000..a11a65734e
--- /dev/null
+++ b/src/deepsparse/transformers/utils/decoder_kv_cache.py
@@ -0,0 +1,198 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List
+
+import numpy
+
+
+__all__ = ["DecoderKVCache", "SEQUENCE_LENGTH_AXIS"]
+
+
+SEQUENCE_LENGTH_AXIS = 2
+
+
+class DecoderKVCache:
+ def __init__(self, use_deepsparse_cache: bool = False):
+ """
+ The goal of this object is to handle the manipulation
+ of the key/value cache.
+
+ :param use_deepsparse_cache: If set to True, the `kv_cache` object
+ from the deepsparse.LIB will be loaded as an attribute.
+ This object is used to handle the manipulation of the
+ key/value buffers on the DeepSparse engine side.
+ """
+ # assuming that kv cache arrays are of shape
+ # [batch_size, num_heads, sequence_length, hidden_size]
+ self._sequence_len_axis = SEQUENCE_LENGTH_AXIS
+ self._use_deepsparse_cache = use_deepsparse_cache
+ self._session_id = None
+ self._freeze_first_position = None
+ self._state = None
+ self._total_num_processed_tokens = None
+ self._kv_cache = None
+
+ def setup_session(
+ self,
+ session_id: str,
+ state: Dict[str, Any],
+ num_processed_tokens: int = 0,
+ freeze_first_position: bool = False,
+ ):
+ """
+ Set up the session - a level of abstraction that allocates
+ the resources to store and manipulate the kv cache.
+
+ :param session_id: The session id to use for the current
+ session. Used to identify the kv cache state
+ :param state: The state of the cache. This is a dictionary
+ that maps the name of the cache array to the cache array.
+ The cache tensor is a numpy array of shape
+ [batch_size, num_heads, sequence_length - num_input_ids, hidden_size]
+ :param num_processed_tokens: The number of tokens processed so far.
+ :param freeze_first_position: If set to True, once the kv cache
+ gets filled, the position along the sequence length axis
+ that corresponds to the first token will be frozen.
+ This ensures that, once the KV cache is full (there are no
+ "blank" entries) and we are removing the "oldest" entry
+ from the cache, we nevertheless keep the cache entry
+ that corresponds to the BOS token in the sequence.
+ By default, set to False.
+ """
+ self._session_id = session_id
+ self._state = state
+ self._freeze_first_position = freeze_first_position
+ self._total_num_processed_tokens = num_processed_tokens
+
+ if self._use_deepsparse_cache:
+ raise NotImplementedError("DeepSparse cache is not supported yet.")
+
+ def update_session(
+ self,
+ state: Dict[str, Any],
+ input_ids_len: int,
+ ):
+ """
+ Updating the session amounts to taking the kv cache
+ output from the forward pass and restructuring it, so it
+ can be directly used as input for the next forward pass.
+
+ :param state: The state of the cache. This is a dictionary
+ that maps the name of the cache array to the cache array.
+ The cache tensor is a numpy array of shape
+ [batch_size, num_heads, sequence_length, hidden_size]
+ :param input_ids_len: The number of input ids in the current
+ input batch: (batch_size, length).
+ Corresponds to `input_ids.shape[1]`
+ """
+ self._total_num_processed_tokens += input_ids_len
+ total_cache_capacity = state[list(state.keys())[0]].shape[
+ self._sequence_len_axis
+ ]
+ # total_capacity = num_tokens (num of non-blank tokens) +
+ # + num_padded_entries (num of blank tokens)
+ num_padded_entries = max(
+ 0, total_cache_capacity - self._total_num_processed_tokens
+ )
+ # we want to remove input_ids_len entries from the cache
+ # because len_input_ids + inp_cache_len = out_cache_len
+ # TODO: Make it more general once
+ # multitoken regression is supported
+ num_entries_to_delete = 1 # input_ids_len
+
+ if num_padded_entries:
+ """
+ Transforms an input KV cache that contains blank entries.
+ It removes the rightmost blank entries from the cache.
+ Example 1:
+ (entries in the cache denote the order in which they were
+ added to the cache, zero denotes a blank entry)
+ ```
+ state["state_name"]: (1, 1, 5, 1) = array([[[[0], [0], [1], [2], [3]]]])
+ -> num_padded_entries = 2
+ -> num_entries_to_delete = 1
+ -> num_padded_entries > num_entries_to_delete
+ # there are more blank entries than entries to delete
+ results in:
+ state["state_name"]: (1, 1, 4, 1) = array([[[[0], [1], [2], [3]]]])
+ ```
+ Example 2:
+ ```
+ state["state_name"]: (1, 1, 6, 1) = array([[[[0], [0], [0], [1], [2], [3]]]]) # noqa: E501
+ -> num_padded_entries = 3
+ -> num_entries_to_delete = 5
+ -> num_padded_entries < num_entries_to_delete
+ # there are fewer blank entries than entries to delete,
+ # so only the blank entries are removed here; the remaining
+ # deletions are handled by the branch below
+ results in:
+ state["state_name"]: (1, 1, 3, 1) = array([[[[1], [2], [3]]]])
+ ```
+ """
+ num_padded_entries_to_delete = min(
+ num_padded_entries, num_entries_to_delete
+ )
+ idxs_to_remove = [
+ num_padded_entries - i - 1 for i in range(num_padded_entries_to_delete)
+ ]
+ # reduce the number of entries still to delete by the number of
+ # blank entries removed; any remainder (non-blank entries) is
+ # handled by the branch below
+ num_entries_to_delete = max(0, num_entries_to_delete - num_padded_entries)
+ # update the state of the cache
+ state = self._delete_entries(state, idxs_to_remove)
+
+ if num_entries_to_delete:
+ """
+ Transforms the input KV cache that has been totally
+ filled with non-blank entries.
+ Example:
+ ```
+ state["state_name"]: (1, 1, 5, 1) = array([[[[1], [2], [3], [4], [5]]]])
+ num_entries_to_delete = 2
+ if self.freeze_first_position == False:
+ state["state_name"]: (1, 1, 3, 1) = array([[[[3], [4], [5]]]])
+ else:
+ state["state_name"]: (1, 1, 3, 1) = array([[[[1], [4], [5]]]])
+ ```
+ """
+ idxs_to_remove = [
+ i + int(self._freeze_first_position)
+ for i in range(num_entries_to_delete)
+ ]
+
+ state = self._delete_entries(state, idxs_to_remove)
+
+ self._state = state
+
+ def _delete_entries(
+ self, state: Dict[str, Any], indices: List[int]
+ ) -> Dict[str, Any]:
+ for key, value in state.items():
+ state[key] = numpy.delete(value, indices, axis=self._sequence_len_axis)
+ state[key] = numpy.ascontiguousarray(state[key])
+ return state
+
+ @property
+ def session_id(self):
+ if self._session_id is None:
+ raise ValueError("Attempted to access session_id before setting up session")
+ return self._session_id
+
+ @session_id.setter
+ def session_id(self, session_id: str):
+ self._session_id = session_id
+
+ @property
+ def cached_inputs(self):
+ return self._state
diff --git a/src/deepsparse/transformers/utils/helpers.py b/src/deepsparse/transformers/utils/helpers.py
new file mode 100644
index 0000000000..f4e72ca665
--- /dev/null
+++ b/src/deepsparse/transformers/utils/helpers.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import uuid
+
+import numpy
+
+
+__all__ = ["softmax", "generate_session_id", "pad_to_fixed_length"]
+
+
+def softmax(x: numpy.ndarray) -> numpy.ndarray:
+ """
+ Compute softmax values for x. This function guards
+ against overflow/underflow by using the
+ trick of shifting the input vector, i.e. subtracting
+ the maximum element in it from all elements
+
+ :param x: input array
+ :return: softmax values
+ """
+ z = x - max(x)
+ numerator = numpy.exp(z)
+ denominator = numpy.sum(numerator)
+ return numerator / denominator
+
+
+def generate_session_id() -> str:
+ """
+ Generate a uuid for the session id. This is used to
+ identify the kv cache session for the user
+ """
+ session_id = str(uuid.uuid4())
+ return session_id
+
+
+def pad_to_fixed_length(
+ array: numpy.ndarray, max_len: int, axis: int = 0, value: int = 0
+) -> numpy.ndarray:
+ """
+ Pads the array to a fixed length along the given axis.
+ The padding is done on the right side of the array.
+ + :param array: array to pad + :param max_len: maximum length to pad to + :param axis: axis to pad along + :param value: value to pad with + :return: padded array + """ + # per dimension padding is (before, after) + padding = [(0, 0)] * len(array.shape) + # for the specified axis, pad to the max length + # (from the right side of the array) + padding[axis] = (0, max_len - array.shape[axis]) + return numpy.pad(array, padding, mode="constant", constant_values=value) diff --git a/tests/deepsparse/transformers/utils/test_decoder_kv_cache.py b/tests/deepsparse/transformers/utils/test_decoder_kv_cache.py new file mode 100644 index 0000000000..82b4d9dd7d --- /dev/null +++ b/tests/deepsparse/transformers/utils/test_decoder_kv_cache.py @@ -0,0 +1,70 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy + +import numpy as np + +import pytest +from deepsparse.transformers.utils import DecoderKVCache + + +@pytest.mark.parametrize( + "state, input_ids_len, freeze_first_position, state_updated", + [ + ( + {"dummy_cache_name": np.array([[[[0], [0], [1], [2], [3]]]])}, + 1, + False, + {"dummy_cache_name": np.array([[[[0], [1], [2], [3]]]])}, + ), + ( + {"dummy_cache_name": np.array([[[[1], [2], [3], [4]]]])}, + 1, + False, + {"dummy_cache_name": np.array([[[[2], [3], [4]]]])}, + ), + ( + {"dummy_cache_name": np.array([[[[1], [2], [3], [4]]]])}, + 1, + True, + {"dummy_cache_name": np.array([[[[1], [3], [4]]]])}, + ), + ], +) +class TestDecoderKVCache: + @pytest.fixture + def setup( + self, + state, + input_ids_len, + freeze_first_position, + state_updated, + ): + decoder = DecoderKVCache() + state_flattened = state["dummy_cache_name"].flatten() + num_processed_tokens = state_flattened[state_flattened != 0].shape[0] + decoder.setup_session( + session_id="None", + state=state, + num_processed_tokens=num_processed_tokens, + freeze_first_position=freeze_first_position, + ) + yield decoder, state, input_ids_len, state_updated + + def test_update_session(self, setup): + decoder, state, input_ids_len, exp_state_updated = setup + decoder.update_session(copy.deepcopy(state), input_ids_len) + state_updated = decoder.cached_inputs + for key in state_updated.keys(): + assert np.array_equal(state_updated[key], exp_state_updated[key])
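
Reviewer note: the snippet below is a minimal usage sketch (not part of the patch) showing how `DecoderKVCache` trims a cache buffer that still contains blank (padded) entries. It mirrors the first parametrized case of the test above; the session id and array values are purely illustrative.

```python
import numpy

from deepsparse.transformers.utils import DecoderKVCache

# one cache array with two blank entries on the left and three real entries,
# following the assumed layout [batch_size, num_heads, sequence_length, hidden_size]
state = {"dummy_cache_name": numpy.array([[[[0], [0], [1], [2], [3]]]])}

kv_cache = DecoderKVCache()
kv_cache.setup_session(
    session_id="example-session",  # illustrative id
    state=state,
    num_processed_tokens=3,  # three non-blank entries are already in the cache
    freeze_first_position=False,
)

# after one more token is processed, one blank entry is dropped from the
# padded region, making room for the key/value pair of the next token
kv_cache.update_session(state, input_ids_len=1)
print(kv_cache.cached_inputs["dummy_cache_name"].shape)  # (1, 1, 4, 1) -> [[[[0], [1], [2], [3]]]]
```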
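The `positions` and attention-mask handling in `process_inputs` and `autoregressive_inference` can be hard to follow from the diff alone, so here is a small, self-contained numpy sketch (not part of the patch) of both computations; the sequence length and token ids are made up for illustration.

```python
import numpy

# prompt processing: positions are a cumulative sum over the attention mask,
# masked back to zero on padded slots and shifted so real tokens start at 0
attention_mask = numpy.array([[0, 0, 1, 1, 1]])  # left-padded prompt
positions = attention_mask.cumsum(1) * attention_mask
positions -= 1
print(positions)  # [[-1 -1  0  1  2]] -> padded slots end up at -1 but are masked out

# single-token (autoregressive) pass: one new token, attention mask is 1s from
# the right for all tokens seen so far, positions holds the index of the new token
sequence_length = 8  # illustrative static sequence length
tokens = [101, 102, 103, 104]  # prompt + generated tokens so far
attention_mask = numpy.zeros((1, sequence_length), dtype=numpy.int64)
attention_mask[:, -min(len(tokens), sequence_length):] = 1
positions = numpy.array([[len(tokens)]], dtype=numpy.int64)  # minus one if shift_positions_by_one
input_ids = numpy.array([[tokens[-1]]])
print(attention_mask, positions, input_ids)
```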
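Finally, a short usage sketch (not part of the patch) for the new helpers: the numerically stable `softmax` and the right-padding of logits to a common length, as done in `join_engine_outputs`. The vocab size and array shapes are illustrative, and the imports assume the new `deepsparse.transformers.utils` module from this patch.

```python
import numpy

from deepsparse.transformers.utils import pad_to_fixed_length, softmax

scores = numpy.array([1000.0, 1001.0, 1002.0])
print(softmax(scores))  # finite probabilities; a naive exp() would overflow here

# two batches of logits with different numbers of generated tokens
logits_a = numpy.ones((1, 3, 50257))  # 3 generated tokens, illustrative vocab size
logits_b = numpy.ones((1, 5, 50257))  # 5 generated tokens
max_len = max(arr.shape[1] for arr in (logits_a, logits_b))
padded = [
    pad_to_fixed_length(array=arr, max_len=max_len, axis=1)
    for arr in (logits_a, logits_b)
]
print(numpy.concatenate(padded, axis=0).shape)  # (2, 5, 50257)
```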