[Export Refactor] Prepare the module to be more general (before including transformers) #1908

Merged
74 changes: 38 additions & 36 deletions src/sparseml/export/export.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import os
 from pathlib import Path
 from typing import Any, List, Optional, Union
 
@@ -22,6 +23,7 @@
     ONNX_MODEL_NAME,
     apply_optimizations,
     create_deployment_folder,
+    create_export_kwargs,
 )
 from sparseml.export.validators import validate_correctness as validate_correctness_
 from sparseml.export.validators import validate_structure as validate_structure_
@@ -44,14 +46,15 @@ def export(
     opset: int = TORCH_DEFAULT_ONNX_OPSET,
     single_graph_file: bool = True,
     num_export_samples: int = 0,
-    batch_size: int = 1,
     deployment_directory_name: str = "deployment",
     device: str = "auto",
     graph_optimizations: Union[str, List[str], None] = "all",
     validate_correctness: bool = False,
     validate_structure: bool = True,
     integration: Optional[str] = None,
     sample_data: Optional[Any] = None,
+    batch_size: Optional[int] = None,
+    task: Optional[str] = None,
     **kwargs,
 ):
     """
@@ -84,6 +87,8 @@
         file. Defaults to True.
     :param num_export_samples: The number of samples to create for
         the exported model. Defaults to 0.
+    :param batch_size: The batch size to use for exporting the data.
+        Defaults to None.
     :param deployment_directory_name: The name of the deployment
         directory to create for the exported model. Thus, the exported
         model will be saved to `target_path/deployment_directory_name`.
@@ -102,7 +107,7 @@
     :param sample_data: Optional sample data to use for exporting
         the model. If not provided, a dummy input will be created
         for the model. Defaults to None.
-    :param batch_size: The batch size to use for exporting the data.
+    :param task: Optional task to use for exporting the model.
        Defaults to None.
     """

@@ -121,7 +126,7 @@ def export(
             f"Got {deployment_target} instead."
         )
 
-    integration = resolve_integration(source_path, integration)
+    integration = resolve_integration(source_path, integration, task)
 
     _LOGGER.info(f"Starting export for {integration} model...")

@@ -130,65 +135,49 @@ def export(
     )
 
     _LOGGER.info("Creating model for the export...")
-    model, validation_dataloader = helper_functions.create_model(
-        source_path, batch_size, device, **kwargs
+
+    # auxiliary_items may include any items
+    # that are needed for the export
+    model, auxiliary_items = helper_functions.create_model(
+        source_path, **dict(device=device, task=task, batch_size=batch_size), **kwargs
     )
 
-    if validation_dataloader:
-        _LOGGER.info("Created validation dataloader for the export")
-    else:
-        _LOGGER.warning(
-            "Failed to create validation dataloader for the export. "
-            "Will be using the dummy (or user-provided) data instead "
-            "and will be not able to export samples or validate the model "
-            "correctness."
+    if auxiliary_items:
+        _LOGGER.info(
+            f"Created auxiliary items for the export: {list(auxiliary_items.keys())}"
         )
 
     sample_data = (
-        helper_functions.create_dummy_input(
-            validation_dataloader=validation_dataloader, **kwargs
-        )
+        helper_functions.create_dummy_input(**auxiliary_items, **kwargs)
         if sample_data is None
         else sample_data
     )

_LOGGER.info(f"Exporting {onnx_model_name} to {target_path}...")

export_kwargs = create_export_kwargs(auxiliary_items)

onnx_file_path = helper_functions.export(
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
model=model,
sample_data=sample_data,
target_path=target_path,
onnx_model_name=onnx_model_name,
deployment_target=deployment_target,
opset=opset,
**export_kwargs,
)
_LOGGER.info(f"Successfully exported {onnx_model_name} to {target_path}...")

_LOGGER.info(
f"Applying optimizations: {graph_optimizations} to the exported model..."
)
apply_optimizations(
onnx_file_path=onnx_file_path,
target_optimizations=graph_optimizations,
available_optimizations=helper_functions.graph_optimizations,
single_graph_file=single_graph_file,
)
_LOGGER.info(f"Successfully exported {onnx_model_name} to {onnx_file_path}...")

if num_export_samples:
_LOGGER.info(f"Exporting {num_export_samples} samples...")
if not validation_dataloader:
raise ValueError(
"To export sample inputs/outputs a data loader is needed. "
"To return a data loader provide the appropriate, integration-specific "
"arguments to `create_model` function"
)
(
input_samples,
output_samples,
label_samples,
) = helper_functions.create_data_samples(
num_samples=num_export_samples,
data_loader=validation_dataloader,
model=model,
**auxiliary_items,
)
export_data_samples(
input_samples=input_samples,
@@ -207,16 +196,29 @@ def export(
         source_path=source_path,
         target_path=target_path,
         deployment_directory_name=deployment_directory_name,
-        deployment_directory_files=helper_functions.deployment_directory_structure,
+        deployment_directory_files_mandatory=helper_functions.deployment_directory_files_mandatory,  # noqa: E501
+        deployment_directory_files_optional=helper_functions.deployment_directory_files_optional,  # noqa: E501
         onnx_model_name=onnx_model_name,
     )
 
+    _LOGGER.info(
+        f"Applying optimizations: {graph_optimizations} to the exported model..."
+    )
+    apply_optimizations(

Review thread on apply_optimizations (a sketch of the behavior described here follows this file's diff):

Member: What is the purpose of the top-level apply_optimizations function over helper_functions.graph_optimizations? Again worried about specificity to ONNX.

Contributor Author (dbogunowicz): graph_optimizations contains a mapping that specifies all the possible optimizations for a specific integration. apply_optimizations is a general function that establishes which of those optimizations should actually be applied, and then applies them. Why worry about ONNX specificity? We could always pass deployment_target as an argument of apply_optimizations - this actually demonstrates the usefulness of the abstraction.

Contributor Author (dbogunowicz): I still fail to see how folding this function into each of the IntegrationHelperFunctions would be beneficial. It would add code duplication and, IMO, worsen readability for the user. I'd propose to leave it for now; it's a very small component and can be refactored easily if needed.

+        onnx_file_path=os.path.join(deployment_path, onnx_model_name),
+        target_optimizations=graph_optimizations,
+        available_optimizations=helper_functions.graph_optimizations,
+        single_graph_file=single_graph_file,
+    )
 
     if validate_structure:
         _LOGGER.info("Validating model structure...")
         validate_structure_(
             target_path=target_path,
             deployment_directory_name=deployment_directory_name,
             onnx_model_name=onnx_model_name,
-            deployment_directory_files=helper_functions.deployment_directory_structure,
+            deployment_directory_files_mandatory=helper_functions.deployment_directory_files_mandatory,  # noqa: E501
+            deployment_directory_files_optional=helper_functions.deployment_directory_files_optional,  # noqa: E501
         )
 
     if validate_correctness:
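To make the thread above concrete, here is a minimal sketch of the select-then-apply contract the author describes, assuming available_optimizations is a name-to-callable mapping supplied by the integration and target_optimizations is "all", a list of names, or None. The names and signature are illustrative only, not SparseML's actual implementation.

```python
from typing import Callable, Dict, List, Optional, Union

# each optimization callable rewrites the exported model file in place
OptimizationFn = Callable[[str], None]


def apply_optimizations_sketch(
    onnx_file_path: str,
    target_optimizations: Union[str, List[str], None],
    available_optimizations: Optional[Dict[str, OptimizationFn]] = None,
) -> None:
    # the integration supplies the mapping; the selection logic stays generic
    available_optimizations = available_optimizations or {}
    if target_optimizations is None:
        selected: List[str] = []
    elif target_optimizations == "all":
        selected = list(available_optimizations)
    else:
        selected = list(target_optimizations)

    unknown = [name for name in selected if name not in available_optimizations]
    if unknown:
        raise KeyError(f"Unknown optimizations requested: {unknown}")

    for name in selected:
        available_optimizations[name](onnx_file_path)
```

Under this shape, any ONNX specificity lives only in the callables each integration registers; the selection logic stays generic, which is the point made in the author's reply.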
55 changes: 38 additions & 17 deletions src/sparseml/export/export_data.py
@@ -18,7 +18,7 @@
 import tarfile
 from enum import Enum
 from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 from tqdm import tqdm
@@ -50,7 +50,7 @@ def create_data_samples(
     data_loader: torch.utils.data.DataLoader,
     model: Optional[torch.nn.Module] = None,
     num_samples: int = 1,
-) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
+) -> Tuple[List[Any], List[Any], List[Any]]:
     """
     Fetch a batch of samples from the data loader and return the inputs and outputs
 
@@ -63,30 +63,44 @@ def create_data_samples(
     inputs, outputs, labels = [], [], []
     if model is None:
         _LOGGER.warning("The model is None. The list of outputs will be empty")
-    for batch_num, (inputs_, labels_) in tqdm(enumerate(data_loader)):
+
+    for batch_num, data in tqdm(enumerate(data_loader)):
         if batch_num == num_samples:
             break
-        if model:
-            outputs_ = model(inputs_)
+
+        if isinstance(data, dict):
+            # assume transformers inference
+            from sparseml.transformers.utils.helpers import run_transformers_inference
+
+            inputs_, labels_, outputs_ = run_transformers_inference(
+                inputs=data, model=model
+            )
+        else:
+            # assume image classification inference
+            inputs_, labels_ = data
+            outputs_ = model(inputs_) if model else None
             if isinstance(outputs_, tuple):
                 # outputs_ contains (logits, softmax)
                 outputs_ = outputs_[0]
-            outputs.append(outputs_)
 
         inputs.append(inputs_)
-        labels.append(
-            torch.IntTensor([labels_])
-            if not isinstance(labels_, torch.Tensor)
-            else labels_
-        )
+        if outputs_ is not None:
+            outputs.append(outputs_)
+        if labels_ is not None:
+            labels.append(
+                torch.IntTensor([labels_])
+                if not isinstance(labels_, torch.Tensor)
+                else labels_
+            )
 
     return inputs, outputs, labels
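For illustration, a hedged usage sketch of the tuple branch above, using a toy dataset and model (purely hypothetical; a loader yielding dicts would instead be routed through run_transformers_inference):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from sparseml.export.export_data import create_data_samples

# toy image-classification-style data: batches arrive as (inputs, labels) tuples
images = torch.randn(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))
loader = DataLoader(TensorDataset(images, labels), batch_size=2)

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))

inputs, outputs, targets = create_data_samples(
    data_loader=loader, model=model, num_samples=3
)
assert len(inputs) == len(outputs) == len(targets) == 3
```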


 def export_data_samples(
     target_path: Union[Path, str],
-    input_samples: Optional[List["torch.Tensor"]] = None,  # noqa F821
-    output_samples: Optional[List["torch.Tensor"]] = None,  # noqa F821
-    label_samples: Optional[List["torch.Tensor"]] = None,  # noqa F821
+    input_samples: Optional[List[Any]] = None,
+    output_samples: Optional[List[Any]] = None,
+    label_samples: Optional[List[Any]] = None,
     as_tar: bool = False,
 ):
     """
Expand Down Expand Up @@ -116,6 +130,7 @@ def export_data_samples(

:param input_samples: The input samples to save.
:param output_samples: The output samples to save.
:param label_samples: The label samples to save.
:param target_path: The path to save the samples to.
:param as_tar: Whether to save the samples as tar files.
"""
@@ -124,16 +139,21 @@
         [input_samples, output_samples, label_samples],
         [InputsNames, OutputsNames, LabelNames],
     ):
-        if samples is not None:
+        if len(samples) > 0:
             _LOGGER.info(f"Exporting {names.basename.value} to {target_path}...")
-            export_data_sample(samples, names, target_path, as_tar)
+            break_batch = isinstance(samples[0], dict)
+            export_data_sample(samples, names, target_path, as_tar, break_batch)
             _LOGGER.info(
                 f"Successfully exported {names.basename.value} to {target_path}!"
             )


 def export_data_sample(
-    samples, names: Enum, target_path: Union[Path, str], as_tar: bool = False
+    samples,
+    names: Enum,
+    target_path: Union[Path, str],
+    as_tar: bool = False,
+    break_batch=False,
 ):

     samples = tensors_to_device(samples, "cpu")
@@ -142,6 +162,7 @@ def export_data_sample(
         tensors=samples,
         export_dir=os.path.join(target_path, names.basename.value),
         name_prefix=names.filename.value,
+        break_batch=break_batch,
     )
     if as_tar:
         folder_path = os.path.join(target_path, names.basename.value)
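Finally, on break_batch: the flag is forwarded to tensors_export, and judging from how it is set above (only when samples are dicts, i.e. on the transformers path), it plausibly tells the exporter to split each batched dict into one file per sample. The function below is a hypothetical illustration of that splitting, not SparseML's actual tensors_export.

```python
import os
from typing import Dict, List

import torch


def save_dict_samples_broken_by_batch(
    samples: List[Dict[str, torch.Tensor]], export_dir: str, name_prefix: str
) -> None:
    # hypothetical: write every row of every batched dict as its own file,
    # so each saved sample is a {tensor_name: single-sample tensor} dict
    os.makedirs(export_dir, exist_ok=True)
    sample_index = 0
    for batch in samples:
        batch_size = next(iter(batch.values())).shape[0]
        for row in range(batch_size):
            single = {key: value[row] for key, value in batch.items()}
            torch.save(
                single,
                os.path.join(export_dir, f"{name_prefix}-{sample_index:04d}.pt"),
            )
            sample_index += 1
```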