Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Onnx framework and sparsification implementation for phase 2 #171

Merged
merged 13 commits into from
Apr 26, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"sphinx_copybutton",
"sphinx_markdown_tables",
"sphinx_multiversion",
"sphinx-pydantic",
"sphinx_rtd_theme",
"recommonmark",
]
Expand All @@ -60,19 +61,19 @@
templates_path = ["_templates"]

# Whitelist pattern for tags (set to None to ignore all tags)
smv_tag_whitelist = r'^v.*$'
smv_tag_whitelist = r"^v.*$"

# Whitelist pattern for branches (set to None to ignore all branches)
smv_branch_whitelist = r'^main$'
smv_branch_whitelist = r"^main$"

# Whitelist pattern for remotes (set to None to use local branches only)
smv_remote_whitelist = r'^.*$'
smv_remote_whitelist = r"^.*$"

# Pattern for released versions
smv_released_pattern = r'^tags/v.*$'
smv_released_pattern = r"^tags/v.*$"

# Format for versioned output directories inside the build directory
smv_outputdir_format = '{ref.name}'
smv_outputdir_format = "{ref.name}"

# Determines whether remote or local git branches/tags are preferred if their output dirs conflict
smv_prefer_remote_refs = False
Expand Down Expand Up @@ -111,8 +112,8 @@
html_logo = "source/icon-sparseml.png"

html_theme_options = {
'analytics_id': 'UA-128364174-1', # Provided by Google in your dashboard
'analytics_anonymize_ip': False,
"analytics_id": "UA-128364174-1", # Provided by Google in your dashboard
"analytics_anonymize_ip": False,
}

# Add any paths that contain custom static files (such as style sheets) here,
Expand Down Expand Up @@ -153,7 +154,13 @@
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, "sparseml.tex", "SparseML Documentation", [author], "manual",),
(
master_doc,
"sparseml.tex",
"SparseML Documentation",
[author],
"manual",
),
]

# -- Options for manual page output ------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ ensure_newline_before_comments = True
force_grid_wrap = 0
include_trailing_comma = True
known_first_party = sparseml,sparsezoo,tests
known_third_party = bs4,requests,packaging,yaml,tqdm,numpy,onnx,onnxruntime,pandas,PIL,psutil,scipy,toposort,pytest,torch,torchvision,keras,tensorflow,merge-args,cv2
known_third_party = bs4,requests,packaging,yaml,pydantic,tqdm,numpy,onnx,onnxruntime,pandas,PIL,psutil,scipy,toposort,pytest,torch,torchvision,keras,tensorflow,merge-args,cv2
sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER

line_length = 88
Expand Down
11 changes: 10 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@
"onnx>=1.5.0,<1.8.0",
"onnxruntime>=1.0.0",
"pandas<1.0.0",
"packaging>=20.0",
"psutil>=5.0.0",
"pydantic>=1.0.0",
"requests>=2.0.0",
"scikit-image>=0.15.0",
"scipy>=1.0.0",
Expand Down Expand Up @@ -80,6 +82,8 @@
"sphinx-copybutton>=0.3.0",
"sphinx-markdown-tables>=0.0.15",
"sphinx-multiversion==0.2.4",
"sphinx-pydantic>=0.1.0",
"sphinx-rtd-theme>=0.5.0",
"wheel>=0.36.2",
"pytest>=6.0.0",
"flaky>=3.0.0",
Expand Down Expand Up @@ -114,7 +118,12 @@ def _setup_extras() -> Dict:


def _setup_entry_points() -> Dict:
return {}
return {
"console_scripts": [
"sparseml.framework=sparseml.framework.info:_main",
"sparseml.sparsification=sparseml.sparsification.info:_main",
]
}


def _setup_long_description() -> Tuple[str, str]:
Expand Down
23 changes: 21 additions & 2 deletions src/sparseml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,27 @@
# flake8: noqa
# isort: skip_file

from .version import *

# be sure to import all logging first and at the root
# this keeps other loggers in nested files creating from the root logger setups
from .log import *
from .version import *

from .base import (
Framework,
check_version,
detect_framework,
execute_in_sparseml_framework,
)
from .framework import (
FrameworkInferenceProviderInfo,
FrameworkInfo,
framework_info,
save_framework_info,
load_framework_info,
)
from .sparsification import (
SparsificationInfo,
sparsification_info,
save_sparsification_info,
load_sparsification_info,
)
214 changes: 214 additions & 0 deletions src/sparseml/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import importlib
import logging
from enum import Enum
from typing import Any, Optional

from packaging import version

import pkg_resources


__all__ = [
"Framework",
"detect_framework",
"execute_in_sparseml_framework",
"get_version",
"check_version",
]


_LOGGER = logging.getLogger(__name__)


class Framework(Enum):
"""
Framework types known of/supported within the sparseml/deepsparse ecosystem
"""

unknown = "unknown"
deepsparse = "deepsparse"
onnx = "onnx"
keras = "keras"
pytorch = "pytorch"
tensorflow_v1 = "tensorflow_v1"


def detect_framework(item: Any) -> Framework:
"""
Detect the supported ML framework for a given item.
Supported input types are the following:
- A Framework enum
- A string of any case representing the name of the framework
(deepsparse, onnx, keras, pytorch, tensorflow_v1)
- A supported file type within the framework such as model files:
(onnx, pth, h5, pb)
- An object from a supported ML framework such as a model instance
If the framework cannot be determined, will return Framework.unknown
:param item: The item to detect the ML framework for
:type item: Any
:return: The detected framework from the given item
:rtype: Framework
"""
_LOGGER.debug("detecting framework for %s", item)
framework = Framework.unknown

if isinstance(item, Framework):
_LOGGER.debug("framework detected from Framework instance")
framework = item
elif isinstance(item, str) and item.lower().strip() in Framework.__members__:
_LOGGER.debug("framework detected from Framework string instance")
framework = Framework[item.lower().strip()]
else:
_LOGGER.debug("detecting framework by calling into supported frameworks")

for test in Framework:
try:
framework = execute_in_sparseml_framework(
test, "detect_framework", item
)
except Exception as err:
# errors are expected if the framework is not installed, log as debug
_LOGGER.debug(f"error while calling detect_framework for {test}: {err}")

if framework != Framework.unknown:
break

_LOGGER.info("detected framework of %s from %s", framework, item)

return framework


def execute_in_sparseml_framework(
framework: Framework, function_name: str, *args, **kwargs
) -> Any:
"""
Execute a general function that is callable from the root of the frameworks
package under SparseML such as sparseml.pytorch.
Useful for benchmarking, analyzing, etc.
Will pass the args and kwargs to the callable function.
:param framework: The ML framework to run the function under in SparseML.
:type framework: Framework
:param function_name: The name of the function in SparseML that should be run
with the given args and kwargs.
:type function_name: str
:param args: Any positional args to be passed into the function.
:param kwargs: Any key word args to be passed into the function.
:return: The return value from the executed function.
:rtype: Any
"""
_LOGGER.debug(
"executing function with name %s for framework %s, args %s, kwargs %s",
function_name,
framework,
args,
kwargs,
)

if not isinstance(framework, Framework):
framework = detect_framework(framework)

if framework == Framework.unknown:
raise ValueError(
f"unknown or unsupported framework {framework}, "
f"cannot call function {function_name}"
)

try:
module = importlib.import_module(f"sparseml.{framework.value}")
function = getattr(module, function_name)
except Exception as err:
raise ValueError(
f"could not find function_name {function_name} in framework {framework}: "
f"{err}"
)

return function(*args, **kwargs)


def get_version(package_name: str, raise_on_error: bool) -> Optional[str]:
"""
:param package_name: The name of the full package, as it would be imported,
to get the version for
:type package_name: str
:param raise_on_error: True to raise an error if package is not installed
or couldn't be imported, False to return None
:return: the version of the desired package if detected, otherwise raises an error
:rtype: str
"""

try:
current_version: str = pkg_resources.get_distribution(package_name).version
except Exception as err:
if raise_on_error:
raise ImportError(
f"error while getting current version for {package_name}: {err}"
)

return None

return current_version


def check_version(
package_name: str,
min_version: Optional[str] = None,
max_version: Optional[str] = None,
raise_on_error: bool = True,
) -> bool:
"""
:param package_name: the name of the package to check the version of
:type package_name: str
:param min_version: The minimum version for the package that it must be greater than
or equal to, if unset will require no minimum version
:type min_version: str
:param max_version: The maximum version for the package that it must be less than
or equal to, if unset will require no maximum version.
:type max_version: str
:param raise_on_error: True to raise any issues such as not installed,
minimum version, or maximum version as ImportError. False to return the result.
:type raise_on_error: bool
:return: If raise_on_error, will return False if the package is not installed
or the version is outside the accepted bounds and True if everything is correct.
:rtype: bool
"""
current_version = get_version(package_name, raise_on_error)

if not current_version:
return False

current_version = version.parse(current_version)
min_version = version.parse(min_version) if min_version else None
max_version = version.parse(max_version) if max_version else None

if min_version and current_version < min_version:
if raise_on_error:
raise ImportError(
f"required min {package_name} version {min_version}, "
f"found {current_version}"
)
return False

if max_version and current_version > max_version:
if raise_on_error:
raise ImportError(
f"required min {package_name} version {min_version}, "
f"found {current_version}"
)
return False

return True
22 changes: 22 additions & 0 deletions src/sparseml/framework/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Functionality related to integrating with, detecting, and getting information for
support and sparsification in ML frameworks.
"""

# flake8: noqa

from .info import *
Loading