diff --git a/src/sparseml/exporters/onnx_to_deepsparse.py b/src/sparseml/exporters/onnx_to_deepsparse.py index fadad423b5c..dbf448ef3e4 100644 --- a/src/sparseml/exporters/onnx_to_deepsparse.py +++ b/src/sparseml/exporters/onnx_to_deepsparse.py @@ -18,6 +18,7 @@ from typing import Union import onnx +from onnx import ModelProto from sparseml.exporters import transforms as sparseml_transforms from sparseml.exporters.base_exporter import BaseExporter @@ -109,7 +110,9 @@ def post_validate(self, model: onnx.ModelProto) -> onnx.ModelProto: raise TypeError(f"Expected onnx.ModelProto, found {type(model)}") return model - def export(self, pre_transforms_model: onnx.ModelProto, file_path: str): + def export(self, pre_transforms_model: Union[ModelProto, str], file_path: str): + if not isinstance(pre_transforms_model, ModelProto): + pre_transforms_model = onnx.load(pre_transforms_model) if self.export_input_model or os.getenv("SAVE_PREQAT_ONNX", False): save_onnx(pre_transforms_model, file_path.replace(".onnx", ".preqat.onnx")) diff --git a/src/sparseml/pytorch/utils/exporter.py b/src/sparseml/pytorch/utils/exporter.py index 6250b889d39..7ac22ef5476 100644 --- a/src/sparseml/pytorch/utils/exporter.py +++ b/src/sparseml/pytorch/utils/exporter.py @@ -34,6 +34,7 @@ from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader +from sparseml.exporters.onnx_to_deepsparse import ONNXToDeepsparse from sparseml.onnx.utils import ONNXGraph from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET from sparseml.pytorch.utils.helpers import ( @@ -200,7 +201,7 @@ def export_onnx( compilation. :param convert_qat: if True and quantization aware training is detected in the module being exported, the resulting QAT ONNX model will be converted - to a fully quantized ONNX model using `quantize_torch_qat_export`. Default + to a fully quantized ONNX model using `ONNXToDeepsparse`. Default is False. :param export_kwargs: kwargs to be passed as is to the torch.onnx.export api call. Useful to pass in dyanmic_axes, input_names, output_names, etc. @@ -444,7 +445,7 @@ def export_onnx( compilation. :param convert_qat: if True and quantization aware training is detected in the module being exported, the resulting QAT ONNX model will be converted - to a fully quantized ONNX model using `quantize_torch_qat_export`. Default + to a fully quantized ONNX model using `ONNXToDeepsparse`. Default is False. :param dynamic_axes: dictionary of input or output names to list of dimensions of those tensors that should be exported as dynamic. May input 'batch' @@ -579,34 +580,16 @@ def export_onnx( save_onnx(onnx_model, file_path) if convert_qat and is_quant_module: - # overwrite exported model with fully quantized version - # import here to avoid cyclic dependency - from sparseml.pytorch.sparsification.quantization import ( - quantize_torch_qat_export, - ) - use_qlinearconv = hasattr(module, "export_with_qlinearconv") and ( + use_qlinear_conv = hasattr(module, "export_with_qlinearconv") and ( module.export_with_qlinearconv ) - quantize_torch_qat_export( - model=file_path, - output_file_path=file_path, - use_qlinearconv=use_qlinearconv, + exporter = ONNXToDeepsparse( + use_qlinear_conv=use_qlinear_conv, + skip_input_quantize=skip_input_quantize, ) - - if skip_input_quantize: - try: - # import here to avoid cyclic dependency - from sparseml.pytorch.sparsification.quantization import ( - skip_onnx_input_quantize, - ) - - skip_onnx_input_quantize(file_path, file_path) - except Exception as e: - _LOGGER.warning( - f"Unable to skip input QuantizeLinear op with exception {e}" - ) + exporter.export(pre_transforms_model=file_path, file_path=file_path) def _copy_file(src: str, target: str):