diff --git a/src/sparseml/exporters/kv_cache_injector.py b/src/sparseml/exporters/kv_cache_injector.py index 89e1cf01c4f..9c69dbdfc19 100644 --- a/src/sparseml/exporters/kv_cache_injector.py +++ b/src/sparseml/exporters/kv_cache_injector.py @@ -143,7 +143,9 @@ def _get_transforms_from_config(config: KeyValueCacheConfig) -> List[OnnxTransfo ) ] if additional_transforms is not None: - transforms += [additional_transforms()] + if not isinstance(additional_transforms, list): + additional_transforms = [additional_transforms] + transforms += [transform() for transform in additional_transforms] return transforms diff --git a/src/sparseml/exporters/transforms/kv_cache/configs.py b/src/sparseml/exporters/transforms/kv_cache/configs.py index 707337cfdc3..0bfbbfb1850 100644 --- a/src/sparseml/exporters/transforms/kv_cache/configs.py +++ b/src/sparseml/exporters/transforms/kv_cache/configs.py @@ -15,10 +15,11 @@ import json import logging from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union from pydantic import BaseModel, Field +from sparseml.exporters.transforms import OnnxTransform from sparseml.exporters.transforms.kv_cache.transforms_codegen import ( AdditionalTransformsCodeGen, ) @@ -37,9 +38,11 @@ class KeyValueCacheConfig(BaseModel): description="The name of the model type. This should correspond to " "the `model_type` field in the transformer's `config.json` file." ) - additional_transforms: Any = Field( - description="A transform class to use for additional transforms " - "to the model required for finalizing the kv cache injection." + additional_transforms: Union[ + List[Type[OnnxTransform]], Type[OnnxTransform], None + ] = Field( + description="A transform class (or list thereof) to use for additional " + "transforms to the model required for finalizing the kv cache injection." ) key_num_attention_heads: str = Field( description="The key to use to get the number of attention heads from the "