[Export Refactor] Export transformers #1909
Merged: dbogunowicz merged 7 commits into feature/damian/export_adapt from feature/damian/export_transformers on Dec 18, 2023
Commits (7, all by dbogunowicz):
- 62c855b cleanup
- dd42815 Delete src/sparseml/transformers/integration_helper_functions_generat…
- b97aace Delete src/sparseml/transformers/utils/optimizations.py
- 54bcb2f Delete tests/sparseml/export/transformers/test_generative_transformer…
- 5f5ad49 Delete tests/sparseml/transformers/test_integration_helper_functions_…
- fd581ea addressing PR reviews
- d6e0894 [Export Refactor] Export generative transformers (#1910)
Changed file: src/sparseml/integration_helper_functions.py
```diff
@@ -32,10 +32,14 @@ class Integrations(Enum):
     """
 
     image_classification = "image-classification"
+    transformers = "transformers"
+    transformers_generative = "transformers-generative"
 
 
 def resolve_integration(
-    source_path: Union[Path, str], integration: Optional[str] = None
+    source_path: Union[Path, str],
+    integration: Optional[str] = None,
+    task: Optional[str] = None,
 ) -> str:
     """
     Resolve the integration to use.
@@ -47,24 +51,46 @@ def resolve_integration(
     :param source_path: The path to the PyTorch model to export.
     :param integration: Optional name of the integration to use. If not provided,
         will attempt to infer it from the source_path.
+    :param task: Optional name of the task to use.
     :return: The name of the integration to use for exporting the model.
     """
 
     if integration is not None:
         integration = integration.replace("_", "-")
 
+    if task is not None:
+        task = task.replace("_", "-")
```

Review comments on the `task = task.replace("_", "-")` line:
- "shouldn't this be resolved at a more consolidated level?"
- "Like in the …"

```diff
+
     from sparseml.pytorch.image_classification.utils.helpers import (
         is_image_classification_model,
     )
+    from sparseml.transformers.utils.helpers import (
+        TaskNames,
+        is_transformer_generative_model,
+        is_transformer_model,
+    )
 
     if (
         integration == Integrations.image_classification.value
         or is_image_classification_model(source_path)
     ):
         # import to register the image_classification integration helper functions
         import sparseml.pytorch.image_classification.integration_helper_functions  # noqa F401
 
         return Integrations.image_classification.value
 
+    elif task in TaskNames.text_generation.value or is_transformer_generative_model(
+        source_path
+    ):
+        import sparseml.transformers.integration_helper_functions_generative  # noqa F401
+
+        return Integrations.transformers_generative.value
+
+    elif integration == Integrations.transformers.value or is_transformer_model(
+        source_path
+    ):
+        import sparseml.transformers.integration_helper_functions  # noqa F401
+
+        return Integrations.transformers.value
     else:
         raise ValueError(
             f"Could not infer integration from source_path:\n{source_path}\n"
```
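For orientation, here is a minimal sketch of how the task-aware dispatch above is meant to be called. The import path and the local model directories are assumptions for illustration, not part of this diff:

```python
# Hypothetical usage of resolve_integration (module path assumed from the
# imports used elsewhere in this PR; checkpoint paths are made up).
from sparseml.integration_helper_functions import resolve_integration

# A task hint: "text_generation" is normalized to "text-generation" and,
# if it is one of TaskNames.text_generation, routes to the
# "transformers-generative" integration.
name = resolve_integration("./local-llm-checkpoint", task="text_generation")

# Without a task, the integration is taken from the explicit name or
# inferred from the files found at source_path
# (is_transformer_model / is_image_classification_model checks).
name = resolve_integration("./local-bert-checkpoint", integration="transformers")
```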
```diff
@@ -81,36 +107,23 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
     integration.
     """
 
-    create_model: Optional[
-        Callable[
-            Tuple[Union[str, Path], Optional[int], str, Optional[Dict[str, Any]]],
-            Tuple[
-                "torch.nn.Module",  # noqa F821
-                Optional["torch.utils.data.Dataloader"],  # noqa F821
-            ],
-        ]
+    create_model: Callable[
+        [Union[str, Path], ...],
+        Tuple[
+            "torch.nn.Module",  # noqa F821
+            Optional[Dict[str, Any]],
+        ],
     ] = Field(
         description="A function that takes: "
         "- a source path to a PyTorch model "
-        "- a batch size "
-        "- a device name "
-        "- (optionally) a dictionary of additional arguments"
+        "- (optionally) additional arguments"
         "and returns: "
         "- a (sparse) PyTorch model "
-        "- (optionally) a data loader "
         "- (optionally) a dictionary of auxiliary items"
     )
-    create_dummy_input: Optional[
-        Callable[
-            Tuple[
-                Optional["torch.utils.data.Dataloader"],  # noqa F821
-                Optional[Dict[str, Any]],
-            ],
-            "torch.Tensor",  # noqa F821
-        ]
-    ] = Field(
+    create_dummy_input: Callable[..., "torch.Tensor"] = Field(  # noqa F821
         description="A function that takes: "
-        "- (optionally) a data loader "
-        "- (optionally) a dictionary of additional arguments"
+        "- appropriate arguments "
         "and returns: "
         "- a dummy input for the model (a torch.Tensor) "
     )
@@ -131,28 +144,29 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
     )
 
     create_data_samples: Callable[
-        Tuple[
-            Optional["torch.nn.Module"],  # noqa F821
-            "torch.utils.data.DataLoader",  # noqa F821
-            int,
-        ],
+        Tuple[Optional["torch.nn.Module"], int, Optional[Dict[str, Any]]],  # noqa F821
         Tuple[
             List["torch.Tensor"],  # noqa F821
-            List["torch.Tensor"],  # noqa F821
+            Optional[List["torch.Tensor"]],  # noqa F821
             Optional[List["torch.Tensor"]],  # noqa F821
         ],
     ] = Field(
         default=create_data_samples_,
         description="A function that takes: "
         " - (optionally) a (sparse) PyTorch model "
         " - a data loader "
         " - the number of samples to generate "
+        " - (optionally) additional auxiliary items "
         "and returns: "
-        " - the inputs, labels and (optionally) outputs as torch tensors ",
+        " - the inputs, (optionally) labels and (optionally) outputs as torch tensors ",
     )
 
-    deployment_directory_structure: List[str] = Field(
+    deployment_directory_files_mandatory: List[str] = Field(
         description="A list that describes the "
-        "expected files of the deployment directory",
+        "mandatory expected files of the deployment directory",
         default=["model.onnx"],
     )
+
+    deployment_directory_files_optional: Optional[List[str]] = Field(
+        description="A list that describes the "
+        "optional expected files of the deployment directory",
+    )
```
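To make the registry contract concrete, here is a sketch of how an export flow might look these helpers up. `load_from_registry` is assumed to be the lookup method exposed by `RegistryMixin`, and the path and task values are hypothetical:

```python
# Sketch only: resolve an integration name, then fetch the helper bundle
# registered under that name. load_from_registry is an assumption about the
# RegistryMixin API; "./local-bert-checkpoint" and "qa" are made up.
integration = resolve_integration("./local-bert-checkpoint", task="qa")

helper_class = IntegrationHelperFunctions.load_from_registry(integration)
helpers = helper_class()  # a pydantic model whose fields are the callables

model, auxiliary = helpers.create_model("./local-bert-checkpoint", task="qa")
dummy_input = helpers.create_dummy_input(**auxiliary)
```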
New file: src/sparseml/transformers/integration_helper_functions.py (168 additions, 0 deletions)
```python
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from pydantic import Field
from transformers import AutoTokenizer

from sparseml.transformers.sparsification.trainer import Trainer
from sparseml.transformers.utils.helpers import (
    MANDATORY_DEPLOYMENT_FILES,
    OPTIONAL_DEPLOYMENT_FILES,
)
from sparseml.transformers.utils.load_task_dataset import load_task_dataset
from src.sparseml.export.export_data import create_data_samples as create_data_samples_
from src.sparseml.integration_helper_functions import (
    IntegrationHelperFunctions,
    Integrations,
)
from src.sparseml.transformers.utils.initializers import (
    _parse_data_args,
    initialize_config,
    initialize_model,
    initialize_tokenizer,
    initialize_trainer,
    resolve_sequence_length,
)


_LOGGER = logging.getLogger(__name__)


def create_model(
    source_path: Union[Path, str],
    device: Optional[str] = None,
    task: Optional[str] = None,
    **kwargs,
) -> Tuple[torch.nn.Module, Dict[str, Any]]:
    """
    A contract to create a model and optional dictionary of
    auxiliary items related to the model

    :param source_path: The path to the model
    :param device: The device to use for the model and dataloader instantiation
    :param task: The task to use for the model and dataloader instantiation

    :return: A tuple of the
        - torch model
        - (optionally) a dictionary of auxiliary items
    """
    config_args = kwargs.get("config_args", {})
    sequence_length = kwargs.get("sequence_length", None)
    data_args = kwargs.get("data_args", {})
    trust_remote_code = kwargs.get("trust_remote_code", False)

    if task is None:
        raise ValueError("To create a transformer model, a task must be specified")

    if not trust_remote_code:
        _LOGGER.warning(
            "trust_remote_code is set to False. It is possible, "
            "that the model will not be loaded correctly."
        )

    config = initialize_config(source_path, trust_remote_code, **config_args)
    sequence_length = sequence_length or resolve_sequence_length(config)
    tokenizer = initialize_tokenizer(source_path, sequence_length, task)
    model = initialize_model(
        model_path=source_path,
        task=task,
        config=config,
        trust_remote_code=trust_remote_code,
        device=device,
    )

    data_args = _parse_data_args(data_args)

    if data_args:
        dataset = load_task_dataset(
            task=task,
            tokenizer=tokenizer,
            data_args=data_args,
            model=model,
            config=config,
        )
        validation_dataset = dataset.get("validation")

    else:
        validation_dataset = None

    model.train()
    trainer = initialize_trainer(model, source_path, validation_dataset)
    model.eval()

    return model, dict(
        trainer=trainer,
        tokenizer=tokenizer,
        input_names=list(next(trainer._get_fake_dataloader(1, tokenizer)).keys()),
    )


def create_dummy_input(
    trainer: Optional[Trainer] = None,
    tokenizer: Optional[AutoTokenizer] = None,
    **kwargs,
) -> torch.Tensor:
    if trainer.eval_dataset is not None:
        data_loader = trainer.get_eval_dataloader()
    else:
        if not tokenizer:
            raise ValueError(
                "Tokenizer is needed to generate "
                "fake sample inputs when the trainer is "
                "not initialized with an eval dataset"
            )
        data_loader = trainer._get_fake_dataloader(num_samples=1, tokenizer=tokenizer)
    return next(iter(data_loader))


def create_data_samples(
    num_samples: int,
    trainer: Trainer,
    model: Optional["torch.nn.Module"] = None,
    **kwargs,
):
    if kwargs.get("batch_size"):
        _LOGGER.info(
            "For exporting samples for transformers integration,"
            "batch size is ignored (equal to 1)"
        )
    if trainer.eval_dataset is None:
        raise ValueError(
            "Attempting to create data samples without an eval dataloader. "
            "Initialize a trainer with an eval dataset"
        )

    return create_data_samples_(
        data_loader=trainer.get_eval_dataloader(), model=model, num_samples=num_samples
    )


@IntegrationHelperFunctions.register(name=Integrations.transformers.value)
class Transformers(IntegrationHelperFunctions):
    create_model: Callable[..., Tuple[torch.nn.Module, Dict[str, Any]]] = Field(
        default=create_model
    )
    create_dummy_input: Callable[..., torch.Tensor] = Field(default=create_dummy_input)
    create_data_samples: Callable = Field(create_data_samples)
    deployment_directory_files_mandatory: List[str] = Field(
        default=list(MANDATORY_DEPLOYMENT_FILES)
    )
    deployment_directory_files_optional: List[str] = Field(
        default=list(OPTIONAL_DEPLOYMENT_FILES)
    )
```
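A hedged sketch of how these three helpers compose at export time; the checkpoint directory, task name, and `data_args` contents below are invented for illustration:

```python
# Hypothetical end-to-end use of the helpers defined above.
model, aux = create_model(
    "./obert-base-squad",                 # made-up checkpoint directory
    device="cpu",
    task="question-answering",            # must be a task the integration supports
    data_args={"dataset_name": "squad"},  # assumed shape of data_args
)

# The auxiliary dict returned by create_model carries the trainer and
# tokenizer that the other helpers consume.
dummy_input = create_dummy_input(trainer=aux["trainer"], tokenizer=aux["tokenizer"])

# Batch size is fixed at 1 for this integration; samples come from the
# trainer's eval dataloader as input/label/output tensors.
samples = create_data_samples(num_samples=4, trainer=aux["trainer"], model=model)
```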
Review discussion

Reviewer: For the top-level integration - can we keep it as `transformers` and imply generative based on the task name? This seems a bit cumbersome and hidden from a UX perspective.

Reply: Let's discuss it today. I agree that it would be more appropriate to "fold" `transformers_generative` under `transformers` in this hierarchy of arguments. The contrary, however, still feels like a good implementation for one important reason: in our context, `transformers_generative` models are quite distinct from the rest of `transformers`. They have their own optimizations, potentially their own recipe-loading logic (need to look into it), and a different set of files. There is also a high probability that as we focus more on LLMs, and as they gain popularity, they will over time increasingly become a concept separate from the "normal" transformers. This is why I feel that the separate `IntegrationHelperFunctions` for `generative_transformers` makes sense, and from that follows the separate integration name.