Skip to content

Commit

Permalink
[Export Refactor] Export generative transformers(#1910)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Dec 18, 2023
1 parent fd581ea commit d6e0894
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Callable, Dict, List

from pydantic import Field

from sparseml.transformers.integration_helper_functions import Transformers
from sparseml.transformers.utils.helpers import (
MANDATORY_DEPLOYMENT_FILES,
NLG_TOKENIZER_FILES,
)
from sparseml.transformers.utils.optimizations import apply_kv_cache_injection
from src.sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
Integrations,
)


# Registry of ONNX graph optimizations for generative transformer exports.
# Keys are optimization names; values are callables invoked with the path to
# the exported ONNX model (see apply_kv_cache_injection's signature).
generative_transformers_graph_optimizations = {
    "kv_cache_injection": apply_kv_cache_injection
}


@IntegrationHelperFunctions.register(name=Integrations.transformers_generative.value)
class GenerativeTransformers(Transformers):
    """
    Integration helper functions for exporting generative transformer models.

    Extends the base ``Transformers`` helper functions, overriding the graph
    optimizations and the set of mandatory deployment files. Registered in the
    ``IntegrationHelperFunctions`` registry under the
    ``transformers_generative`` integration name.
    """

    # Graph optimizations applied to the exported ONNX model; by default only
    # KV-cache injection is performed.
    graph_optimizations: Dict[str, Callable] = Field(
        default=generative_transformers_graph_optimizations
    )
    # Files that must be present in the deployment directory: the base
    # mandatory deployment files plus the NLG tokenizer files.
    deployment_directory_files_mandatory: List[str] = Field(
        default=list(MANDATORY_DEPLOYMENT_FILES.union(NLG_TOKENIZER_FILES))
    )
41 changes: 41 additions & 0 deletions src/sparseml/transformers/utils/optimizations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from pathlib import Path
from typing import Union

import onnx

from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector


__all__ = ["apply_kv_cache_injection"]

_LOGGER = logging.getLogger(__name__)


def apply_kv_cache_injection(onnx_model_path: Union[str, Path]) -> bool:
    """
    Apply key-value cache injection to an ONNX model, overwriting the model
    file at ``onnx_model_path`` in place.

    :param onnx_model_path: path to the ONNX model file to inject
    :return: True on completion; any failure propagates as an exception
        (the function has no False-returning path)
    """
    # Load graph structure only; external weight data stays on disk so the
    # injector works with models larger than the 2GB protobuf limit.
    onnx_model = onnx.load(onnx_model_path, load_external_data=False)
    # The injector needs the model's directory to locate auxiliary files
    # (e.g. config) alongside the ONNX file.
    model_path = os.path.dirname(onnx_model_path)
    exporter = KeyValueCacheInjector(model_path=model_path)
    # Writes the transformed model back to the original path.
    exporter.export(onnx_model, onnx_model_path)
    return True
79 changes: 79 additions & 0 deletions tests/sparseml/export/transformers/test_generative_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil

import pytest
import torch

from huggingface_hub import snapshot_download
from sparseml.export.export import export


@pytest.mark.parametrize(
    "stub, task",
    [("roneneldan/TinyStories-1M", "text-generation")],
)
class TestEndToEndExport:
    """End-to-end tests of the export pipeline for a small generative model."""

    @pytest.fixture()
    def setup(self, tmp_path, stub, task):
        """
        Download the model snapshot into a temp dir, yield
        (source_path, target_path, task), and remove the temp tree afterwards.
        """
        model_path = tmp_path / "model"
        target_path = tmp_path / "target"

        source_path = snapshot_download(stub, local_dir=model_path)

        yield source_path, target_path, task

        shutil.rmtree(tmp_path)

    def test_export_happy_path(self, setup):
        # Default export should produce a deployment-ready ONNX model.
        source_path, target_path, task = setup
        export(
            source_path=source_path,
            target_path=target_path,
            task=task,
        )
        assert (target_path / "deployment" / "model.onnx").exists()

    def test_export_with_sample_data(self, setup):
        # Export should also accept user-supplied sample inputs.
        source_path, target_path, task = setup

        sequence_length = 32
        sample_data = dict(
            input_ids=torch.ones((10, sequence_length), dtype=torch.long),
            attention_mask=torch.ones((10, sequence_length), dtype=torch.long),
        )
        export(
            source_path=source_path,
            target_path=target_path,
            task=task,
            sample_data=sample_data,
        )
        assert (target_path / "deployment" / "model.onnx").exists()

    # NOTE: `pytest.mark.skipif(reason=...)` with no condition does NOT skip
    # the test; `pytest.mark.skip` is the unconditional form intended here.
    @pytest.mark.skip(
        reason="skipping since this functionality needs some more attention"
    )
    def test_export_validate_correctness(self, setup):
        source_path, target_path, task = setup

        num_samples = 4

        export(
            source_path=source_path,
            target_path=target_path,
            task=task,
            num_export_samples=num_samples,
            validate_correctness=True,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from src.sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
Integrations,
)


def test_integration_helper_functions():
    """
    Check that the generative-transformers helper functions register correctly
    and expose the expected optimizations and deployment file sets.
    """
    # import needed to register the object on the fly
    import sparseml.transformers.integration_helper_functions_generative  # noqa F401

    transformers_gen = IntegrationHelperFunctions.load_from_registry(
        Integrations.transformers_generative.value
    )
    assert transformers_gen.create_model
    assert transformers_gen.create_dummy_input
    assert transformers_gen.export
    # BUG FIX: a dict view never compares equal to a list, and the values are
    # callables, not strings — the original assertion could never pass.
    # Compare the registered optimization callables by name instead.
    assert [
        func.__name__ for func in transformers_gen.graph_optimizations.values()
    ] == ["apply_kv_cache_injection"]
    assert transformers_gen.create_data_samples
    assert set(transformers_gen.deployment_directory_files_mandatory) == {
        "model.onnx",
        "tokenizer_config.json",
        "config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
    }
    assert set(transformers_gen.deployment_directory_files_optional) == {
        "tokenizer.json",
        "tokenizer.model",
    }

0 comments on commit d6e0894

Please sign in to comment.