Skip to content

Commit

Permalink
[KV Cache Injection] Causal Mask implementation for OPT and CodeGen (#1677)
Browse files Browse the repository at this point in the history

* initial commit

* [KV Cache Injection] Causal Mask for CodeGen (#1676)

* initial implementation; testing now

* fix a small blunder

* cleanup

---------

Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>

* [KV Cache Injection] Causal Mask for OPT (#1688)

* initial implementation; testing now

* fix a small blunder

* cleanup

* initial implementation

* on to testing with deepsparse

---------

Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>

* replace boolean causal mask with int64 causal mask

* better logging info

* allow transformations to also be a list

---------

Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>
  • Loading branch information
dbogunowicz and bogunowicz@arrival.com committed Jul 27, 2023
1 parent ebc4ac6 commit b1d5ea2
Show file tree
Hide file tree
Showing 9 changed files with 550 additions and 294 deletions.
14 changes: 7 additions & 7 deletions src/sparseml/exporters/kv_cache_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,8 @@ def __init__(
This transformation not only injects the kv cache
inputs/outputs, but also adjusts the original ONNX graph to
account for the necessary changes. This involves e.g. adding
the 'position' input to the model, so that the positional
embeddings of the new model are compatible with the past kv
cache information.
account for the necessary changes. This is done by the
optional `additional_transforms` variable.
Usage:
```python
Expand Down Expand Up @@ -133,7 +131,7 @@ def export(self, pre_transforms_model: onnx.ModelProto, file_path: str):

@staticmethod
def _get_transforms_from_config(config: KeyValueCacheConfig) -> List[OnnxTransform]:
positions_adjustment = config.positions_adjustment_transform
additional_transforms = config.additional_transforms

transforms = [
CacheKeysAndValues(
Expand All @@ -144,8 +142,10 @@ def _get_transforms_from_config(config: KeyValueCacheConfig) -> List[OnnxTransfo
transpose_key_input=config.transpose_key_input,
)
]
if positions_adjustment is not None:
transforms += [positions_adjustment()]
if additional_transforms is not None:
if not isinstance(additional_transforms, list):
additional_transforms = [additional_transforms]
transforms += [transform() for transform in additional_transforms]

return transforms

Expand Down
6 changes: 3 additions & 3 deletions src/sparseml/exporters/transforms/kv_cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# isort:skip_file

from .cache_keys_and_values import *
from .positions_adjustment_base import *
from .positions_adjustment_opt import *
from .positions_adjustment_codegen import *
from .transforms_base import *
from .transforms_opt import *
from .transforms_codegen import *
from .configs import *
28 changes: 14 additions & 14 deletions src/sparseml/exporters/transforms/kv_cache/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from pydantic import BaseModel, Field

from sparseml.exporters.transforms.kv_cache.positions_adjustment_codegen import (
PositionsAdjustmentCodeGen,
from sparseml.exporters.transforms import OnnxTransform
from sparseml.exporters.transforms.kv_cache.transforms_codegen import (
AdditionalTransformsCodeGen,
)
from sparseml.exporters.transforms.kv_cache.positions_adjustment_opt import (
PositionsAdjustmentOPT,
from sparseml.exporters.transforms.kv_cache.transforms_opt import (
AdditionalTransformsOPT,
)


Expand All @@ -37,12 +38,11 @@ class KeyValueCacheConfig(BaseModel):
description="The name of the model type. This should correspond to "
"the `model_type` field in the transformer's `config.json` file."
)
positions_adjustment_transform: Any = Field(
description="The class to use to transform the positional embeddings. "
"This should be a subclass of `PositionsAdjustmentBase`. Note: In the "
"future, when we encounter models that are more complex than just "
"editing the positions in the model, we can make this transformation more "
"general."
additional_transforms: Union[
List[Type[OnnxTransform]], Type[OnnxTransform], None
] = Field(
description="A transform class (or list thereof) to use for additional "
"transforms to the model required for finalizing the kv cache injection."
)
key_num_attention_heads: str = Field(
description="The key to use to get the number of attention heads from the "
Expand Down Expand Up @@ -84,7 +84,7 @@ class Config:

OPT_CONFIG = KeyValueCacheConfig(
model_name="opt",
positions_adjustment_transform=PositionsAdjustmentOPT,
additional_transforms=AdditionalTransformsOPT,
key_num_attention_heads="num_attention_heads",
key_num_embedding_hidden_size="hidden_size",
transpose_value_input=None,
Expand All @@ -94,7 +94,7 @@ class Config:

CODEGEN_CONFIG = KeyValueCacheConfig(
model_name="codegen",
positions_adjustment_transform=PositionsAdjustmentCodeGen,
additional_transforms=AdditionalTransformsCodeGen,
key_num_attention_heads="n_head",
key_num_embedding_hidden_size="n_embd",
transpose_value_input=(0, 2, 1, 3),
Expand All @@ -104,7 +104,7 @@ class Config:

BLOOM_CONFIG = KeyValueCacheConfig(
model_name="bloom",
positional_embedding_transform=None,
additional_transforms=None,
key_num_attention_heads="num_attention_heads",
key_num_embedding_hidden_size="n_embed",
transpose_value_input=None,
Expand Down

This file was deleted.

This file was deleted.

133 changes: 0 additions & 133 deletions src/sparseml/exporters/transforms/kv_cache/positions_adjustment_opt.py

This file was deleted.

Loading

0 comments on commit b1d5ea2

Please sign in to comment.