Skip to content

Commit

Permalink
working implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
bogunowicz@arrival.com committed Jun 19, 2023
1 parent 6c33599 commit 87b40fb
Show file tree
Hide file tree
Showing 17 changed files with 1,533 additions and 63 deletions.
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,17 @@
_deepsparse_ent_deps = [f"deepsparse-ent~={version_nm_deps}"]

_onnxruntime_deps = ["onnxruntime>=1.0.0"]
supported_torch_version = "torch>=1.7.0,<1.14"
supported_torch_version = "torch>=1.7.0,<=2.0"
_pytorch_deps = [
supported_torch_version,
"gputils",
]
_pytorch_all_deps = _pytorch_deps + [
"torchvision>=0.3.0,<0.15",
"torchaudio<=0.13",
"torchvision>=0.3.0,<=0.15.1",
"torchaudio<=2.0.1",
]
_pytorch_vision_deps = _pytorch_deps + [
"torchvision>=0.3.0,<0.15",
"torchvision>=0.3.0,<=0.15.1",
"opencv-python<=4.6.0.66",
]
_transformers_deps = _pytorch_deps + [
Expand Down
156 changes: 156 additions & 0 deletions src/sparseml/exporters/kv_cache_injector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from copy import deepcopy
from pathlib import Path
from typing import Any, Optional, Union

import onnx

from sparseml.exporters.base_exporter import BaseExporter
from sparseml.exporters.transforms.kv_cache import (
CacheKeysAndValues,
get_kv_cache_config,
)
from sparsezoo.utils import save_onnx


_LOGGER = logging.getLogger(__name__)


class KeyValueCacheInjector(BaseExporter):
def __init__(
self,
model_path: Optional[str] = None,
inplace: bool = True,
**kwargs: Any,
):
"""
A transformation that injects Key Value cache support into the model.
This means that the
- autoregressive model that
* takes input_ids and attention_mask as INPUT
* returns logits as OUTPUT
- is transformed into a model that
* takes input_ids, attention_mask, and kv_cache as INPUT
* returns logits and updated kv_cache as OUTPUT
The goal of the KV cache injection is speed up the autoregressive
generation process (reduce the compute of key/value pairs by storing
the results of previous computations in memory).
The exporter will look for a `config.json` file in the `model_path`
directory to determine the parameters for KV cache injection.
If `model_path` is not provided, the requested parameters can be
provided in the `kwargs`.
This transformation not only solely injects the kv cache
inputs/outputs, but also adjusts the original ONNX graph to
account for the necessary changes. This involves e.g. adding
the 'position' input to the model, so that the positional
embeddings of the new model are compatible with the past kv
cache information.
Usage:
```python
onnx_model: onnx.ModelProto = ...
exporter = KeyValueCacheInjector(model_path="path/to/model")
exporter.export(onnx_model, "model.onnx")
```
Alternatively:
```python
onnx_model: onnx.ModelProto = ...
exporter = KeyValueCacheInjector(model_path="path/to/model")
exporter = KeyValueCacheInjector(num_attention_heads = 16,
hidden_size_dim = 64)
exporter.export(onnx_model, "model.onnx")
```
You can also just optimize the model directly without saving to disk:
```python
onnx_model: onnx.ModelProto = ...
exporter = KeyValueCacheInjector(model_path="path/to/model")
optimized_model = exporter.apply(onnx_model)
```
:param model_path: The path to the directory containing the model.
:param inplace: If True, the model will be modified in place.
If False, a copy of the model will be made and modified.
:param kwargs: (Optionally) the parameters for the KV cache injection
if no `model_path` is provided.
"""
self.inplace = inplace
self.config = get_kv_cache_config(model_path)

if model_path is not None:
# get the parameters from the config
self.config = get_kv_cache_config(model_path)

num_attention_heads = self.config.num_attention_heads
hidden_size_kv_cache_dim = self.config.hidden_size_kv_cache
multiply_batch_by_num_att_heads = (
self.config.multiply_batch_by_num_att_heads
)
transpose_value_input = self.config.transpose_value_input
transpose_key_input = self.config.transpose_key_input
positions_adjustment = self.config.positions_adjustment_transform

elif kwargs:
# get the parameters from the kwargs
num_attention_heads = kwargs.get("num_attention_heads")
hidden_size_kv_cache_dim = kwargs.get("hidden_size_kv_cache_dim")
multiply_batch_by_num_att_heads = kwargs.get(
"multiply_batch_by_num_att_heads", False
)
transpose_value_input = kwargs.get("transpose_value_input")
transpose_key_input = kwargs.get("transpose_key_input")
positions_adjustment = None

else:
raise ValueError(
"Either `model_path` or kwargs must be provided to "
"KeyValueCacheInjector"
)

transforms = [
CacheKeysAndValues(
num_attention_heads=num_attention_heads,
hidden_size_kv_cache=hidden_size_kv_cache_dim,
multiply_batch_by_num_att_heads=multiply_batch_by_num_att_heads,
transpose_value_input=transpose_value_input,
transpose_key_input=transpose_key_input,
)
]
if positions_adjustment is not None:
transforms += [positions_adjustment()]
super().__init__(transforms)

def pre_validate(self, model: Union[onnx.ModelProto, str, Path]) -> onnx.ModelProto:
if isinstance(model, (str, Path)):
model = onnx.load(str(model))

if not isinstance(model, onnx.ModelProto):
raise TypeError(f"Expected onnx.ModelProto, found {type(model)}")
return model if self.inplace else deepcopy(model)

def post_validate(self, model: onnx.ModelProto) -> onnx.ModelProto:
if not isinstance(model, onnx.ModelProto):
raise TypeError(f"Expected onnx.ModelProto, found {type(model)}")
return model

def export(self, pre_transforms_model: onnx.ModelProto, file_path: str):
post_transforms_model: onnx.ModelProto = self.apply(pre_transforms_model)
save_onnx(post_transforms_model, file_path)
1 change: 1 addition & 0 deletions src/sparseml/exporters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@
from .remove_duplicate_qconv_weights import RemoveDuplicateQConvWeights
from .remove_duplicate_quantize_ops import RemoveDuplicateQuantizeOps
from .skip_input_quantize import SkipInputQuantize
from .kv_cache import *
25 changes: 25 additions & 0 deletions src/sparseml/exporters/transforms/kv_cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Transforms for adding KV caching mechanism into language models
"""

# flake8: noqa
# isort:skip_file

from .cache_keys_and_values import *
from .positions_adjustment_base import *
from .positions_adjustment_opt import *
from .positions_adjustment_codegen import *
from .configs import *
Loading

0 comments on commit 87b40fb

Please sign in to comment.