Skip to content

Commit

Permalink
[sparseml.transformers.export_onnx] add --opset flag (#1768)
Browse files Browse the repository at this point in the history
  • Loading branch information
bfineran committed Oct 17, 2023
1 parent 6a2b650 commit 07dbfae
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/sparseml/transformers/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
from transformers.tokenization_utils_base import PaddingStrategy

from sparseml.optim import parse_recipe_variables
from sparseml.pytorch.opset import TORCH_DEFAULT_ONNX_OPSET
from sparseml.pytorch.optim import ScheduledModifierManager
from sparseml.pytorch.utils import export_onnx
from sparseml.transformers.sparsification import Trainer
Expand Down Expand Up @@ -245,6 +246,7 @@ def export_transformer_to_onnx(
trust_remote_code: bool = False,
data_args: Optional[Union[Dict[str, Any], str]] = None,
one_shot: Optional[str] = None,
opset: int = TORCH_DEFAULT_ONNX_OPSET,
) -> str:
"""
Exports the saved transformers file to ONNX at batch size 1 using
Expand All @@ -266,6 +268,7 @@ def export_transformer_to_onnx(
:param data_args: additional args to instantiate a `DataTrainingArguments`
instance for exporting samples
:param one_shot: one shot recipe to be applied before exporting model
:param opset: ONNX opset to export with
:return: path to the exported ONNX file
"""
task = task.replace("_", "-").replace(" ", "-")
Expand Down Expand Up @@ -398,6 +401,7 @@ def export_transformer_to_onnx(
inputs,
onnx_file_path,
convert_qat=convert_qat,
opset=opset,
**kwargs,
)
_LOGGER.info(f"ONNX exported to {onnx_file_path}")
Expand Down Expand Up @@ -562,6 +566,12 @@ def _parse_args() -> argparse.Namespace:
action="store_true",
help=("Set flag to allow custom models in HF-transformers"),
)
parser.add_argument(
"--opset",
type=int,
default=TORCH_DEFAULT_ONNX_OPSET,
help=f"ONNX opset to export with, default: {TORCH_DEFAULT_ONNX_OPSET}",
)

return parser.parse_args()

Expand All @@ -577,6 +587,7 @@ def export(
trust_remote_code: bool = False,
data_args: Optional[str] = None,
one_shot: Optional[str] = None,
opset: int = TORCH_DEFAULT_ONNX_OPSET,
):
if os.path.exists(model_path):
# expand to absolute path to support downstream logic
Expand All @@ -592,6 +603,7 @@ def export(
trust_remote_code=trust_remote_code,
data_args=data_args,
one_shot=one_shot,
opset=opset,
)

deployment_folder_dir = create_deployment_folder(
Expand All @@ -616,6 +628,7 @@ def main():
trust_remote_code=args.trust_remote_code,
data_args=args.data_args,
one_shot=args.one_shot,
opset=args.opset,
)


Expand Down

0 comments on commit 07dbfae

Please sign in to comment.