[WiP] [KV Cache Interface] Text Generation & Decoder Engine Implementation (#1089)

* initial commit

* Update src/deepsparse/license.py

* limit to 150mb

* ready to review

* initial commit

* [Codegen][ORT][Static Seq Length] TextGenerationPipeline (#946)

* initial commit

* coreys simplifications

* finishing the second model static

* ready, time for beautification

* ready for review

* moved the code to examples

* fix eos logic

* add argument num_tokens_to_generate

* [CodeGen][Documentation] (#956)

* initial commit

* coreys simplifications

* finishing the second model static

* ready, time for beautification

* ready for review

* moved the code to examples

* fix eos logic

* add argument num_tokens_to_generate

* initial commit

* change order

* Update examples/codegen/README.md

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>

---------

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>

* reimplementation for generative pipelines

* restore text generation from examples

* [CodeGen] ONNX model loading to support >2Gb models / two engines (#991)

* refactor successful

* Pipeline fully refactored, time to test engine support. Note: Sliding window not yet implemented!

* First iteration with Sage

* Apply suggestions from code review

* ORT agrees with the Engine. But they both give not entirely correct result. Hey, this is good news still

* dynamic ORT vs static DS

* pipeline handles OPT multitoken pass

* fixes to get static pipeline a little further along

* adjust shapes and slicing to enable static autoregressive pass - ISSUE: tokens past the base seq len are repeated

* migrate from cache_length to positions input

* got it working for multitoken + single token scenario

* cleanup the pipeline

* further cleanup post merge

* Pipeline working for single-token inference only

* do not load the onnx model with external files twice

* pipeline never redundantly saves the external data + more robust tokenizer

* Stop saving tmp files, otherwise the engine looks for external files in the wrong place

* Left pad support

* cleanup

* cleanup2

* Add in pipeline timing

* add in force tokens logic

* remove input validation for text generation pipelines

* remove multitoken support for now

* remove kv cache engine and other fixes

* nest input shape override

* comment out input shape override

* add non batch override for ORT

* clean up generation pipeline

* initial commit

* Update src/deepsparse/license.py

* limit to 150mb

* ready to review

* fix the erroneous Makefile

* perhaps fixed GHA

* take into consideration that GHA creates four files

* initial commit

* tested with actual model

* remove val_inp argument

* Update README.md

* Apply suggestions from code review

* Update README.md

* initial implementation

* initial implementation

* Revert "initial implementation"

This reverts commit 765a5f7.

* rebase

* add tests

* strip down complexity out of text generation pipeline

* initial implementation

* In a good state for the review on 22.06

* remove files to make review easier

* Revert "remove files to make review easier"

This reverts commit ea82e99.

* Merge DecoderKVCache with KVCacheORT (KVCacheORT will not exist, it is just an abstraction)

* rebase

* add tests

* Delete decoder_kv_cache.py

* Delete test_decoder_kv_cache.py

* DecoderKVCache that manipulates cache state and additionally passes info to the engine via KVCache object

* fix formatting of the transformers/utils/__init__.py

* improvements after the sync with Mark

* All changes applied, time for testing

* Scaffolding to also run multitoken

* add delay_overwriting_inputs

* multitoken is working (although in limited capacity)

* fix no kv cache inference

* Do not create engine if not needed

* remove the prefill option

* fix docstring

* remove prefill

* fix the computation of total cache capacity

* merge

* addressed PR comments

* quality

---------

Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>
Co-authored-by: Mark Kurtz <mark.kurtz@neuralmagic.com>
Co-authored-by: Benjamin <ben@neuralmagic.com>
4 people committed Jun 28, 2023
1 parent 0d6a423 commit 0809aea
Showing 12 changed files with 914 additions and 88 deletions.
62 changes: 13 additions & 49 deletions src/deepsparse/engine.py
@@ -28,7 +28,6 @@
from deepsparse.benchmark import BenchmarkResults
from deepsparse.utils import (
generate_random_inputs,
get_output_names,
model_to_path,
override_onnx_input_shapes,
)
@@ -54,7 +53,6 @@
"Scheduler",
"Context",
"MultiModelEngine",
"KVCacheEngine",
"BaseEngine",
]

@@ -214,6 +212,7 @@ def construct(
num_streams: int = None,
scheduler: Scheduler = None,
input_shapes: List[List[int]] = None,
cache_input_bools: Optional[List[bool]] = None,
):
_analytics.send_event("python__engine__init")
self._model_path = model_to_path(model)
@@ -224,6 +223,7 @@
self._input_shapes = input_shapes
self._cpu_avx_type = AVX_TYPE
self._cpu_vnni = VNNI
self._cache_input_bools = cache_input_bools

def construct_with_context(
self,
@@ -276,9 +276,17 @@ def __init__(
num_streams: int = None,
scheduler: Scheduler = None,
input_shapes: List[List[int]] = None,
cache_input_bools: Optional[List[bool]] = None,
):
BaseEngine.construct(
self, model, batch_size, num_cores, num_streams, scheduler, input_shapes
self,
model,
batch_size,
num_cores,
num_streams,
scheduler,
input_shapes,
cache_input_bools,
)

if self._input_shapes:
@@ -292,6 +300,7 @@ def __init__(
self._num_streams,
self._scheduler.value,
None,
self._cache_input_bools,
)
else:
self._eng_net = LIB.deepsparse_engine(
@@ -301,6 +310,7 @@
self._num_streams,
self._scheduler.value,
None,
self._cache_input_bools,
)

def __call__(
@@ -845,52 +855,6 @@ def __init__(
)


class KVCacheEngine(Engine):
"""
Engine that can do kv caching.
"""

def __init__(
self,
model: Union[str, "Model", "File"],
batch_size: int = 1,
num_cores: int = None,
num_streams: int = None,
scheduler: Scheduler = None,
input_shapes: List[List[int]] = None,
kv_cache_bools: List[bool] = None,
prev_cache_length: int = 0,
):
BaseEngine.construct(
self, model, batch_size, num_cores, num_streams, scheduler, input_shapes
)

if kv_cache_bools is None:
# If no list was provided, then we assume all outputs except for the first are KV caches
# Note: In the future we can look at the names of outputs to be more sure
#
# Create a boolean list of every output of the model
output_names = get_output_names(self._model_path)
kv_cache_bools = [True for i in range(len(output_names))]
# Assume first input is logits and logits ought not to be cached
kv_cache_bools[0] = False

num_streams = _validate_num_streams(num_streams, self._num_cores)
if self._input_shapes:
raise NotImplementedError("Don't do this yet :)")
else:
self._eng_net = LIB.deepsparse_engine(
self._model_path,
self._batch_size,
self._num_cores,
num_streams,
self._scheduler.value,
None,
kv_cache_bools,
prev_cache_length,
)


def compile_model(
model: Union[str, "Model", "File"],
batch_size: int = 1,
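For illustration only (not part of the diff): a minimal sketch of how the new `cache_input_bools` argument could be passed when compiling a decoder model whose ONNX graph exposes KV-cache tensors. The model path and the layout of the boolean list are assumptions, not values taken from this commit.

```python
from deepsparse import Engine

# Assumed layout: one bool per model input, True for inputs that carry cache
# state (past key/value tensors), False for regular inputs such as input_ids
# and attention_mask. The counts below are purely illustrative.
cache_input_bools = [False, False] + [True] * 48

engine = Engine(
    model="deployment/model_kvcache.onnx",  # assumed local ONNX export with KV-cache support
    batch_size=1,
    cache_input_bools=cache_input_bools,
)
```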
57 changes: 36 additions & 21 deletions src/deepsparse/pipeline.py
@@ -53,6 +53,7 @@
"yolo_pipeline",
"Bucketable",
"BucketingPipeline",
"create_engine",
]

DEEPSPARSE_ENGINE = "deepsparse"
@@ -157,6 +158,7 @@ def __init__(
logger: Optional[Union[BaseLogger, str]] = None,
benchmark: bool = False,
_delay_engine_initialize: bool = False, # internal use only
_delay_overwriting_inputs: bool = False, # internal use only
):
self._benchmark = benchmark
self._model_path_orig = model_path
@@ -200,7 +202,7 @@ def __init__(
if engine_type.lower() == DEEPSPARSE_ENGINE:
self._engine_args["scheduler"] = scheduler

self.onnx_file_path = self.setup_onnx_file_path()
self.onnx_file_path = self.setup_onnx_file_path(_delay_overwriting_inputs)

if _delay_engine_initialize:
self.engine = None
@@ -810,26 +812,10 @@ def log_inference_times(self, timer: StagedTimer):
category=MetricCategories.SYSTEM,
)

def _initialize_engine(self) -> Union[Engine, ORTEngine]:
engine_type = self.engine_type.lower()

if engine_type == DEEPSPARSE_ENGINE:
if self.context is not None and isinstance(self.context, Context):
self._engine_args.pop("num_cores", None)
self._engine_args.pop("scheduler", None)
self._engine_args["context"] = self.context
return MultiModelEngine(
model=self.onnx_file_path,
**self._engine_args,
)
return Engine(self.onnx_file_path, **self._engine_args)
elif engine_type == ORT_ENGINE:
return ORTEngine(self.onnx_file_path, **self._engine_args)
else:
raise ValueError(
f"Unknown engine_type {self.engine_type}. Supported values include: "
f"{SUPPORTED_PIPELINE_ENGINES}"
)
def _initialize_engine(self) -> Union[Engine, MultiModelEngine, ORTEngine]:
return create_engine(
self.onnx_file_path, self.engine_type, self._engine_args, self.context
)

def _identifier(self):
# get pipeline identifier; used in the context of logging
@@ -1007,6 +993,35 @@ def route_input_to_bucket(
pass


def create_engine(
onnx_file_path: str,
engine_type: str,
engine_args: Dict,
context: Optional[Context] = None,
) -> Union[Engine, MultiModelEngine, ORTEngine]:
engine_type = engine_type.lower()

if engine_type == DEEPSPARSE_ENGINE:
if context is not None and isinstance(context, Context):
engine_args.pop("num_cores", None)
engine_args.pop("scheduler", None)
engine_args["context"] = context
return MultiModelEngine(
model=onnx_file_path,
**engine_args,
)
return Engine(onnx_file_path, **engine_args)

if engine_type == ORT_ENGINE:
engine_args.pop("cache_input_bools", None)
return ORTEngine(onnx_file_path, **engine_args)

raise ValueError(
f"Unknown engine_type {engine_type}. Supported values include: "
f"{SUPPORTED_PIPELINE_ENGINES}"
)


def _initialize_executor_and_workers(
batch_size: Optional[int],
workers_or_executor: Optional[Union[int, ThreadPoolExecutor]],
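For reference, a minimal usage sketch of the new `create_engine` helper called outside of a `Pipeline`; the ONNX path and the contents of `engine_args` are illustrative assumptions rather than values from this commit.

```python
from deepsparse.pipeline import create_engine

engine = create_engine(
    onnx_file_path="deployment/model.onnx",  # assumed local model path
    engine_type="deepsparse",                # or "onnxruntime" to fall back to ORTEngine
    engine_args={"batch_size": 1},
    context=None,
)
```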
23 changes: 23 additions & 0 deletions src/deepsparse/tasks.py
@@ -95,6 +95,12 @@ class SupportedTasks:
),
)

text_generation = namedtuple("text_generation", ["opt", "codegen", "bloom"])(
codegen=AliasedTask("codegen", []),
opt=AliasedTask("opt", []),
bloom=AliasedTask("bloom", []),
)

image_classification = namedtuple("image_classification", ["image_classification"])(
image_classification=AliasedTask(
"image_classification",
@@ -150,6 +156,9 @@ def check_register_task(
# custom task, register the CustomPipeline
import deepsparse.pipelines.custom_pipeline # noqa: F401

elif cls.is_text_generation(task):
import deepsparse.transformers.pipelines.text_generation # noqa: F401

elif cls.is_nlp(task):
# trigger transformers pipelines to register with Pipeline.register
import deepsparse.transformers.pipelines # noqa: F401
@@ -193,6 +202,20 @@ def check_register_task(
f"{list(all_tasks)}"
)

@classmethod
def is_text_generation(cls, task: str) -> bool:
"""
:param task: the name of the task to check whether it is a text generation task
such as codegen
:return: True if it is a text generation task, False otherwise
"""
return any(
[
text_generation_task.matches(task)
for text_generation_task in cls.text_generation
]
)

@classmethod
def is_nlp(cls, task: str) -> bool:
"""
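A short sketch of querying the new helper; "opt", "codegen", and "bloom" are the task names registered in the `text_generation` namedtuple above, while the negative case is an arbitrary non-generation task used for illustration.

```python
from deepsparse.tasks import SupportedTasks

assert SupportedTasks.is_text_generation("opt")
assert SupportedTasks.is_text_generation("codegen")
assert not SupportedTasks.is_text_generation("question_answering")
```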
50 changes: 47 additions & 3 deletions src/deepsparse/transformers/README.md
@@ -10,6 +10,7 @@ methods such as [pruning](https://neuralmagic.com/blog/pruning-overview/) and [q
These techniques result in significantly more performant and smaller models with limited to no effect on the baseline metrics.

This integration currently supports several fundamental NLP tasks:
- **Text Generation** - given the input prompt, generate an output text sequence (e.g. to fill in incomplete text or paraphrase part of the prompt)
- **Question Answering** - posing questions about a document
- **Sentiment Analysis** - assigning a sentiment to a piece of text
- **Text Classification** - assigning a label or class to a piece of text (e.g. duplicate question pairing)
@@ -30,10 +31,12 @@ compatible with our [hardware requirements](https://docs.neuralmagic.com/deepspa
By default, deploying a transformer with the DeepSparse Engine requires supplying the model in the ONNX format along with the supporting HuggingFace files.
This grants the engine the flexibility to serve any model in a framework-agnostic environment.

The DeepSparse pipelines require the following files within a folder on the local server to properly load a Transformers model:
In general, the DeepSparse pipelines require the following files within a folder on the local server to properly load a Transformers model:
- `model.onnx`: The exported Transformers model in the [ONNX format](https://github.com/onnx/onnx).
- `tokenizer.json`: The [HuggingFace compatible tokenizer configuration](https://huggingface.co/docs/transformers/fast_tokenizers) used with the model.
- `model_kvcache.onnx` (optional): The ONNX model with KV cache support (akin to the Transformers model exported with `use_cache = True`). Specific to the `text-generation` integration.
- `config.json`: The [HuggingFace compatible configuration file](https://huggingface.co/docs/transformers/main_classes/configuration) used with the model.
- `tokenizer_config.json`: The [HuggingFace compatible tokenizer configuration](https://huggingface.co/docs/transformers/fast_tokenizers) used with the model.
- `tokenizer.json`, `special_tokens_map.json`, `vocab.json`, `merges.txt` (optional): Other files that may be required by the tokenizer.

Below we describe two ways to obtain the required structure.

@@ -48,7 +51,7 @@ sparseml.transformers.export_onnx --task question-answering --model_path model_p
```

This creates a `model.onnx` file in the directory of your `model_path` (e.g. `/trained_model/model.onnx`).
The `tokenizer.json` and `config.json` are stored under the `model_path` folder as well, so a DeepSparse pipeline ca be directly instantiated by using that folder after export (e.g. `/trained_model/`).
Any additional required files, such as `tokenizer.json` or `config.json`, are stored under the `model_path` folder as well, so a DeepSparse pipeline can be instantiated directly from that folder after export (e.g. `/trained_model/`).
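For example (a hypothetical sketch assuming the export above produced `/trained_model/`), a pipeline can then be created directly from that folder:

```python
from deepsparse import Pipeline

# Assumed local folder produced by sparseml.transformers.export_onnx above
qa_pipeline = Pipeline.create(task="question-answering", model_path="/trained_model/")

inference = qa_pipeline(question="What's my name?", context="My name is Snorlax")
```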

#### SparseZoo Stub
Alternatively, you can skip the process of the ONNX model export by using Neural Magic's [SparseZoo](https://sparsezoo.neuralmagic.com/). The SparseZoo contains pre-sparsified models and SparseZoo stubs enable you to reference any model on the SparseZoo in a convenient and predictable way.
@@ -137,6 +140,47 @@ response.text

>> '{"score":0.9534820914268494,"start":8,"end":14,"answer":"batman"}'
```
### Text Generation
The text generation task generates a sequence of tokens that continues a given prompt. Popular text generation LLMs (Large Language Models) are used
for chat (instruction-tuned models), code generation, text summarization, and filling in missing text. The following example uses a sparsified text generation
OPT model to complete the prompt.

[List of available SparseZoo Text Generation Models](
https://sparsezoo.neuralmagic.com/?useCase=text_generation)

#### Python Pipeline
```python
from deepsparse import Pipeline

opt_pipeline = Pipeline.create(task="opt")

inference = opt_pipeline("Who is the president of the United States?")

>> 'The president of the United States is the head of the executive branch of government...'
```

#### HTTP Server
Spinning up:
```bash
deepsparse.server \
task text-generation \
--model_path # TODO: Pending until text generation models get uploaded to SparseZoo
```

Making a request:
```python
import requests

url = "http://localhost:5543/predict" # Server's port default to 5543

obj = {"sequence": "Who is the president of the United States?"}

response = requests.post(url, json=obj)
response.text

>> 'The president of the United States is the head of the executive branch of government...'
```

### Sentiment Analysis
The sentiment analysis task takes in a sentence and classifies its sentiment. The following example
15 changes: 15 additions & 0 deletions src/deepsparse/transformers/engines/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
from .nl_decoder_engine import *