[KV Cache Injection] Pathway to handle GQA models (#1818)
* initial commit

* beautification

* Update src/sparseml/exporters/transforms/kv_cache/configs.py

Co-authored-by: Michael Goin <michael@neuralmagic.com>

* complete the condition for GQA

* Update src/sparseml/exporters/transforms/kv_cache/configs.py

Co-authored-by: Michael Goin <michael@neuralmagic.com>

* Update src/sparseml/exporters/transforms/kv_cache/configs.py

Co-authored-by: Michael Goin <michael@neuralmagic.com>

---------

Co-authored-by: Michael Goin <michael@neuralmagic.com>
2 people authored and bfineran committed Nov 16, 2023
1 parent e587c59 commit 5254cfb
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions src/sparseml/exporters/transforms/kv_cache/configs.py
@@ -223,10 +223,57 @@ def get_kv_cache_config(
    kv_cache_config.num_attention_heads = num_attention_heads
    kv_cache_config.hidden_size_kv_cache = hidden_size_kv_cache

    kv_cache_config = adapt_cache_structure_for_gqa(
        kv_cache_config, transformers_config
    )

    _LOGGER.info("Properly configured arguments for KV Cache Transformation")
    return kv_cache_config


def adapt_cache_structure_for_gqa(
    kv_cache_config: KeyValueCacheConfig,
    transformers_config: Dict[str, Any],
    model_names: List[str] = ["llama"],
) -> KeyValueCacheConfig:
"""
Potentially adapts the kv_cache_config, so that it
properly works with Grouped Query Attention (GQA).
For now, this function only supports the llama model.
Llama uses:
Multi Head Attention (MHA) if `num_key_value_heads==num_attention_heads` (default),
Grouped Query Attention (GQA) if `num_key_value_heads<num_attention_heads`,
Multi Query Attention (MQA) if `num_key_value_heads==1`,
:param kv_cache_config: The kv cache config for the model.
:param transformers_config: The transformers config for
the model. If contains the key:`num_key_value_heads`,
the model may be potentially using GQA instead of
MHA and thus the kv_cache_config needs to be adapted.
:param model_names: The list of model names that may use
GQA instead of MQA.
:return: Potentially adapted kv cache config for the model.
If the model does not use GQA, the kv_cache_config is
returned unchanged.
"""
    # For now, we only support GQA for Llama.
    model_name = kv_cache_config.model_name
    num_attention_heads = kv_cache_config.num_attention_heads
    num_key_value_heads = transformers_config.get("num_key_value_heads")

    if num_key_value_heads is not None and model_name in model_names:
        if num_key_value_heads > 1 and num_key_value_heads != num_attention_heads:
            # introduce the modification to the config to support GQA for Llama
            kv_cache_config.transpose_value_input = None

            _LOGGER.info(
                f"Adapted the model: {transformers_config['model_type']} "
                f"to work with GQA."
            )
    return kv_cache_config


def _get_transformers_config(model_path: Union[str, Path]) -> Dict[str, Any]:
    # from the model path, get the config.json file and return it as a dict.
    model_path = Path(model_path) if isinstance(model_path, str) else model_path
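A note on the logic in the new helper: the adaptation hinges entirely on comparing `num_key_value_heads` against `num_attention_heads` from the Hugging Face `config.json`. The sketch below isolates that classification so it can be checked on its own; `classify_attention` and the example head counts are illustrative assumptions, not part of sparseml's API.

from typing import Any, Dict, Optional


def classify_attention(hf_config: Dict[str, Any]) -> str:
    """Classify the attention variant implied by a Llama-style HF config.

    Mirrors the condition used in adapt_cache_structure_for_gqa:
      - MHA if num_key_value_heads is absent or equal to num_attention_heads
      - MQA if num_key_value_heads == 1
      - GQA otherwise (1 < num_key_value_heads < num_attention_heads)
    """
    num_attention_heads: int = hf_config["num_attention_heads"]
    num_key_value_heads: Optional[int] = hf_config.get("num_key_value_heads")

    if num_key_value_heads is None or num_key_value_heads == num_attention_heads:
        return "MHA"
    if num_key_value_heads == 1:
        return "MQA"
    return "GQA"


# Illustrative head counts (roughly Llama-2-70B-like); values are assumptions.
print(classify_attention({"num_attention_heads": 64, "num_key_value_heads": 8}))  # GQA
print(classify_attention({"num_attention_heads": 32}))                            # MHA

Only the GQA case triggers the `transpose_value_input = None` modification in the diff; MHA and MQA configs pass through `adapt_cache_structure_for_gqa` unchanged.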
