Skip to content

Commit

Permalink
Merge pull request #171 from neuralmagic/interface-framework-onnx
Browse files Browse the repository at this point in the history
Onnx framework and sparsification implementation for phase 2
  • Loading branch information
markurtz authored Apr 26, 2021
2 parents 1ae34c5 + ba797ac commit cc2ba1d
Show file tree
Hide file tree
Showing 14 changed files with 661 additions and 61 deletions.
7 changes: 5 additions & 2 deletions src/sparseml/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
# limitations under the License.

"""
Code for working with the ONNX framework for creating /
editing models for performance in the Neural Magic System
Functionality for working with and sparsifying Models in the ONNX/ONNXRuntime framework
"""

# flake8: noqa

from .base import *
from .framework import detect_framework, framework_info, is_supported
from .sparsification import sparsification_info
178 changes: 178 additions & 0 deletions src/sparseml/onnx/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
from typing import Optional

from sparseml.base import check_version


try:
import onnx

onnx_err = None
except Exception as err:
onnx = object() # TODO: populate with fake object for necessary imports
onnx_err = err

try:
import onnxruntime

onnxruntime_err = None
except Exception as err:
onnxruntime = object() # TODO: populate with fake object for necessary imports
onnxruntime_err = err


__all__ = [
"onnx",
"onnx_err",
"onnxruntime",
"onnxruntime_err",
"check_onnx_install",
"check_onnxruntime_install",
"require_onnx",
"require_onnxruntime",
]


_ONNX_MIN_VERSION = "1.5.0"
_ORT_MIN_VERSION = "1.0.0"


def check_onnx_install(
min_version: Optional[str] = _ONNX_MIN_VERSION,
max_version: Optional[str] = None,
raise_on_error: bool = True,
) -> bool:
"""
Check that the onnx package is installed.
If raise_on_error, will raise an ImportError if it is not installed or
the required version range, if set, is not installed.
If not raise_on_error, will return True if installed with required version
and False otherwise.
:param min_version: The minimum version for onnx that it must be greater than
or equal to, if unset will require no minimum version
:type min_version: str
:param max_version: The maximum version for onnx that it must be less than
or equal to, if unset will require no maximum version.
:type max_version: str
:param raise_on_error: True to raise any issues such as not installed,
minimum version, or maximum version as ImportError. False to return the result.
:type raise_on_error: bool
:return: If raise_on_error, will return False if onnx is not installed
or the version is outside the accepted bounds and True if everything is correct.
:rtype: bool
"""
if onnx_err is not None:
if raise_on_error:
raise onnx_err
return False

return check_version("onnx", min_version, max_version, raise_on_error)


def check_onnxruntime_install(
min_version: Optional[str] = _ORT_MIN_VERSION,
max_version: Optional[str] = None,
raise_on_error: bool = True,
) -> bool:
"""
Check that the onnxruntime package is installed.
If raise_on_error, will raise an ImportError if it is not installed or
the required version range, if set, is not installed.
If not raise_on_error, will return True if installed with required version
and False otherwise.
:param min_version: The minimum version for onnxruntime that it must be greater than
or equal to, if unset will require no minimum version
:type min_version: str
:param max_version: The maximum version for onnxruntime that it must be less than
or equal to, if unset will require no maximum version.
:type max_version: str
:param raise_on_error: True to raise any issues such as not installed,
minimum version, or maximum version as ImportError. False to return the result.
:type raise_on_error: bool
:return: If raise_on_error, will return False if onnxruntime is not installed
or the version is outside the accepted bounds and True if everything is correct.
:rtype: bool
"""
if onnxruntime_err is not None:
if raise_on_error:
raise onnxruntime_err
return False

return check_version("onnxruntime", min_version, max_version, raise_on_error)


def require_onnx(
min_version: Optional[str] = _ONNX_MIN_VERSION, max_version: Optional[str] = None
):
"""
Decorator function to require use of onnx.
Will check that onnx package is installed and within the bounding
ranges of min_version and max_version if they are set before calling
the wrapped function.
See :func:`check_onnx_install` for more info.
param min_version: The minimum version for onnx that it must be greater than
or equal to, if unset will require no minimum version
:type min_version: str
:param max_version: The maximum version for onnx that it must be less than
or equal to, if unset will require no maximum version.
:type max_version: str
"""

def _decorator(func):
@functools.wraps(func)
def _wrapper(*args, **kwargs):
check_onnx_install(min_version, max_version)

return func(*args, **kwargs)

return _wrapper

return _decorator


def require_onnxruntime(
min_version: Optional[str] = _ORT_MIN_VERSION, max_version: Optional[str] = None
):
"""
Decorator function to require use of onnxruntime.
Will check that onnxruntime package is installed and within the bounding
ranges of min_version and max_version if they are set before calling
the wrapped function.
See :func:`check_onnxruntime_install` for more info.
param min_version: The minimum version for onnxruntime that it must be greater than
or equal to, if unset will require no minimum version
:type min_version: str
:param max_version: The maximum version for onnxruntime that it must be less than
or equal to, if unset will require no maximum version.
:type max_version: str
"""

def _decorator(func):
@functools.wraps(func)
def _wrapper(*args, **kwargs):
check_onnxruntime_install(min_version, max_version)

return func(*args, **kwargs)

return _wrapper

return _decorator
24 changes: 24 additions & 0 deletions src/sparseml/onnx/framework/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

"""
Functionality related to integrating with, detecting, and getting information for
support and sparsification in the ONNX/ONNXRuntime framework.
"""

# flake8: noqa

from .info import *
157 changes: 157 additions & 0 deletions src/sparseml/onnx/framework/info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Functionality related to detecting and getting information for
support and sparsification in the ONNX/ONNXRuntime framework.
"""

import logging
from typing import Any

from sparseml.base import Framework, get_version
from sparseml.framework import FrameworkInferenceProviderInfo, FrameworkInfo
from sparseml.onnx.base import check_onnx_install, check_onnxruntime_install
from sparseml.onnx.sparsification import sparsification_info
from sparseml.sparsification import SparsificationInfo


__all__ = ["is_supported", "detect_framework", "framework_info"]


_LOGGER = logging.getLogger(__name__)


def is_supported(item: Any) -> bool:
"""
:param item: The item to detect the support for
:type item: Any
:return: True if the item is supported by onnx/onnxruntime, False otherwise
:rtype: bool
"""
framework = detect_framework(item)

return framework == Framework.onnx


def detect_framework(item: Any) -> Framework:
"""
Detect the supported ML framework for a given item specifically for the
onnx/onnxruntime package.
Supported input types are the following:
- A Framework enum
- A string of any case representing the name of the framework
(deepsparse, onnx, keras, pytorch, tensorflow_v1)
- A supported file type within the framework such as model files:
(onnx, pth, h5, pb)
- An object from a supported ML framework such as a model instance
If the framework cannot be determined, will return Framework.unknown
:param item: The item to detect the ML framework for
:type item: Any
:return: The detected framework from the given item
:rtype: Framework
"""
framework = Framework.unknown

if isinstance(item, Framework):
_LOGGER.debug("framework detected from Framework instance")
framework = item
elif isinstance(item, str) and item.lower().strip() in Framework.__members__:
_LOGGER.debug("framework detected from Framework string instance")
framework = Framework[item.lower().strip()]
elif isinstance(item, str) and "onnx" in item.lower().strip():
_LOGGER.debug("framework detected from onnx text")
# string, check if it's a string saying onnx first
framework = Framework.onnx
elif isinstance(item, str) and ".onnx" in item.lower().strip():
_LOGGER.debug("framework detected from .onnx")
# string, check if it's a file url or path that ends with onnx extension
framework = Framework.onnx
elif check_onnx_install(raise_on_error=False):
from onnx import ModelProto

if isinstance(item, ModelProto):
_LOGGER.debug("framework detected from ONNX instance")
# onnx native support
framework = Framework.onnx

return framework


def framework_info() -> FrameworkInfo:
"""
Detect the information for the onnx/onnxruntime framework such as package versions,
availability for core actions such as training and inference,
sparsification support, and inference provider support.
:return: The framework info for onnx/onnxruntime
:rtype: FrameworkInfo
"""
all_providers = []
available_providers = []
if check_onnxruntime_install(raise_on_error=False):
from onnxruntime import get_all_providers, get_available_providers

available_providers = get_available_providers()
all_providers = get_all_providers()

cpu_provider = FrameworkInferenceProviderInfo(
name="cpu",
description="Base CPU provider within ONNXRuntime",
device="cpu",
supported_sparsification=SparsificationInfo(), # TODO: fill in when available
available=(
check_onnx_install(raise_on_error=False)
and check_onnxruntime_install(raise_on_error=False)
and "CPUExecutionProvider" in available_providers
),
properties={},
warnings=[],
)
gpu_provider = FrameworkInferenceProviderInfo(
name="cuda",
description="Base GPU CUDA provider within ONNXRuntime",
device="gpu",
supported_sparsification=SparsificationInfo(), # TODO: fill in when available
available=(
check_onnx_install(raise_on_error=False)
and check_onnxruntime_install(raise_on_error=False)
and "CUDAExecutionProvider" in available_providers
),
properties={},
warnings=[],
)

return FrameworkInfo(
framework=Framework.onnx,
package_versions={
"onnx": get_version(package_name="onnx", raise_on_error=False),
"onnxruntime": (
get_version(package_name="onnxruntime", raise_on_error=False)
),
"sparsezoo": get_version(package_name="sparsezoo", raise_on_error=False),
"sparseml": get_version(package_name="sparseml", raise_on_error=False),
},
sparsification=sparsification_info(),
inference_providers=[cpu_provider, gpu_provider],
properties={
"available_providers": available_providers,
"all_providers": all_providers,
},
training_available=False,
sparsification_available=True,
exporting_onnx_available=True,
inference_available=True,
)
Loading

0 comments on commit cc2ba1d

Please sign in to comment.