ONNX Export for Refactored Modifiers #1769

Merged: 14 commits, Oct 20, 2023
7 changes: 5 additions & 2 deletions setup.py
@@ -187,8 +187,11 @@ def _setup_entry_points() -> Dict:
         ]
     )

-    entry_points["console_scripts"].append(
-        "sparseml.transformers.export_onnx=sparseml.transformers.export:main"
+    entry_points["console_scripts"].extend(
+        [
+            "sparseml.transformers.export_onnx=sparseml.transformers.export:main",
+            "sparseml.transformers.export_onnx_refactor=sparseml.transformers.sparsification.obcq.export:main",  # noqa 501
+        ]
     )

     # image classification integration
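For context on the new entry point: a setuptools console_scripts entry of the form "command=package.module:func" generates an executable that imports the referenced callable and invokes it with no arguments. A minimal sketch of what the new sparseml.transformers.export_onnx_refactor command resolves to (the flags it accepts are defined inside the target main(), which is not part of this diff):

# Hedged sketch: equivalent of the generated console script for the new
# entry point; argument parsing happens inside obcq.export.main itself.
from sparseml.transformers.sparsification.obcq.export import main

if __name__ == "__main__":
    main()  # same call the sparseml.transformers.export_onnx_refactor script makes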
5 changes: 1 addition & 4 deletions src/sparseml/modifiers/quantization/base.py
@@ -14,7 +14,7 @@

 from typing import Any, Dict, List, Optional

-from sparseml.core import Event, Modifier, State
+from sparseml.core import Event, Modifier
 from sparseml.modifiers.quantization.utils.quantization_scheme import (
     QuantizationScheme,
     QuantizationSchemeLoadable,
@@ -156,9 +156,6 @@ def check_should_disable_observer(self, event: Event) -> bool:
             return True
         return False

-    def on_initialize_structure(self, state: State, **kwargs):
-        pass  # nothing needed for this modifier
-

 class _QuantizationSchemesDict(dict):
     # wrapper class for dict to override the __str__ method for yaml serialization
47 changes: 27 additions & 20 deletions src/sparseml/modifiers/quantization/pytorch.py
@@ -44,6 +44,11 @@ class QuantizationModifierPyTorch(QuantizationModifier):
     quantization_observer_disabled_: bool = False
     bn_stats_frozen_: bool = False

+    def on_initialize_structure(self, state: State, **kwargs):
+        module = state.model.model
+        self._enable_module_qat(module)
+        state.model.model.apply(torch.quantization.disable_observer)
+
     def on_initialize(self, state: State, **kwargs) -> bool:
         raise_if_torch_quantization_not_available()
         if self.end and self.end != -1:
@@ -57,6 +62,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:

         if self.calculate_start() == -1:  # one-shot
             self._enable_module_qat(module)
+            self._calibrate_if_possible(module)
             self._disable_quantization_observer(module)

         return True
@@ -95,30 +101,31 @@ def _disable_quantization_observer(self, model: Module):
         self.quantization_observer_disabled_ = True

     def _enable_module_qat(self, module: Module):
-        # fuse conv-bn-relu blocks prior to quantization emulation
-        self._fuse(module)
-
-        # add quantization_schemes to target submodules
-        set_quantization_schemes(
-            module,
-            scheme=self.scheme,
-            scheme_overrides=self.scheme_overrides,
-            ignore=self.ignore,
-            strict=self.strict,
-        )
+        module.apply(torch.quantization.enable_observer)

-        # fix for freezing batchnorm statistics when not fusing BN with convs.
-        # pytorch only supports freezing batchnorm statistics for fused modules.
-        # this fix wraps BN modules adding with a new module class that supports
-        # methods related to freezing/unfreezing BN statistics.
-        configure_module_bn_wrappers(module)
+        if not self.qat_enabled_:
+            # fuse conv-bn-relu blocks prior to quantization emulation
+            self._fuse(module)
+
+            # add quantization_schemes to target submodules
+            set_quantization_schemes(
+                module,
+                scheme=self.scheme,
+                scheme_overrides=self.scheme_overrides,
+                ignore=self.ignore,
+                strict=self.strict,
+            )

-        # convert target qconfig layers to QAT modules with FakeQuantize
-        convert_module_qat_from_schemes(module)
+            # fix for freezing batchnorm statistics when not fusing BN with convs.
+            # pytorch only supports freezing batchnorm statistics for fused modules.
+            # this fix wraps BN modules adding with a new module class that supports
+            # methods related to freezing/unfreezing BN statistics.
+            configure_module_bn_wrappers(module)

-        self.qat_enabled_ = True
+            # convert target qconfig layers to QAT modules with FakeQuantize
+            convert_module_qat_from_schemes(module)

-        self._calibrate_if_possible(module)
+        self.qat_enabled_ = True

     def _fuse(self, module: Module):
         if self.model_fuse_fn_name in [None, "conv_bn_relus"]:
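A note on the observer toggling introduced above: on_initialize_structure re-applies the QAT structure and then applies torch.quantization.disable_observer so that restoring an already-quantized checkpoint (for example, for ONNX export) does not resume calibration. Below is a minimal, self-contained sketch of that enable/disable pattern on a toy module; it is not sparseml code, and the module shape and qconfig are illustrative assumptions.

import torch
from torch import nn

# Toy module prepared for quantization-aware training (illustrative only).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# Observers on: FakeQuantize modules update their quantization ranges from data.
model.apply(torch.quantization.enable_observer)
model(torch.randn(1, 3, 16, 16))

# Observers off: ranges are frozen, e.g. before export or after reloading
# an already-calibrated checkpoint.
model.apply(torch.quantization.disable_observer)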