Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[Export Refactor] Export generative transformers #1910

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Callable, Dict, List

from pydantic import Field

from sparseml.transformers.integration_helper_functions import Transformers
from sparseml.transformers.utils.helpers import (
MANDATORY_DEPLOYMENT_FILES,
NLG_TOKENIZER_FILES,
)
from sparseml.transformers.utils.optimizations import apply_kv_cache_injection
from src.sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
Integrations,
)


# registry of graph optimizations applied to exported generative models,
# keyed by optimization name
generative_transformers_graph_optimizations = dict(
    kv_cache_injection=apply_kv_cache_injection,
)


@IntegrationHelperFunctions.register(name=Integrations.transformers_generative.value)
class GenerativeTransformers(Transformers):
    """
    Integration helper functions for exporting generative transformer models.

    Extends the base ``Transformers`` helpers with graph optimizations
    (KV-cache injection) and the additional tokenizer files that
    natural-language-generation (NLG) deployments require.
    """

    # NOTE(review): models that went through the OBCQ/finetuning flow may need
    # apply_recipe_structure_to_model applied before export (recipe loading via
    # the new session framework differs from the existing automodel path) —
    # confirm before relying on this export path for such models.

    # graph optimizations applied to the exported ONNX model, keyed by name
    graph_optimizations: Dict[str, Callable] = Field(
        default=generative_transformers_graph_optimizations
    )
    # files that must be present in the deployment directory; NLG models
    # require tokenizer files beyond the base mandatory set
    deployment_directory_files_mandatory: List[str] = Field(
        default=list(MANDATORY_DEPLOYMENT_FILES.union(NLG_TOKENIZER_FILES))
    )
41 changes: 41 additions & 0 deletions src/sparseml/transformers/utils/optimizations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from pathlib import Path
from typing import Union

import onnx

from sparseml.exporters.kv_cache_injector import KeyValueCacheInjector


__all__ = ["apply_kv_cache_injection"]

_LOGGER = logging.getLogger(__name__)


def apply_kv_cache_injection(onnx_model_path: Union[str, Path]) -> bool:
    """
    Apply key value cache injection to an ONNX model

    :param onnx_model_path: path to the ONNX model to inject
    :return: True if successful, False otherwise
    """
    # load the graph only; external weight tensors stay on disk
    model = onnx.load(onnx_model_path, load_external_data=False)
    injector = KeyValueCacheInjector(model_path=os.path.dirname(onnx_model_path))
    # write the injected model back over the original path
    injector.export(model, onnx_model_path)
    return True
79 changes: 79 additions & 0 deletions tests/sparseml/export/transformers/test_generative_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil

import pytest
import torch

from huggingface_hub import snapshot_download
from sparseml.export.export import export


@pytest.mark.parametrize(
    "stub, task",
    [("roneneldan/TinyStories-1M", "text-generation")],
)
class TestEndToEndExport:
    """End-to-end tests of the export pipeline for generative transformers."""

    @pytest.fixture()
    def setup(self, tmp_path, stub, task):
        """Download the model into a tmp dir; yield paths; clean up after."""
        model_path = tmp_path / "model"
        target_path = tmp_path / "target"

        source_path = snapshot_download(stub, local_dir=model_path)

        yield source_path, target_path, task

        shutil.rmtree(tmp_path)

    def test_export_happy_path(self, setup):
        source_path, target_path, task = setup
        export(
            source_path=source_path,
            target_path=target_path,
            task=task,
        )
        assert (target_path / "deployment" / "model.onnx").exists()

    def test_export_with_sample_data(self, setup):
        source_path, target_path, task = setup

        # batch of 10 dummy sequences to drive the export tracing
        sequence_length = 32
        sample_data = dict(
            input_ids=torch.ones((10, sequence_length), dtype=torch.long),
            attention_mask=torch.ones((10, sequence_length), dtype=torch.long),
        )
        export(
            source_path=source_path,
            target_path=target_path,
            task=task,
            sample_data=sample_data,
        )
        assert (target_path / "deployment" / "model.onnx").exists()

    # fix: pytest.mark.skipif requires a condition argument; for an
    # unconditional skip with a reason, pytest.mark.skip is the correct marker
    @pytest.mark.skip(
        reason="skipping since this functionality needs some more attention"
    )
    def test_export_validate_correctness(self, setup):
        source_path, target_path, task = setup

        num_samples = 4

        export(
            source_path=source_path,
            target_path=target_path,
            task=task,
            num_export_samples=num_samples,
            validate_correctness=True,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from src.sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
Integrations,
)


def test_integration_helper_functions():
    """Check that the generative-transformers helper functions register
    correctly and expose the expected optimizations and deployment files."""
    # import needed to register the object on the fly
    import sparseml.transformers.integration_helper_functions_generative  # noqa F401

    transformers_gen = IntegrationHelperFunctions.load_from_registry(
        Integrations.transformers_generative.value
    )
    assert transformers_gen.create_model
    assert transformers_gen.create_dummy_input
    assert transformers_gen.export
    # fix: dict_values never compares equal to a list, and the values are
    # callables, not strings — compare the callables' names instead
    assert [
        func.__name__ for func in transformers_gen.graph_optimizations.values()
    ] == ["apply_kv_cache_injection"]
    assert transformers_gen.create_data_samples
    assert set(transformers_gen.deployment_directory_files_mandatory) == {
        "model.onnx",
        "tokenizer_config.json",
        "config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
    }
    assert set(transformers_gen.deployment_directory_files_optional) == {
        "tokenizer.json",
        "tokenizer.model",
    }
Loading