Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Export Refactor] End to end testing #1898

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
6d8991c
initial commit
dbogunowicz Dec 5, 2023
059ceab
looking good, time to cleanup
dbogunowicz Dec 6, 2023
b390c2e
Delete src/sparseml/export/helpers.py
dbogunowicz Dec 6, 2023
c2c8444
Delete tests/sparseml/export/test_helpers.py
dbogunowicz Dec 6, 2023
6ce6ba5
ready for review
dbogunowicz Dec 6, 2023
6f3e5e7
Merge branch 'feature/damian/create_model_ic' of github.com:neuralmag…
dbogunowicz Dec 6, 2023
5dfbdcd
improve design
dbogunowicz Dec 6, 2023
042c193
tests pass
dbogunowicz Dec 6, 2023
29cfa1d
reuse _validate_dataset_num_classes
dbogunowicz Dec 6, 2023
ab73aec
initial commit
dbogunowicz Dec 6, 2023
f628532
Update src/sparseml/pytorch/image_classification/integration_helper_f…
dbogunowicz Dec 6, 2023
b93b634
Update src/sparseml/pytorch/image_classification/integration_helper_f…
dbogunowicz Dec 6, 2023
e7606cd
ready for review
dbogunowicz Dec 7, 2023
ea9cb61
Update src/sparseml/export/export.py
dbogunowicz Dec 7, 2023
9572e0b
Update src/sparseml/integration_helper_functions.py
dbogunowicz Dec 7, 2023
0229deb
initial commit
dbogunowicz Dec 7, 2023
a8c1b68
fixes
dbogunowicz Dec 7, 2023
2379354
ready for review
dbogunowicz Dec 7, 2023
741fb12
nit
dbogunowicz Dec 7, 2023
ebdeb9f
add return
dbogunowicz Dec 7, 2023
8b2fca0
initial commit
dbogunowicz Dec 7, 2023
ff52598
initial commit
dbogunowicz Dec 8, 2023
08f42c2
initial commit
dbogunowicz Dec 8, 2023
fc7cf74
initial commit
dbogunowicz Dec 8, 2023
a94dda3
Merge remote-tracking branch 'origin/feature/damian/feature_branch_ex…
dbogunowicz Dec 12, 2023
8110e10
Delete tests/sparseml/test_integration_helper_functions.py
dbogunowicz Dec 12, 2023
3404b89
Merge remote-tracking branch 'origin/feature/damian/feature_branch_ex…
dbogunowicz Dec 12, 2023
578a2a9
ready to merge
dbogunowicz Dec 12, 2023
2ed0d50
add structure validator
dbogunowicz Dec 12, 2023
4d2f11f
ready for review
dbogunowicz Dec 13, 2023
93bc8c2
Delete tests/sparseml/export/model.onnx
dbogunowicz Dec 13, 2023
0f06f92
Delete tests/sparseml/export/image_classification/model.onnx
dbogunowicz Dec 13, 2023
cdb2ce8
Delete tests/sparseml/export/image_classification/conftest.py
dbogunowicz Dec 13, 2023
4e1fa25
Merge remote-tracking branch 'origin/feature/damian/feature_branch_ex…
dbogunowicz Dec 13, 2023
d00e229
PR comments
dbogunowicz Dec 14, 2023
038a9bd
remove onnx
dbogunowicz Dec 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 118 additions & 56 deletions src/sparseml/export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,21 @@
from pathlib import Path
from typing import Any, List, Optional, Union

from sparseml.export.export_data import export_data_samples
from sparseml.export.helpers import (
AVAILABLE_DEPLOYMENT_TARGETS,
ONNX_MODEL_NAME,
apply_optimizations,
create_deployment_folder,
)
from sparseml.export.validate_correctness import validate_correctness
from sparseml.export_data import export_data_samples
from sparseml.integration_helper_functions import (
from sparseml.export.validators import validate_correctness as validate_correctness_
from sparseml.export.validators import validate_structure as validate_structure_
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparseml.pytorch.utils.helpers import default_device
from src.sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
infer_integration,
resolve_integration,
)
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET


_LOGGER = logging.getLogger(__name__)
Expand All @@ -37,135 +39,195 @@
def export(
source_path: Union[Path, str],
target_path: Union[Path, str],
model_onnx_name: str = ONNX_MODEL_NAME,
onnx_model_name: str = ONNX_MODEL_NAME,
deployment_target: str = "deepsparse",
integration: Optional[str] = None,
sample_data: Optional[Any] = None,
opset: int = TORCH_DEFAULT_ONNX_OPSET,
batch_size: Optional[int] = None,
single_graph_file: bool = True,
graph_optimizations: Union[str, List[str], None] = "all",
validate_model_correctness: bool = False,
num_export_samples: int = 0,
deployment_directory_name: str = "deployment",
device: str = "auto",
graph_optimizations: Union[str, List[str], None] = "all",
validate_correctness: bool = False,
validate_structure: bool = True,
integration: Optional[str] = None,
sample_data: Optional[Any] = None,
batch_size: Optional[int] = None,
**kwargs,
):
"""
Export a PyTorch model to a deployment target specified by the `deployment_target`.

The functionality follows a set of steps:
1. Create a PyTorch model from the file located in source_path.
2. Create model dummy input.
3. Export the model to the format specified by the `deployment_target`.
4. (Optional) Apply optimizations to the exported model.
5. Export sample inputs and outputs for the exported model (optional).
6. Create a deployment folder for the exported model with the appropriate structure.
7. Validate the correctness of the exported model (optional).
Export a PyTorch model located in source_path, to target_path.
The deployment files will be located at target_path/deployment_directory_name

The exporting logic consists of the following steps:
1. Create the model and validation dataloader (if needed) using the
integration-specific `create_model` function.
2. Export the model to ONNX using the integration-specific `export` function.
3. Apply the graph optimizations to the exported model.
4. Create the deployment folder at target_path/deployment_directory_name
using the integration-specific `create_deployment_folder` function.
5. Optionally, export samples using the integration-specific
`create_data_samples` function.
6. Optionally, validate the correctness of the exported model using
the integration-specific `validate_correctness` function.
7. Optionally, validate the structure of the exported model using
the integration-specific `validate_structure` function.

:param source_path: The path to the PyTorch model to export.
:param target_path: The path to save the exported model to.
:param model_onnx_name: The name of the exported model.
:param onnx_model_name: The name of the exported model.
Defaults to ONNX_MODEL_NAME.
:param deployment_target: The deployment target to export
the model to. Defaults to 'deepsparse'.
:param integration: The name of the integration to use for
exporting the model. Defaults to None, which will infer
the integration from the source_path.
:param sample_data: Optional sample data to use for exporting
the model. If not provided, a dummy input will be created
for the model. Defaults to None.
:param opset: The ONNX opset to use for exporting the model.
Defaults to the latest supported opset.
:param batch_size: The batch size to use for exporting the model.
Defaults to None.
:param single_graph_file: Whether to save the model as a single
file (that contains both the model graph and model weights).
Defaults to True.
:param graph_optimizations: The graph optimizations to apply
to the exported model. Defaults to 'all'.
:param validate_model_correctness: Whether to validate the correctness
of the exported model. Defaults to False.
:param num_export_samples: The number of samples to export for
file. Defaults to True.
:param num_export_samples: The number of samples to create for
the exported model. Defaults to 0.
:param deployment_directory_name: The name of the deployment
directory to create for the exported model. Thus, the exported
model will be saved to `target_path/deployment_directory_name`.
Defaults to 'deployment'.
:param device: The device to use for exporting the model.
Defaults to 'auto'.
:param graph_optimizations: The graph optimizations to apply
to the exported model. Defaults to 'all'.
:param validate_correctness: Whether to validate the correctness
of the exported model. Defaults to False.
:param validate_structure: Whether to validate the structure
of the exported model (contents of the target_path).
Defaults to True.
:param integration: The name of the integration to use for
exporting the model. Defaults to None, which will infer
the integration from the source_path.
:param sample_data: Optional sample data to use for exporting
the model. If not provided, a dummy input will be created
for the model. Defaults to None.
:param batch_size: The batch size to use for exporting the data.
Defaults to None.
"""

# create the target path if it doesn't exist
if not Path(target_path).exists():
Path(target_path).mkdir(parents=True, exist_ok=True)

# choose the appropriate device
device = default_device() if device == "auto" else device

# assert the valid deployment target
if deployment_target not in AVAILABLE_DEPLOYMENT_TARGETS:
raise ValueError(
"Argument: deployment_target must be "
f"one of {AVAILABLE_DEPLOYMENT_TARGETS}. "
f"Got {deployment_target} instead."
)

integration = integration or infer_integration(source_path)
integration = resolve_integration(source_path, integration)

_LOGGER.info(f"Starting export for {integration} model...")

helper_functions: IntegrationHelperFunctions = (
IntegrationHelperFunctions.load_from_registry(integration)
)

# for now, this code is not runnable, serves as a blueprint
model, auxiliary_items = helper_functions.create_model(
source_path, **kwargs # noqa: F821
_LOGGER.info("Creating model for the export...")
model, validation_dataloader = helper_functions.create_model(
source_path, batch_size, device, **kwargs
)

if validation_dataloader:
_LOGGER.info("Created validation dataloader for the export")
else:
_LOGGER.warning(
"Failed to create validation dataloader for the export. "
"Will be using the dummy (or user-provided) data instead "
"and will be not able to export samples or validate the model "
"correctness."
)

sample_data = (
helper_functions.create_dummy_input(**auxiliary_items)
helper_functions.create_dummy_input(
validation_dataloader=validation_dataloader, **kwargs
)
if sample_data is None
else sample_data
)
onnx_file_path = helper_functions.export_model(
model, sample_data, target_path, deployment_target, opset

_LOGGER.info(f"Exporting {onnx_model_name} to {target_path}...")
onnx_file_path = helper_functions.export(
model=model,
sample_data=sample_data,
target_path=target_path,
onnx_model_name=onnx_model_name,
deployment_target=deployment_target,
opset=opset,
)
_LOGGER.info(f"Successfully exported {onnx_model_name} to {target_path}...")

_LOGGER.info(
f"Applying optimizations: {graph_optimizations} to the exported model..."
)
apply_optimizations(
onnx_file_path=onnx_file_path,
graph_optimizations=graph_optimizations,
available_graph_optimizations=helper_functions.graph_optimizations,
target_optimizations=graph_optimizations,
available_optimizations=helper_functions.graph_optimizations,
single_graph_file=single_graph_file,
)

if num_export_samples:
data_loader = auxiliary_items.get("validation_loader")
if data_loader is None:
_LOGGER.info(f"Exporting {num_export_samples} samples...")
if not validation_dataloader:
raise ValueError(
"To export sample inputs/outputs a data loader is needed."
"To enable the export, provide a `validation_loader` "
"as a part of `auxiliary_items` output of the `create_model` function."
"To export sample inputs/outputs a data loader is needed. "
"To return a data loader provide the appropriate, integration-specific "
"arguments to `create_model` function"
)
(
input_samples,
output_samples,
label_samples,
) = helper_functions.create_data_samples(
num_samples=num_export_samples, data_loader=data_loader, model=model
num_samples=num_export_samples,
data_loader=validation_dataloader,
model=model,
)
export_data_samples(
input_samples=input_samples,
output_samples=output_samples,
label_samples=label_samples,
target_path=target_path,
as_tar=True,
as_tar=False,
)

_LOGGER.info(
f"Creating deployment folder {deployment_directory_name} "
f"at directory: {target_path}..."
)

deployment_path = create_deployment_folder(
source_path=source_path,
target_path=target_path,
deployment_directory_name=deployment_directory_name,
deployment_directory_files=helper_functions.deployment_directory_structure,
onnx_model_name=model_onnx_name,
onnx_model_name=onnx_model_name,
)
if validate_structure:
_LOGGER.info("Validating model structure...")
validate_structure_(
target_path=target_path,
deployment_directory_name=deployment_directory_name,
onnx_model_name=onnx_model_name,
deployment_directory_files=helper_functions.deployment_directory_structure,
)

if validate_model_correctness:
if validate_correctness:
_LOGGER.info("Validating model correctness...")
if not num_export_samples:
raise ValueError(
"To validate correctness sample inputs/outputs are needed."
"To enable the validation, set `num_export_samples`"
"to True"
)
validate_correctness(deployment_path, model_onnx_name)
validate_correctness_(target_path, deployment_path, onnx_model_name)

_LOGGER.info(
f"Successfully exported model from:\n{target_path}"
Expand Down
34 changes: 23 additions & 11 deletions src/sparseml/export/export_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
from tqdm import tqdm

from sparseml.pytorch.utils.helpers import tensors_export, tensors_to_device

__all__ = ["create_data_samples"]

__all__ = ["create_data_samples", "export_data_samples"]

_LOGGER = logging.getLogger(__name__)

Expand All @@ -44,30 +47,37 @@ class InputsNames(Enum):


def create_data_samples(
data_loader: "torch.utils.data.DataLoader", # noqa F821
model: Optional["torch.nn.Module"] = None, # noqa F821
data_loader: torch.utils.data.DataLoader,
model: Optional[torch.nn.Module] = None,
num_samples: int = 1,
) -> Tuple[
List["torch.Tensor"], # noqa F821
Optional[List["torch.Tensor"]], # noqa F821
List["torch.Tensor"], # noqa F821
]:
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
"""
Fetch a batch of samples from the data loader and return the inputs and outputs

:param data_loader: The data loader to get a batch of inputs/outputs from.
:param model: The model to run the inputs through to get the outputs.
If None, the outputs will be an empty list.
:param num_samples: The number of samples to generate. Defaults to 1
:return: The inputs and outputs as lists of torch tensors
"""
inputs, outputs, labels = [], [], []
if model is None:
_LOGGER.warning("The model is None. The list of outputs will be empty")
for batch_num, (inputs_, labels_) in tqdm(enumerate(data_loader)):
if batch_num == num_samples:
break
if model:
outputs_ = model(inputs_)
if isinstance(outputs_, tuple):
# outputs_ contains (logits, softmax)
outputs_ = outputs_[0]
outputs.append(outputs_)
inputs.append(inputs_)
labels.append(labels_)
labels.append(
torch.IntTensor([labels_])
if not isinstance(labels_, torch.Tensor)
else labels_
)

return inputs, outputs, labels

Expand Down Expand Up @@ -115,14 +125,16 @@ def export_data_samples(
[InputsNames, OutputsNames, LabelNames],
):
if samples is not None:
_LOGGER.info(f"Exporting {names.basename.value} to {target_path}")
_LOGGER.info(f"Exporting {names.basename.value} to {target_path}...")
export_data_sample(samples, names, target_path, as_tar)
_LOGGER.info(
f"Successfully exported {names.basename.value} to {target_path}!"
)


def export_data_sample(
samples, names: Enum, target_path: Union[Path, str], as_tar: bool = False
):
from sparseml.pytorch.utils.helpers import tensors_export, tensors_to_device

samples = tensors_to_device(samples, "cpu")

Expand Down
13 changes: 7 additions & 6 deletions src/sparseml/export/export_torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pathlib import Path
from typing import Union

import onnx
import torch

from sparseml.exporters import ExportTargets
Expand Down Expand Up @@ -50,15 +51,15 @@ def export_model(
"""

model.eval()

path_to_exported_model = os.path.join(target_path, onnx_model_name)
exporter = TorchToONNX(sample_batch=sample_data, opset=opset, **kwargs)
exporter.export(model, os.path.join(target_path, onnx_model_name))
exporter.export(model, path_to_exported_model)
if deployment_target == ExportTargets.deepsparse.value:
exporter = ONNXToDeepsparse()
model = exporter.load_model(os.path.join(target_path, onnx_model_name))
exporter.export(model, os.path.join(target_path, onnx_model_name))
model = onnx.load(path_to_exported_model)
exporter.export(model, path_to_exported_model)
return path_to_exported_model
if deployment_target == ExportTargets.onnx.value:
pass
return path_to_exported_model
else:
raise ValueError(f"Unsupported deployment target: {deployment_target}")
return os.path.join(target_path, onnx_model_name)
Loading