update sparseml.pytorch export_onnx to use ONNXToDeepsparse (#1692)
* update sparseml.pytorch export_onnx to use ONNXToDeepsparse

* test fix
bfineran committed Jul 28, 2023
1 parent b1d5ea2 commit 9b4221e
Showing 2 changed files with 12 additions and 26 deletions.
5 changes: 4 additions & 1 deletion src/sparseml/exporters/onnx_to_deepsparse.py
@@ -18,6 +18,7 @@
 from typing import Union
 
 import onnx
+from onnx import ModelProto
 
 from sparseml.exporters import transforms as sparseml_transforms
 from sparseml.exporters.base_exporter import BaseExporter
@@ -109,7 +110,9 @@ def post_validate(self, model: onnx.ModelProto) -> onnx.ModelProto:
             raise TypeError(f"Expected onnx.ModelProto, found {type(model)}")
         return model
 
-    def export(self, pre_transforms_model: onnx.ModelProto, file_path: str):
+    def export(self, pre_transforms_model: Union[ModelProto, str], file_path: str):
+        if not isinstance(pre_transforms_model, ModelProto):
+            pre_transforms_model = onnx.load(pre_transforms_model)
         if self.export_input_model or os.getenv("SAVE_PREQAT_ONNX", False):
             save_onnx(pre_transforms_model, file_path.replace(".onnx", ".preqat.onnx"))
33 changes: 8 additions & 25 deletions src/sparseml/pytorch/utils/exporter.py
@@ -34,6 +34,7 @@
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 
+from sparseml.exporters.onnx_to_deepsparse import ONNXToDeepsparse
 from sparseml.onnx.utils import ONNXGraph
 from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
 from sparseml.pytorch.utils.helpers import (
@@ -200,7 +201,7 @@ def export_onnx(
             compilation.
         :param convert_qat: if True and quantization aware training is detected in
             the module being exported, the resulting QAT ONNX model will be converted
-            to a fully quantized ONNX model using `quantize_torch_qat_export`. Default
+            to a fully quantized ONNX model using `ONNXToDeepsparse`. Default
             is False.
         :param export_kwargs: kwargs to be passed as is to the torch.onnx.export api
             call. Useful to pass in dynamic_axes, input_names, output_names, etc.
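
For orientation, this first docstring belongs to `ModuleExporter.export_onnx`. A hedged sketch of a call that reaches the `convert_qat` path (the module, sample batch shape, and output directory are assumptions for illustration):

import torch

from sparseml.pytorch.utils import ModuleExporter

# hypothetical module; in practice this would be a QAT-trained network
module = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
exporter = ModuleExporter(module, output_dir="onnx_exports")  # assumed dir name
exporter.export_onnx(
    sample_batch=torch.randn(1, 3, 224, 224),  # assumed input shape
    convert_qat=True,  # routes the exported model through ONNXToDeepsparse
)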
@@ -444,7 +445,7 @@ def export_onnx(
         compilation.
     :param convert_qat: if True and quantization aware training is detected in
         the module being exported, the resulting QAT ONNX model will be converted
-        to a fully quantized ONNX model using `quantize_torch_qat_export`. Default
+        to a fully quantized ONNX model using `ONNXToDeepsparse`. Default
         is False.
     :param dynamic_axes: dictionary of input or output names to list of dimensions
         of those tensors that should be exported as dynamic. May input 'batch'
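
A small hedged example of the `dynamic_axes` shapes this docstring describes (the tensor names are assumptions, and the 'batch' shorthand behavior is inferred from the truncated docstring text):

# mark dim 0 of hypothetical "input" and "output" tensors as dynamic
dynamic_axes = {"input": [0], "output": [0]}
# or use the documented shorthand, which appears to make the batch
# dimension of all inputs and outputs dynamic
dynamic_axes = "batch"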
@@ -579,34 +580,16 @@ def export_onnx(
     save_onnx(onnx_model, file_path)
 
     if convert_qat and is_quant_module:
-        # overwrite exported model with fully quantized version
-        # import here to avoid cyclic dependency
-        from sparseml.pytorch.sparsification.quantization import (
-            quantize_torch_qat_export,
-        )
-
-        use_qlinearconv = hasattr(module, "export_with_qlinearconv") and (
+        use_qlinear_conv = hasattr(module, "export_with_qlinearconv") and (
             module.export_with_qlinearconv
         )
-
-        quantize_torch_qat_export(
-            model=file_path,
-            output_file_path=file_path,
-            use_qlinearconv=use_qlinearconv,
+        exporter = ONNXToDeepsparse(
+            use_qlinear_conv=use_qlinear_conv,
             skip_input_quantize=skip_input_quantize,
         )
-
-        if skip_input_quantize:
-            try:
-                # import here to avoid cyclic dependency
-                from sparseml.pytorch.sparsification.quantization import (
-                    skip_onnx_input_quantize,
-                )
-
-                skip_onnx_input_quantize(file_path, file_path)
-            except Exception as e:
-                _LOGGER.warning(
-                    f"Unable to skip input QuantizeLinear op with exception {e}"
-                )
+        exporter.export(pre_transforms_model=file_path, file_path=file_path)
 
 
 def _copy_file(src: str, target: str):
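
This hunk collapses the old two-step post-processing (`quantize_torch_qat_export`, then an optional `skip_onnx_input_quantize` wrapped in a try/except) into a single `ONNXToDeepsparse` exporter. The `use_qlinear_conv` flag is still driven by an opt-in attribute on the exported module; a minimal sketch of that convention (the module itself is hypothetical):

import torch


class QATConvModel(torch.nn.Module):
    # opt-in attribute read via hasattr(module, "export_with_qlinearconv")
    export_with_qlinearconv = True

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)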
