Skip to content

Commit

Permalink
[Export Refactor] Export transformers (#1909)
Browse files Browse the repository at this point in the history
* cleanup

* Delete src/sparseml/transformers/integration_helper_functions_generative.py

* Delete src/sparseml/transformers/utils/optimizations.py

* Delete tests/sparseml/export/transformers/test_generative_transformers.py

* Delete tests/sparseml/transformers/test_integration_helper_functions_generative.py

* addressing PR reviews

* [Export Refactor] Export generative transformers(#1910)
  • Loading branch information
dbogunowicz committed Dec 18, 2023
1 parent 05a9ee3 commit 450b286
Show file tree
Hide file tree
Showing 19 changed files with 1,298 additions and 128 deletions.
3 changes: 2 additions & 1 deletion src/sparseml/export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from sparseml.export.validators import validate_correctness as validate_correctness_
from sparseml.export.validators import validate_structure as validate_structure_
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparseml.pytorch.utils.helpers import default_device
from sparseml.pytorch.utils.helpers import default_device, use_single_gpu
from src.sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
resolve_integration,
Expand Down Expand Up @@ -117,6 +117,7 @@ def export(

# choose the appropriate device
device = default_device() if device == "auto" else device
device = use_single_gpu(device) if "cuda" in device else device

# assert the valid deployment target
if deployment_target not in AVAILABLE_DEPLOYMENT_TARGETS:
Expand Down
84 changes: 49 additions & 35 deletions src/sparseml/integration_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@ class Integrations(Enum):
"""

image_classification = "image-classification"
transformers = "transformers"
transformers_generative = "transformers-generative"


def resolve_integration(
source_path: Union[Path, str], integration: Optional[str] = None
source_path: Union[Path, str],
integration: Optional[str] = None,
task: Optional[str] = None,
) -> str:
"""
Resolve the integration to use.
Expand All @@ -47,24 +51,46 @@ def resolve_integration(
:param source_path: The path to the PyTorch model to export.
:param integration: Optional name of the integration to use. If not provided,
will attempt to infer it from the source_path.
:param task: Optional name of the task to use.
:return: The name of the integration to use for exporting the model.
"""

if integration is not None:
integration = integration.replace("_", "-")

if task is not None:
task = task.replace("_", "-")

from sparseml.pytorch.image_classification.utils.helpers import (
is_image_classification_model,
)
from sparseml.transformers.utils.helpers import (
TaskNames,
is_transformer_generative_model,
is_transformer_model,
)

if (
integration == Integrations.image_classification.value
or is_image_classification_model(source_path)
):
# import to register the image_classification integration helper functions
import sparseml.pytorch.image_classification.integration_helper_functions # noqa F401

return Integrations.image_classification.value

elif task in TaskNames.text_generation.value or is_transformer_generative_model(
source_path
):
import sparseml.transformers.integration_helper_functions_generative # noqa F401

return Integrations.transformers_generative.value

elif integration == Integrations.transformers.value or is_transformer_model(
source_path
):
import sparseml.transformers.integration_helper_functions # noqa F401

return Integrations.transformers.value
else:
raise ValueError(
f"Could not infer integration from source_path:\n{source_path}\n"
Expand All @@ -81,36 +107,23 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
integration.
"""

create_model: Optional[
Callable[
Tuple[Union[str, Path], Optional[int], str, Optional[Dict[str, Any]]],
Tuple[
"torch.nn.Module", # noqa F821
Optional["torch.utils.data.Dataloader"], # noqa F821
],
]
create_model: Callable[
[Union[str, Path], ...],
Tuple[
"torch.nn.Module", # noqa F821
Optional[Dict[str, Any]],
],
] = Field(
description="A function that takes: "
"- a source path to a PyTorch model "
"- a batch size "
"- a device name "
"- (optionally) a dictionary of additional arguments"
"- (optionally) additional arguments"
"and returns: "
"- a (sparse) PyTorch model "
"- (optionally) a data loader "
"- (optionally) a dictionary of auxiliary items"
)
create_dummy_input: Optional[
Callable[
Tuple[
Optional["torch.utils.data.Dataloader"], # noqa F821
Optional[Dict[str, Any]],
],
"torch.Tensor", # noqa F821
]
] = Field(
create_dummy_input: Callable[..., "torch.Tensor"] = Field( # noqa F821
description="A function that takes: "
"- (optionally) a data loader "
"- (optionally) a dictionary of additional arguments"
"- appropriate arguments "
"and returns: "
"- a dummy input for the model (a torch.Tensor) "
)
Expand All @@ -131,28 +144,29 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
)

create_data_samples: Callable[
Tuple[
Optional["torch.nn.Module"], # noqa F821
"torch.utils.data.DataLoader", # noqa F821
int,
],
Tuple[Optional["torch.nn.Module"], int, Optional[Dict[str, Any]]], # noqa F821
Tuple[
List["torch.Tensor"], # noqa F821
Optional[List["torch.Tensor"]], # noqa F821
List["torch.Tensor"], # noqa F821
Optional[List["torch.Tensor"]], # noqa F821
],
] = Field(
default=create_data_samples_,
description="A function that takes: "
" - (optionally) a (sparse) PyTorch model "
" - a data loader "
" - the number of samples to generate "
" - (optionally) additional auxiliary items "
"and returns: "
" - the inputs, labels and (optionally) outputs as torch tensors ",
" - the inputs, (optionally) labels and (optionally) outputs as torch tensors ",
)

deployment_directory_structure: List[str] = Field(
deployment_directory_files_mandatory: List[str] = Field(
description="A list that describes the "
"expected files of the deployment directory",
"mandatory expected files of the deployment directory",
default=["model.onnx"],
)

deployment_directory_files_optional: Optional[List[str]] = Field(
description="A list that describes the "
"optional expected files of the deployment directory",
)
12 changes: 11 additions & 1 deletion src/sparseml/pytorch/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,15 @@ def default_device() -> str:
return "cuda:{}".format(",".join(device_ids))


def use_single_gpu(device: str) -> str:
"""
return: the first gpu in the device string if multiple are available
"""
if "cuda" not in device:
raise ValueError("use_single_gpu should only be called on cuda devices")
return device.split(",")[0]


def device_of(inputs: Any):
if isinstance(inputs, Tensor):
return inputs.device
Expand Down Expand Up @@ -538,7 +547,8 @@ def _tensors_export_batch(
return

if isinstance(tensors, Iterable):
for index, tens in enumerate(zip(*tensors)):
# TODO: I am breaking something here? - dbogunowicz
for index, tens in enumerate(zip(tensors)):
exported_paths.append(
tensor_export(
tens, export_dir, "{}-{:04d}".format(name_prefix, counter + index)
Expand Down
168 changes: 168 additions & 0 deletions src/sparseml/transformers/integration_helper_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from pydantic import Field
from transformers import AutoTokenizer

from sparseml.transformers.sparsification.trainer import Trainer
from sparseml.transformers.utils.helpers import (
MANDATORY_DEPLOYMENT_FILES,
OPTIONAL_DEPLOYMENT_FILES,
)
from sparseml.transformers.utils.load_task_dataset import load_task_dataset
from src.sparseml.export.export_data import create_data_samples as create_data_samples_
from src.sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
Integrations,
)
from src.sparseml.transformers.utils.initializers import (
_parse_data_args,
initialize_config,
initialize_model,
initialize_tokenizer,
initialize_trainer,
resolve_sequence_length,
)


_LOGGER = logging.getLogger(__name__)


def create_model(
source_path: Union[Path, str],
device: Optional[str] = None,
task: Optional[str] = None,
**kwargs,
) -> Tuple[torch.nn.Module, Dict[str, Any]]:
"""
A contract to create a model and optional dictionary of
auxiliary items related to the model
:param source_path: The path to the model
:param device: The device to use for the model and dataloader instantiation
:param task: The task to use for the model and dataloader instantiation
:return: A tuple of the
- torch model
- (optionally) a dictionary of auxiliary items
"""
config_args = kwargs.get("config_args", {})
sequence_length = kwargs.get("sequence_length", None)
data_args = kwargs.get("data_args", {})
trust_remote_code = kwargs.get("trust_remote_code", False)

if task is None:
raise ValueError("To create a transformer model, a task must be specified")

if not trust_remote_code:
_LOGGER.warning(
"trust_remote_code is set to False. It is possible, "
"that the model will not be loaded correctly."
)

config = initialize_config(source_path, trust_remote_code, **config_args)
sequence_length = sequence_length or resolve_sequence_length(config)
tokenizer = initialize_tokenizer(source_path, sequence_length, task)
model = initialize_model(
model_path=source_path,
task=task,
config=config,
trust_remote_code=trust_remote_code,
device=device,
)

data_args = _parse_data_args(data_args)

if data_args:
dataset = load_task_dataset(
task=task,
tokenizer=tokenizer,
data_args=data_args,
model=model,
config=config,
)
validation_dataset = dataset.get("validation")

else:
validation_dataset = None

model.train()
trainer = initialize_trainer(model, source_path, validation_dataset)
model.eval()

return model, dict(
trainer=trainer,
tokenizer=tokenizer,
input_names=list(next(trainer._get_fake_dataloader(1, tokenizer)).keys()),
)


def create_dummy_input(
trainer: Optional[Trainer] = None,
tokenizer: Optional[AutoTokenizer] = None,
**kwargs,
) -> torch.Tensor:
if trainer.eval_dataset is not None:
data_loader = trainer.get_eval_dataloader()
else:
if not tokenizer:
raise ValueError(
"Tokenizer is needed to generate "
"fake sample inputs when the trainer is "
"not initialized with an eval dataset"
)
data_loader = trainer._get_fake_dataloader(num_samples=1, tokenizer=tokenizer)
return next(iter(data_loader))


def create_data_samples(
num_samples: int,
trainer: Trainer,
model: Optional["torch.nn.Module"] = None,
**kwargs,
):
if kwargs.get("batch_size"):
_LOGGER.info(
"For exporting samples for transformers integration,"
"batch size is ignored (equal to 1)"
)
if trainer.eval_dataset is None:
raise ValueError(
"Attempting to create data samples without an eval dataloader. "
"Initialize a trainer with an eval dataset"
)

return create_data_samples_(
data_loader=trainer.get_eval_dataloader(), model=model, num_samples=num_samples
)


@IntegrationHelperFunctions.register(name=Integrations.transformers.value)
class Transformers(IntegrationHelperFunctions):
create_model: Callable[..., Tuple[torch.nn.Module, Dict[str, Any]]] = Field(
default=create_model
)
create_dummy_input: Callable[..., torch.Tensor] = Field(default=create_dummy_input)
create_data_samples: Callable = Field(create_data_samples)
deployment_directory_files_mandatory: List[str] = Field(
default=list(MANDATORY_DEPLOYMENT_FILES)
)
deployment_directory_files_optional: List[str] = Field(
default=list(OPTIONAL_DEPLOYMENT_FILES)
)
Loading

0 comments on commit 450b286

Please sign in to comment.