Skip to content

Commit

Permalink
[KV Cache] BLOOM support(#1664)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Jul 12, 2023
1 parent c4f211f commit 3593b1a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

_LOGGER = logging.getLogger(__name__)

ALLOWED_NODES_BEFORE_SOFTMAX = ["Cast"]
ALLOWED_NODES_BEFORE_SOFTMAX = ["Cast", "Reshape"]
OUTPUT_CACHE_NAME = """present.{attention_layer_idx}.{cache_type}"""
INPUT_CACHE_NAME = """past_key_values.{attention_layer_idx}.{cache_type}"""

Expand Down
13 changes: 12 additions & 1 deletion src/sparseml/exporters/transforms/kv_cache/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,20 @@ class Config:
multiply_batch_by_num_att_heads=False,
)

BLOOM_CONFIG = KeyValueCacheConfig(
model_name="bloom",
positional_embedding_transform=None,
key_num_attention_heads="num_attention_heads",
key_num_embedding_hidden_size="n_embed",
transpose_value_input=None,
transpose_key_input=(0, 1, 3, 2),
multiply_batch_by_num_att_heads=True,
)


def get_kv_cache_config(
model_path: str, supported_configs: List[BaseModel] = [OPT_CONFIG, CODEGEN_CONFIG]
model_path: str,
supported_configs: List[BaseModel] = [OPT_CONFIG, CODEGEN_CONFIG, BLOOM_CONFIG],
) -> KeyValueCacheConfig:
"""
Get the kv cache config for the model at the given path.
Expand Down

0 comments on commit 3593b1a

Please sign in to comment.