Skip to content

Commit

Permalink
Add framework fallback ability to execute in sparseml (#413) (#422)
Browse files Browse the repository at this point in the history
* Add framework fallback ability to execute in sparseml

* remove unused variable

* decrease complexity of falling back on framework execution

* quality and test fixes

* update docs

* fix tests
  • Loading branch information
markurtz committed Oct 19, 2021
1 parent 3f165e6 commit 6e9cc4f
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 55 deletions.
127 changes: 90 additions & 37 deletions src/sparseml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import importlib
import logging
from collections import OrderedDict
from enum import Enum
from typing import Any, List, Optional

Expand All @@ -25,6 +26,7 @@

__all__ = [
"Framework",
"detect_frameworks",
"detect_framework",
"execute_in_sparseml_framework",
"get_version",
Expand All @@ -48,61 +50,109 @@ class Framework(Enum):
tensorflow_v1 = "tensorflow_v1"


def detect_framework(item: Any) -> Framework:
def _execute_sparseml_package_function(
framework: Framework, function_name: str, *args, **kwargs
):
try:
module = importlib.import_module(f"sparseml.{framework.value}")
function = getattr(module, function_name)
except Exception as err:
raise ValueError(
f"unknown or unsupported framework {framework}, "
f"cannot call function {function_name}: {err}"
)

return function(*args, **kwargs)


def detect_frameworks(item: Any) -> List[Framework]:
"""
Detect the supported ML framework for a given item.
Detects the supported ML frameworks for a given item.
Supported input types are the following:
- A Framework enum
- A string of any case representing the name of the framework
(deepsparse, onnx, keras, pytorch, tensorflow_v1)
- A supported file type within the framework such as model files:
(onnx, pth, h5, pb)
- An object from a supported ML framework such as a model instance
If the framework cannot be determined, will return Framework.unknown
If the framework cannot be determined, an empty list will be returned
:param item: The item to detect the ML framework for
:type item: Any
:return: The detected framework from the given item
:rtype: Framework
:return: The detected ML frameworks from the given item
:rtype: List[Framework]
"""
_LOGGER.debug("detecting framework for %s", item)
framework = Framework.unknown
_LOGGER.debug("detecting frameworks for %s", item)
frameworks = []

if isinstance(item, str) and item.lower().strip() in Framework.__members__:
_LOGGER.debug("framework detected from Framework string instance")
item = Framework[item.lower().strip()]

if isinstance(item, Framework):
_LOGGER.debug("framework detected from Framework instance")
framework = item
elif isinstance(item, str) and item.lower().strip() in Framework.__members__:
_LOGGER.debug("framework detected from Framework string instance")
framework = Framework[item.lower().strip()]

if item != Framework.unknown:
frameworks.append(item)
else:
_LOGGER.debug("detecting framework by calling into supported frameworks")
_LOGGER.debug("detecting frameworks by calling into supported frameworks")
frameworks = []

for test in Framework:
if test == Framework.unknown:
continue

try:
framework = execute_in_sparseml_framework(
detected = _execute_sparseml_package_function(
test, "detect_framework", item
)
frameworks.append(detected)
except Exception as err:
# errors are expected if the framework is not installed, log as debug
_LOGGER.debug(f"error while calling detect_framework for {test}: {err}")
_LOGGER.debug(
"error while calling detect_framework for %s: %s", test, err
)

_LOGGER.info("detected frameworks of %s from %s", frameworks, item)

return frameworks

if framework != Framework.unknown:
break

_LOGGER.info("detected framework of %s from %s", framework, item)
def detect_framework(item: Any) -> Framework:
"""
Detect the supported ML framework for a given item.
Supported input types are the following:
- A Framework enum
- A string of any case representing the name of the framework
(deepsparse, onnx, keras, pytorch, tensorflow_v1)
- A supported file type within the framework such as model files:
(onnx, pth, h5, pb)
- An object from a supported ML framework such as a model instance
If the framework cannot be determined, will return Framework.unknown
:param item: The item to detect the ML framework for
:type item: Any
:return: The detected framework from the given item
:rtype: Framework
"""
_LOGGER.debug("detecting framework for %s", item)
frameworks = detect_frameworks(item)

return framework
return frameworks[0] if len(frameworks) > 0 else Framework.unknown


def execute_in_sparseml_framework(
framework: Framework, function_name: str, *args, **kwargs
framework: Any, function_name: str, *args, **kwargs
) -> Any:
"""
Execute a general function that is callable from the root of the frameworks
package under SparseML such as sparseml.pytorch.
Useful for benchmarking, analyzing, etc.
Will pass the args and kwargs to the callable function.
:param framework: The ML framework to run the function under in SparseML.
:type framework: Framework
:param framework: The item to detect the ML framework for to run the function under,
see detect_frameworks for more details on acceptible inputs
:type framework: Any
:param function_name: The name of the function in SparseML that should be run
with the given args and kwargs.
:type function_name: str
Expand All @@ -119,25 +169,28 @@ def execute_in_sparseml_framework(
kwargs,
)

if not isinstance(framework, Framework):
framework = detect_framework(framework)
framework_errs = OrderedDict()
test_frameworks = detect_frameworks(framework)

if framework == Framework.unknown:
raise ValueError(
f"unknown or unsupported framework {framework}, "
f"cannot call function {function_name}"
)
for test_framework in test_frameworks:
try:
module = importlib.import_module(f"sparseml.{test_framework.value}")
function = getattr(module, function_name)

try:
module = importlib.import_module(f"sparseml.{framework.value}")
function = getattr(module, function_name)
except Exception as err:
raise ValueError(
f"could not find function_name {function_name} in framework {framework}: "
f"{err}"
)
return function(*args, **kwargs)
except Exception as err:
framework_errs[framework] = err

return function(*args, **kwargs)
if len(framework_errs) == 1:
raise list(framework_errs.values())[0]

if len(framework_errs) > 1:
raise RuntimeError(str(framework_errs))

raise ValueError(
f"unknown or unsupported framework {framework}, "
f"cannot call function {function_name}"
)


def get_version(
Expand Down
20 changes: 3 additions & 17 deletions src/sparseml/benchmark/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@

from tqdm import auto

from sparseml.base import Framework, detect_framework, execute_in_sparseml_framework
from sparseml.base import Framework, execute_in_sparseml_framework
from sparseml.benchmark.serialization import (
BatchBenchmarkResult,
BenchmarkConfig,
Expand Down Expand Up @@ -369,13 +369,8 @@ def save_benchmark_results(
pass to the runner
:param show_progress: True to show a tqdm bar when running, False otherwise
"""
if framework is None:
framework = detect_framework(model)
else:
framework = detect_framework(framework)

results = execute_in_sparseml_framework(
framework,
framework if framework is not None else model,
"run_benchmark",
model,
data,
Expand Down Expand Up @@ -442,18 +437,9 @@ def load_and_run_benchmark(
:param save_path: path to save the new benchmark results
"""
_LOGGER.info(f"rerunning benchmark {load}")

info = load_benchmark_info(load)

framework = info.framework

if framework is None:
framework = detect_framework(model)
else:
framework = detect_framework(framework)

save_benchmark_results(
model,
info.framework if info.framework is not None else model,
data,
batch_size=info.config.batch_size,
iterations=info.config.iterations,
Expand Down
2 changes: 1 addition & 1 deletion tests/sparseml/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_execute_in_sparseml_framework():
with pytest.raises(ValueError):
execute_in_sparseml_framework(Framework.unknown, "unknown")

with pytest.raises(ValueError):
with pytest.raises(Exception):
execute_in_sparseml_framework(Framework.onnx, "unknown")

# TODO: fill in with sample functions to execute in frameworks once available
Expand Down

0 comments on commit 6e9cc4f

Please sign in to comment.