Skip to content

Commit

Permalink
[Export Refactor][Image Classification] apply_optimizations function (
Browse files Browse the repository at this point in the history
#1884)

* initial commit

* looking good, time to cleanup

* Delete src/sparseml/export/helpers.py

* Delete tests/sparseml/export/test_helpers.py

* ready for review

* improve design

* tests pass

* reuse _validate_dataset_num_classes

* initial commit

* Update src/sparseml/pytorch/image_classification/integration_helper_functions.py

* Update src/sparseml/pytorch/image_classification/integration_helper_functions.py

* ready for review

* Update src/sparseml/export/export.py

* Update src/sparseml/integration_helper_functions.py

* initial commit

* fixes

* ready for review

* nit

* add return

* initial commit
  • Loading branch information
dbogunowicz committed Dec 11, 2023
1 parent 0830372 commit 3f86251
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 15 deletions.
20 changes: 11 additions & 9 deletions src/sparseml/export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pathlib import Path
from typing import Any, List, Optional, Union

from sparseml.export.helpers import apply_optimizations
from sparseml.exporters import ExportTargets
from sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
Expand Down Expand Up @@ -49,11 +50,10 @@ def export(
Export a PyTorch model to a deployment target specified by the `deployment_target`.
The functionality follows a set of steps:
1. Create a PyTorch model from the source_path.
2. Create a dummy input for the model.
3. Export the model, using the precomputed dummy input, to an
ONNX format appropriate for the deployment target.
4. Apply optimizations to the exported model (optional).
1. Create a PyTorch model from the file located in source_path.
2. Create model dummy input.
3. Export the model to the format specified by the `deployment_target`.
4. (Optional) Apply optimizations to the exported model.
5. Export sample inputs and outputs for the exported model (optional).
6. Create a deployment folder for the exported model with the appropriate structure.
7. Validate the correctness of the exported model (optional).
Expand Down Expand Up @@ -113,12 +113,14 @@ def export(
else sample_data
)
onnx_file_path = helper_functions.export_model(
model, sample_data, target_path, deployment_target, opset, single_graph_file
model, sample_data, target_path, deployment_target, opset
)

helper_functions.apply_optimizations(
onnx_file_path,
graph_optimizations,
apply_optimizations(
onnx_file_path=onnx_file_path,
graph_optimizations=graph_optimizations,
available_graph_optimizations=helper_functions.graph_optimizations,
single_graph_file=single_graph_file,
)

if export_sample_inputs_outputs:
Expand Down
107 changes: 107 additions & 0 deletions src/sparseml/export/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from enum import Enum
from pathlib import Path
from typing import Callable, List, Union

import onnx

from sparsezoo.utils.onnx import save_onnx


__all__ = ["apply_optimizations"]


class GraphOptimizationOptions(Enum):
"""
Holds the string names of the graph optimization options.
"""

none = "none"
all = "all"


def apply_optimizations(
onnx_file_path: Union[str, Path],
available_optimizations: OrderedDict[str, Callable],
target_optimizations: Union[str, List[str]] = GraphOptimizationOptions.all.value,
single_graph_file: bool = True,
):
"""
Apply optimizations to the graph of the ONNX model.
:param onnx_file_path: The path to the ONNX model file.
:param available_optimizations: The graph optimizations available
for the model. It is an ordered mapping from the string names
to functions that alter the model
:param target_optimizations: The name(s) of optimizations to apply.
It can be either a list of string name or a single string option
that specifies the set of optimizations to apply.
If is string, refer to the `GraphOptimizationOptions` enum
for the available options.
:param single_graph_file: Whether to save the optimized graph to a single
file or split it into multiple files. By default, it is True.
"""
optimizations: List[Callable] = resolve_graph_optimizations(
available_optimizations=available_optimizations,
optimizations=target_optimizations,
)

onnx_model = onnx.load(onnx_file_path)

for optimization in optimizations:
onnx_model = optimization(onnx_model)

if single_graph_file:
save_onnx(onnx_model, onnx_file_path)
return

save_onnx_multiple_files(onnx_model)


def resolve_graph_optimizations(
available_optimizations: OrderedDict[str, Callable],
optimizations: Union[str, List[str]],
) -> List[Callable]:
"""
Get the optimization functions to apply to the onnx model.
:param available_optimizations: The graph optimizations available
for the model. It is an ordered mapping from the string names
to functions that alter the model
:param optimizations: The name(s) of optimizations to apply.
It can be either a list of string name or a single string option
that specifies the set of optimizations to apply.
If is string, refer to the `GraphOptimizationOptions` enum
for the available options.
return The list of optimization functions to apply.
"""
if isinstance(optimizations, str):
if optimizations == GraphOptimizationOptions.none.value:
return []
elif optimizations == GraphOptimizationOptions.all.value:
return list(available_optimizations.values())
else:
raise KeyError(f"Unknown graph optimization option: {optimizations}")
elif isinstance(optimizations, list):
return [available_optimizations[optimization] for optimization in optimizations]
else:
raise KeyError(f"Unknown graph optimization option: {optimizations}")


# TODO: To discuss with @bfineran
def save_onnx_multiple_files(*args, **kwargs):
raise NotImplementedError
10 changes: 4 additions & 6 deletions src/sparseml/integration_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
"and returns: "
"- a dummy input for the model (a torch.Tensor) "
)
export: Callable[..., str] = Field(
export: Optional[Callable[..., str]] = Field(
description="A function that takes: "
" - a (sparse) PyTorch model "
" - sample input data "
Expand All @@ -70,12 +70,10 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
" - the deployment target to export to "
" - the opset to use for the export "
" - (optionally) a dictionary of additional arguments"
"and returns path to the exported model",
default=export_model,
"and returns nothing"
)
apply_optimizations: Optional[Callable] = Field(
description="A function that takes a set of "
"optimizations and applies them to an ONNX model."
graph_optimizations: Optional[Dict[str, Callable]] = Field(
description="A mapping from names to graph optimization functions "
)
export_sample_inputs_outputs: Optional[Callable] = Field(
description="A function that exports input/output samples given "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch
from pydantic import Field

from sparseml.pytorch.image_classification.utils.helpers import export_model
from src.sparseml.integration_helper_functions import (
IntegrationHelperFunctions,
Integrations,
Expand Down Expand Up @@ -66,7 +67,9 @@ def create_dummy_input(

@IntegrationHelperFunctions.register(name=Integrations.image_classification.value)
class ImageClassification(IntegrationHelperFunctions):

create_model: Callable[..., Tuple[torch.nn.Module, Dict[str, Any]]] = Field(
default=create_model
)
create_dummy_input: Callable[..., torch.Tensor] = Field(default=create_dummy_input)
export: Callable[..., str] = Field(default=export_model)
41 changes: 41 additions & 0 deletions src/sparseml/pytorch/image_classification/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,18 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset

from sparseml.exporters import ExportTargets
from sparseml.exporters.onnx_to_deepsparse import ONNXToDeepsparse
from sparseml.optim.manager import BaseManager
from sparseml.pytorch.datasets import DatasetRegistry
from sparseml.pytorch.datasets.image_classification.ffcv_dataset import (
FFCVCompatibleDataset,
)
from sparseml.pytorch.image_classification.utils.constants import AVAILABLE_DATASETS
from sparseml.pytorch.models import ModelRegistry
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.torch_to_onnx_exporter import TorchToONNX
from sparseml.pytorch.utils import (
DEFAULT_LOSS_KEY,
CrossEntropyLossWrapper,
Expand Down Expand Up @@ -77,6 +81,43 @@
]


def export_model(
model: torch.nn.Module,
sample_data: torch.Tensor,
target_path: Union[Path, str],
onnx_model_name: str,
deployment_target: str = "deepsparse",
opset: int = TORCH_DEFAULT_ONNX_OPSET,
**kwargs,
) -> str:
"""
Exports the torch model to the deployment target
:param model: The torch model to export
:param sample_data: The sample data to use for the export
:param target_path: The path to export the model to
:param onnx_model_name: The name to save the exported ONNX model as
:param deployment_target: The deployment target to export to. Defaults to deepsparse
:param opset: The opset to use for the export. Defaults to TORCH_DEFAULT_ONNX_OPSET
:param kwargs: Additional kwargs to pass to the TorchToONNX exporter
:return: The path to the exported model
"""

model.eval()

exporter = TorchToONNX(sample_batch=sample_data, opset=opset, **kwargs)
exporter.export(model, os.path.join(target_path, onnx_model_name))
if deployment_target == ExportTargets.deepsparse.value:
exporter = ONNXToDeepsparse()
model = exporter.load_model(os.path.join(target_path, onnx_model_name))
exporter.export(model, os.path.join(target_path, onnx_model_name))
if deployment_target == ExportTargets.onnx.value:
pass
else:
raise ValueError(f"Unsupported deployment target: {deployment_target}")
return os.path.join(target_path, onnx_model_name)


def save_zoo_directory(
output_dir: str, training_outputs_dir: str, logs_path: Optional[str] = None
):
Expand Down
111 changes: 111 additions & 0 deletions tests/sparseml/export/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections import OrderedDict

import onnx
import pytest

from src.sparseml.export.helpers import apply_optimizations
from tests.sparseml.exporters.transforms.test_onnx_transform import (
_create_model as create_dummy_onnx_file,
)


def foo(onnx_model):
logging.debug("foo")
return onnx_model


def bar(onnx_model):
logging.debug("bar")
return onnx_model


@pytest.fixture()
def available_optimizations():
return OrderedDict(zip(["bar", "foo"], [bar, foo]))


@pytest.fixture()
def available_optimizations_empty():
return OrderedDict()


@pytest.mark.parametrize(
"target_optimizations, should_raise_error",
[("none", False), ("all", False), ("error_name", True), (["error_name"], True)],
)
def test_apply_optimizations_empty(
tmp_path, available_optimizations_empty, target_optimizations, should_raise_error
):
onnx_model = create_dummy_onnx_file()
onnx_file_path = tmp_path / "test.onnx"
onnx.save(onnx_model, onnx_file_path)

if not should_raise_error:
apply_optimizations(
onnx_file_path=onnx_file_path,
target_optimizations=target_optimizations,
available_optimizations=available_optimizations_empty,
)
else:
with pytest.raises(KeyError):
apply_optimizations(
onnx_file_path=onnx_file_path,
target_optimizations=target_optimizations,
available_optimizations=available_optimizations_empty,
)


@pytest.mark.parametrize(
"target_optimizations, expected_logs, should_raise_error",
[
("none", [], False),
("all", ["bar", "foo"], False),
(["foo"], ["foo"], False),
("error_name", [], True),
(["error_name"], [], True),
],
)
def test_apply_optimizations(
caplog,
tmp_path,
available_optimizations,
target_optimizations,
expected_logs,
should_raise_error,
):
onnx_model = create_dummy_onnx_file()
onnx_file_path = tmp_path / "test.onnx"
onnx.save(onnx_model, onnx_file_path)

if should_raise_error:
with pytest.raises(KeyError):
apply_optimizations(
onnx_file_path=onnx_file_path,
target_optimizations=target_optimizations,
available_optimizations=available_optimizations,
)
return

with caplog.at_level(logging.DEBUG):
apply_optimizations(
onnx_file_path=onnx_file_path,
target_optimizations=target_optimizations,
available_optimizations=available_optimizations,
)

assert caplog.messages == expected_logs
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ def test_integration_helper_functions():
assert image_classification.create_model
assert image_classification.create_dummy_input
assert image_classification.export
assert image_classification.graph_optimizations is None

0 comments on commit 3f86251

Please sign in to comment.