Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update sparseml.pytorch export_onnx to use ONNXToDeepsparse #1692

Merged
merged 2 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/sparseml/exporters/onnx_to_deepsparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Union

import onnx
from onnx import ModelProto

from sparseml.exporters import transforms as sparseml_transforms
from sparseml.exporters.base_exporter import BaseExporter
Expand Down Expand Up @@ -109,7 +110,9 @@ def post_validate(self, model: onnx.ModelProto) -> onnx.ModelProto:
raise TypeError(f"Expected onnx.ModelProto, found {type(model)}")
return model

def export(self, pre_transforms_model: onnx.ModelProto, file_path: str):
def export(self, pre_transforms_model: Union[ModelProto, str], file_path: str):
if not isinstance(pre_transforms_model, ModelProto):
pre_transforms_model = onnx.load(pre_transforms_model)
if self.export_input_model or os.getenv("SAVE_PREQAT_ONNX", False):
save_onnx(pre_transforms_model, file_path.replace(".onnx", ".preqat.onnx"))

Expand Down
33 changes: 8 additions & 25 deletions src/sparseml/pytorch/utils/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from sparseml.exporters.onnx_to_deepsparse import ONNXToDeepsparse
from sparseml.onnx.utils import ONNXGraph
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparseml.pytorch.utils.helpers import (
Expand Down Expand Up @@ -200,7 +201,7 @@ def export_onnx(
compilation.
:param convert_qat: if True and quantization aware training is detected in
the module being exported, the resulting QAT ONNX model will be converted
to a fully quantized ONNX model using `quantize_torch_qat_export`. Default
to a fully quantized ONNX model using `ONNXToDeepsparse`. Default
is False.
:param export_kwargs: kwargs to be passed as is to the torch.onnx.export api
call. Useful to pass in dynamic_axes, input_names, output_names, etc.
Expand Down Expand Up @@ -444,7 +445,7 @@ def export_onnx(
compilation.
:param convert_qat: if True and quantization aware training is detected in
the module being exported, the resulting QAT ONNX model will be converted
to a fully quantized ONNX model using `quantize_torch_qat_export`. Default
to a fully quantized ONNX model using `ONNXToDeepsparse`. Default
is False.
:param dynamic_axes: dictionary of input or output names to list of dimensions
of those tensors that should be exported as dynamic. May input 'batch'
Expand Down Expand Up @@ -579,34 +580,16 @@ def export_onnx(
save_onnx(onnx_model, file_path)

if convert_qat and is_quant_module:
# overwrite exported model with fully quantized version
# import here to avoid cyclic dependency
from sparseml.pytorch.sparsification.quantization import (
quantize_torch_qat_export,
)

use_qlinearconv = hasattr(module, "export_with_qlinearconv") and (
use_qlinear_conv = hasattr(module, "export_with_qlinearconv") and (
module.export_with_qlinearconv
)

quantize_torch_qat_export(
model=file_path,
output_file_path=file_path,
use_qlinearconv=use_qlinearconv,
exporter = ONNXToDeepsparse(
use_qlinear_conv=use_qlinear_conv,
skip_input_quantize=skip_input_quantize,
)

if skip_input_quantize:
try:
# import here to avoid cyclic dependency
from sparseml.pytorch.sparsification.quantization import (
skip_onnx_input_quantize,
)

skip_onnx_input_quantize(file_path, file_path)
except Exception as e:
_LOGGER.warning(
f"Unable to skip input QuantizeLinear op with exception {e}"
)
exporter.export(pre_transforms_model=file_path, file_path=file_path)


def _copy_file(src: str, target: str):
Expand Down
Loading