ONNX Export for Refactored Modifiers (#1769)
* ONNX export for modifier refactor

* export script runs for new modifier framework

* add cli alias

* style

* reload about QAT adjustment

* style

* rename session alias
Satrat committed Oct 20, 2023
1 parent 3137d56 commit c9c1c21
Showing 6 changed files with 673 additions and 32 deletions.
@@ -24,7 +24,7 @@ def main():
from torch.utils.data import DataLoader
from torchvision import transforms

import sparseml.core.session as sml
import sparseml.core.session as session_manager
from sparseml.core.event import EventType
from sparseml.core.framework import Framework
from sparseml.pytorch.utils import (
@@ -40,8 +40,8 @@ def main():
device = "cuda:0"

# set up SparseML session
sml.create_session()
session = sml.active_session()
session_manager.create_session()
session = session_manager.active_session()

# download model
model = torchvision.models.mobilenet_v2(
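The hunk above only renames the import alias from sml to session_manager; a minimal sketch of how the renamed session module is driven end to end is shown below. The initialize() call and its arguments are assumptions for illustration and do not appear in this hunk.

import sparseml.core.session as session_manager

# create and fetch the global SparseML session, exactly as in the example script above
session_manager.create_session()
session = session_manager.active_session()

# hypothetical next step (not shown in this hunk): hand the session a model and
# recipe before training or export, e.g.
# session.initialize(framework=Framework.pytorch, model=model, recipe=recipe)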
7 changes: 5 additions & 2 deletions setup.py
@@ -187,8 +187,11 @@ def _setup_entry_points() -> Dict:
]
)

entry_points["console_scripts"].append(
"sparseml.transformers.export_onnx=sparseml.transformers.export:main"
entry_points["console_scripts"].extend(
[
"sparseml.transformers.export_onnx=sparseml.transformers.export:main",
"sparseml.transformers.export_onnx_refactor=sparseml.transformers.sparsification.obcq.export:main", # noqa 501
]
)

# image classification integration
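For context, the new console-script alias added above points the command name at main() in the obcq export module; the sketch below is a rough stand-alone equivalent and assumes main() parses its own command-line arguments, which this diff does not show.

# Rough equivalent of running the new `sparseml.transformers.export_onnx_refactor`
# command; assumes main() reads its options from sys.argv (not shown in this diff).
from sparseml.transformers.sparsification.obcq.export import main

if __name__ == "__main__":
    main()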
5 changes: 1 addition & 4 deletions src/sparseml/modifiers/quantization/base.py
@@ -14,7 +14,7 @@

from typing import Any, Dict, List, Optional

from sparseml.core import Event, Modifier, State
from sparseml.core import Event, Modifier


__all__ = ["QuantizationModifier"]
@@ -136,6 +136,3 @@ def check_should_disable_observer(self, event: Event) -> bool:
if event.current_index >= disable_epoch:
return True
return False

def on_initialize_structure(self, state: State, **kwargs):
pass # nothing needed for this modifier
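Removing the no-op stub here leaves on_initialize_structure entirely to the framework-specific implementation; a minimal sketch of the resulting split follows, with the PyTorch class name assumed for illustration (it is not visible in this hunk).

from sparseml.core import Modifier

# framework-agnostic base: no on_initialize_structure any more
class QuantizationModifier(Modifier):
    ...  # shared configuration and scheduling logic only

# the PyTorch implementation (next file) overrides the hook to rebuild
# QAT structure when a quantized checkpoint is reloaded
class QuantizationModifierPyTorch(QuantizationModifier):
    def on_initialize_structure(self, state, **kwargs):
        ...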
47 changes: 27 additions & 20 deletions src/sparseml/modifiers/quantization/pytorch.py
@@ -71,6 +71,11 @@ def __init__(self, **kwargs):
self.scheme_overrides, self.scheme
)

def on_initialize_structure(self, state: State, **kwargs):
module = state.model.model
self._enable_module_qat(module)
state.model.model.apply(torch.quantization.disable_observer)

def on_initialize(self, state: State, **kwargs) -> bool:
raise_if_torch_quantization_not_available()
if self.end and self.end != -1:
@@ -84,6 +89,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:

if self.calculate_start() == -1: # one-shot
self._enable_module_qat(module)
self._calibrate_if_possible(module)
self._disable_quantization_observer(module)

return True
@@ -122,30 +128,31 @@ def _disable_quantization_observer(self, model: Module):
self.quantization_observer_disabled_ = True

def _enable_module_qat(self, module: Module):
# fuse conv-bn-relu blocks prior to quantization emulation
self._fuse(module)

# add quantization_schemes to target submodules
set_quantization_schemes(
module,
scheme=self.scheme,
scheme_overrides=self.scheme_overrides,
ignore=self.ignore,
strict=self.strict,
)
module.apply(torch.quantization.enable_observer)

# fix for freezing batchnorm statistics when not fusing BN with convs.
# pytorch only supports freezing batchnorm statistics for fused modules.
# this fix wraps BN modules adding with a new module class that supports
# methods related to freezing/unfreezing BN statistics.
configure_module_bn_wrappers(module)
if not self.qat_enabled_:
# fuse conv-bn-relu blocks prior to quantization emulation
self._fuse(module)

# add quantization_schemes to target submodules
set_quantization_schemes(
module,
scheme=self.scheme,
scheme_overrides=self.scheme_overrides,
ignore=self.ignore,
strict=self.strict,
)

# convert target qconfig layers to QAT modules with FakeQuantize
convert_module_qat_from_schemes(module)
# fix for freezing batchnorm statistics when not fusing BN with convs.
# pytorch only supports freezing batchnorm statistics for fused modules.
# this fix wraps BN modules adding with a new module class that supports
# methods related to freezing/unfreezing BN statistics.
configure_module_bn_wrappers(module)

self.qat_enabled_ = True
# convert target qconfig layers to QAT modules with FakeQuantize
convert_module_qat_from_schemes(module)

self._calibrate_if_possible(module)
self.qat_enabled_ = True

def _fuse(self, module: Module):
if self.model_fuse_fn_name in [None, "conv_bn_relus"]:
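The reworked _enable_module_qat above is guarded by self.qat_enabled_, so it can be reached both from on_initialize_structure (when a quantized checkpoint is reloaded) and from on_initialize without fusing or converting the model twice; observers are then disabled because export needs no further calibration. Below is a stand-alone sketch of that observer toggling; the toy module and qconfig choice are illustrative only and not part of the diff.

import torch
import torch.nn as nn

# toy model standing in for the module handled by the modifier above
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)  # inserts FakeQuantize + observers

model.apply(torch.quantization.enable_observer)   # observers collect stats during calibration/QAT
model.apply(torch.quantization.disable_observer)  # freeze ranges, e.g. before ONNX export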