[Export Refactor][Image Classification] export_model function (#1883)
* initial commit

* looking good, time to cleanup

* Delete src/sparseml/export/helpers.py

* Delete tests/sparseml/export/test_helpers.py

* ready for review

* improve design

* tests pass

* reuse _validate_dataset_num_classes

* initial commit

* Update src/sparseml/pytorch/image_classification/integration_helper_functions.py

* Update src/sparseml/pytorch/image_classification/integration_helper_functions.py

* ready for review

* Update src/sparseml/export/export.py

* Update src/sparseml/integration_helper_functions.py

* initial commit

* fixes

* ready for review

* nit

* add return

* make export function more general
dbogunowicz committed Dec 11, 2023
1 parent bac0802 commit 0830372
Showing 6 changed files with 104 additions and 13 deletions.
14 changes: 11 additions & 3 deletions src/sparseml/export/export.py
@@ -16,6 +16,7 @@
from pathlib import Path
from typing import Any, List, Optional, Union

from sparseml.exporters import ExportTargets
from sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
infer_integration,
@@ -24,12 +24,14 @@


_LOGGER = logging.getLogger(__name__)
AVAILABLE_DEPLOYMENT_TARGETS = ["deepsparse", "onnxruntime"]
AVAILABLE_DEPLOYMENT_TARGETS = [target.value for target in ExportTargets]
ONNX_MODEL_NAME = "model.onnx"


def export(
source_path: Union[Path, str],
target_path: Union[Path, str],
model_onnx_name: str = ONNX_MODEL_NAME,
deployment_target: str = "deepsparse",
integration: Optional[str] = None,
sample_data: Optional[Any] = None,
@@ -57,6 +60,8 @@ def export(
:param source_path: The path to the PyTorch model to export.
:param target_path: The path to save the exported model to.
:param model_onnx_name: The name of the exported model.
Defaults to ONNX_MODEL_NAME.
:param deployment_target: The deployment target to export
the model to. Defaults to 'deepsparse'.
:param integration: The name of the integration to use for
@@ -101,7 +106,7 @@
# for now, this code is not runnable, serves as a blueprint
model, auxiliary_items = helper_functions.create_model(
source_path, **kwargs # noqa: F821
)
)
sample_data = (
helper_functions.create_dummy_input(**auxiliary_items)
if sample_data is None
@@ -111,7 +116,10 @@
model, sample_data, target_path, deployment_target, opset, single_graph_file
)

helper_functions.apply_optimizations(onnx_file_path, graph_optimizations)
helper_functions.apply_optimizations(
onnx_file_path,
graph_optimizations,
)

if export_sample_inputs_outputs:
helper_functions.export_sample_inputs_outputs(model, target_path)
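For orientation, a rough sketch of how the updated export entrypoint is intended to be called once the blueprint above is runnable (not part of the diff; the paths and the "image-classification" integration name are assumptions):

from sparseml.export.export import export  # import path assumed from the file location

export(
    source_path="./checkpoint",           # hypothetical path to the trained PyTorch model
    target_path="./deployment",           # hypothetical output directory
    model_onnx_name="model.onnx",         # new argument introduced in this commit
    deployment_target="deepsparse",       # must be one of AVAILABLE_DEPLOYMENT_TARGETS
    integration="image-classification",   # assumed registry name; inferred when omitted
)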
64 changes: 64 additions & 0 deletions src/sparseml/export/export_torch_model.py
@@ -0,0 +1,64 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path
from typing import Union

import torch

from sparseml.exporters import ExportTargets
from sparseml.exporters.onnx_to_deepsparse import ONNXToDeepsparse
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparseml.pytorch.torch_to_onnx_exporter import TorchToONNX


__all__ = ["export_model"]


def export_model(
model: torch.nn.Module,
sample_data: torch.Tensor,
target_path: Union[Path, str],
onnx_model_name: str,
deployment_target: str = "deepsparse",
opset: int = TORCH_DEFAULT_ONNX_OPSET,
**kwargs,
) -> str:
"""
Exports the torch model to the deployment target
:param model: The torch model to export
:param sample_data: The sample data to use for the export
:param target_path: The path to export the model to
:param onnx_model_name: The name to save the exported ONNX model as
:param deployment_target: The deployment target to export to. Defaults to deepsparse
:param opset: The opset to use for the export. Defaults to TORCH_DEFAULT_ONNX_OPSET
:param kwargs: Additional kwargs to pass to the TorchToONNX exporter
:return: The path to the exported model
"""

model.eval()

exporter = TorchToONNX(sample_batch=sample_data, opset=opset, **kwargs)
exporter.export(model, os.path.join(target_path, onnx_model_name))
if deployment_target == ExportTargets.deepsparse.value:
exporter = ONNXToDeepsparse()
model = exporter.load_model(os.path.join(target_path, onnx_model_name))
exporter.export(model, os.path.join(target_path, onnx_model_name))
elif deployment_target == ExportTargets.onnx.value:
pass
else:
raise ValueError(f"Unsupported deployment target: {deployment_target}")
return os.path.join(target_path, onnx_model_name)
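A minimal, self-contained sketch of calling the new export_model helper directly (not part of the diff; the torchvision model, input shape, and paths are assumptions):

import os

import torch
from torchvision.models import resnet18  # stand-in dense model, purely illustrative

from sparseml.export.export_torch_model import export_model

model = resnet18()
sample = torch.randn(1, 3, 224, 224)  # dummy batch matching the model's expected input

os.makedirs("./deployment", exist_ok=True)

# "onnx" skips the ONNXToDeepsparse post-processing pass;
# pass deployment_target="deepsparse" to apply it after the ONNX export
onnx_path = export_model(
    model=model,
    sample_data=sample,
    target_path="./deployment",
    onnx_model_name="model.onnx",
    deployment_target="onnx",
)
print(onnx_path)  # ./deployment/model.onnx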
11 changes: 11 additions & 0 deletions src/sparseml/exporters/__init__.py
@@ -11,3 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum


class ExportTargets(Enum):
"""
Holds the names of the supported export targets
"""

deepsparse = "deepsparse"
onnx = "onnx"
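This enum is what export.py now iterates to build its list of deployment targets; a small sketch of that relationship (values as of this commit):

from sparseml.exporters import ExportTargets

available_targets = [target.value for target in ExportTargets]
assert available_targets == ["deepsparse", "onnx"]
assert ExportTargets.deepsparse.value == "deepsparse"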
25 changes: 16 additions & 9 deletions src/sparseml/integration_helper_functions.py
@@ -14,10 +14,11 @@

from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

from pydantic import BaseModel, Field

from sparseml.export.export_torch_model import export_model
from sparsezoo.utils.registry import RegistryMixin


@@ -41,9 +42,8 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):

create_model: Optional[
Callable[
[Union[str, Path], Optional[Dict[str, Any]]][
"torch.nn.Module", Optional[Dict[str, Any]] # noqa F821
]
Tuple[Union[str, Path], Optional[Dict[str, Any]]],
Tuple["torch.nn.Module", Dict[str, Any]], # noqa F821
]
] = Field(
description="A function that takes: "
@@ -54,17 +54,24 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
"- (optionally) a dictionary of additional arguments"
)
create_dummy_input: Optional[
Callable[Any]["torch.Tensor"] # noqa F821
Callable[..., "torch.Tensor"] # noqa F821
] = Field( # noqa: F82
description="A function that takes: "
"- a dictionary of arguments"
"and returns: "
"- a dummy input for the model (a torch.Tensor) "
)
export_model: Optional[Callable] = Field(
description="A function that exports a (sparse) PyTorch "
"model to an ONNX format appropriate for a "
"deployment target."
export: Callable[..., str] = Field(
description="A function that takes: "
" - a (sparse) PyTorch model "
" - sample input data "
" - the path to save the exported model to "
" - the name to save the exported ONNX model as "
" - the deployment target to export to "
" - the opset to use for the export "
" - (optionally) a dictionary of additional arguments"
"and returns path to the exported model",
default=export_model,
)
apply_optimizations: Optional[Callable] = Field(
description="A function that takes a set of "
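Because the `export` field now defaults to export_model, an integration only needs to override it for target-specific behavior. A hedged sketch of checking that default (assumes pydantic v1's __fields__ API, which the BaseModel usage here appears to rely on):

from sparseml.export.export_torch_model import export_model
from sparseml.integration_helper_functions import IntegrationHelperFunctions

# the class-level default for the `export` field is the generic torch exporter
assert IntegrationHelperFunctions.__fields__["export"].default is export_model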
2 changes: 1 addition & 1 deletion src/sparseml/pytorch/torch_to_onnx_exporter.py
@@ -63,7 +63,7 @@ class TorchToONNX(BaseExporter):
inference engine, and other engines, perform batch norm fusing at model
compilation.
:param export_kwargs: kwargs to be passed as is to the torch.onnx.export api
call. Useful to pass in dyanmic_axes, input_names, output_names, etc.
call. Useful to pass in dynamic_axes, input_names, output_names, etc.
See more on the torch.onnx.export api spec in the PyTorch docs:
https://pytorch.org/docs/stable/onnx.html
"""
@@ -26,3 +26,4 @@ def test_integration_helper_functions():
)
assert image_classification.create_model
assert image_classification.create_dummy_input
assert image_classification.export
