-
Notifications
You must be signed in to change notification settings - Fork 144
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #171 from neuralmagic/interface-framework-onnx
Onnx framework and sparsification implementation for phase 2
- Loading branch information
Showing
14 changed files
with
661 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import functools | ||
from typing import Optional | ||
|
||
from sparseml.base import check_version | ||
|
||
|
||
try: | ||
import onnx | ||
|
||
onnx_err = None | ||
except Exception as err: | ||
onnx = object() # TODO: populate with fake object for necessary imports | ||
onnx_err = err | ||
|
||
try: | ||
import onnxruntime | ||
|
||
onnxruntime_err = None | ||
except Exception as err: | ||
onnxruntime = object() # TODO: populate with fake object for necessary imports | ||
onnxruntime_err = err | ||
|
||
|
||
__all__ = [ | ||
"onnx", | ||
"onnx_err", | ||
"onnxruntime", | ||
"onnxruntime_err", | ||
"check_onnx_install", | ||
"check_onnxruntime_install", | ||
"require_onnx", | ||
"require_onnxruntime", | ||
] | ||
|
||
|
||
_ONNX_MIN_VERSION = "1.5.0" | ||
_ORT_MIN_VERSION = "1.0.0" | ||
|
||
|
||
def check_onnx_install( | ||
min_version: Optional[str] = _ONNX_MIN_VERSION, | ||
max_version: Optional[str] = None, | ||
raise_on_error: bool = True, | ||
) -> bool: | ||
""" | ||
Check that the onnx package is installed. | ||
If raise_on_error, will raise an ImportError if it is not installed or | ||
the required version range, if set, is not installed. | ||
If not raise_on_error, will return True if installed with required version | ||
and False otherwise. | ||
:param min_version: The minimum version for onnx that it must be greater than | ||
or equal to, if unset will require no minimum version | ||
:type min_version: str | ||
:param max_version: The maximum version for onnx that it must be less than | ||
or equal to, if unset will require no maximum version. | ||
:type max_version: str | ||
:param raise_on_error: True to raise any issues such as not installed, | ||
minimum version, or maximum version as ImportError. False to return the result. | ||
:type raise_on_error: bool | ||
:return: If raise_on_error, will return False if onnx is not installed | ||
or the version is outside the accepted bounds and True if everything is correct. | ||
:rtype: bool | ||
""" | ||
if onnx_err is not None: | ||
if raise_on_error: | ||
raise onnx_err | ||
return False | ||
|
||
return check_version("onnx", min_version, max_version, raise_on_error) | ||
|
||
|
||
def check_onnxruntime_install( | ||
min_version: Optional[str] = _ORT_MIN_VERSION, | ||
max_version: Optional[str] = None, | ||
raise_on_error: bool = True, | ||
) -> bool: | ||
""" | ||
Check that the onnxruntime package is installed. | ||
If raise_on_error, will raise an ImportError if it is not installed or | ||
the required version range, if set, is not installed. | ||
If not raise_on_error, will return True if installed with required version | ||
and False otherwise. | ||
:param min_version: The minimum version for onnxruntime that it must be greater than | ||
or equal to, if unset will require no minimum version | ||
:type min_version: str | ||
:param max_version: The maximum version for onnxruntime that it must be less than | ||
or equal to, if unset will require no maximum version. | ||
:type max_version: str | ||
:param raise_on_error: True to raise any issues such as not installed, | ||
minimum version, or maximum version as ImportError. False to return the result. | ||
:type raise_on_error: bool | ||
:return: If raise_on_error, will return False if onnxruntime is not installed | ||
or the version is outside the accepted bounds and True if everything is correct. | ||
:rtype: bool | ||
""" | ||
if onnxruntime_err is not None: | ||
if raise_on_error: | ||
raise onnxruntime_err | ||
return False | ||
|
||
return check_version("onnxruntime", min_version, max_version, raise_on_error) | ||
|
||
|
||
def require_onnx( | ||
min_version: Optional[str] = _ONNX_MIN_VERSION, max_version: Optional[str] = None | ||
): | ||
""" | ||
Decorator function to require use of onnx. | ||
Will check that onnx package is installed and within the bounding | ||
ranges of min_version and max_version if they are set before calling | ||
the wrapped function. | ||
See :func:`check_onnx_install` for more info. | ||
param min_version: The minimum version for onnx that it must be greater than | ||
or equal to, if unset will require no minimum version | ||
:type min_version: str | ||
:param max_version: The maximum version for onnx that it must be less than | ||
or equal to, if unset will require no maximum version. | ||
:type max_version: str | ||
""" | ||
|
||
def _decorator(func): | ||
@functools.wraps(func) | ||
def _wrapper(*args, **kwargs): | ||
check_onnx_install(min_version, max_version) | ||
|
||
return func(*args, **kwargs) | ||
|
||
return _wrapper | ||
|
||
return _decorator | ||
|
||
|
||
def require_onnxruntime( | ||
min_version: Optional[str] = _ORT_MIN_VERSION, max_version: Optional[str] = None | ||
): | ||
""" | ||
Decorator function to require use of onnxruntime. | ||
Will check that onnxruntime package is installed and within the bounding | ||
ranges of min_version and max_version if they are set before calling | ||
the wrapped function. | ||
See :func:`check_onnxruntime_install` for more info. | ||
param min_version: The minimum version for onnxruntime that it must be greater than | ||
or equal to, if unset will require no minimum version | ||
:type min_version: str | ||
:param max_version: The maximum version for onnxruntime that it must be less than | ||
or equal to, if unset will require no maximum version. | ||
:type max_version: str | ||
""" | ||
|
||
def _decorator(func): | ||
@functools.wraps(func) | ||
def _wrapper(*args, **kwargs): | ||
check_onnxruntime_install(min_version, max_version) | ||
|
||
return func(*args, **kwargs) | ||
|
||
return _wrapper | ||
|
||
return _decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# flake8: noqa | ||
|
||
""" | ||
Functionality related to integrating with, detecting, and getting information for | ||
support and sparsification in the ONNX/ONNXRuntime framework. | ||
""" | ||
|
||
# flake8: noqa | ||
|
||
from .info import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" | ||
Functionality related to detecting and getting information for | ||
support and sparsification in the ONNX/ONNXRuntime framework. | ||
""" | ||
|
||
import logging | ||
from typing import Any | ||
|
||
from sparseml.base import Framework, get_version | ||
from sparseml.framework import FrameworkInferenceProviderInfo, FrameworkInfo | ||
from sparseml.onnx.base import check_onnx_install, check_onnxruntime_install | ||
from sparseml.onnx.sparsification import sparsification_info | ||
from sparseml.sparsification import SparsificationInfo | ||
|
||
|
||
__all__ = ["is_supported", "detect_framework", "framework_info"] | ||
|
||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
def is_supported(item: Any) -> bool: | ||
""" | ||
:param item: The item to detect the support for | ||
:type item: Any | ||
:return: True if the item is supported by onnx/onnxruntime, False otherwise | ||
:rtype: bool | ||
""" | ||
framework = detect_framework(item) | ||
|
||
return framework == Framework.onnx | ||
|
||
|
||
def detect_framework(item: Any) -> Framework: | ||
""" | ||
Detect the supported ML framework for a given item specifically for the | ||
onnx/onnxruntime package. | ||
Supported input types are the following: | ||
- A Framework enum | ||
- A string of any case representing the name of the framework | ||
(deepsparse, onnx, keras, pytorch, tensorflow_v1) | ||
- A supported file type within the framework such as model files: | ||
(onnx, pth, h5, pb) | ||
- An object from a supported ML framework such as a model instance | ||
If the framework cannot be determined, will return Framework.unknown | ||
:param item: The item to detect the ML framework for | ||
:type item: Any | ||
:return: The detected framework from the given item | ||
:rtype: Framework | ||
""" | ||
framework = Framework.unknown | ||
|
||
if isinstance(item, Framework): | ||
_LOGGER.debug("framework detected from Framework instance") | ||
framework = item | ||
elif isinstance(item, str) and item.lower().strip() in Framework.__members__: | ||
_LOGGER.debug("framework detected from Framework string instance") | ||
framework = Framework[item.lower().strip()] | ||
elif isinstance(item, str) and "onnx" in item.lower().strip(): | ||
_LOGGER.debug("framework detected from onnx text") | ||
# string, check if it's a string saying onnx first | ||
framework = Framework.onnx | ||
elif isinstance(item, str) and ".onnx" in item.lower().strip(): | ||
_LOGGER.debug("framework detected from .onnx") | ||
# string, check if it's a file url or path that ends with onnx extension | ||
framework = Framework.onnx | ||
elif check_onnx_install(raise_on_error=False): | ||
from onnx import ModelProto | ||
|
||
if isinstance(item, ModelProto): | ||
_LOGGER.debug("framework detected from ONNX instance") | ||
# onnx native support | ||
framework = Framework.onnx | ||
|
||
return framework | ||
|
||
|
||
def framework_info() -> FrameworkInfo: | ||
""" | ||
Detect the information for the onnx/onnxruntime framework such as package versions, | ||
availability for core actions such as training and inference, | ||
sparsification support, and inference provider support. | ||
:return: The framework info for onnx/onnxruntime | ||
:rtype: FrameworkInfo | ||
""" | ||
all_providers = [] | ||
available_providers = [] | ||
if check_onnxruntime_install(raise_on_error=False): | ||
from onnxruntime import get_all_providers, get_available_providers | ||
|
||
available_providers = get_available_providers() | ||
all_providers = get_all_providers() | ||
|
||
cpu_provider = FrameworkInferenceProviderInfo( | ||
name="cpu", | ||
description="Base CPU provider within ONNXRuntime", | ||
device="cpu", | ||
supported_sparsification=SparsificationInfo(), # TODO: fill in when available | ||
available=( | ||
check_onnx_install(raise_on_error=False) | ||
and check_onnxruntime_install(raise_on_error=False) | ||
and "CPUExecutionProvider" in available_providers | ||
), | ||
properties={}, | ||
warnings=[], | ||
) | ||
gpu_provider = FrameworkInferenceProviderInfo( | ||
name="cuda", | ||
description="Base GPU CUDA provider within ONNXRuntime", | ||
device="gpu", | ||
supported_sparsification=SparsificationInfo(), # TODO: fill in when available | ||
available=( | ||
check_onnx_install(raise_on_error=False) | ||
and check_onnxruntime_install(raise_on_error=False) | ||
and "CUDAExecutionProvider" in available_providers | ||
), | ||
properties={}, | ||
warnings=[], | ||
) | ||
|
||
return FrameworkInfo( | ||
framework=Framework.onnx, | ||
package_versions={ | ||
"onnx": get_version(package_name="onnx", raise_on_error=False), | ||
"onnxruntime": ( | ||
get_version(package_name="onnxruntime", raise_on_error=False) | ||
), | ||
"sparsezoo": get_version(package_name="sparsezoo", raise_on_error=False), | ||
"sparseml": get_version(package_name="sparseml", raise_on_error=False), | ||
}, | ||
sparsification=sparsification_info(), | ||
inference_providers=[cpu_provider, gpu_provider], | ||
properties={ | ||
"available_providers": available_providers, | ||
"all_providers": all_providers, | ||
}, | ||
training_available=False, | ||
sparsification_available=True, | ||
exporting_onnx_available=True, | ||
inference_available=True, | ||
) |
Oops, something went wrong.