Skip to content

Commit

Permalink
Add ONNX export support for ViTMAE and ViTMSN
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Aug 30, 2024
1 parent 844aa66 commit de07c7a
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- UniSpeech SAT
- Vision Encoder Decoder
- ViT
- ViTMAE
- ViTMSN
- Wav2Vec2
- Wav2Vec2 Conformer
- WavLM
Expand Down
12 changes: 12 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,18 @@ class PvtOnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 11


class VitMAEOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for ViTMAE: the ViT configuration with the default opset set to 14."""

    # Opset 11 is too low here: the export fails with
    # torch.onnx.errors.UnsupportedOperatorError because the operator
    # 'aten::scaled_dot_product_attention' is only supported from opset 14 onwards.
    DEFAULT_ONNX_OPSET = 14


class VitMSNOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for ViTMSN: the ViT configuration with the default opset set to 14."""

    # Opset 11 is too low here: the export fails with
    # torch.onnx.errors.UnsupportedOperatorError because the operator
    # 'aten::scaled_dot_product_attention' is only supported from opset 14 onwards.
    DEFAULT_ONNX_OPSET = 14


class Dinov2DummyInputGenerator(DummyVisionInputGenerator):
def __init__(
self,
Expand Down
16 changes: 15 additions & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,21 @@ class TasksManager:
onnx="VisionEncoderDecoderOnnxConfig",
),
"vit": supported_tasks_mapping(
"feature-extraction", "image-classification", "masked-im", onnx="ViTOnnxConfig"
"feature-extraction",
"image-classification",
"masked-im",
onnx="ViTOnnxConfig",
),
"vit-mae": supported_tasks_mapping(
"feature-extraction",
"masked-im",
onnx="VitMAEOnnxConfig",
),
"vit-msn": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"masked-im",
onnx="VitMSNOnnxConfig",
),
"vits": supported_tasks_mapping(
"text-to-audio",
Expand Down
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@
"t5": "hf-internal-testing/tiny-random-t5",
"table-transformer": "hf-internal-testing/tiny-random-TableTransformerModel",
"vit": "hf-internal-testing/tiny-random-vit",
"vit-msn": "hf-internal-testing/tiny-random-ViTMSNForImageClassification",
"vits": "echarlaix/tiny-random-vits",
"yolos": "hf-internal-testing/tiny-random-YolosModel",
"whisper": "openai/whisper-tiny.en", # hf-internal-testing ones are broken
Expand Down Expand Up @@ -279,6 +280,7 @@
"t5": "t5-small",
"table-transformer": "microsoft/table-transformer-detection",
"vit": "google/vit-base-patch16-224",
"vit-msn": "facebook/vit-msn-small",
"yolos": "hustvl/yolos-tiny",
"whisper": "openai/whisper-tiny.en",
"hubert": "facebook/hubert-base-ls960",
Expand Down

0 comments on commit de07c7a

Please sign in to comment.