[KV Cache] Support for CodeGen #1590

Merged: 12 commits, Jun 27, 2023
Changes from 4 commits
138 changes: 69 additions & 69 deletions src/sparseml/exporters/kv_cache_injector.py
@@ -12,26 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union

import onnx

from sparseml.exporters.base_exporter import BaseExporter
from sparseml.exporters.transforms.kv_cache import (
CacheKeysAndValues,
PositionEmbeddingsAdjustment,
get_kv_cache_config,
)
from sparsezoo.utils import save_onnx


_LOGGER = logging.getLogger(__name__)

_SUPPORTED_ARCHITECTURES = ["opt"]


class KeyValueCacheInjector(BaseExporter):
def __init__(
@@ -42,31 +39,43 @@ def __init__(
):
"""
A transformation that injects Key Value cache support into the model.
This means that the autoregressive model takes as an input / returns
as an output a cache of key value pairs that are used to speed up the
autoregressive generation process (reduce the compute of key/value pairs
by storing the results of previous computations in memory).

The exporter will look for a `config.json` file in the `model_path` directory
to determine the values:
- num_attention_heads
- hidden_size_kv_cache
required to enforce static dimensions of the kv cache input/output.
If `model_path` is not provided, the two aforementioned values must
be provided in the `kwargs`.

This transformation not only injects the cache support, but also adjusts
the model to account for the cache support. This means altering the input
to the model, such as altering the positions to account for the injected
key/value pairs.
This means that the
- autoregressive model that
* takes input_ids and attention_mask as INPUT
* returns logits as OUTPUT
- is transformed into a model that
* takes input_ids, attention_mask, and kv_cache as INPUT
* returns logits and updated kv_cache as OUTPUT

The goal of the KV cache injection is to speed up the autoregressive
generation process (reduce the compute of key/value pairs by storing
the results of previous computations in memory).

The exporter will look for a `config.json` file in the `model_path`
directory to determine the parameters for KV cache injection.
If `model_path` is not provided, the requested parameters can be
provided in the `kwargs`.

This transformation not only injects the kv cache
inputs/outputs, but also adjusts the original ONNX graph to
account for the necessary changes. This involves e.g. adding
the 'position' input to the model, so that the positional
embeddings of the new model are compatible with the past kv
cache information.

Usage:
```python
onnx_model: onnx.ModelProto = ...
exporter = KeyValueCacheInjector(model_path="path/to/model")
# alternatively
# exporter = KeyValueCacheInjector(num_attention_heads=16,
#                                  hidden_size_kv_cache_dim=64)
exporter.export(onnx_model, "model.onnx")
```

Alternatively:
```python
onnx_model: onnx.ModelProto = ...
exporter = KeyValueCacheInjector(model_path="path/to/model")
exporter = KeyValueCacheInjector(num_attention_heads = 16,
hidden_size_dim = 64)
exporter.export(onnx_model, "model.onnx")
```

@@ -80,63 +89,54 @@ def __init__(
:param model_path: The path to the directory containing the model.
:param inplace: If True, the model will be modified in place.
If False, a copy of the model will be made and modified.
:param kwargs: (Optionally) the parameters for the KV cache injection
if no `model_path` is provided.
"""
self.inplace = inplace
if model_path:
self.config = self.get_config(model_path)
if not self.config["model_type"] in _SUPPORTED_ARCHITECTURES:
raise ValueError(
f"Model type {self.config.model_type} is currently not supported. "
f"Supported model types: {_SUPPORTED_ARCHITECTURES}."
f"Proceeding with transformation, but may require additional "
f"customization..."
)

num_attention_heads = self.config["num_attention_heads"]
hidden_size_kv_cache = self.config["hidden_size"] // num_attention_heads

if model_path is not None:
# get the parameters from the config
self.config = get_kv_cache_config(model_path)

num_attention_heads = self.config.num_attention_heads
hidden_size_kv_cache_dim = self.config.hidden_size_kv_cache
multiply_batch_by_num_att_heads = (
self.config.multiply_batch_by_num_att_heads
)
transpose_value_input = self.config.transpose_value_input
transpose_key_input = self.config.transpose_key_input
positions_adjustment = self.config.positions_adjustment_transform

elif kwargs:
# get the parameters from the kwargs
num_attention_heads = kwargs.get("num_attention_heads")
hidden_size_kv_cache = kwargs.get("hidden_size_kv_cache")
hidden_size_kv_cache_dim = kwargs.get("hidden_size_kv_cache_dim")
multiply_batch_by_num_att_heads = kwargs.get(
"multiply_batch_by_num_att_heads", False
)
transpose_value_input = kwargs.get("transpose_value_input")
transpose_key_input = kwargs.get("transpose_key_input")
positions_adjustment = None

else:
raise ValueError(
"Either `model_path` or `kwargs` must be provided to the "
"KeyValueCacheInjector."
"Either `model_path` or kwargs must be provided to "
"KeyValueCacheInjector"
)

transforms = [
CacheKeysAndValues(
num_attention_heads=num_attention_heads,
hidden_size_kv_cache=hidden_size_kv_cache,
),
PositionEmbeddingsAdjustment(),
]

super().__init__(transforms)

def get_config(self, model_path: Union[str, Path]) -> Dict[str, Any]:
"""
From the model path, get the config.json file and return it as a dict.

:param model_path: The path to the directory containing the model.
:return: The config.json file as a dict.
"""
model_path = Path(model_path) if isinstance(model_path, str) else model_path

if not model_path.is_dir():
raise ValueError(
f"`model_path` is expected to be a directory, found {model_path}"
)
config_file = [
file for file in model_path.iterdir() if file.name == "config.json"
]
config_file = config_file[0]

_LOGGER.info(f"Found config file {config_file}")

with open(config_file) as f:
config = json.load(f)

return config

transforms = [
CacheKeysAndValues(
num_attention_heads=num_attention_heads,
hidden_size_kv_cache=hidden_size_kv_cache_dim,
multiply_batch_by_num_att_heads=multiply_batch_by_num_att_heads,
transpose_value_input=transpose_value_input,
transpose_key_input=transpose_key_input,
)
]

if positions_adjustment is not None:
transforms += [positions_adjustment()]

super().__init__(transforms)

def pre_validate(self, model: Union[onnx.ModelProto, str, Path]) -> onnx.ModelProto:
if isinstance(model, (str, Path)):
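For reference, a minimal usage sketch of the injector after this change, following the docstring above. The file paths and parameter values are placeholders, and the kwargs-only variant assumes no architecture-specific positions adjustment is required.

```python
import onnx

from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector

# Load the exported autoregressive ONNX model (path is a placeholder).
onnx_model = onnx.load("deployment/model.onnx")

# Preferred path: read the KV cache parameters from the config.json found in
# the model directory (supported architectures such as OPT or CodeGen).
exporter = KeyValueCacheInjector(model_path="deployment")
exporter.export(onnx_model, "model_kv_cache.onnx")

# Fallback path: pass the parameters explicitly when no config.json is
# available. Without a config, no positions adjustment transform is applied.
manual_exporter = KeyValueCacheInjector(
    num_attention_heads=16,
    hidden_size_kv_cache_dim=64,
)
manual_exporter.export(onnx_model, "model_kv_cache_manual.onnx")
```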
5 changes: 4 additions & 1 deletion src/sparseml/exporters/transforms/kv_cache/__init__.py
@@ -19,4 +19,7 @@
# isort:skip_file

from .cache_keys_and_values import *
from .position_embeddings_adjustment import *
from .positions_adjustment_base import *
from .positions_adjustment_opt import *
from .positions_adjustment_codegen import *
from .configs import *
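The new `configs` module supplies the per-architecture parameters that `get_kv_cache_config` returns. The sketch below only illustrates the shape such a config object would need in order to satisfy `KeyValueCacheInjector.__init__`: the field names are taken from the attributes read in the diff above, while the dataclass layout and types are assumptions, not the actual contents of `configs.py`.

```python
from dataclasses import dataclass
from typing import Optional, Type


@dataclass
class ExampleKVCacheConfig:
    # Hypothetical stand-in: field names mirror the attributes read in
    # KeyValueCacheInjector.__init__; the real definitions live in
    # sparseml/exporters/transforms/kv_cache/configs.py.
    num_attention_heads: int
    hidden_size_kv_cache: int
    multiply_batch_by_num_att_heads: bool
    transpose_value_input: Optional[tuple]  # exact type assumed
    transpose_key_input: Optional[tuple]  # exact type assumed
    # Architecture-specific transform class (e.g. the OPT or CodeGen
    # positions adjustment); None means no adjustment is applied.
    positions_adjustment_transform: Optional[Type] = None
```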