
[KV Cache Injection] Causal Mask implementation for OPT and CodeGen #1677

Merged · 12 commits · Jul 27, 2023
12 changes: 5 additions & 7 deletions src/sparseml/exporters/kv_cache_injector.py
@@ -60,10 +60,8 @@ def __init__(

This transformation not only solely injects the kv cache
inputs/outputs, but also adjusts the original ONNX graph to
- account for the necessary changes. This involves e.g. adding
- the 'position' input to the model, so that the positional
- embeddings of the new model are compatible with the past kv
- cache information.
+ account for the necessary changes. This is done by the
+ optional `additional_transforms` variable.

Usage:
```python
@@ -131,7 +129,7 @@ def export(self, pre_transforms_model: onnx.ModelProto, file_path: str):

@staticmethod
def _get_transforms_from_config(config: KeyValueCacheConfig) -> List[OnnxTransform]:
- positions_adjustment = config.positions_adjustment_transform
+ additional_transforms = config.additional_transforms

transforms = [
CacheKeysAndValues(
@@ -142,8 +140,8 @@ def _get_transforms_from_config(config: KeyValueCacheConfig) -> List[OnnxTransform]:
transpose_key_input=config.transpose_key_input,
)
]
- if positions_adjustment is not None:
-     transforms += [positions_adjustment()]
+ if additional_transforms is not None:
+     transforms += [additional_transforms()]

return transforms

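The docstring's `Usage` example is collapsed in this diff. As a minimal sketch of how the injector is typically driven, assuming the constructor takes a `model_path` pointing at a deployment directory (only the `export(pre_transforms_model, file_path)` signature is visible above; the paths and file names below are illustrative):

```python
import onnx

from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector

# Hypothetical paths -- not taken from this PR.
onnx_model = onnx.load("deployment/model.onnx")

# The injector resolves the model-specific KeyValueCacheConfig and applies
# CacheKeysAndValues plus any `additional_transforms` registered for it.
injector = KeyValueCacheInjector(model_path="deployment")
injector.export(onnx_model, "model_kv_cache.onnx")
```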
23 changes: 10 additions & 13 deletions src/sparseml/exporters/transforms/kv_cache/configs.py
@@ -19,11 +19,11 @@

from pydantic import BaseModel, Field

- from sparseml.exporters.transforms.kv_cache.positions_adjustment_codegen import (
-     PositionsAdjustmentCodeGen,
+ from sparseml.exporters.transforms.kv_cache.transforms_codegen import (
+     AdditionalTransformsCodeGen,
)
- from sparseml.exporters.transforms.kv_cache.positions_adjustment_opt import (
-     PositionsAdjustmentOPT,
+ from sparseml.exporters.transforms.kv_cache.transforms_opt import (
+     AdditionalTransformsOPT,
)


@@ -37,12 +37,9 @@ class KeyValueCacheConfig(BaseModel):
description="The name of the model type. This should correspond to "
"the `model_type` field in the transformer's `config.json` file."
)
- positions_adjustment_transform: Any = Field(
-     description="The class to use to transform the positional embeddings. "
-     "This should be a subclass of `PositionsAdjustmentBase`. Note: In the "
-     "future, when we encounter models that are more complex than just "
-     "editing the positions in the model, we can make this transformation more "
-     "general."
+ additional_transforms: Any = Field(
+     description="A transform class to use for additional transforms "
+     "to the model required for finalizing the kv cache injection."
)
key_num_attention_heads: str = Field(
description="The key to use to get the number of attention heads from the "
@@ -84,7 +81,7 @@ class Config:

OPT_CONFIG = KeyValueCacheConfig(
model_name="opt",
- positions_adjustment_transform=PositionsAdjustmentOPT,
+ additional_transforms=AdditionalTransformsOPT,
Reviewer comment (Member): would be great to support a list here eventually, in case we don't want to squash many steps into a single transform, or to enable more code sharing between model-specific additional transforms outside of inheritance (a hypothetical sketch of this appears after the configs.py diff below).

key_num_attention_heads="num_attention_heads",
key_num_embedding_hidden_size="hidden_size",
transpose_value_input=None,
@@ -94,7 +91,7 @@ class Config:

CODEGEN_CONFIG = KeyValueCacheConfig(
model_name="codegen",
- positions_adjustment_transform=PositionsAdjustmentCodeGen,
+ additional_transforms=AdditionalTransformsCodeGen,
key_num_attention_heads="n_head",
key_num_embedding_hidden_size="n_embd",
transpose_value_input=(0, 2, 1, 3),
@@ -104,7 +101,7 @@ class Config:

BLOOM_CONFIG = KeyValueCacheConfig(
model_name="bloom",
- positional_embedding_transform=None,
+ additional_transforms=None,
key_num_attention_heads="num_attention_heads",
key_num_embedding_hidden_size="n_embed",
transpose_value_input=None,
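As a hypothetical illustration of the reviewer's suggestion above (not part of this PR), `additional_transforms` could also accept a list of transform classes, with a small helper normalizing the field before instantiation:

```python
from typing import Any, List


def _as_transform_list(additional_transforms: Any) -> List[Any]:
    """Normalize None, a single transform class, or a list of classes
    into a list of instantiated transforms (hypothetical helper)."""
    if additional_transforms is None:
        return []
    if not isinstance(additional_transforms, (list, tuple)):
        additional_transforms = [additional_transforms]
    return [transform_cls() for transform_cls in additional_transforms]


# e.g. inside KeyValueCacheInjector._get_transforms_from_config:
#     transforms += _as_transform_list(config.additional_transforms)
```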
src/sparseml/exporters/transforms/kv_cache/positions_adjustment_base.py → src/sparseml/exporters/transforms/kv_cache/transforms_base.py
@@ -19,7 +19,10 @@
from sparseml.exporters.transforms.onnx_transform import OnnxTransform


- class PositionsAdjustmentBase(OnnxTransform):
+ __all__ = ["AdditionalTransformsBase"]
+
+
+ class AdditionalTransformsBase(OnnxTransform):

POSITIONS_NAME = "positions" # matches intermediate var name in torch

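The methods of `AdditionalTransformsBase` are collapsed in this diff. A minimal sketch of the kind of helper such a base class might provide, adding a `positions` graph input with plain `onnx.helper` calls (dtype and shape here are assumptions, not taken from this PR):

```python
import onnx
from onnx import ModelProto, TensorProto


def add_positions_input(model: ModelProto, name: str = "positions") -> ModelProto:
    """Append a `positions` input of shape (batch, sequence_length) to the graph.

    Illustrative only; the real transform may pick a different dtype, shape,
    or insertion point for the input.
    """
    positions = onnx.helper.make_tensor_value_info(
        name, TensorProto.INT64, ["batch", "sequence_length"]
    )
    model.graph.input.append(positions)
    return model
```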
src/sparseml/exporters/transforms/kv_cache/positions_adjustment_codegen.py → src/sparseml/exporters/transforms/kv_cache/transforms_codegen.py
@@ -14,17 +14,17 @@

from onnx import ModelProto, NodeProto

- from sparseml.exporters.transforms.kv_cache.positions_adjustment_base import (
-     PositionsAdjustmentBase,
+ from sparseml.exporters.transforms.kv_cache.transforms_base import (
+     AdditionalTransformsBase,
)
from sparseml.exporters.transforms.utils.matching import get_structural_matches
from sparseml.onnx.utils import ONNXGraph


- __all__ = ["PositionsAdjustmentCodeGen"]
+ __all__ = ["AdditionalTransformsCodeGen"]


- class PositionsAdjustmentCodeGen(PositionsAdjustmentBase):
+ class AdditionalTransformsCodeGen(AdditionalTransformsBase):

# The pattern that matches the node that creates
# the `position_ids` tensor
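The comment above refers to a pattern that locates the node producing `position_ids`. The transform itself relies on `get_structural_matches`, whose use is not shown in this view; as a loose, hedged stand-in, a search over the graph might look like:

```python
from typing import Optional

from onnx import ModelProto, NodeProto


def find_node_by_output_suffix(model: ModelProto, suffix: str) -> Optional[NodeProto]:
    """Return the first node whose output name ends with `suffix`.

    Illustrative stand-in only: AdditionalTransformsCodeGen matches on op
    types and graph structure, not on output names.
    """
    for node in model.graph.node:
        if any(output.endswith(suffix) for output in node.output):
            return node
    return None
```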
src/sparseml/exporters/transforms/kv_cache/positions_adjustment_opt.py → src/sparseml/exporters/transforms/kv_cache/transforms_opt.py
@@ -14,20 +14,20 @@

from onnx import ModelProto, NodeProto

- from sparseml.exporters.transforms.kv_cache.positions_adjustment_base import (
-     PositionsAdjustmentBase,
+ from sparseml.exporters.transforms.kv_cache.transforms_base import (
+     AdditionalTransformsBase,
)
from sparseml.onnx.utils import ONNXGraph


- __all__ = ["PositionsAdjustmentOPT"]
+ __all__ = ["AdditionalTransformsOPT"]


# name position embeddings weights
_EMBED_POSITIONS_ID = "model.decoder.embed_positions.weight"


- class PositionsAdjustmentOPT(PositionsAdjustmentBase):
+ class AdditionalTransformsOPT(AdditionalTransformsBase):
"""
Base class for model architecture specific transforms to adjust graph
to take input_id positions as an argument rather than computing them
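The docstring above says the graph is adjusted to take `input_id` positions as an argument rather than computing them; together with `_EMBED_POSITIONS_ID`, that suggests rerouting the positional-embedding lookup to the injected `positions` input. A hedged sketch, assuming the lookup exports as a Gather whose first input is the embed_positions weight:

```python
from onnx import ModelProto

_EMBED_POSITIONS_ID = "model.decoder.embed_positions.weight"


def route_positions_to_embedding(
    model: ModelProto, positions_name: str = "positions"
) -> ModelProto:
    """Point the OPT positional-embedding Gather at the `positions` input.

    Sketch only: the transform in this PR may match and rewrite the graph
    differently (e.g. via structural matching rather than a name check).
    """
    for node in model.graph.node:
        if (
            node.op_type == "Gather"
            and len(node.input) >= 2
            and node.input[0] == _EMBED_POSITIONS_ID
        ):
            # Replace the computed position indices with the injected graph input.
            node.input[1] = positions_name
    return model
```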