Onnx framework and sparsification implementation for phase 2 #171

Merged · 13 commits, merged Apr 26, 2021
Changes from 11 commits
7 changes: 5 additions & 2 deletions src/sparseml/onnx/__init__.py
@@ -13,8 +13,11 @@
# limitations under the License.

"""
Code for working with the ONNX framework for creating /
editing models for performance in the Neural Magic System
Functionality for working with and sparsifying models in the ONNX/ONNXRuntime framework
"""

# flake8: noqa

from .base import *
from .framework import detect_framework, framework_info, is_supported
from .sparsification import sparsification_info
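
For orientation (not part of the diff): the public surface re-exported above can be imported straight from the package root. The names below are exactly the ones brought in by the imports in this __init__.py; a minimal sketch, assuming the package is installed:

from sparseml.onnx import (
    detect_framework,
    framework_info,
    is_supported,
    sparsification_info,
)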
178 changes: 178 additions & 0 deletions src/sparseml/onnx/base.py
@@ -0,0 +1,178 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
from typing import Optional

from sparseml.base import check_version


try:
    import onnx

    onnx_err = None
except Exception as err:
    onnx = object()  # TODO: populate with fake object for necessary imports
    onnx_err = err

try:
    import onnxruntime

    onnxruntime_err = None
except Exception as err:
    onnxruntime = object()  # TODO: populate with fake object for necessary imports
    onnxruntime_err = err


__all__ = [
    "onnx",
    "onnx_err",
    "onnxruntime",
    "onnxruntime_err",
    "check_onnx_install",
    "check_onnxruntime_install",
    "require_onnx",
    "require_onnxruntime",
]


_ONNX_MIN_VERSION = "1.5.0"
_ORT_MIN_VERSION = "1.0.0"


def check_onnx_install(
    min_version: Optional[str] = _ONNX_MIN_VERSION,
    max_version: Optional[str] = None,
    raise_on_error: bool = True,
) -> bool:
    """
    Check that the onnx package is installed.
    If raise_on_error is True, raises an ImportError if onnx is not installed
    or the installed version falls outside the required range, if one is set.
    If raise_on_error is False, returns True if onnx is installed with a
    compatible version and False otherwise.

    :param min_version: The minimum version of onnx required (inclusive);
        if unset, no minimum version is enforced
    :type min_version: str
    :param max_version: The maximum version of onnx allowed (inclusive);
        if unset, no maximum version is enforced
    :type max_version: str
    :param raise_on_error: True to raise an ImportError for any issue such as
        onnx not being installed or its version violating the minimum or
        maximum bound; False to return the result instead
    :type raise_on_error: bool
    :return: True if onnx is installed within the accepted version bounds,
        False otherwise (False is only returned when raise_on_error is False)
    :rtype: bool
    """
    if onnx_err is not None:
        if raise_on_error:
            raise onnx_err
        return False

    return check_version("onnx", min_version, max_version, raise_on_error)


def check_onnxruntime_install(
    min_version: Optional[str] = _ORT_MIN_VERSION,
    max_version: Optional[str] = None,
    raise_on_error: bool = True,
) -> bool:
    """
    Check that the onnxruntime package is installed.
    If raise_on_error is True, raises an ImportError if onnxruntime is not
    installed or the installed version falls outside the required range,
    if one is set.
    If raise_on_error is False, returns True if onnxruntime is installed with a
    compatible version and False otherwise.

    :param min_version: The minimum version of onnxruntime required (inclusive);
        if unset, no minimum version is enforced
    :type min_version: str
    :param max_version: The maximum version of onnxruntime allowed (inclusive);
        if unset, no maximum version is enforced
    :type max_version: str
    :param raise_on_error: True to raise an ImportError for any issue such as
        onnxruntime not being installed or its version violating the minimum or
        maximum bound; False to return the result instead
    :type raise_on_error: bool
    :return: True if onnxruntime is installed within the accepted version bounds,
        False otherwise (False is only returned when raise_on_error is False)
    :rtype: bool
    """
    if onnxruntime_err is not None:
        if raise_on_error:
            raise onnxruntime_err
        return False

    return check_version("onnxruntime", min_version, max_version, raise_on_error)


def require_onnx(
    min_version: Optional[str] = _ONNX_MIN_VERSION, max_version: Optional[str] = None
):
    """
    Decorator function to require use of onnx.
    Will check that the onnx package is installed and within the bounding
    range of min_version and max_version, if they are set, before calling
    the wrapped function.
    See :func:`check_onnx_install` for more info.

    :param min_version: The minimum version of onnx required (inclusive);
        if unset, no minimum version is enforced
    :type min_version: str
    :param max_version: The maximum version of onnx allowed (inclusive);
        if unset, no maximum version is enforced
    :type max_version: str
    """

    def _decorator(func):
        @functools.wraps(func)
        def _wrapper(*args, **kwargs):
            check_onnx_install(min_version, max_version)

            return func(*args, **kwargs)

        return _wrapper

    return _decorator


def require_onnxruntime(
    min_version: Optional[str] = _ORT_MIN_VERSION, max_version: Optional[str] = None
):
    """
    Decorator function to require use of onnxruntime.
    Will check that the onnxruntime package is installed and within the bounding
    range of min_version and max_version, if they are set, before calling
    the wrapped function.
    See :func:`check_onnxruntime_install` for more info.

    :param min_version: The minimum version of onnxruntime required (inclusive);
        if unset, no minimum version is enforced
    :type min_version: str
    :param max_version: The maximum version of onnxruntime allowed (inclusive);
        if unset, no maximum version is enforced
    :type max_version: str
    """

    def _decorator(func):
        @functools.wraps(func)
        def _wrapper(*args, **kwargs):
            check_onnxruntime_install(min_version, max_version)

            return func(*args, **kwargs)

        return _wrapper

    return _decorator
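
Stepping outside the diff for a moment, here is a minimal usage sketch for the helpers added in base.py. The caller code and the count_model_nodes helper are hypothetical, and the sketch assumes compatible onnx/onnxruntime versions are installed; only check_onnx_install, check_onnxruntime_install, and require_onnx come from this file.

from sparseml.onnx.base import (
    check_onnx_install,
    check_onnxruntime_install,
    require_onnx,
)

# Boolean checks instead of raising: convenient for optional code paths
if check_onnx_install(raise_on_error=False) and check_onnxruntime_install(
    raise_on_error=False
):
    print("onnx and onnxruntime are installed in supported versions")


# The decorator runs check_onnx_install (raising ImportError on failure)
# before the wrapped function executes
@require_onnx()
def count_model_nodes(model_path: str) -> int:
    # hypothetical helper: load an ONNX model and count its graph nodes
    import onnx

    model = onnx.load(model_path)

    return len(model.graph.node)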
24 changes: 24 additions & 0 deletions src/sparseml/onnx/framework/__init__.py
@@ -0,0 +1,24 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

"""
Functionality related to integrating with, detecting, and getting information for
support and sparsification in the ONNX/ONNXRuntime framework.
"""

# flake8: noqa

from .info import *
157 changes: 157 additions & 0 deletions src/sparseml/onnx/framework/info.py
@@ -0,0 +1,157 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Functionality related to detecting and getting information for
support and sparsification in the ONNX/ONNXRuntime framework.
"""

import logging
from typing import Any

from sparseml.base import Framework, get_version
from sparseml.framework import FrameworkInferenceProviderInfo, FrameworkInfo
from sparseml.onnx.base import check_onnx_install, check_onnxruntime_install
from sparseml.onnx.sparsification import sparsification_info
from sparseml.sparsification import SparsificationInfo


__all__ = ["is_supported", "detect_framework", "framework_info"]


_LOGGER = logging.getLogger(__name__)


def is_supported(item: Any) -> bool:
    """
    :param item: The item to detect the support for
    :type item: Any
    :return: True if the item is supported by onnx/onnxruntime, False otherwise
    :rtype: bool
    """
    framework = detect_framework(item)

    return framework == Framework.onnx


def detect_framework(item: Any) -> Framework:
    """
    Detect the supported ML framework for a given item specifically for the
    onnx/onnxruntime package.
    Supported input types are the following:
    - A Framework enum
    - A string of any case representing the name of the framework
      (deepsparse, onnx, keras, pytorch, tensorflow_v1)
    - A supported file type within the framework such as model files:
      (onnx, pth, h5, pb)
    - An object from a supported ML framework such as a model instance
    If the framework cannot be determined, will return Framework.unknown

    :param item: The item to detect the ML framework for
    :type item: Any
    :return: The detected framework from the given item
    :rtype: Framework
    """
    framework = Framework.unknown

    if isinstance(item, Framework):
        _LOGGER.debug("framework detected from Framework instance")
        framework = item
    elif isinstance(item, str) and item.lower().strip() in Framework.__members__:
        _LOGGER.debug("framework detected from Framework string instance")
        framework = Framework[item.lower().strip()]
    elif isinstance(item, str) and "onnx" in item.lower().strip():
        _LOGGER.debug("framework detected from onnx text")
        # string, check if it's a string saying onnx first
        framework = Framework.onnx
    elif isinstance(item, str) and ".onnx" in item.lower().strip():
        _LOGGER.debug("framework detected from .onnx")
        # string, check if it's a file url or path that ends with onnx extension
        framework = Framework.onnx
    elif check_onnx_install(raise_on_error=False):
        from onnx import ModelProto

        if isinstance(item, ModelProto):
            _LOGGER.debug("framework detected from ONNX instance")
            # onnx native support
            framework = Framework.onnx

    return framework


def framework_info() -> FrameworkInfo:
    """
    Detect the information for the onnx/onnxruntime framework such as package
    versions, availability for core actions such as training and inference,
    sparsification support, and inference provider support.

    :return: The framework info for onnx/onnxruntime
    :rtype: FrameworkInfo
    """
    all_providers = []
    available_providers = []
    if check_onnxruntime_install(raise_on_error=False):
        from onnxruntime import get_all_providers, get_available_providers

        available_providers = get_available_providers()
        all_providers = get_all_providers()

    cpu_provider = FrameworkInferenceProviderInfo(
        name="cpu",
        description="Base CPU provider within ONNXRuntime",
        device="cpu",
        supported_sparsification=SparsificationInfo(),  # TODO: fill in when available
        available=(
            check_onnx_install(raise_on_error=False)
            and check_onnxruntime_install(raise_on_error=False)
            and "CPUExecutionProvider" in available_providers
        ),
        properties={},
        warnings=[],
    )
    gpu_provider = FrameworkInferenceProviderInfo(
        name="cuda",
        description="Base GPU CUDA provider within ONNXRuntime",
        device="gpu",
        supported_sparsification=SparsificationInfo(),  # TODO: fill in when available
        available=(
            check_onnx_install(raise_on_error=False)
            and check_onnxruntime_install(raise_on_error=False)
            and "CUDAExecutionProvider" in available_providers
        ),
        properties={},
        warnings=[],
    )

    return FrameworkInfo(
        framework=Framework.onnx,
        package_versions={
            "onnx": get_version(package_name="onnx", raise_on_error=False),
            "onnxruntime": (
                get_version(package_name="onnxruntime", raise_on_error=False)
            ),
            "sparsezoo": get_version(package_name="sparsezoo", raise_on_error=False),
            "sparseml": get_version(package_name="sparseml", raise_on_error=False),
        },
        sparsification=sparsification_info(),
        inference_providers=[cpu_provider, gpu_provider],
        properties={
            "available_providers": available_providers,
            "all_providers": all_providers,
        },
        training_available=False,
        sparsification_available=True,
        exporting_onnx_available=True,
        inference_available=True,
    )
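
Again outside the diff, a short sketch of how the detection and info entry points added here might be exercised. The model path is a placeholder, the attribute names (framework, package_versions, inference_providers, name, available) mirror the constructors above, and the printed values depend on the local install.

from sparseml.onnx.framework import detect_framework, framework_info, is_supported

# Detection works from enums, strings, file paths, or loaded ONNX models
print(detect_framework("onnx"))        # Framework.onnx, matched from the string
print(detect_framework("model.onnx"))  # Framework.onnx, matched from the file name
print(is_supported("pytorch"))         # False, resolves to Framework.pytorch

# Aggregated info: package versions, inference providers, sparsification support
info = framework_info()
print(info.package_versions["onnxruntime"])
for provider in info.inference_providers:
    print(provider.name, provider.available)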