Skip to content

Commit

Permalink
Deprecate the use of QLinearMatMul in favor of MatMulInteger (#1745)
Browse files: browse the repository at this point in the history
* Deprecate the use of QLinearMatMul in favor of MatMulInteger

* Quality fixes
  • Loading branch information
anmarques committed Oct 3, 2023
1 parent 1174f8a commit 38fe044
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
33 changes: 21 additions & 12 deletions src/sparseml/exporters/onnx_to_deepsparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class ONNXToDeepsparse(BaseExporter):
def __init__(
self,
use_qlinear_conv: bool = False,
use_qlinear_matmul: bool = False,
skip_input_quantize: bool = False,
inplace: bool = True,
export_input_model: bool = False,
Expand All @@ -77,19 +78,27 @@ def __init__(
sparseml_transforms.QuantizeQATEmbedding(),
sparseml_transforms.PropagateEmbeddingQuantization(),
sparseml_transforms.PropagateDequantThroughSplit(),
sparseml_transforms.MatMulToQLinearMatMul(),
sparseml_transforms.MatMulAddToMatMulIntegerAddCastMul(),
sparseml_transforms.MatMulToMatMulIntegerCastMul(),
sparseml_transforms.FoldReLUQuants(),
sparseml_transforms.ConvToQLinearConv()
if use_qlinear_conv
else sparseml_transforms.ConvToConvIntegerAddCastMul(),
sparseml_transforms.GemmToQLinearMatMul(),
sparseml_transforms.GemmToMatMulIntegerAddCastMul(),
sparseml_transforms.QuantizeResiduals(),
sparseml_transforms.RemoveDuplicateQConvWeights(),
sparseml_transforms.RemoveDuplicateQuantizeOps(),
]
if use_qlinear_matmul:
transforms.append(
sparseml_transforms.MatMulToQLinearMatMul(),
)

transforms.extend(
[
sparseml_transforms.MatMulAddToMatMulIntegerAddCastMul(),
sparseml_transforms.MatMulToMatMulIntegerCastMul(),
sparseml_transforms.FoldReLUQuants(),
sparseml_transforms.ConvToQLinearConv()
if use_qlinear_conv
else sparseml_transforms.ConvToConvIntegerAddCastMul(),
sparseml_transforms.GemmToQLinearMatMul(),
sparseml_transforms.GemmToMatMulIntegerAddCastMul(),
sparseml_transforms.QuantizeResiduals(),
sparseml_transforms.RemoveDuplicateQConvWeights(),
sparseml_transforms.RemoveDuplicateQuantizeOps(),
]
)

if skip_input_quantize:
transforms.append(sparseml_transforms.SkipInputQuantize())
Expand Down
5 changes: 5 additions & 0 deletions src/sparseml/pytorch/utils/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,8 +585,13 @@ def export_onnx(
module.export_with_qlinearconv
)

use_qlinear_matmul = hasattr(module, "export_with_qlinearmatmul") and (
module.export_with_qlinearmatmul
)

exporter = ONNXToDeepsparse(
use_qlinear_conv=use_qlinear_conv,
use_qlinear_matmul=use_qlinear_matmul,
skip_input_quantize=skip_input_quantize,
)
exporter.export(pre_transforms_model=file_path, file_path=file_path)
Expand Down

0 comments on commit 38fe044

Please sign in to comment.