Skip to content

Commit

Permalink
Enable ONNX export for transformers 4.45 (#2045)
Browse files Browse the repository at this point in the history
* Enable ONNX export for transformers 4.45

* add comment

* update setup
  • Loading branch information
echarlaix authored Oct 9, 2024
1 parent d3c56cd commit 2c0476e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
11 changes: 5 additions & 6 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import numpy as np
import onnx
import transformers
from transformers.modeling_utils import get_parameter_dtype
from transformers.utils import is_tf_available, is_torch_available

Expand Down Expand Up @@ -531,6 +530,11 @@ def export_pytorch(
logger.info(f"Using framework PyTorch: {torch.__version__}")
FORCE_ONNX_EXTERNAL_DATA = os.getenv("FORCE_ONNX_EXTERNAL_DATA", "0") == "1"

model_kwargs = model_kwargs or {}
# num_logits_to_keep was added in transformers 4.45 and isn't added as inputs when exporting the model
if check_if_transformers_greater("4.44.99") and "num_logits_to_keep" in signature(model.forward).parameters.keys():
model_kwargs["num_logits_to_keep"] = 0

with torch.no_grad():
model.config.return_dict = True
model = model.eval()
Expand Down Expand Up @@ -1001,11 +1005,6 @@ def onnx_export_from_model(
>>> onnx_export_from_model(model, output="gpt2_onnx/")
```
"""
if check_if_transformers_greater("4.44.99"):
raise ImportError(
f"ONNX conversion disabled for now for transformers version greater than v4.45, found {transformers.__version__}"
)

TasksManager.standardize_model_attributes(model)

if hasattr(model.config, "export_model_type"):
Expand Down
7 changes: 2 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
"datasets>=1.2.1",
"evaluate",
"protobuf>=3.20.1",
"transformers<4.45.0",
],
"onnxruntime-gpu": [
"onnx",
Expand All @@ -63,10 +62,9 @@
"evaluate",
"protobuf>=3.20.1",
"accelerate", # ORTTrainer requires it.
"transformers<4.45.0",
],
"exporters": ["onnx", "onnxruntime", "timm", "transformers<4.45.0"],
"exporters-gpu": ["onnx", "onnxruntime-gpu", "timm", "transformers<4.45.0"],
"exporters": ["onnx", "onnxruntime", "timm"],
"exporters-gpu": ["onnx", "onnxruntime-gpu", "timm"],
"exporters-tf": [
"tensorflow>=2.4,<=2.12.1",
"tf2onnx",
Expand All @@ -77,7 +75,6 @@
"numpy<1.24.0",
"datasets<=2.16",
"transformers[sentencepiece]>=4.26,<4.38",
"transformers<4.45.0",
],
"diffusers": ["diffusers"],
"intel": "optimum-intel>=1.18.0",
Expand Down

0 comments on commit 2c0476e

Please sign in to comment.