Commit: complete implementation
bfineran authored and dbogunowicz committed Jul 12, 2023
1 parent 6e761ed commit 20d1944
Showing 5 changed files with 94 additions and 160 deletions.
113 changes: 24 additions & 89 deletions src/sparseml/exporters/kv_cache_injector.py
@@ -12,58 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Union

import onnx

from sparseml.exporters.base_exporter import BaseExporter
from sparseml.exporters.transforms import OnnxTransform
from sparseml.exporters.transforms.kv_cache import (
    CacheKeysAndValues,
    KeyValueCacheConfig,
    get_kv_cache_config,
)
from sparsezoo.utils import save_onnx


_LOGGER = logging.getLogger(__name__)


class KeyValueCacheInjector(BaseExporter):
    def __init__(
        self,
        model_path: Optional[str] = None,
        model_path: str,
        inplace: bool = True,
        **kwargs: Any,
    ):
"""
A transformation that injects Key Value cache support into the model.
This means that the
- autoregressive model that
    * takes input_ids and attention_mask as INPUT
    * returns logits as OUTPUT
- is transformed into a model that
    * takes input_ids, attention_mask, and kv_cache as INPUT
    * returns logits and updated kv_cache as OUTPUT
This means that the autoregressive model takes as an input / returns
as an output a cache of key value pairs that are used to speed up the
autoregressive generation process (reduce the compute of key/value pairs
by storing the results of previous computations in memory).
The goal of the KV cache injection is to speed up the autoregressive
generation process (reduce the compute of key/value pairs by storing
the results of previous computations in memory).
The exporter will look for a `config.json` file in the `model_path` directory
to determine the static dimensions of the kv cache input/output.
The exporter will look for a `config.json` file in the `model_path`
directory to determine the parameters for KV cache injection.
If `model_path` is not provided, the requested parameters can be
provided in the `kwargs`.
This transformation not only injects the kv cache
inputs/outputs, but also adjusts the original ONNX graph to
account for the necessary changes. This involves, e.g., adding
the 'position' input to the model, so that the positional
embeddings of the new model are compatible with the past kv
cache information.
This transformation not only injects the cache support, but also adjusts
the model to account for it. This means altering the inputs to the model,
such as adding a "position" input.
Usage:
```python
@@ -72,15 +53,6 @@ def __init__(
exporter.export(onnx_model, "model.onnx")
```
Alternatively:
```python
onnx_model: onnx.ModelProto = ...
exporter = KeyValueCacheInjector(model_path="path/to/model")
exporter = KeyValueCacheInjector(num_attention_heads=16,
                                 hidden_size_dim=64)
exporter.export(onnx_model, "model.onnx")
```
You can also just optimize the model directly without saving to disk:
```python
onnx_model: onnx.ModelProto = ...
@@ -91,25 +63,21 @@
:param model_path: The path to the directory containing the model.
:param inplace: If True, the model will be modified in place.
If False, a copy of the model will be made and modified.
:param kwargs: (Optionally) the parameters for the KV cache injection
if no `model_path` is provided.
"""
        self.inplace = inplace

        config = get_kv_cache_config(model_path)

        if config is not None:
            transforms = self._get_transforms_from_config(config)

        elif kwargs:
            transforms = self._get_transforms_from_kwargs(kwargs)

        else:
            raise ValueError(
                "Either `model_path` or kwargs must be provided to "
                "KeyValueCacheInjector"
            )

        self.config = get_kv_cache_config(model_path)
        transforms = [
            CacheKeysAndValues(
                num_attention_heads=self.config.num_attention_heads,
                hidden_size_kv_cache=self.config.hidden_size_kv_cache,
                internally_multiply_batch_by_num_attention_heads=self.config.internally_multiply_batch_by_num_attention_heads,  # noqa: E501
                transpose_kv_value_input=self.config.transpose_kv_value_input,
                transpose_kv_key_input=self.config.transpose_kv_key_input,
            ),
        ]
        if self.config.positional_embedding_transform:
            PositionEmbeddingAdjustment = self.config.positional_embedding_transform
            transforms.append(PositionEmbeddingAdjustment())
        super().__init__(transforms)

    def pre_validate(self, model: Union[onnx.ModelProto, str, Path]) -> onnx.ModelProto:
@@ -128,36 +96,3 @@ def post_validate(self, model: onnx.ModelProto) -> onnx.ModelProto:
    def export(self, pre_transforms_model: onnx.ModelProto, file_path: str):
        post_transforms_model: onnx.ModelProto = self.apply(pre_transforms_model)
        save_onnx(post_transforms_model, file_path)

    @staticmethod
    def _get_transforms_from_config(config: KeyValueCacheConfig) -> List[OnnxTransform]:
        positions_adjustment = config.positions_adjustment_transform

        transforms = [
            CacheKeysAndValues(
                num_attention_heads=config.num_attention_heads,
                hidden_size_kv_cache=config.hidden_size_kv_cache,
                multiply_batch_by_num_att_heads=config.multiply_batch_by_num_att_heads,
                transpose_value_input=config.transpose_value_input,
                transpose_key_input=config.transpose_key_input,
            )
        ]
        if positions_adjustment is not None:
            transforms += [positions_adjustment()]

        return transforms

    @staticmethod
    def _get_transforms_from_kwargs(kwargs: Dict[str, Any]) -> List[OnnxTransform]:
        transforms = [
            CacheKeysAndValues(
                num_attention_heads=kwargs.get("num_attention_heads"),
                hidden_size_kv_cache=kwargs.get("hidden_size_kv_cache"),
                multiply_batch_by_num_att_heads=kwargs.get(
                    "multiply_batch_by_num_att_heads", False
                ),
                transpose_value_input=kwargs.get("transpose_value_input", None),
                transpose_key_input=kwargs.get("transpose_key_input", None),
            )
        ]
        return transforms
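For context, a minimal end-to-end sketch of the interface as it stands after this commit. The file paths and the final print are illustrative assumptions, not taken from the commit; the constructor and `export` calls mirror the docstring above.

```python
import onnx

from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector

# `model_path` is now required: the exporter resolves the cache geometry
# (attention heads, kv-cache hidden size, transpose flags) from the
# `config.json` in this directory via `get_kv_cache_config`.
onnx_model = onnx.load("deployment/model.onnx")  # hypothetical path
exporter = KeyValueCacheInjector(model_path="deployment", inplace=True)
exporter.export(onnx_model, "model_kv_cache.onnx")

# Sanity check: the injected graph should list cache tensors (and, when the
# config defines a positional-embedding transform, a positions input)
# alongside the original inputs.
injected = onnx.load("model_kv_cache.onnx")
print([inp.name for inp in injected.graph.input])
```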
7 changes: 3 additions & 4 deletions src/sparseml/exporters/transforms/kv_cache/__init__.py
@@ -19,7 +19,6 @@
# isort:skip_file

from .cache_keys_and_values import *
from .positions_adjustment_base import *
from .positions_adjustment_opt import *
from .positions_adjustment_codegen import *
from .configs import *
from .position_embeddings_adjustment_base import *
from .position_embeddings_adjustment_opt import *
from .position_embeddings_adjustment_codegen import *
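The `positional_embedding_transform` entry consulted by the new `__init__` is expected to resolve to one of the position-embeddings adjustments re-exported above (the OPT or CodeGen variants). A sketch of the config contract the new code reads, with field names taken from the diff; the dataclass packaging and field types are assumptions:

```python
from dataclasses import dataclass
from typing import Any, Optional, Type


@dataclass
class KVCacheConfigSketch:
    """Stand-in for the object returned by get_kv_cache_config();
    field names mirror those referenced by the new __init__."""

    num_attention_heads: int
    hidden_size_kv_cache: int
    internally_multiply_batch_by_num_attention_heads: bool
    transpose_kv_value_input: Any  # exact type not visible in the diff
    transpose_kv_key_input: Any    # exact type not visible in the diff
    # e.g. a position-embeddings adjustment transform class, or None when
    # the architecture needs no adjustment
    positional_embedding_transform: Optional[Type] = None
```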
(The diffs for the remaining three changed files are not rendered here.)
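Finally, as background for the docstring's performance claim ("reduce the compute of key/value pairs by storing the results of previous computations in memory"): each decoding step only needs to compute the key/value projections for the newest token and append them to the cached ones. A minimal single-head decoding loop in plain numpy, illustrative only and unrelated to sparseml's actual graph transforms:

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


d = 64                      # head dimension (illustrative)
k_cache = np.zeros((0, d))  # keys from all previous decoding steps
v_cache = np.zeros((0, d))  # values from all previous decoding steps

for step in range(8):
    # Only the newest token's projections are computed each step; stand-in
    # random vectors take the place of real learned projections here.
    q = np.random.randn(1, d)
    k_new = np.random.randn(1, d)
    v_new = np.random.randn(1, d)
    # Append to the cache instead of recomputing K/V for every past token;
    # this is what the injected cache inputs/outputs carry between steps.
    k_cache = np.concatenate([k_cache, k_new])
    v_cache = np.concatenate([v_cache, v_new])
    attn_out = softmax(q @ k_cache.T / np.sqrt(d)) @ v_cache
```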
