# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Callable, Dict, List

from pydantic import Field

# NOTE(review): fixed mixed import roots — the original imported from
# `src.sparseml.integration_helper_functions`, which only resolves when the
# repository root is on sys.path and creates a second, distinct module object
# (so registration against it would be invisible to `sparseml.*` consumers).
from sparseml.integration_helper_functions import (
    IntegrationHelperFunctions,
    Integrations,
)
from sparseml.transformers.integration_helper_functions import Transformers
from sparseml.transformers.utils.helpers import (
    MANDATORY_DEPLOYMENT_FILES,
    NLG_TOKENIZER_FILES,
)
from sparseml.transformers.utils.optimizations import apply_kv_cache_injection


# Graph optimizations applicable to generative (text-generation) exports;
# keys are optimization names, values are callables applied to the ONNX model.
generative_transformers_graph_optimizations = {
    "kv_cache_injection": apply_kv_cache_injection
}


@IntegrationHelperFunctions.register(name=Integrations.transformers_generative.value)
class GenerativeTransformers(Transformers):
    """
    Helper functions for exporting generative transformer models.

    Extends the base ``Transformers`` helpers with KV-cache injection as a
    graph optimization and adds the NLG tokenizer files to the mandatory
    deployment-directory file set.
    """

    # Mapping of optimization name -> callable applied to the exported model.
    graph_optimizations: Dict[str, Callable] = Field(
        default=generative_transformers_graph_optimizations
    )
    # Mandatory deployment files: base mandatory set plus NLG tokenizer files.
    deployment_directory_files_mandatory: List[str] = Field(
        default=list(MANDATORY_DEPLOYMENT_FILES.union(NLG_TOKENIZER_FILES))
    )
100644 index 00000000000..20b58775b10 --- /dev/null +++ b/src/sparseml/transformers/utils/optimizations.py @@ -0,0 +1,41 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from pathlib import Path +from typing import Union + +import onnx + +from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector + + +__all__ = ["apply_kv_cache_injection"] + +_LOGGER = logging.getLogger(__name__) + + +def apply_kv_cache_injection(onnx_model_path: Union[str, Path]) -> bool: + """ + Apply key value cache injection to an ONNX model + + :param onnx_model_path: path to the ONNX model to inject + :return: True if successful, False otherwise + """ + onnx_model = onnx.load(onnx_model_path, load_external_data=False) + model_path = os.path.dirname(onnx_model_path) + exporter = KeyValueCacheInjector(model_path=model_path) + exporter.export(onnx_model, onnx_model_path) + return True diff --git a/tests/sparseml/export/transformers/test_generative_transformers.py b/tests/sparseml/export/transformers/test_generative_transformers.py new file mode 100644 index 00000000000..8a014360fee --- /dev/null +++ b/tests/sparseml/export/transformers/test_generative_transformers.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
@pytest.mark.parametrize(
    "stub, task",
    [("roneneldan/TinyStories-1M", "text-generation")],
)
class TestEndToEndExport:
    """
    End-to-end export tests for a small generative transformer model.

    Downloads the model snapshot into a pytest ``tmp_path``, runs the
    ``sparseml.export.export`` pipeline, and checks the deployment artifacts.
    """

    @pytest.fixture()
    def setup(self, tmp_path, stub, task):
        # download the model once per test into tmp_path; the whole tree is
        # removed in teardown so repeated runs stay clean
        model_path = tmp_path / "model"
        target_path = tmp_path / "target"

        source_path = snapshot_download(stub, local_dir=model_path)

        yield source_path, target_path, task

        shutil.rmtree(tmp_path)

    def test_export_happy_path(self, setup):
        source_path, target_path, task = setup
        export(
            source_path=source_path,
            target_path=target_path,
            task=task,
        )
        assert (target_path / "deployment" / "model.onnx").exists()

    def test_export_with_sample_data(self, setup):
        source_path, target_path, task = setup

        sequence_length = 32
        # batch of 10 all-ones token sequences; shape/dtype mirror tokenizer
        # output for the text-generation task
        sample_data = dict(
            input_ids=torch.ones((10, sequence_length), dtype=torch.long),
            attention_mask=torch.ones((10, sequence_length), dtype=torch.long),
        )
        export(
            source_path=source_path,
            target_path=target_path,
            task=task,
            sample_data=sample_data,
        )
        assert (target_path / "deployment" / "model.onnx").exists()

    # NOTE(review): original used `skipif(reason=...)` with no condition,
    # which does not skip anything; an unconditional skip needs `mark.skip`.
    @pytest.mark.skip(
        reason="skipping since this functionality needs some more attention"
    )
    def test_export_validate_correctness(self, setup):
        source_path, target_path, task = setup

        num_samples = 4

        export(
            source_path=source_path,
            target_path=target_path,
            task=task,
            num_export_samples=num_samples,
            validate_correctness=True,
        )
# NOTE(review): fixed import root — `src.sparseml` only resolves from the repo
# root and yields a different module object than `sparseml`, so the registry
# lookup would miss objects registered via the installed package path.
from sparseml.integration_helper_functions import (
    IntegrationHelperFunctions,
    Integrations,
)


def test_integration_helper_functions():
    """Verify the generative-transformers helper object registers correctly."""
    # import needed to register the object on the fly
    import sparseml.transformers.integration_helper_functions_generative  # noqa F401

    transformers_gen = IntegrationHelperFunctions.load_from_registry(
        Integrations.transformers_generative.value
    )
    assert transformers_gen.create_model
    assert transformers_gen.create_dummy_input
    assert transformers_gen.export
    # NOTE(review): original compared dict_values to a list (always False) and
    # compared callables to a string; check the registered callables' names.
    assert [
        optimization.__name__
        for optimization in transformers_gen.graph_optimizations.values()
    ] == ["apply_kv_cache_injection"]
    assert transformers_gen.create_data_samples
    assert set(transformers_gen.deployment_directory_files_mandatory) == {
        "model.onnx",
        "tokenizer_config.json",
        "config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
    }
    assert set(transformers_gen.deployment_directory_files_optional) == {
        "tokenizer.json",
        "tokenizer.model",
    }