-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Export Refactor][Image Classification]
apply_optimizations
function (
#1884) * initial commit * looking good, time to cleanup * Delete src/sparseml/export/helpers.py * Delete tests/sparseml/export/test_helpers.py * ready for review * improve design * tests pass * reuse _validate_dataset_num_classes * initial commit * Update src/sparseml/pytorch/image_classification/integration_helper_functions.py * Update src/sparseml/pytorch/image_classification/integration_helper_functions.py * ready for review * Update src/sparseml/export/export.py * Update src/sparseml/integration_helper_functions.py * initial commit * fixes * ready for review * nit * add return * initial commit
- Loading branch information
1 parent
59f3f5a
commit 9096b0d
Showing
7 changed files
with
278 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from collections import OrderedDict | ||
from enum import Enum | ||
from pathlib import Path | ||
from typing import Callable, List, Union | ||
|
||
import onnx | ||
|
||
from sparsezoo.utils.onnx import save_onnx | ||
|
||
|
||
__all__ = ["apply_optimizations"] | ||
|
||
|
||
class GraphOptimizationOptions(Enum): | ||
""" | ||
Holds the string names of the graph optimization options. | ||
""" | ||
|
||
none = "none" | ||
all = "all" | ||
|
||
|
||
def apply_optimizations( | ||
onnx_file_path: Union[str, Path], | ||
available_optimizations: OrderedDict[str, Callable], | ||
target_optimizations: Union[str, List[str]] = GraphOptimizationOptions.all.value, | ||
single_graph_file: bool = True, | ||
): | ||
""" | ||
Apply optimizations to the graph of the ONNX model. | ||
:param onnx_file_path: The path to the ONNX model file. | ||
:param available_optimizations: The graph optimizations available | ||
for the model. It is an ordered mapping from the string names | ||
to functions that alter the model | ||
:param target_optimizations: The name(s) of optimizations to apply. | ||
It can be either a list of string name or a single string option | ||
that specifies the set of optimizations to apply. | ||
If is string, refer to the `GraphOptimizationOptions` enum | ||
for the available options. | ||
:param single_graph_file: Whether to save the optimized graph to a single | ||
file or split it into multiple files. By default, it is True. | ||
""" | ||
optimizations: List[Callable] = resolve_graph_optimizations( | ||
available_optimizations=available_optimizations, | ||
optimizations=target_optimizations, | ||
) | ||
|
||
onnx_model = onnx.load(onnx_file_path) | ||
|
||
for optimization in optimizations: | ||
onnx_model = optimization(onnx_model) | ||
|
||
if single_graph_file: | ||
save_onnx(onnx_model, onnx_file_path) | ||
return | ||
|
||
save_onnx_multiple_files(onnx_model) | ||
|
||
|
||
def resolve_graph_optimizations( | ||
available_optimizations: OrderedDict[str, Callable], | ||
optimizations: Union[str, List[str]], | ||
) -> List[Callable]: | ||
""" | ||
Get the optimization functions to apply to the onnx model. | ||
:param available_optimizations: The graph optimizations available | ||
for the model. It is an ordered mapping from the string names | ||
to functions that alter the model | ||
:param optimizations: The name(s) of optimizations to apply. | ||
It can be either a list of string name or a single string option | ||
that specifies the set of optimizations to apply. | ||
If is string, refer to the `GraphOptimizationOptions` enum | ||
for the available options. | ||
return The list of optimization functions to apply. | ||
""" | ||
if isinstance(optimizations, str): | ||
if optimizations == GraphOptimizationOptions.none.value: | ||
return [] | ||
elif optimizations == GraphOptimizationOptions.all.value: | ||
return list(available_optimizations.values()) | ||
else: | ||
raise KeyError(f"Unknown graph optimization option: {optimizations}") | ||
elif isinstance(optimizations, list): | ||
return [available_optimizations[optimization] for optimization in optimizations] | ||
else: | ||
raise KeyError(f"Unknown graph optimization option: {optimizations}") | ||
|
||
|
||
# TODO: To discuss with @bfineran | ||
def save_onnx_multiple_files(*args, **kwargs): | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
from collections import OrderedDict | ||
|
||
import onnx | ||
import pytest | ||
|
||
from src.sparseml.export.helpers import apply_optimizations | ||
from tests.sparseml.exporters.transforms.test_onnx_transform import ( | ||
_create_model as create_dummy_onnx_file, | ||
) | ||
|
||
|
||
def foo(onnx_model): | ||
logging.debug("foo") | ||
return onnx_model | ||
|
||
|
||
def bar(onnx_model): | ||
logging.debug("bar") | ||
return onnx_model | ||
|
||
|
||
@pytest.fixture() | ||
def available_optimizations(): | ||
return OrderedDict(zip(["bar", "foo"], [bar, foo])) | ||
|
||
|
||
@pytest.fixture() | ||
def available_optimizations_empty(): | ||
return OrderedDict() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"target_optimizations, should_raise_error", | ||
[("none", False), ("all", False), ("error_name", True), (["error_name"], True)], | ||
) | ||
def test_apply_optimizations_empty( | ||
tmp_path, available_optimizations_empty, target_optimizations, should_raise_error | ||
): | ||
onnx_model = create_dummy_onnx_file() | ||
onnx_file_path = tmp_path / "test.onnx" | ||
onnx.save(onnx_model, onnx_file_path) | ||
|
||
if not should_raise_error: | ||
apply_optimizations( | ||
onnx_file_path=onnx_file_path, | ||
target_optimizations=target_optimizations, | ||
available_optimizations=available_optimizations_empty, | ||
) | ||
else: | ||
with pytest.raises(KeyError): | ||
apply_optimizations( | ||
onnx_file_path=onnx_file_path, | ||
target_optimizations=target_optimizations, | ||
available_optimizations=available_optimizations_empty, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"target_optimizations, expected_logs, should_raise_error", | ||
[ | ||
("none", [], False), | ||
("all", ["bar", "foo"], False), | ||
(["foo"], ["foo"], False), | ||
("error_name", [], True), | ||
(["error_name"], [], True), | ||
], | ||
) | ||
def test_apply_optimizations( | ||
caplog, | ||
tmp_path, | ||
available_optimizations, | ||
target_optimizations, | ||
expected_logs, | ||
should_raise_error, | ||
): | ||
onnx_model = create_dummy_onnx_file() | ||
onnx_file_path = tmp_path / "test.onnx" | ||
onnx.save(onnx_model, onnx_file_path) | ||
|
||
if should_raise_error: | ||
with pytest.raises(KeyError): | ||
apply_optimizations( | ||
onnx_file_path=onnx_file_path, | ||
target_optimizations=target_optimizations, | ||
available_optimizations=available_optimizations, | ||
) | ||
return | ||
|
||
with caplog.at_level(logging.DEBUG): | ||
apply_optimizations( | ||
onnx_file_path=onnx_file_path, | ||
target_optimizations=target_optimizations, | ||
available_optimizations=available_optimizations, | ||
) | ||
|
||
assert caplog.messages == expected_logs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters