Skip to content

Commit

Permalink
allow transformations to be also a list
Browse files Browse the repository at this point in the history
  • Loading branch information
bogunowicz@arrival.com committed Jul 27, 2023
1 parent 224412f commit 9b92f77
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
4 changes: 3 additions & 1 deletion src/sparseml/exporters/kv_cache_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def _get_transforms_from_config(config: KeyValueCacheConfig) -> List[OnnxTransfo
)
]
if additional_transforms is not None:
transforms += [additional_transforms()]
if not isinstance(additional_transforms, list):
additional_transforms = [additional_transforms]
transforms += [transform() for transform in additional_transforms]

return transforms

Expand Down
11 changes: 7 additions & 4 deletions src/sparseml/exporters/transforms/kv_cache/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from pydantic import BaseModel, Field

from sparseml.exporters.transforms import OnnxTransform
from sparseml.exporters.transforms.kv_cache.transforms_codegen import (
AdditionalTransformsCodeGen,
)
Expand All @@ -37,9 +38,11 @@ class KeyValueCacheConfig(BaseModel):
description="The name of the model type. This should correspond to "
"the `model_type` field in the transformer's `config.json` file."
)
additional_transforms: Any = Field(
description="A transform class to use for additional transforms "
"to the model required for finalizing the kv cache injection."
additional_transforms: Union[
List[Type[OnnxTransform]], Type[OnnxTransform], None
] = Field(
description="A transform class (or list thereof) to use for additional "
"transforms to the model required for finalizing the kv cache injection."
)
key_num_attention_heads: str = Field(
description="The key to use to get the number of attention heads from the "
Expand Down

0 comments on commit 9b92f77

Please sign in to comment.