[KV Cache Injection] Causal Mask implementation for OPT and CodeGen #1677

Merged · 12 commits · Jul 27, 2023
src/sparseml/exporters/kv_cache_injector.py (5 additions, 7 deletions)
@@ -60,10 +60,8 @@ def __init__(

     This transformation not only injects the kv cache
     inputs/outputs, but also adjusts the original ONNX graph to
-    account for the necessary changes. This involves e.g. adding
-    the 'position' input to the model, so that the positional
-    embeddings of the new model are compatible with the past kv
-    cache information.
+    account for the necessary changes. This is done by the
+    optional `additional_transforms` variable.

     Usage:
     ```python
     […]
     ```

@@ -133,7 +131,7 @@ def export(self, pre_transforms_model: onnx.ModelProto, file_path: str):

     @staticmethod
     def _get_transforms_from_config(config: KeyValueCacheConfig) -> List[OnnxTransform]:
-        positions_adjustment = config.positions_adjustment_transform
+        additional_transforms = config.additional_transforms

         transforms = [
             CacheKeysAndValues(
[…]
@@ -144,8 +142,8 @@ def _get_transforms_from_config(config: KeyValueCacheConfig) -> List[OnnxTransform]:
                 transpose_key_input=config.transpose_key_input,
             )
         ]
-        if positions_adjustment is not None:
-            transforms += [positions_adjustment()]
+        if additional_transforms is not None:
+            transforms += [additional_transforms()]

         return transforms

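The docstring's usage example is collapsed in this diff view. Based on the `export` signature shown above, driving the injector plausibly looks like the sketch below; the `KeyValueCacheInjector` class name matches the file, while the `model_path` constructor argument and the file paths are assumptions.

```python
import onnx

from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector

# Hypothetical paths; `model_path` as a constructor argument is an
# assumption (the injector needs the transformer's config.json to select
# the matching KeyValueCacheConfig, e.g. OPT_CONFIG vs. CODEGEN_CONFIG).
onnx_model = onnx.load("deployment/model.onnx")
injector = KeyValueCacheInjector(model_path="deployment/")

# `export(pre_transforms_model, file_path)` matches the signature in the
# hunk above.
injector.export(onnx_model, "deployment/model_kv_cache.onnx")
```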
src/sparseml/exporters/transforms/kv_cache/__init__.py (3 additions, 3 deletions)
@@ -19,7 +19,7 @@
 # isort:skip_file

 from .cache_keys_and_values import *
-from .positions_adjustment_base import *
-from .positions_adjustment_opt import *
-from .positions_adjustment_codegen import *
+from .transforms_base import *
+from .transforms_opt import *
+from .transforms_codegen import *
 from .configs import *
src/sparseml/exporters/transforms/kv_cache/configs.py (10 additions, 13 deletions)
@@ -19,11 +19,11 @@

 from pydantic import BaseModel, Field

-from sparseml.exporters.transforms.kv_cache.positions_adjustment_codegen import (
-    PositionsAdjustmentCodeGen,
+from sparseml.exporters.transforms.kv_cache.transforms_codegen import (
+    AdditionalTransformsCodeGen,
 )
-from sparseml.exporters.transforms.kv_cache.positions_adjustment_opt import (
-    PositionsAdjustmentOPT,
+from sparseml.exporters.transforms.kv_cache.transforms_opt import (
+    AdditionalTransformsOPT,
 )


@@ -37,12 +37,9 @@ class KeyValueCacheConfig(BaseModel):
         description="The name of the model type. This should correspond to "
         "the `model_type` field in the transformer's `config.json` file."
     )
-    positions_adjustment_transform: Any = Field(
-        description="The class to use to transform the positional embeddings. "
-        "This should be a subclass of `PositionsAdjustmentBase`. Note: In the "
-        "future, when we encounter models that are more complex than just "
-        "editing the positions in the model, we can make this transformation more "
-        "general."
+    additional_transforms: Any = Field(
+        description="A transform class to use for additional transforms "
+        "to the model required for finalizing the kv cache injection."
     )
     key_num_attention_heads: str = Field(
         description="The key to use to get the number of attention heads from the "
@@ -84,7 +81,7 @@ class Config:

 OPT_CONFIG = KeyValueCacheConfig(
     model_name="opt",
-    positions_adjustment_transform=PositionsAdjustmentOPT,
+    additional_transforms=AdditionalTransformsOPT,
     key_num_attention_heads="num_attention_heads",
     key_num_embedding_hidden_size="hidden_size",
     transpose_value_input=None,

A Member commented on the `additional_transforms=AdditionalTransformsOPT,` line:

> Would be great to support a list here eventually, instead of a single class, in case we don't want to squash many steps into a single transform, or to enable more code sharing between model-specific additional transforms outside of inheritance.
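The list support suggested above could be handled at the end of `_get_transforms_from_config` roughly as follows (a sketch of the reviewer's idea, not code from this PR):

```python
# Sketch of the reviewer's suggestion: let `additional_transforms` be
# either a single transform class or a list of them.
if additional_transforms is not None:
    if not isinstance(additional_transforms, (list, tuple)):
        additional_transforms = [additional_transforms]
    transforms += [transform_cls() for transform_cls in additional_transforms]
```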
@@ -94,7 +91,7 @@ class Config:

 CODEGEN_CONFIG = KeyValueCacheConfig(
     model_name="codegen",
-    positions_adjustment_transform=PositionsAdjustmentCodeGen,
+    additional_transforms=AdditionalTransformsCodeGen,
     key_num_attention_heads="n_head",
     key_num_embedding_hidden_size="n_embd",
     transpose_value_input=(0, 2, 1, 3),
@@ -104,7 +101,7 @@ class Config:

 BLOOM_CONFIG = KeyValueCacheConfig(
     model_name="bloom",
-    positional_embedding_transform=None,
+    additional_transforms=None,
     key_num_attention_heads="num_attention_heads",
     key_num_embedding_hidden_size="n_embed",
     transpose_value_input=None,
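The `key_*` fields differ per model because each architecture names these values differently in its Hugging Face `config.json`; a sketch of how they are presumably consumed (the path below is hypothetical):

```python
import json

from sparseml.exporters.transforms.kv_cache.configs import CODEGEN_CONFIG

# CodeGen's config.json calls these "n_head" and "n_embd", which is why
# CODEGEN_CONFIG carries different keys than OPT_CONFIG.
with open("deployment/config.json") as f:
    hf_config = json.load(f)

num_attention_heads = hf_config[CODEGEN_CONFIG.key_num_attention_heads]
hidden_size = hf_config[CODEGEN_CONFIG.key_num_embedding_hidden_size]
```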

This file was deleted.

This file was deleted.
